Compare commits
6 Commits
gemini-3-p
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
*.pt
|
||||
__pycache__/
|
||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"learningrate": 0.0003,
|
||||
"weight_decay": 1e-05,
|
||||
"n_updates": 150000,
|
||||
"plot_at": 400,
|
||||
"early_stopping_patience": 40,
|
||||
"print_stats_at": 200,
|
||||
"print_train_stats_at": 50,
|
||||
"validate_at": 200,
|
||||
"accumulation_steps": 1,
|
||||
"commands": {
|
||||
"save_checkpoint": false,
|
||||
"run_test_validation": false,
|
||||
"generate_predictions": false
|
||||
}
|
||||
}
|
||||
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -15,35 +15,51 @@ def init_weights(m):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""Channel attention module (squeeze-and-excitation style)"""
|
||||
def __init__(self, channels, reduction=16):
|
||||
class GatedSkipConnection(nn.Module):
|
||||
"""Gated skip connection for better feature fusion"""
|
||||
def __init__(self, up_channels, skip_channels):
|
||||
super().__init__()
|
||||
self.gate = nn.Sequential(
|
||||
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
# Project skip to match up_channels if they differ
|
||||
if skip_channels != up_channels:
|
||||
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
|
||||
else:
|
||||
self.skip_proj = nn.Identity()
|
||||
|
||||
def forward(self, x, skip):
|
||||
skip_proj = self.skip_proj(skip)
|
||||
combined = torch.cat([x, skip], dim=1)
|
||||
gate = self.gate(combined)
|
||||
return x * gate + skip_proj * (1 - gate)
|
||||
|
||||
|
||||
class EfficientChannelAttention(nn.Module):
|
||||
"""Efficient channel attention without dimensionality reduction"""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
reduced = max(channels // reduction, 8)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
||||
)
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
return x * self.sigmoid(avg_out + max_out)
|
||||
# Global pooling
|
||||
y = self.avg_pool(x)
|
||||
# 1D convolution on channel dimension
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
y = self.sigmoid(y)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""Spatial attention module"""
|
||||
"""Efficient spatial attention module"""
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||
@@ -57,12 +73,12 @@ class SpatialAttention(nn.Module):
|
||||
return x * attn
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""Convolutional Block Attention Module"""
|
||||
def __init__(self, channels, reduction=16):
|
||||
class EfficientAttention(nn.Module):
|
||||
"""Lightweight attention module combining channel and spatial"""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.channel_attn = ChannelAttention(channels, reduction)
|
||||
self.spatial_attn = SpatialAttention()
|
||||
self.channel_attn = EfficientChannelAttention(channels)
|
||||
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.channel_attn(x)
|
||||
@@ -71,157 +87,221 @@ class CBAM(nn.Module):
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> InstanceNorm2d -> GELU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0, dilation=1):
|
||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||
super().__init__()
|
||||
if separable and in_channels > 1:
|
||||
# Depthwise separable convolution for efficiency
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||
nn.Conv2d(in_channels, out_channels, 1)
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||
# InstanceNorm is preferred for style/inpainting tasks
|
||||
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
||||
self.act = nn.GELU()
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(self.act(self.bn(self.conv(x))))
|
||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||
|
||||
|
||||
class DenseBlock(nn.Module):
|
||||
"""Lightweight dense block for better gradient flow"""
|
||||
def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout))
|
||||
self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
|
||||
self.bn = nn.BatchNorm2d(channels)
|
||||
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for layer in self.layers:
|
||||
out = layer(torch.cat(features, dim=1))
|
||||
features.append(out)
|
||||
out = self.fusion(torch.cat(features, dim=1))
|
||||
out = self.relu(self.bn(out))
|
||||
return out + x # Residual connection
|
||||
|
||||
class ResidualConvBlock(nn.Module):
|
||||
"""Residual convolutional block for better gradient flow"""
|
||||
def __init__(self, channels, dropout=0.0, dilation=1):
|
||||
"""Improved residual convolutional block with pre-activation"""
|
||||
def __init__(self, channels, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn1 = nn.InstanceNorm2d(channels, affine=True)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn2 = nn.InstanceNorm2d(channels, affine=True)
|
||||
self.act = nn.GELU()
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.act(self.bn1(self.conv1(x)))
|
||||
out = self.relu1(self.bn1(x))
|
||||
out = self.conv1(out)
|
||||
out = self.relu2(self.bn2(out))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + residual
|
||||
return self.act(out)
|
||||
out = self.conv2(out)
|
||||
return out + residual
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Enhanced downsampling block with dense and residual connections"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
if use_dense:
|
||||
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||
else:
|
||||
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.residual(x)
|
||||
x = self.dense(x)
|
||||
skip = self.attention(x)
|
||||
return self.pool(skip), skip
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Enhanced upsampling block with gated skip connections"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||
super().__init__()
|
||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
||||
# Skip connection has in_channels, upsampled has out_channels
|
||||
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||
# After gated skip: out_channels
|
||||
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
if use_dense:
|
||||
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||
else:
|
||||
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||
|
||||
def forward(self, x, skip):
|
||||
x = self.up(x)
|
||||
# Handle dimension mismatch by interpolating x to match skip's size
|
||||
# Handle dimension mismatch
|
||||
if x.shape[2:] != skip.shape[2:]:
|
||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
x = self.gated_skip(x, skip)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.residual(x)
|
||||
x = self.dense(x)
|
||||
x = self.attention(x)
|
||||
return x
|
||||
|
||||
class MyModel(nn.Module):
|
||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
||||
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
# Initial convolution with larger receptive field
|
||||
self.init_conv = nn.Sequential(
|
||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
||||
ConvBlock(base_channels, base_channels),
|
||||
ResidualConvBlock(base_channels)
|
||||
# Separate mask processing for better feature extraction
|
||||
self.mask_conv = nn.Sequential(
|
||||
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
# Encoder (downsampling path)
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
|
||||
# Bottleneck with multiple residual blocks
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=2),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=4),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=8),
|
||||
CBAM(base_channels * 16)
|
||||
)
|
||||
|
||||
# Decoder (upsampling path)
|
||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
||||
|
||||
# Final refinement layers
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvBlock(base_channels * 2, base_channels),
|
||||
ResidualConvBlock(base_channels),
|
||||
# Image processing path
|
||||
self.image_conv = nn.Sequential(
|
||||
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
# Output layer with smooth transition
|
||||
# Fusion of mask and image features
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
# Encoder with progressive feature extraction
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||
|
||||
# Enhanced bottleneck with multi-scale features and dense connections
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||
DenseBlock(base_channels * 8, growth_rate=10, num_layers=3, dropout=dropout),
|
||||
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||
EfficientAttention(base_channels * 8)
|
||||
)
|
||||
|
||||
# Decoder with progressive reconstruction
|
||||
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
|
||||
# Multi-scale feature fusion with dense connections
|
||||
self.multiscale_fusion = nn.Sequential(
|
||||
ConvBlock(base_channels * 2, base_channels),
|
||||
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
# Output with residual connection to input
|
||||
self.pre_output = nn.Sequential(
|
||||
ConvBlock(base_channels, base_channels),
|
||||
ConvBlock(base_channels, base_channels // 2)
|
||||
)
|
||||
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
||||
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels // 2, 3, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Apply weight initialization
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
# Initial convolution
|
||||
x0 = self.init_conv(x)
|
||||
# Split input into image and mask
|
||||
image = x[:, :3, :, :]
|
||||
mask = x[:, 3:4, :, :]
|
||||
|
||||
# Process mask and image separately
|
||||
mask_features = self.mask_conv(mask)
|
||||
image_features = self.image_conv(image)
|
||||
|
||||
# Fuse features
|
||||
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||
|
||||
# Encoder
|
||||
x1, skip1 = self.down1(x0)
|
||||
x2, skip2 = self.down2(x1)
|
||||
x3, skip3 = self.down3(x2)
|
||||
x4, skip4 = self.down4(x3)
|
||||
|
||||
# Bottleneck
|
||||
x = self.bottleneck(x4)
|
||||
x = self.bottleneck(x3)
|
||||
|
||||
# Decoder with skip connections
|
||||
x = self.up1(x, skip4)
|
||||
x = self.up2(x, skip3)
|
||||
x = self.up3(x, skip2)
|
||||
x = self.up4(x, skip1)
|
||||
x = self.up1(x, skip3)
|
||||
x = self.up2(x, skip2)
|
||||
x = self.up3(x, skip1)
|
||||
|
||||
# Handle dimension mismatch for final concatenation
|
||||
# Handle dimension mismatch for final fusion
|
||||
if x.shape[2:] != x0.shape[2:]:
|
||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||
|
||||
# Concatenate with initial features for better detail preservation
|
||||
# Multi-scale fusion with initial features
|
||||
x = torch.cat([x, x0], dim=1)
|
||||
x = self.final_conv(x)
|
||||
x = self.multiscale_fusion(x)
|
||||
|
||||
# Output
|
||||
# Pre-output processing
|
||||
x = self.pre_output(x)
|
||||
|
||||
# Concatenate with original masked image for residual learning
|
||||
x = torch.cat([x, image], dim=1)
|
||||
x = self.output(x)
|
||||
|
||||
return x
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
import random
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageEnhance
|
||||
|
||||
IMAGE_DIMENSION = 100
|
||||
|
||||
@@ -26,34 +26,58 @@ def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tu
|
||||
|
||||
return image_array, known_array
|
||||
|
||||
def resize(img: Image, augment: bool = False):
|
||||
transforms_list = [
|
||||
def resize(img: Image):
|
||||
resize_transforms = transforms.Compose([
|
||||
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||
]
|
||||
|
||||
if augment:
|
||||
transforms_list = [
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomVerticalFlip(),
|
||||
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
|
||||
transforms.RandomRotation(10),
|
||||
] + transforms_list
|
||||
|
||||
resize_transforms = transforms.Compose(transforms_list)
|
||||
])
|
||||
return resize_transforms(img)
|
||||
|
||||
def preprocess(input_array: np.ndarray):
|
||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||
return input_array
|
||||
|
||||
def augment_image(img: Image, strength: float = 0.7) -> Image:
|
||||
"""Apply comprehensive data augmentation for better generalization"""
|
||||
# Random horizontal flip
|
||||
if random.random() > 0.5:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
# Random vertical flip
|
||||
if random.random() > 0.5:
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Random rotation (90, 180, 270 degrees)
|
||||
if random.random() > 0.5:
|
||||
angle = random.choice([90, 180, 270])
|
||||
img = img.rotate(angle)
|
||||
|
||||
# Color augmentation - more aggressive for long training
|
||||
rand = random.random()
|
||||
if rand > 0.75:
|
||||
# Brightness
|
||||
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||
img = ImageEnhance.Brightness(img).enhance(factor)
|
||||
elif rand > 0.5:
|
||||
# Contrast
|
||||
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||
img = ImageEnhance.Contrast(img).enhance(factor)
|
||||
elif rand > 0.25:
|
||||
# Saturation
|
||||
factor = 1.0 + random.uniform(-0.15, 0.15) * strength
|
||||
img = ImageEnhance.Color(img).enhance(factor)
|
||||
|
||||
return img
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder
|
||||
Dataset class for loading images from a folder with augmentation support
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str, augment: bool = False):
|
||||
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.7):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||
self.augment = augment
|
||||
self.augment_strength = augment_strength
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
@@ -62,7 +86,13 @@ class ImageDataset(torch.utils.data.Dataset):
|
||||
index = int(idx)
|
||||
|
||||
image = Image.open(self.imagefiles[index])
|
||||
image = np.asarray(resize(image, self.augment))
|
||||
image = resize(image)
|
||||
|
||||
# Apply augmentation
|
||||
if self.augment:
|
||||
image = augment_image(image, self.augment_strength)
|
||||
|
||||
image = np.asarray(image)
|
||||
image = preprocess(image)
|
||||
spacing_x = random.randint(2,6)
|
||||
spacing_y = random.randint(2,6)
|
||||
|
||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
||||
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||
config_dict['n_updates'] = 40000 # Extended training
|
||||
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 100
|
||||
config_dict['plot_at'] = 300
|
||||
config_dict['validate_at'] = 300 # Validate more frequently
|
||||
config_dict['print_train_stats_at'] = 50
|
||||
config_dict['print_stats_at'] = 200
|
||||
config_dict['plot_at'] = 500
|
||||
config_dict['validate_at'] = 500 # Regular validation
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 48, # Good balance between capacity and memory
|
||||
'dropout': 0.1 # Regularization
|
||||
'base_channels': 64,
|
||||
'dropout': 0.1 # Proper dropout for regularization
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -10,49 +10,36 @@ from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Subset
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||
class EnhancedRMSELoss(nn.Module):
|
||||
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse_weight = mse_weight
|
||||
self.l1_weight = l1_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.mse = nn.MSELoss()
|
||||
self.l1 = nn.L1Loss()
|
||||
|
||||
# Sobel filters for edge detection
|
||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
||||
|
||||
def edge_loss(self, pred, target):
|
||||
"""Compute edge-aware loss using Sobel filters"""
|
||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
|
||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||
return edge_loss
|
||||
|
||||
def forward(self, pred, target):
|
||||
mse_loss = self.mse(pred, target)
|
||||
l1_loss = self.l1(pred, target)
|
||||
edge_loss = self.edge_loss(pred, target)
|
||||
# Compute per-pixel squared error
|
||||
se = (pred - target) ** 2
|
||||
|
||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
||||
return total_loss
|
||||
# Weight edges more heavily for sharper results
|
||||
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||
|
||||
# Apply weighting
|
||||
weighted_se = se * edge_weight
|
||||
|
||||
# Compute RMSE
|
||||
mse = weighted_se.mean()
|
||||
rmse = torch.sqrt(mse + 1e-8)
|
||||
return rmse
|
||||
|
||||
|
||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||
@@ -68,6 +55,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
# Enable mixed precision training for memory efficiency
|
||||
use_amp = torch.cuda.is_available()
|
||||
scaler = torch.amp.GradScaler('cuda') if use_amp else None
|
||||
|
||||
if use_wandb:
|
||||
wandb.login()
|
||||
wandb.init(project="image_inpainting", config={
|
||||
@@ -84,21 +75,16 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
plotpath = os.path.join(results_path, "plots")
|
||||
os.makedirs(plotpath, exist_ok=True)
|
||||
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||
|
||||
n_total = len(image_dataset)
|
||||
n_test = int(n_total * testset_ratio)
|
||||
n_valid = int(n_total * validset_ratio)
|
||||
n_train = n_total - n_test - n_valid
|
||||
indices = np.random.permutation(n_total)
|
||||
|
||||
# Create datasets with and without augmentation
|
||||
train_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=True)
|
||||
val_test_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
|
||||
dataset_train = Subset(train_dataset_source, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(val_test_dataset_source, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(val_test_dataset_source, indices=indices[n_train + n_valid:n_total])
|
||||
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||
|
||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
@@ -116,15 +102,17 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.to(device)
|
||||
network.train()
|
||||
|
||||
# defining the loss - combined loss for better reconstruction
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||
# defining the loss - Enhanced RMSE for sharper predictions
|
||||
rmse_loss = EnhancedRMSELoss().to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||
|
||||
# Learning rate scheduler for better convergence
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
||||
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||
)
|
||||
|
||||
if use_wandb:
|
||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||
@@ -149,17 +137,31 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Mixed precision training for memory efficiency
|
||||
if use_amp:
|
||||
with torch.amp.autocast('cuda'):
|
||||
output = network(input)
|
||||
loss = rmse_loss(output, target)
|
||||
|
||||
loss = combined_loss(output, target)
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Gradient clipping for training stability
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
output = network(input)
|
||||
loss = rmse_loss(output, target)
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping for training stability
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
||||
|
||||
scheduler.step()
|
||||
|
||||
loss_list.append(loss.item())
|
||||
|
||||
@@ -170,7 +172,11 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
# plotting
|
||||
if (i + 1) % plot_at == 0:
|
||||
print(f"Plotting images, current update {i + 1}")
|
||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
||||
# Convert to float32 for matplotlib compatibility
|
||||
plot(input.float().cpu().numpy(),
|
||||
target.detach().float().cpu().numpy(),
|
||||
output.detach().float().cpu().numpy(),
|
||||
plotpath, i)
|
||||
|
||||
# evaluating model every validate_at sample
|
||||
if (i + 1) % validate_at == 0:
|
||||
|
||||
Reference in New Issue
Block a user