Compare commits
6 Commits
claude-son
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
*.pt
|
||||
__pycache__/
|
||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"learningrate": 0.0003,
|
||||
"weight_decay": 1e-05,
|
||||
"n_updates": 150000,
|
||||
"plot_at": 400,
|
||||
"early_stopping_patience": 40,
|
||||
"print_stats_at": 200,
|
||||
"print_train_stats_at": 50,
|
||||
"validate_at": 200,
|
||||
"accumulation_steps": 1,
|
||||
"commands": {
|
||||
"save_checkpoint": false,
|
||||
"run_test_validation": false,
|
||||
"generate_predictions": false
|
||||
}
|
||||
}
|
||||
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -20,28 +20,46 @@ def init_weights(m):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""Channel attention module (squeeze-and-excitation style)"""
|
||||
def __init__(self, channels, reduction=16):
|
||||
class GatedSkipConnection(nn.Module):
    """Gated skip connection for better feature fusion.

    A 1x1 convolution followed by a sigmoid looks at the channel-wise
    concatenation of the upsampled features and the skip features and
    produces a per-element gate in (0, 1).  The output is a convex blend
    of the upsampled features and the (projected) skip features.

    Args:
        up_channels: Channel count of the upsampled tensor ``x``; also the
            channel count of the returned tensor.
        skip_channels: Channel count of the encoder ``skip`` tensor.
    """

    def __init__(self, up_channels, skip_channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(up_channels + skip_channels, up_channels, 1),
            nn.Sigmoid(),
        )
        # Project skip to match up_channels if they differ, so that the
        # blend in forward() is between tensors of equal channel count.
        if skip_channels != up_channels:
            self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
        else:
            self.skip_proj = nn.Identity()

    def forward(self, x, skip):
        """Blend ``x`` and ``skip`` with a learned gate.

        ``x`` and ``skip`` are concatenated along dim 1, so they must agree
        in every other dimension.

        Returns:
            ``x * gate + skip_proj(skip) * (1 - gate)``, with ``up_channels``
            channels.
        """
        projected_skip = self.skip_proj(skip)
        # The gate sees the *unprojected* skip features next to x.
        gate = self.gate(torch.cat([x, skip], dim=1))
        return x * gate + projected_skip * (1 - gate)
|
||||
|
||||
|
||||
class EfficientChannelAttention(nn.Module):
|
||||
"""Efficient channel attention without dimensionality reduction"""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
reduced = max(channels // reduction, 8)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
||||
)
|
||||
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
return x * self.sigmoid(avg_out + max_out)
|
||||
# Global pooling
|
||||
y = self.avg_pool(x)
|
||||
# 1D convolution on channel dimension
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
y = self.sigmoid(y)
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
"""Spatial attention module"""
|
||||
"""Efficient spatial attention module"""
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||
@@ -55,12 +73,12 @@ class SpatialAttention(nn.Module):
|
||||
return x * attn
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
"""Convolutional Block Attention Module"""
|
||||
def __init__(self, channels, reduction=16):
|
||||
class EfficientAttention(nn.Module):
|
||||
"""Lightweight attention module combining channel and spatial"""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.channel_attn = ChannelAttention(channels, reduction)
|
||||
self.spatial_attn = SpatialAttention()
|
||||
self.channel_attn = EfficientChannelAttention(channels)
|
||||
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.channel_attn(x)
|
||||
@@ -70,176 +88,220 @@ class CBAM(nn.Module):
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
||||
if separable and in_channels > 1:
|
||||
# Depthwise separable convolution for efficiency
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||
nn.Conv2d(in_channels, out_channels, 1)
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||
|
||||
|
||||
class DenseBlock(nn.Module):
    """Lightweight dense block for better gradient flow.

    Each inner layer receives the concatenation of the block input and the
    outputs of all previous inner layers, and contributes ``growth_rate``
    new channels.  A 1x1 convolution then fuses everything back to
    ``channels`` channels, and the block input is added as a residual.

    Args:
        channels: Channel count of the input and of the output.
        growth_rate: Channels produced by each inner layer.
        num_layers: Number of inner layers.
        dropout: Dropout value forwarded to every inner ``ConvBlock``.
    """

    def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
        super().__init__()
        # Layer i sees the input plus the i earlier growth_rate-wide outputs.
        self.layers = nn.ModuleList([
            ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout)
            for i in range(num_layers)
        ])
        self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Run the dense layers, fuse their outputs and add ``x`` back."""
        features = [x]
        for layer in self.layers:
            features.append(layer(torch.cat(features, dim=1)))
        out = self.fusion(torch.cat(features, dim=1))
        out = self.relu(self.bn(out))
        return out + x  # Residual connection
