Compare commits
6 Commits
gemini-3-p
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
|||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"learningrate": 0.0003,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"n_updates": 150000,
|
||||||
|
"plot_at": 400,
|
||||||
|
"early_stopping_patience": 40,
|
||||||
|
"print_stats_at": 200,
|
||||||
|
"print_train_stats_at": 50,
|
||||||
|
"validate_at": 200,
|
||||||
|
"accumulation_steps": 1,
|
||||||
|
"commands": {
|
||||||
|
"save_checkpoint": false,
|
||||||
|
"run_test_validation": false,
|
||||||
|
"generate_predictions": false
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -15,35 +15,51 @@ def init_weights(m):
|
|||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
if m.weight is not None:
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.bias, 0)
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class GatedSkipConnection(nn.Module):
|
||||||
"""Channel attention module (squeeze-and-excitation style)"""
|
"""Gated skip connection for better feature fusion"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, up_channels, skip_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
# Project skip to match up_channels if they differ
|
||||||
|
if skip_channels != up_channels:
|
||||||
|
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
|
||||||
|
else:
|
||||||
|
self.skip_proj = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, skip):
|
||||||
|
skip_proj = self.skip_proj(skip)
|
||||||
|
combined = torch.cat([x, skip], dim=1)
|
||||||
|
gate = self.gate(combined)
|
||||||
|
return x * gate + skip_proj * (1 - gate)
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientChannelAttention(nn.Module):
|
||||||
|
"""Efficient channel attention without dimensionality reduction"""
|
||||||
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||||
reduced = max(channels // reduction, 8)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
|
||||||
)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
avg_out = self.fc(self.avg_pool(x))
|
# Global pooling
|
||||||
max_out = self.fc(self.max_pool(x))
|
y = self.avg_pool(x)
|
||||||
return x * self.sigmoid(avg_out + max_out)
|
# 1D convolution on channel dimension
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
y = self.sigmoid(y)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
class SpatialAttention(nn.Module):
|
class SpatialAttention(nn.Module):
|
||||||
"""Spatial attention module"""
|
"""Efficient spatial attention module"""
|
||||||
def __init__(self, kernel_size=7):
|
def __init__(self, kernel_size=7):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
@@ -57,12 +73,12 @@ class SpatialAttention(nn.Module):
|
|||||||
return x * attn
|
return x * attn
|
||||||
|
|
||||||
|
|
||||||
class CBAM(nn.Module):
|
class EfficientAttention(nn.Module):
|
||||||
"""Convolutional Block Attention Module"""
|
"""Lightweight attention module combining channel and spatial"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channel_attn = ChannelAttention(channels, reduction)
|
self.channel_attn = EfficientChannelAttention(channels)
|
||||||
self.spatial_attn = SpatialAttention()
|
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.channel_attn(x)
|
x = self.channel_attn(x)
|
||||||
@@ -71,157 +87,221 @@ class CBAM(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> InstanceNorm2d -> GELU"""
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0, dilation=1):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
if separable and in_channels > 1:
|
||||||
# InstanceNorm is preferred for style/inpainting tasks
|
# Depthwise separable convolution for efficiency
|
||||||
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
self.conv = nn.Sequential(
|
||||||
self.act = nn.GELU()
|
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||||
|
nn.Conv2d(in_channels, out_channels, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.dropout(self.act(self.bn(self.conv(x))))
|
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(nn.Module):
|
||||||
|
"""Lightweight dense block for better gradient flow"""
|
||||||
|
def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
self.layers.append(ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout))
|
||||||
|
self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
|
||||||
|
self.bn = nn.BatchNorm2d(channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
features = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
out = layer(torch.cat(features, dim=1))
|
||||||
|
features.append(out)
|
||||||
|
out = self.fusion(torch.cat(features, dim=1))
|
||||||
|
out = self.relu(self.bn(out))
|
||||||
|
return out + x # Residual connection
|
||||||
|
|
||||||
class ResidualConvBlock(nn.Module):
|
class ResidualConvBlock(nn.Module):
|
||||||
"""Residual convolutional block for better gradient flow"""
|
"""Improved residual convolutional block with pre-activation"""
|
||||||
def __init__(self, channels, dropout=0.0, dilation=1):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
self.bn1 = nn.InstanceNorm2d(channels, affine=True)
|
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = nn.InstanceNorm2d(channels, affine=True)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
self.act = nn.GELU()
|
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
out = self.act(self.bn1(self.conv1(x)))
|
out = self.relu1(self.bn1(x))
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.relu2(self.bn2(out))
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.bn2(self.conv2(out))
|
out = self.conv2(out)
|
||||||
out = out + residual
|
return out + residual
|
||||||
return self.act(out)
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Enhanced downsampling block with dense and residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
skip = self.attention(x)
|
skip = self.attention(x)
|
||||||
return self.pool(skip), skip
|
return self.pool(skip), skip
|
||||||
|
|
||||||
class UpBlock(nn.Module):
|
class UpBlock(nn.Module):
|
||||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
"""Enhanced upsampling block with gated skip connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
# Skip connection has in_channels, upsampled has out_channels
|
||||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||||
|
# After gated skip: out_channels
|
||||||
|
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, skip):
|
def forward(self, x, skip):
|
||||||
x = self.up(x)
|
x = self.up(x)
|
||||||
# Handle dimension mismatch by interpolating x to match skip's size
|
# Handle dimension mismatch
|
||||||
if x.shape[2:] != skip.shape[2:]:
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||||
x = torch.cat([x, skip], dim=1)
|
x = self.gated_skip(x, skip)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
x = self.attention(x)
|
x = self.attention(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class MyModel(nn.Module):
|
class MyModel(nn.Module):
|
||||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Initial convolution with larger receptive field
|
# Separate mask processing for better feature extraction
|
||||||
self.init_conv = nn.Sequential(
|
self.mask_conv = nn.Sequential(
|
||||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||||
ConvBlock(base_channels, base_channels),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
ResidualConvBlock(base_channels)
|
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encoder (downsampling path)
|
# Image processing path
|
||||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
self.image_conv = nn.Sequential(
|
||||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
|
||||||
|
|
||||||
# Bottleneck with multiple residual blocks
|
|
||||||
self.bottleneck = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=2),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=4),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=8),
|
|
||||||
CBAM(base_channels * 16)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decoder (upsampling path)
|
|
||||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
|
||||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
|
||||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
|
||||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
|
||||||
|
|
||||||
# Final refinement layers
|
|
||||||
self.final_conv = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 2, base_channels),
|
|
||||||
ResidualConvBlock(base_channels),
|
|
||||||
ConvBlock(base_channels, base_channels)
|
ConvBlock(base_channels, base_channels)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output layer with smooth transition
|
# Fusion of mask and image features
|
||||||
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||||
|
nn.BatchNorm2d(base_channels),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encoder with progressive feature extraction
|
||||||
|
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
|
||||||
|
# Enhanced bottleneck with multi-scale features and dense connections
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||||
|
DenseBlock(base_channels * 8, growth_rate=10, num_layers=3, dropout=dropout),
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||||
|
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||||
|
EfficientAttention(base_channels * 8)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decoder with progressive reconstruction
|
||||||
|
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
|
||||||
|
# Multi-scale feature fusion with dense connections
|
||||||
|
self.multiscale_fusion = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 2, base_channels),
|
||||||
|
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||||
|
ConvBlock(base_channels, base_channels)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output with residual connection to input
|
||||||
|
self.pre_output = nn.Sequential(
|
||||||
|
ConvBlock(base_channels, base_channels),
|
||||||
|
ConvBlock(base_channels, base_channels // 2)
|
||||||
|
)
|
||||||
|
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||||
nn.GELU(),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
nn.Conv2d(base_channels // 2, 3, 1),
|
||||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply weight initialization
|
# Apply weight initialization
|
||||||
self.apply(init_weights)
|
self.apply(init_weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Initial convolution
|
# Split input into image and mask
|
||||||
x0 = self.init_conv(x)
|
image = x[:, :3, :, :]
|
||||||
|
mask = x[:, 3:4, :, :]
|
||||||
|
|
||||||
|
# Process mask and image separately
|
||||||
|
mask_features = self.mask_conv(mask)
|
||||||
|
image_features = self.image_conv(image)
|
||||||
|
|
||||||
|
# Fuse features
|
||||||
|
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||||
|
|
||||||
# Encoder
|
# Encoder
|
||||||
x1, skip1 = self.down1(x0)
|
x1, skip1 = self.down1(x0)
|
||||||
x2, skip2 = self.down2(x1)
|
x2, skip2 = self.down2(x1)
|
||||||
x3, skip3 = self.down3(x2)
|
x3, skip3 = self.down3(x2)
|
||||||
x4, skip4 = self.down4(x3)
|
|
||||||
|
|
||||||
# Bottleneck
|
# Bottleneck
|
||||||
x = self.bottleneck(x4)
|
x = self.bottleneck(x3)
|
||||||
|
|
||||||
# Decoder with skip connections
|
# Decoder with skip connections
|
||||||
x = self.up1(x, skip4)
|
x = self.up1(x, skip3)
|
||||||
x = self.up2(x, skip3)
|
x = self.up2(x, skip2)
|
||||||
x = self.up3(x, skip2)
|
x = self.up3(x, skip1)
|
||||||
x = self.up4(x, skip1)
|
|
||||||
|
|
||||||
# Handle dimension mismatch for final concatenation
|
# Handle dimension mismatch for final fusion
|
||||||
if x.shape[2:] != x0.shape[2:]:
|
if x.shape[2:] != x0.shape[2:]:
|
||||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
# Concatenate with initial features for better detail preservation
|
# Multi-scale fusion with initial features
|
||||||
x = torch.cat([x, x0], dim=1)
|
x = torch.cat([x, x0], dim=1)
|
||||||
x = self.final_conv(x)
|
x = self.multiscale_fusion(x)
|
||||||
|
|
||||||
# Output
|
# Pre-output processing
|
||||||
|
x = self.pre_output(x)
|
||||||
|
|
||||||
|
# Concatenate with original masked image for residual learning
|
||||||
|
x = torch.cat([x, image], dim=1)
|
||||||
x = self.output(x)
|
x = self.output(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image, ImageEnhance
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
@@ -26,34 +26,58 @@ def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tu
|
|||||||
|
|
||||||
return image_array, known_array
|
return image_array, known_array
|
||||||
|
|
||||||
def resize(img: Image, augment: bool = False):
|
def resize(img: Image):
|
||||||
transforms_list = [
|
resize_transforms = transforms.Compose([
|
||||||
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
||||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||||
]
|
])
|
||||||
|
|
||||||
if augment:
|
|
||||||
transforms_list = [
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.RandomVerticalFlip(),
|
|
||||||
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
|
|
||||||
transforms.RandomRotation(10),
|
|
||||||
] + transforms_list
|
|
||||||
|
|
||||||
resize_transforms = transforms.Compose(transforms_list)
|
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
|
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
return input_array
|
return input_array
|
||||||
|
|
||||||
|
def augment_image(img: Image, strength: float = 0.7) -> Image:
|
||||||
|
"""Apply comprehensive data augmentation for better generalization"""
|
||||||
|
# Random horizontal flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
|
||||||
|
# Random vertical flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
# Random rotation (90, 180, 270 degrees)
|
||||||
|
if random.random() > 0.5:
|
||||||
|
angle = random.choice([90, 180, 270])
|
||||||
|
img = img.rotate(angle)
|
||||||
|
|
||||||
|
# Color augmentation - more aggressive for long training
|
||||||
|
rand = random.random()
|
||||||
|
if rand > 0.75:
|
||||||
|
# Brightness
|
||||||
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Brightness(img).enhance(factor)
|
||||||
|
elif rand > 0.5:
|
||||||
|
# Contrast
|
||||||
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Contrast(img).enhance(factor)
|
||||||
|
elif rand > 0.25:
|
||||||
|
# Saturation
|
||||||
|
factor = 1.0 + random.uniform(-0.15, 0.15) * strength
|
||||||
|
img = ImageEnhance.Color(img).enhance(factor)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder
|
Dataset class for loading images from a folder with augmentation support
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str, augment: bool = False):
|
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.7):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
|
self.augment_strength = augment_strength
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
@@ -62,7 +86,13 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index])
|
image = Image.open(self.imagefiles[index])
|
||||||
image = np.asarray(resize(image, self.augment))
|
image = resize(image)
|
||||||
|
|
||||||
|
# Apply augmentation
|
||||||
|
if self.augment:
|
||||||
|
image = augment_image(image, self.augment_strength)
|
||||||
|
|
||||||
|
image = np.asarray(image)
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
spacing_x = random.randint(2,6)
|
spacing_x = random.randint(2,6)
|
||||||
spacing_y = random.randint(2,6)
|
spacing_y = random.randint(2,6)
|
||||||
|
|||||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
config_dict['n_updates'] = 40000 # Extended training
|
||||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 50
|
||||||
config_dict['print_stats_at'] = 100
|
config_dict['print_stats_at'] = 200
|
||||||
config_dict['plot_at'] = 300
|
config_dict['plot_at'] = 500
|
||||||
config_dict['validate_at'] = 300 # Validate more frequently
|
config_dict['validate_at'] = 500 # Regular validation
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 48, # Good balance between capacity and memory
|
'base_channels': 64,
|
||||||
'dropout': 0.1 # Regularization
|
'dropout': 0.1 # Proper dropout for regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,49 +10,36 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data import Subset
|
from torch.utils.data import Subset
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class EnhancedRMSELoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
|
||||||
self.l1_weight = l1_weight
|
|
||||||
self.edge_weight = edge_weight
|
|
||||||
self.mse = nn.MSELoss()
|
|
||||||
self.l1 = nn.L1Loss()
|
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
|
||||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
|
||||||
|
|
||||||
def edge_loss(self, pred, target):
|
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
|
||||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
|
||||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
|
||||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
|
||||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
|
||||||
return edge_loss
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
mse_loss = self.mse(pred, target)
|
# Compute per-pixel squared error
|
||||||
l1_loss = self.l1(pred, target)
|
se = (pred - target) ** 2
|
||||||
edge_loss = self.edge_loss(pred, target)
|
|
||||||
|
|
||||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
# Weight edges more heavily for sharper results
|
||||||
return total_loss
|
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||||
|
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||||
|
|
||||||
|
# Apply weighting
|
||||||
|
weighted_se = se * edge_weight
|
||||||
|
|
||||||
|
# Compute RMSE
|
||||||
|
mse = weighted_se.mean()
|
||||||
|
rmse = torch.sqrt(mse + 1e-8)
|
||||||
|
return rmse
|
||||||
|
|
||||||
|
|
||||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||||
@@ -68,6 +55,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# Enable mixed precision training for memory efficiency
|
||||||
|
use_amp = torch.cuda.is_available()
|
||||||
|
scaler = torch.amp.GradScaler('cuda') if use_amp else None
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.login()
|
wandb.login()
|
||||||
wandb.init(project="image_inpainting", config={
|
wandb.init(project="image_inpainting", config={
|
||||||
@@ -84,21 +75,16 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
plotpath = os.path.join(results_path, "plots")
|
plotpath = os.path.join(results_path, "plots")
|
||||||
os.makedirs(plotpath, exist_ok=True)
|
os.makedirs(plotpath, exist_ok=True)
|
||||||
|
|
||||||
image_dataset = datasets.ImageDataset(datafolder=data_path, augment=False)
|
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||||
|
|
||||||
n_total = len(image_dataset)
|
n_total = len(image_dataset)
|
||||||
n_test = int(n_total * testset_ratio)
|
n_test = int(n_total * testset_ratio)
|
||||||
n_valid = int(n_total * validset_ratio)
|
n_valid = int(n_total * validset_ratio)
|
||||||
n_train = n_total - n_test - n_valid
|
n_train = n_total - n_test - n_valid
|
||||||
indices = np.random.permutation(n_total)
|
indices = np.random.permutation(n_total)
|
||||||
|
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||||
# Create datasets with and without augmentation
|
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||||
train_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=True)
|
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||||
val_test_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=False)
|
|
||||||
|
|
||||||
dataset_train = Subset(train_dataset_source, indices=indices[0:n_train])
|
|
||||||
dataset_valid = Subset(val_test_dataset_source, indices=indices[n_train:n_train + n_valid])
|
|
||||||
dataset_test = Subset(val_test_dataset_source, indices=indices[n_train + n_valid:n_total])
|
|
||||||
|
|
||||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||||
|
|
||||||
@@ -116,15 +102,17 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss for better reconstruction
|
# defining the loss - Enhanced RMSE for sharper predictions
|
||||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
rmse_loss = EnhancedRMSELoss().to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
|
||||||
# Learning rate scheduler for better convergence
|
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||||
|
)
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||||
@@ -149,17 +137,31 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
output = network(input)
|
# Mixed precision training for memory efficiency
|
||||||
|
if use_amp:
|
||||||
|
with torch.amp.autocast('cuda'):
|
||||||
|
output = network(input)
|
||||||
|
loss = rmse_loss(output, target)
|
||||||
|
|
||||||
loss = combined_loss(output, target)
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
loss.backward()
|
# Gradient clipping for training stability
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
# Gradient clipping for training stability
|
scaler.step(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
scaler.update()
|
||||||
|
else:
|
||||||
|
output = network(input)
|
||||||
|
loss = rmse_loss(output, target)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
optimizer.step()
|
# Gradient clipping for training stability
|
||||||
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
loss_list.append(loss.item())
|
loss_list.append(loss.item())
|
||||||
|
|
||||||
@@ -170,7 +172,11 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
# plotting
|
# plotting
|
||||||
if (i + 1) % plot_at == 0:
|
if (i + 1) % plot_at == 0:
|
||||||
print(f"Plotting images, current update {i + 1}")
|
print(f"Plotting images, current update {i + 1}")
|
||||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
# Convert to float32 for matplotlib compatibility
|
||||||
|
plot(input.float().cpu().numpy(),
|
||||||
|
target.detach().float().cpu().numpy(),
|
||||||
|
output.detach().float().cpu().numpy(),
|
||||||
|
plotpath, i)
|
||||||
|
|
||||||
# evaluating model every validate_at sample
|
# evaluating model every validate_at sample
|
||||||
if (i + 1) % validate_at == 0:
|
if (i + 1) % validate_at == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user