Compare commits
6 Commits
gpt-5.2
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/*
|
data/*
|
||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"learningrate": 0.0003,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"n_updates": 150000,
|
||||||
|
"plot_at": 400,
|
||||||
|
"early_stopping_patience": 40,
|
||||||
|
"print_stats_at": 200,
|
||||||
|
"print_train_stats_at": 50,
|
||||||
|
"validate_at": 200,
|
||||||
|
"accumulation_steps": 1,
|
||||||
|
"commands": {
|
||||||
|
"save_checkpoint": false,
|
||||||
|
"run_test_validation": false,
|
||||||
|
"generate_predictions": false
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -18,44 +18,48 @@ def init_weights(m):
|
|||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.GroupNorm):
|
|
||||||
if m.weight is not None:
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_norm(num_channels: int) -> nn.Module:
|
class GatedSkipConnection(nn.Module):
|
||||||
"""Batch-size independent normalization (works well for batch_size=1 eval)."""
|
"""Gated skip connection for better feature fusion"""
|
||||||
# Choose a group count that divides num_channels.
|
def __init__(self, up_channels, skip_channels):
|
||||||
num_groups = min(32, num_channels)
|
super().__init__()
|
||||||
while num_groups > 1 and (num_channels % num_groups) != 0:
|
self.gate = nn.Sequential(
|
||||||
num_groups //= 2
|
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
|
||||||
return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
# Project skip to match up_channels if they differ
|
||||||
|
if skip_channels != up_channels:
|
||||||
|
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
|
||||||
|
else:
|
||||||
|
self.skip_proj = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, skip):
|
||||||
|
skip_proj = self.skip_proj(skip)
|
||||||
|
combined = torch.cat([x, skip], dim=1)
|
||||||
|
gate = self.gate(combined)
|
||||||
|
return x * gate + skip_proj * (1 - gate)
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class EfficientChannelAttention(nn.Module):
|
||||||
"""Channel attention module (squeeze-and-excitation style)"""
|
"""Efficient channel attention without dimensionality reduction"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||||
reduced = max(channels // reduction, 8)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
|
||||||
)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
avg_out = self.fc(self.avg_pool(x))
|
# Global pooling
|
||||||
max_out = self.fc(self.max_pool(x))
|
y = self.avg_pool(x)
|
||||||
return x * self.sigmoid(avg_out + max_out)
|
# 1D convolution on channel dimension
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
y = self.sigmoid(y)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
class SpatialAttention(nn.Module):
|
class SpatialAttention(nn.Module):
|
||||||
"""Spatial attention module"""
|
"""Efficient spatial attention module"""
|
||||||
def __init__(self, kernel_size=7):
|
def __init__(self, kernel_size=7):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
@@ -69,12 +73,12 @@ class SpatialAttention(nn.Module):
|
|||||||
return x * attn
|
return x * attn
|
||||||
|
|
||||||
|
|
||||||
class CBAM(nn.Module):
|
class EfficientAttention(nn.Module):
|
||||||
"""Convolutional Block Attention Module"""
|
"""Lightweight attention module combining channel and spatial"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channel_attn = ChannelAttention(channels, reduction)
|
self.channel_attn = EfficientChannelAttention(channels)
|
||||||
self.spatial_attn = SpatialAttention()
|
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.channel_attn(x)
|
x = self.channel_attn(x)
|
||||||
@@ -84,175 +88,220 @@ class CBAM(nn.Module):
|
|||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
if separable and in_channels > 1:
|
||||||
self.bn = _make_norm(out_channels)
|
# Depthwise separable convolution for efficiency
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||||
|
nn.Conv2d(in_channels, out_channels, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(nn.Module):
|
||||||
|
"""Lightweight dense block for better gradient flow"""
|
||||||
|
def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
self.layers.append(ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout))
|
||||||
|
self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
|
||||||
|
self.bn = nn.BatchNorm2d(channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
features = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
out = layer(torch.cat(features, dim=1))
|
||||||
|
features.append(out)
|
||||||
|
out = self.fusion(torch.cat(features, dim=1))
|
||||||
|
out = self.relu(self.bn(out))
|
||||||
|
return out + x # Residual connection
|
||||||
|
|
||||||
class ResidualConvBlock(nn.Module):
|
class ResidualConvBlock(nn.Module):
|
||||||
"""Residual convolutional block for better gradient flow"""
|
"""Improved residual convolutional block with pre-activation"""
|
||||||
def __init__(self, channels, dropout=0.0):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
|
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn1 = _make_norm(channels)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
|
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = _make_norm(channels)
|
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
out = self.relu(self.bn1(self.conv1(x)))
|
out = self.relu1(self.bn1(x))
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.relu2(self.bn2(out))
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.bn2(self.conv2(out))
|
out = self.conv2(out)
|
||||||
out = out + residual
|
return out + residual
|
||||||
return self.relu(out)
|
|
||||||
|
|
||||||
|
|
||||||
class GatedConvBlock(nn.Module):
|
|
||||||
"""Gated convolution block (helps the network condition on the mask channel)."""
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.feature = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
|
||||||
self.gate = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
|
||||||
self.norm = _make_norm(out_channels)
|
|
||||||
self.act = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
feat = self.feature(x)
|
|
||||||
gate = torch.sigmoid(self.gate(x))
|
|
||||||
out = feat * gate
|
|
||||||
out = self.norm(out)
|
|
||||||
out = self.act(out)
|
|
||||||
out = self.dropout(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Enhanced downsampling block with dense and residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
skip = self.attention(x)
|
skip = self.attention(x)
|
||||||
return self.pool(skip), skip
|
return self.pool(skip), skip
|
||||||
|
|
||||||
class UpBlock(nn.Module):
|
class UpBlock(nn.Module):
|
||||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
"""Enhanced upsampling block with gated skip connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
# Skip connection has in_channels, upsampled has out_channels
|
||||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||||
|
# After gated skip: out_channels
|
||||||
|
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, skip):
|
def forward(self, x, skip):
|
||||||
x = self.up(x)
|
x = self.up(x)
|
||||||
# Handle dimension mismatch by interpolating x to match skip's size
|
# Handle dimension mismatch
|
||||||
if x.shape[2:] != skip.shape[2:]:
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||||
x = torch.cat([x, skip], dim=1)
|
x = self.gated_skip(x, skip)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
x = self.attention(x)
|
x = self.attention(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class MyModel(nn.Module):
|
class MyModel(nn.Module):
|
||||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Initial convolution with larger receptive field
|
# Separate mask processing for better feature extraction
|
||||||
self.init_conv = nn.Sequential(
|
self.mask_conv = nn.Sequential(
|
||||||
GatedConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3, dropout=dropout),
|
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||||
ConvBlock(base_channels, base_channels),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
ResidualConvBlock(base_channels)
|
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encoder (downsampling path)
|
# Image processing path
|
||||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
self.image_conv = nn.Sequential(
|
||||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
|
||||||
|
|
||||||
# Bottleneck with multiple residual blocks
|
|
||||||
self.bottleneck = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
CBAM(base_channels * 16)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decoder (upsampling path)
|
|
||||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
|
||||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
|
||||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
|
||||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
|
||||||
|
|
||||||
# Final refinement layers
|
|
||||||
self.final_conv = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 2, base_channels),
|
|
||||||
ResidualConvBlock(base_channels),
|
|
||||||
ConvBlock(base_channels, base_channels)
|
ConvBlock(base_channels, base_channels)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output layer with smooth transition
|
# Fusion of mask and image features
|
||||||
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||||
|
nn.BatchNorm2d(base_channels),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encoder with progressive feature extraction
|
||||||
|
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
|
||||||
|
# Enhanced bottleneck with multi-scale features and dense connections
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||||
|
DenseBlock(base_channels * 8, growth_rate=10, num_layers=3, dropout=dropout),
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||||
|
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||||
|
EfficientAttention(base_channels * 8)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decoder with progressive reconstruction
|
||||||
|
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
|
||||||
|
# Multi-scale feature fusion with dense connections
|
||||||
|
self.multiscale_fusion = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 2, base_channels),
|
||||||
|
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||||
|
ConvBlock(base_channels, base_channels)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output with residual connection to input
|
||||||
|
self.pre_output = nn.Sequential(
|
||||||
|
ConvBlock(base_channels, base_channels),
|
||||||
|
ConvBlock(base_channels, base_channels // 2)
|
||||||
|
)
|
||||||
|
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||||
nn.LeakyReLU(0.1, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
nn.Conv2d(base_channels // 2, 3, 1),
|
||||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply weight initialization
|
# Apply weight initialization
|
||||||
self.apply(init_weights)
|
self.apply(init_weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Initial convolution
|
# Split input into image and mask
|
||||||
x0 = self.init_conv(x)
|
image = x[:, :3, :, :]
|
||||||
|
mask = x[:, 3:4, :, :]
|
||||||
|
|
||||||
|
# Process mask and image separately
|
||||||
|
mask_features = self.mask_conv(mask)
|
||||||
|
image_features = self.image_conv(image)
|
||||||
|
|
||||||
|
# Fuse features
|
||||||
|
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||||
|
|
||||||
# Encoder
|
# Encoder
|
||||||
x1, skip1 = self.down1(x0)
|
x1, skip1 = self.down1(x0)
|
||||||
x2, skip2 = self.down2(x1)
|
x2, skip2 = self.down2(x1)
|
||||||
x3, skip3 = self.down3(x2)
|
x3, skip3 = self.down3(x2)
|
||||||
x4, skip4 = self.down4(x3)
|
|
||||||
|
|
||||||
# Bottleneck
|
# Bottleneck
|
||||||
x = self.bottleneck(x4)
|
x = self.bottleneck(x3)
|
||||||
|
|
||||||
# Decoder with skip connections
|
# Decoder with skip connections
|
||||||
x = self.up1(x, skip4)
|
x = self.up1(x, skip3)
|
||||||
x = self.up2(x, skip3)
|
x = self.up2(x, skip2)
|
||||||
x = self.up3(x, skip2)
|
x = self.up3(x, skip1)
|
||||||
x = self.up4(x, skip1)
|
|
||||||
|
|
||||||
# Handle dimension mismatch for final concatenation
|
# Handle dimension mismatch for final fusion
|
||||||
if x.shape[2:] != x0.shape[2:]:
|
if x.shape[2:] != x0.shape[2:]:
|
||||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
# Concatenate with initial features for better detail preservation
|
# Multi-scale fusion with initial features
|
||||||
x = torch.cat([x, x0], dim=1)
|
x = torch.cat([x, x0], dim=1)
|
||||||
x = self.final_conv(x)
|
x = self.multiscale_fusion(x)
|
||||||
|
|
||||||
# Output
|
# Pre-output processing
|
||||||
|
x = self.pre_output(x)
|
||||||
|
|
||||||
|
# Concatenate with original masked image for residual learning
|
||||||
|
x = torch.cat([x, image], dim=1)
|
||||||
x = self.output(x)
|
x = self.output(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image, ImageEnhance
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
@@ -33,35 +33,51 @@ def resize(img: Image):
|
|||||||
])
|
])
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
|
|
||||||
|
|
||||||
def augment_geometric(img: Image.Image) -> Image.Image:
|
|
||||||
"""Lightweight, label-preserving augmentation (safe for train/val/test splits)."""
|
|
||||||
# Horizontal flip
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
|
|
||||||
# Vertical flip (less frequent)
|
|
||||||
if random.random() < 0.2:
|
|
||||||
img = img.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
|
||||||
# 90-degree rotations (no interpolation artifacts)
|
|
||||||
r = random.random()
|
|
||||||
if r < 0.25:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
elif r < 0.5:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif r < 0.75:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
return img
|
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
return input_array
|
return input_array
|
||||||
|
|
||||||
|
def augment_image(img: Image, strength: float = 0.7) -> Image:
|
||||||
|
"""Apply comprehensive data augmentation for better generalization"""
|
||||||
|
# Random horizontal flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
|
||||||
|
# Random vertical flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
# Random rotation (90, 180, 270 degrees)
|
||||||
|
if random.random() > 0.5:
|
||||||
|
angle = random.choice([90, 180, 270])
|
||||||
|
img = img.rotate(angle)
|
||||||
|
|
||||||
|
# Color augmentation - more aggressive for long training
|
||||||
|
rand = random.random()
|
||||||
|
if rand > 0.75:
|
||||||
|
# Brightness
|
||||||
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Brightness(img).enhance(factor)
|
||||||
|
elif rand > 0.5:
|
||||||
|
# Contrast
|
||||||
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Contrast(img).enhance(factor)
|
||||||
|
elif rand > 0.25:
|
||||||
|
# Saturation
|
||||||
|
factor = 1.0 + random.uniform(-0.15, 0.15) * strength
|
||||||
|
img = ImageEnhance.Color(img).enhance(factor)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder
|
Dataset class for loading images from a folder with augmentation support
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str):
|
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.7):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
|
self.augment = augment
|
||||||
|
self.augment_strength = augment_strength
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
@@ -69,17 +85,19 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx:int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index]).convert("RGB")
|
image = Image.open(self.imagefiles[index])
|
||||||
image = augment_geometric(image)
|
image = resize(image)
|
||||||
image = np.asarray(resize(image))
|
|
||||||
|
# Apply augmentation
|
||||||
|
if self.augment:
|
||||||
|
image = augment_image(image, self.augment_strength)
|
||||||
|
|
||||||
|
image = np.asarray(image)
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
spacing_x = random.randint(2,6)
|
||||||
# Sample a grid-mask similar in density to the challenge testset (~8% known pixels).
|
spacing_y = random.randint(2,6)
|
||||||
# IMPORTANT: offset ranges must be tied to spacing to avoid accidental distribution shift.
|
offset_x = random.randint(0,8)
|
||||||
spacing_x = random.randint(4, 6)
|
offset_y = random.randint(0,8)
|
||||||
spacing_y = random.randint(2, 4)
|
|
||||||
offset_x = random.randint(0, spacing_x - 1)
|
|
||||||
offset_y = random.randint(0, spacing_y - 1)
|
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
|
|||||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
config_dict['n_updates'] = 40000 # Extended training
|
||||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 50
|
||||||
config_dict['print_stats_at'] = 100
|
config_dict['print_stats_at'] = 200
|
||||||
config_dict['plot_at'] = 300
|
config_dict['plot_at'] = 500
|
||||||
config_dict['validate_at'] = 300 # Validate more frequently
|
config_dict['validate_at'] = 500 # Regular validation
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 48, # Good balance between capacity and memory
|
'base_channels': 64,
|
||||||
'dropout': 0.1 # Regularization
|
'dropout': 0.1 # Proper dropout for regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,49 +10,36 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data import Subset
|
from torch.utils.data import Subset
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class EnhancedRMSELoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
|
||||||
self.l1_weight = l1_weight
|
|
||||||
self.edge_weight = edge_weight
|
|
||||||
self.mse = nn.MSELoss()
|
|
||||||
self.l1 = nn.L1Loss()
|
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
|
||||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
|
||||||
|
|
||||||
def edge_loss(self, pred, target):
|
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
|
||||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
|
||||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
|
||||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
|
||||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
|
||||||
return edge_loss
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
mse_loss = self.mse(pred, target)
|
# Compute per-pixel squared error
|
||||||
l1_loss = self.l1(pred, target)
|
se = (pred - target) ** 2
|
||||||
edge_loss = self.edge_loss(pred, target)
|
|
||||||
|
|
||||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
# Weight edges more heavily for sharper results
|
||||||
return total_loss
|
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||||
|
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||||
|
|
||||||
|
# Apply weighting
|
||||||
|
weighted_se = se * edge_weight
|
||||||
|
|
||||||
|
# Compute RMSE
|
||||||
|
mse = weighted_se.mean()
|
||||||
|
rmse = torch.sqrt(mse + 1e-8)
|
||||||
|
return rmse
|
||||||
|
|
||||||
|
|
||||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||||
@@ -67,6 +54,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# Enable mixed precision training for memory efficiency
|
||||||
|
use_amp = torch.cuda.is_available()
|
||||||
|
scaler = torch.amp.GradScaler('cuda') if use_amp else None
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.login()
|
wandb.login()
|
||||||
@@ -111,15 +102,17 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss for better reconstruction
|
# defining the loss - Enhanced RMSE for sharper predictions
|
||||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
rmse_loss = EnhancedRMSELoss().to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
|
||||||
# Learning rate scheduler for better convergence
|
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||||
|
)
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||||
@@ -144,17 +137,31 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
output = network(input)
|
# Mixed precision training for memory efficiency
|
||||||
|
if use_amp:
|
||||||
loss = combined_loss(output, target)
|
with torch.amp.autocast('cuda'):
|
||||||
|
output = network(input)
|
||||||
loss.backward()
|
loss = rmse_loss(output, target)
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# Gradient clipping for training stability
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
output = network(input)
|
||||||
|
loss = rmse_loss(output, target)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping for training stability
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
# Gradient clipping for training stability
|
scheduler.step()
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
|
||||||
|
|
||||||
loss_list.append(loss.item())
|
loss_list.append(loss.item())
|
||||||
|
|
||||||
@@ -165,7 +172,11 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
# plotting
|
# plotting
|
||||||
if (i + 1) % plot_at == 0:
|
if (i + 1) % plot_at == 0:
|
||||||
print(f"Plotting images, current update {i + 1}")
|
print(f"Plotting images, current update {i + 1}")
|
||||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
# Convert to float32 for matplotlib compatibility
|
||||||
|
plot(input.float().cpu().numpy(),
|
||||||
|
target.detach().float().cpu().numpy(),
|
||||||
|
output.detach().float().cpu().numpy(),
|
||||||
|
plotpath, i)
|
||||||
|
|
||||||
# evaluating model every validate_at sample
|
# evaluating model every validate_at sample
|
||||||
if (i + 1) % validate_at == 0:
|
if (i + 1) % validate_at == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user