|
||||
|
||||
class ResidualConvBlock(nn.Module):
|
||||
"""Residual convolutional block for better gradient flow"""
|
||||
"""Improved residual convolutional block with pre-activation"""
|
||||
def __init__(self, channels, dropout=0.0):
|
||||
super().__init__()
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.relu1(self.bn1(x))
|
||||
out = self.conv1(out)
|
||||
out = self.relu2(self.bn2(out))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + residual
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
class DilatedResidualBlock(nn.Module):
|
||||
"""Residual block with dilated convolutions for larger receptive field"""
|
||||
def __init__(self, channels, dilation=2, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + residual
|
||||
return self.relu(out)
|
||||
out = self.conv2(out)
|
||||
return out + residual
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Enhanced downsampling block with dense and residual connections"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
if use_dense:
|
||||
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||
else:
|
||||
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.residual(x)
|
||||
x = self.dense(x)
|
||||
skip = self.attention(x)
|
||||
return self.pool(skip), skip
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Enhanced upsampling block with gated skip connections"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||
super().__init__()
|
||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
||||
# Skip connection has in_channels, upsampled has out_channels
|
||||
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||
# After gated skip: out_channels
|
||||
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
if use_dense:
|
||||
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||
else:
|
||||
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||
|
||||
def forward(self, x, skip):
|
||||
x = self.up(x)
|
||||
# Handle dimension mismatch by interpolating x to match skip's size
|
||||
# Handle dimension mismatch
|
||||
if x.shape[2:] != skip.shape[2:]:
|
||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
x = self.gated_skip(x, skip)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.residual(x)
|
||||
x = self.dense(x)
|
||||
x = self.attention(x)
|
||||
return x
|
||||
|
||||
class MyModel(nn.Module):
|
||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
||||
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
# Initial convolution with larger receptive field
|
||||
self.init_conv = nn.Sequential(
|
||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
||||
ConvBlock(base_channels, base_channels),
|
||||
ResidualConvBlock(base_channels)
|
||||
# Separate mask processing for better feature extraction
|
||||
self.mask_conv = nn.Sequential(
|
||||
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
# Encoder (downsampling path)
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
|
||||
# Bottleneck with multi-scale dilated convolutions (ASPP-style)
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
DilatedResidualBlock(base_channels * 16, dilation=2, dropout=dropout),
|
||||
DilatedResidualBlock(base_channels * 16, dilation=4, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
CBAM(base_channels * 16)
|
||||
)
|
||||
|
||||
# Decoder (upsampling path)
|
||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
||||
|
||||
# Final refinement layers
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvBlock(base_channels * 2, base_channels),
|
||||
ResidualConvBlock(base_channels),
|
||||
# Image processing path
|
||||
self.image_conv = nn.Sequential(
|
||||
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
# Output layer with smooth transition
|
||||
# Fusion of mask and image features
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
# Encoder with progressive feature extraction
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||
|
||||
# Enhanced bottleneck with multi-scale features and dense connections
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||
DenseBlock(base_channels * 8, growth_rate=10, num_layers=3, dropout=dropout),
|
||||
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||
EfficientAttention(base_channels * 8)
|
||||
)
|
||||
|
||||
# Decoder with progressive reconstruction
|
||||
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
|
||||
# Multi-scale feature fusion with dense connections
|
||||
self.multiscale_fusion = nn.Sequential(
|
||||
ConvBlock(base_channels * 2, base_channels),
|
||||
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
# Output with residual connection to input
|
||||
self.pre_output = nn.Sequential(
|
||||
ConvBlock(base_channels, base_channels),
|
||||
ConvBlock(base_channels, base_channels // 2)
|
||||
)
|
||||
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.1, inplace=True),
|
||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
||||
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels // 2, 3, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Apply weight initialization
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
# Initial convolution
|
||||
x0 = self.init_conv(x)
|
||||
# Split input into image and mask
|
||||
image = x[:, :3, :, :]
|
||||
mask = x[:, 3:4, :, :]
|
||||
|
||||
# Process mask and image separately
|
||||
mask_features = self.mask_conv(mask)
|
||||
image_features = self.image_conv(image)
|
||||
|
||||
# Fuse features
|
||||
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||
|
||||
# Encoder
|
||||
x1, skip1 = self.down1(x0)
|
||||
x2, skip2 = self.down2(x1)
|
||||
x3, skip3 = self.down3(x2)
|
||||
x4, skip4 = self.down4(x3)
|
||||
|
||||
# Bottleneck
|
||||
x = self.bottleneck(x4)
|
||||
x = self.bottleneck(x3)
|
||||
|
||||
# Decoder with skip connections
|
||||
x = self.up1(x, skip4)
|
||||
x = self.up2(x, skip3)
|
||||
x = self.up3(x, skip2)
|
||||
x = self.up4(x, skip1)
|
||||
x = self.up1(x, skip3)
|
||||
x = self.up2(x, skip2)
|
||||
x = self.up3(x, skip1)
|
||||
|
||||
# Handle dimension mismatch for final concatenation
|
||||
# Handle dimension mismatch for final fusion
|
||||
if x.shape[2:] != x0.shape[2:]:
|
||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||
|
||||
# Concatenate with initial features for better detail preservation
|
||||
# Multi-scale fusion with initial features
|
||||
x = torch.cat([x, x0], dim=1)
|
||||
x = self.final_conv(x)
|
||||
x = self.multiscale_fusion(x)
|
||||
|
||||
# Output
|
||||
# Pre-output processing
|
||||
x = self.pre_output(x)
|
||||
|
||||
# Concatenate with original masked image for residual learning
|
||||
x = torch.cat([x, image], dim=1)
|
||||
x = self.output(x)
|
||||
|
||||
return x
|
||||
@@ -32,75 +32,72 @@ def resize(img: Image):
|
||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||
])
|
||||
return resize_transforms(img)
|
||||
|
||||
def preprocess(input_array: np.ndarray) -> np.ndarray:
    """Convert ``input_array`` to float32 and divide every value by 255.

    Presumably used to map 8-bit pixel values into [0, 1]; the function
    itself does not check the input range.
    """
    return np.asarray(input_array, dtype=np.float32) / 255.0
|
||||
|
||||
def augment_image(img: Image, strength: float = 0.7) -> Image:
    """Apply comprehensive data augmentation for better generalization.

    Applies, each with probability 0.5 and in this order: a horizontal
    flip, a vertical flip, and a rotation by 90, 180 or 270 degrees.
    Afterwards at most one colour adjustment is applied, chosen by a single
    random draw: brightness, contrast or saturation (0.25 probability each),
    or none.

    Args:
        img: Image to augment; assumed to be a PIL image, since
            ``transpose``/``rotate`` and ``ImageEnhance`` are used on it.
        strength: Scales the magnitude of the colour adjustment around 1.0.

    Returns:
        The augmented image.
    """
    # Random horizontal flip
    if random.random() > 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # Random vertical flip
    if random.random() > 0.5:
        img = img.transpose(Image.FLIP_TOP_BOTTOM)

    # Random rotation (90, 180, 270 degrees)
    # NOTE(review): rotate() is called without expand, so this presumably
    # relies on the image already being square -- confirm against callers.
    if random.random() > 0.5:
        angle = random.choice([90, 180, 270])
        img = img.rotate(angle)

    # Color augmentation: one draw selects at most one of three adjustments.
    rand = random.random()
    if rand > 0.75:
        # Brightness
        factor = 1.0 + random.uniform(-0.2, 0.2) * strength
        img = ImageEnhance.Brightness(img).enhance(factor)
    elif rand > 0.5:
        # Contrast
        factor = 1.0 + random.uniform(-0.2, 0.2) * strength
        img = ImageEnhance.Contrast(img).enhance(factor)
    elif rand > 0.25:
        # Saturation
        factor = 1.0 + random.uniform(-0.15, 0.15) * strength
        img = ImageEnhance.Color(img).enhance(factor)

    return img
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder with data augmentation
|
||||
Dataset class for loading images from a folder with augmentation support
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str, augment: bool = True):
|
||||
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.7):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||
self.augment = augment
|
||||
self.augment_strength = augment_strength
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
|
||||
def augment_image(self, image: Image) -> Image:
|
||||
"""Apply random augmentations to image"""
|
||||
# Random horizontal flip
|
||||
if random.random() > 0.5:
|
||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
# Random vertical flip
|
||||
if random.random() > 0.5:
|
||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Random rotation (90, 180, 270 degrees)
|
||||
if random.random() > 0.5:
|
||||
angle = random.choice([90, 180, 270])
|
||||
image = image.rotate(angle)
|
||||
|
||||
# Random brightness adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Brightness(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
# Random contrast adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Contrast(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
# Random color adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Color(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
return image
|
||||
|
||||
def __getitem__(self, idx:int):
|
||||
index = int(idx)
|
||||
|
||||
image = Image.open(self.imagefiles[index])
|
||||
image = resize(image)
|
||||
|
||||
# Apply augmentation if enabled
|
||||
# Apply augmentation
|
||||
if self.augment:
|
||||
image = self.augment_image(image)
|
||||
image = augment_image(image, self.augment_strength)
|
||||
|
||||
image = np.asarray(image)
|
||||
image = preprocess(image)
|
||||
|
||||
# Vary spacing and offset more for additional diversity
|
||||
spacing_x = random.randint(2,7)
|
||||
spacing_y = random.randint(2,7)
|
||||
offset_x = random.randint(0,10)
|
||||
offset_y = random.randint(0,10)
|
||||
spacing_x = random.randint(2,6)
|
||||
spacing_y = random.randint(2,6)
|
||||
offset_x = random.randint(0,8)
|
||||
offset_y = random.randint(0,8)
|
||||
spacing = (spacing_x, spacing_y)
|
||||
offset = (offset_x, offset_y)
|
||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||
|
||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 2e-4 # Slightly lower for more stable training
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced for less aggressive regularization
|
||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 12 # Larger batch for more stable gradients
|
||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
||||
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||
config_dict['n_updates'] = 40000 # Extended training
|
||||
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 100
|
||||
config_dict['plot_at'] = 400
|
||||
config_dict['validate_at'] = 200 # Validate frequently but not too often
|
||||
config_dict['print_train_stats_at'] = 50
|
||||
config_dict['print_stats_at'] = 200
|
||||
config_dict['plot_at'] = 500
|
||||
config_dict['validate_at'] = 500 # Regular validation
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 32, # Smaller base for efficiency, depth compensates
|
||||
'dropout': 0.15 # Slightly more regularization with augmentation
|
||||
'base_channels': 64,
|
||||
'dropout': 0.1 # Proper dropout for regularization
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -10,49 +10,36 @@ from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Subset
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""Combined loss: MSE + L1 + Edge-aware component for better reconstruction"""
|
||||
def __init__(self, mse_weight=0.7, l1_weight=0.8, edge_weight=0.2):
|
||||
class EnhancedRMSELoss(nn.Module):
|
||||
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse_weight = mse_weight
|
||||
self.l1_weight = l1_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.mse = nn.MSELoss()
|
||||
self.l1 = nn.L1Loss()
|
||||
|
||||
# Sobel filters for edge detection
|
||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
||||
|
||||
def edge_loss(self, pred, target):
|
||||
"""Compute edge-aware loss using Sobel filters"""
|
||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
|
||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||
return edge_loss
|
||||
|
||||
def forward(self, pred, target):
|
||||
mse_loss = self.mse(pred, target)
|
||||
l1_loss = self.l1(pred, target)
|
||||
edge_loss = self.edge_loss(pred, target)
|
||||
# Compute per-pixel squared error
|
||||
se = (pred - target) ** 2
|
||||
|
||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
||||
return total_loss
|
||||
# Weight edges more heavily for sharper results
|
||||
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||
|
||||
# Apply weighting
|
||||
weighted_se = se * edge_weight
|
||||
|
||||
# Compute RMSE
|
||||
mse = weighted_se.mean()
|
||||
rmse = torch.sqrt(mse + 1e-8)
|
||||
return rmse
|
||||
|
||||
|
||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||
@@ -68,6 +55,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
# Enable mixed precision training for memory efficiency
|
||||
use_amp = torch.cuda.is_available()
|
||||
scaler = torch.amp.GradScaler('cuda') if use_amp else None
|
||||
|
||||
if use_wandb:
|
||||
wandb.login()
|
||||
wandb.init(project="image_inpainting", config={
|
||||
@@ -84,24 +75,20 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
plotpath = os.path.join(results_path, "plots")
|
||||
os.makedirs(plotpath, exist_ok=True)
|
||||
|
||||
# Create dataset with augmentation for training, without for validation/test
|
||||
image_dataset_full = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||
|
||||
n_total = len(image_dataset_full)
|
||||
n_total = len(image_dataset)
|
||||
n_test = int(n_total * testset_ratio)
|
||||
n_valid = int(n_total * validset_ratio)
|
||||
n_train = n_total - n_test - n_valid
|
||||
indices = np.random.permutation(n_total)
|
||||
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||
|
||||
# Create augmented dataset for training
|
||||
image_dataset_train = datasets.ImageDataset(datafolder=data_path, augment=True)
|
||||
dataset_train = Subset(image_dataset_train, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset_full, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset_full, indices=indices[n_train + n_valid:n_total])
|
||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
assert n_total == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
del image_dataset_full, image_dataset_train
|
||||
del image_dataset
|
||||
|
||||
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
||||
num_workers=0, shuffle=True)
|
||||
@@ -115,19 +102,17 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.to(device)
|
||||
network.train()
|
||||
|
||||
# defining the loss - combined loss with optimized weights
|
||||
combined_loss = CombinedLoss(mse_weight=0.7, l1_weight=0.8, edge_weight=0.2).to(device)
|
||||
# defining the loss - Enhanced RMSE for sharper predictions
|
||||
rmse_loss = EnhancedRMSELoss().to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||
|
||||
# Learning rate scheduler with better configuration
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=1e-7)
|
||||
|
||||
# Mixed precision training for faster computation and lower memory usage
|
||||
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
|
||||
use_amp = scaler is not None
|
||||
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||
)
|
||||
|
||||
if use_wandb:
|
||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||
@@ -136,13 +121,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
counter = 0
|
||||
best_validation_loss = np.inf
|
||||
loss_list = []
|
||||
accumulation_steps = 2 # Gradient accumulation for effective larger batch size
|
||||
|
||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||
|
||||
print(f"Started training on device {device}")
|
||||
print(f"Using mixed precision: {use_amp}")
|
||||
print(f"Gradient accumulation steps: {accumulation_steps}")
|
||||
|
||||
while i < n_updates:
|
||||
|
||||
@@ -153,33 +135,35 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
if (i + 1) % print_train_stats_at == 0:
|
||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||
|
||||
# Use mixed precision if available
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Mixed precision training for memory efficiency
|
||||
if use_amp:
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast('cuda'):
|
||||
output = network(input)
|
||||
loss = combined_loss(output, target)
|
||||
loss = loss / accumulation_steps
|
||||
loss = rmse_loss(output, target)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Gradient clipping for training stability
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
output = network(input)
|
||||
loss = combined_loss(output, target)
|
||||
loss = loss / accumulation_steps
|
||||
loss = rmse_loss(output, target)
|
||||
loss.backward()
|
||||
|
||||
# Gradient accumulation - update weights every accumulation_steps
|
||||
if (i + 1) % accumulation_steps == 0:
|
||||
if use_amp:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
scheduler.step(i / n_updates)
|
||||
# Gradient clipping for training stability
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
loss_list.append(loss.item() * accumulation_steps)
|
||||
optimizer.step()
|
||||
|
||||
scheduler.step()
|
||||
|
||||
loss_list.append(loss.item())
|
||||
|
||||
# writing the stats to wandb
|
||||
if use_wandb and (i+1) % print_stats_at == 0:
|
||||
@@ -188,9 +172,11 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
# plotting
|
||||
if (i + 1) % plot_at == 0:
|
||||
print(f"Plotting images, current update {i + 1}")
|
||||
# Convert to float32 for matplotlib compatibility (mixed precision may produce float16)
|
||||
plot(input.float().cpu().numpy(), target.detach().float().cpu().numpy(),
|
||||
output.detach().float().cpu().numpy(), plotpath, i)
|
||||
# Convert to float32 for matplotlib compatibility
|
||||
plot(input.float().cpu().numpy(),
|
||||
target.detach().float().cpu().numpy(),
|
||||
output.detach().float().cpu().numpy(),
|
||||
plotpath, i)
|
||||
|
||||
# evaluating model every validate_at sample
|
||||
if (i + 1) % validate_at == 0:
|
||||
|
||||
Reference in New Issue
Block a user