Compare commits
6 Commits
claude-son
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
|||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"learningrate": 0.0003,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"n_updates": 150000,
|
||||||
|
"plot_at": 400,
|
||||||
|
"early_stopping_patience": 40,
|
||||||
|
"print_stats_at": 200,
|
||||||
|
"print_train_stats_at": 50,
|
||||||
|
"validate_at": 200,
|
||||||
|
"accumulation_steps": 1,
|
||||||
|
"commands": {
|
||||||
|
"save_checkpoint": false,
|
||||||
|
"run_test_validation": false,
|
||||||
|
"generate_predictions": false
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -20,28 +20,46 @@ def init_weights(m):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class GatedSkipConnection(nn.Module):
|
||||||
"""Channel attention module (squeeze-and-excitation style)"""
|
"""Gated skip connection for better feature fusion"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, up_channels, skip_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
# Project skip to match up_channels if they differ
|
||||||
|
if skip_channels != up_channels:
|
||||||
|
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
|
||||||
|
else:
|
||||||
|
self.skip_proj = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, skip):
|
||||||
|
skip_proj = self.skip_proj(skip)
|
||||||
|
combined = torch.cat([x, skip], dim=1)
|
||||||
|
gate = self.gate(combined)
|
||||||
|
return x * gate + skip_proj * (1 - gate)
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientChannelAttention(nn.Module):
|
||||||
|
"""Efficient channel attention without dimensionality reduction"""
|
||||||
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||||
reduced = max(channels // reduction, 8)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
|
||||||
)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
avg_out = self.fc(self.avg_pool(x))
|
# Global pooling
|
||||||
max_out = self.fc(self.max_pool(x))
|
y = self.avg_pool(x)
|
||||||
return x * self.sigmoid(avg_out + max_out)
|
# 1D convolution on channel dimension
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
y = self.sigmoid(y)
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
class SpatialAttention(nn.Module):
|
class SpatialAttention(nn.Module):
|
||||||
"""Spatial attention module"""
|
"""Efficient spatial attention module"""
|
||||||
def __init__(self, kernel_size=7):
|
def __init__(self, kernel_size=7):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
@@ -55,12 +73,12 @@ class SpatialAttention(nn.Module):
|
|||||||
return x * attn
|
return x * attn
|
||||||
|
|
||||||
|
|
||||||
class CBAM(nn.Module):
|
class EfficientAttention(nn.Module):
|
||||||
"""Convolutional Block Attention Module"""
|
"""Lightweight attention module combining channel and spatial"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channel_attn = ChannelAttention(channels, reduction)
|
self.channel_attn = EfficientChannelAttention(channels)
|
||||||
self.spatial_attn = SpatialAttention()
|
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.channel_attn(x)
|
x = self.channel_attn(x)
|
||||||
@@ -70,176 +88,220 @@ class CBAM(nn.Module):
|
|||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
if separable and in_channels > 1:
|
||||||
|
# Depthwise separable convolution for efficiency
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||||
|
nn.Conv2d(in_channels, out_channels, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||||
self.bn = nn.BatchNorm2d(out_channels)
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(nn.Module):
|
||||||
|
"""Lightweight dense block for better gradient flow"""
|
||||||
|
def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
self.layers.append(ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout))
|
||||||
|
self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
|
||||||
|
self.bn = nn.BatchNorm2d(channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
features = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
out = layer(torch.cat(features, dim=1))
|
||||||
|
features.append(out)
|
||||||
|
out = self.fusion(torch.cat(features, dim=1))
|
||||||
|
out = self.relu(self.bn(out))
|
||||||
|
return out + x # Residual connection
|
||||||
|
|
||||||
class ResidualConvBlock(nn.Module):
|
class ResidualConvBlock(nn.Module):
|
||||||
"""Residual convolutional block for better gradient flow"""
|
"""Improved residual convolutional block with pre-activation"""
|
||||||
def __init__(self, channels, dropout=0.0):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
|
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn1 = nn.BatchNorm2d(channels)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
|
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = nn.BatchNorm2d(channels)
|
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
out = self.relu(self.bn1(self.conv1(x)))
|
out = self.relu1(self.bn1(x))
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.relu2(self.bn2(out))
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.bn2(self.conv2(out))
|
out = self.conv2(out)
|
||||||
out = out + residual
|
return out + residual
|
||||||
return self.relu(out)
|
|
||||||
|
|
||||||
|
|
||||||
class DilatedResidualBlock(nn.Module):
|
|
||||||
"""Residual block with dilated convolutions for larger receptive field"""
|
|
||||||
def __init__(self, channels, dilation=2, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
|
||||||
self.bn1 = nn.BatchNorm2d(channels)
|
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
|
||||||
self.bn2 = nn.BatchNorm2d(channels)
|
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
out = self.relu(self.bn1(self.conv1(x)))
|
|
||||||
out = self.dropout(out)
|
|
||||||
out = self.bn2(self.conv2(out))
|
|
||||||
out = out + residual
|
|
||||||
return self.relu(out)
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Enhanced downsampling block with dense and residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
skip = self.attention(x)
|
skip = self.attention(x)
|
||||||
return self.pool(skip), skip
|
return self.pool(skip), skip
|
||||||
|
|
||||||
class UpBlock(nn.Module):
|
class UpBlock(nn.Module):
|
||||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
"""Enhanced upsampling block with gated skip connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
# Skip connection has in_channels, upsampled has out_channels
|
||||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||||
|
# After gated skip: out_channels
|
||||||
|
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, skip):
|
def forward(self, x, skip):
|
||||||
x = self.up(x)
|
x = self.up(x)
|
||||||
# Handle dimension mismatch by interpolating x to match skip's size
|
# Handle dimension mismatch
|
||||||
if x.shape[2:] != skip.shape[2:]:
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||||
x = torch.cat([x, skip], dim=1)
|
x = self.gated_skip(x, skip)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
x = self.attention(x)
|
x = self.attention(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class MyModel(nn.Module):
|
class MyModel(nn.Module):
|
||||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Initial convolution with larger receptive field
|
# Separate mask processing for better feature extraction
|
||||||
self.init_conv = nn.Sequential(
|
self.mask_conv = nn.Sequential(
|
||||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||||
ConvBlock(base_channels, base_channels),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
ResidualConvBlock(base_channels)
|
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encoder (downsampling path)
|
# Image processing path
|
||||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
self.image_conv = nn.Sequential(
|
||||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
|
||||||
|
|
||||||
# Bottleneck with multi-scale dilated convolutions (ASPP-style)
|
|
||||||
self.bottleneck = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
DilatedResidualBlock(base_channels * 16, dilation=2, dropout=dropout),
|
|
||||||
DilatedResidualBlock(base_channels * 16, dilation=4, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
CBAM(base_channels * 16)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decoder (upsampling path)
|
|
||||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
|
||||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
|
||||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
|
||||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
|
||||||
|
|
||||||
# Final refinement layers
|
|
||||||
self.final_conv = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 2, base_channels),
|
|
||||||
ResidualConvBlock(base_channels),
|
|
||||||
ConvBlock(base_channels, base_channels)
|
ConvBlock(base_channels, base_channels)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output layer with smooth transition
|
# Fusion of mask and image features
|
||||||
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||||
|
nn.BatchNorm2d(base_channels),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encoder with progressive feature extraction
|
||||||
|
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
|
||||||
|
# Enhanced bottleneck with multi-scale features and dense connections
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||||
|
DenseBlock(base_channels * 8, growth_rate=10, num_layers=3, dropout=dropout),
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||||
|
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||||
|
EfficientAttention(base_channels * 8)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decoder with progressive reconstruction
|
||||||
|
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
|
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
|
||||||
|
# Multi-scale feature fusion with dense connections
|
||||||
|
self.multiscale_fusion = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 2, base_channels),
|
||||||
|
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||||
|
ConvBlock(base_channels, base_channels)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output with residual connection to input
|
||||||
|
self.pre_output = nn.Sequential(
|
||||||
|
ConvBlock(base_channels, base_channels),
|
||||||
|
ConvBlock(base_channels, base_channels // 2)
|
||||||
|
)
|
||||||
|
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||||
nn.LeakyReLU(0.1, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
nn.Conv2d(base_channels // 2, 3, 1),
|
||||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply weight initialization
|
# Apply weight initialization
|
||||||
self.apply(init_weights)
|
self.apply(init_weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Initial convolution
|
# Split input into image and mask
|
||||||
x0 = self.init_conv(x)
|
image = x[:, :3, :, :]
|
||||||
|
mask = x[:, 3:4, :, :]
|
||||||
|
|
||||||
|
# Process mask and image separately
|
||||||
|
mask_features = self.mask_conv(mask)
|
||||||
|
image_features = self.image_conv(image)
|
||||||
|
|
||||||
|
# Fuse features
|
||||||
|
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||||
|
|
||||||
# Encoder
|
# Encoder
|
||||||
x1, skip1 = self.down1(x0)
|
x1, skip1 = self.down1(x0)
|
||||||
x2, skip2 = self.down2(x1)
|
x2, skip2 = self.down2(x1)
|
||||||
x3, skip3 = self.down3(x2)
|
x3, skip3 = self.down3(x2)
|
||||||
x4, skip4 = self.down4(x3)
|
|
||||||
|
|
||||||
# Bottleneck
|
# Bottleneck
|
||||||
x = self.bottleneck(x4)
|
x = self.bottleneck(x3)
|
||||||
|
|
||||||
# Decoder with skip connections
|
# Decoder with skip connections
|
||||||
x = self.up1(x, skip4)
|
x = self.up1(x, skip3)
|
||||||
x = self.up2(x, skip3)
|
x = self.up2(x, skip2)
|
||||||
x = self.up3(x, skip2)
|
x = self.up3(x, skip1)
|
||||||
x = self.up4(x, skip1)
|
|
||||||
|
|
||||||
# Handle dimension mismatch for final concatenation
|
# Handle dimension mismatch for final fusion
|
||||||
if x.shape[2:] != x0.shape[2:]:
|
if x.shape[2:] != x0.shape[2:]:
|
||||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
# Concatenate with initial features for better detail preservation
|
# Multi-scale fusion with initial features
|
||||||
x = torch.cat([x, x0], dim=1)
|
x = torch.cat([x, x0], dim=1)
|
||||||
x = self.final_conv(x)
|
x = self.multiscale_fusion(x)
|
||||||
|
|
||||||
# Output
|
# Pre-output processing
|
||||||
|
x = self.pre_output(x)
|
||||||
|
|
||||||
|
# Concatenate with original masked image for residual learning
|
||||||
|
x = torch.cat([x, image], dim=1)
|
||||||
x = self.output(x)
|
x = self.output(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -32,56 +32,55 @@ def resize(img: Image):
|
|||||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||||
])
|
])
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
|
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
return input_array
|
return input_array
|
||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
def augment_image(img: Image, strength: float = 0.7) -> Image:
|
||||||
"""
|
"""Apply comprehensive data augmentation for better generalization"""
|
||||||
Dataset class for loading images from a folder with data augmentation
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, datafolder: str, augment: bool = True):
|
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
|
||||||
self.augment = augment
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.imagefiles)
|
|
||||||
|
|
||||||
def augment_image(self, image: Image) -> Image:
|
|
||||||
"""Apply random augmentations to image"""
|
|
||||||
# Random horizontal flip
|
# Random horizontal flip
|
||||||
if random.random() > 0.5:
|
if random.random() > 0.5:
|
||||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
|
||||||
# Random vertical flip
|
# Random vertical flip
|
||||||
if random.random() > 0.5:
|
if random.random() > 0.5:
|
||||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
# Random rotation (90, 180, 270 degrees)
|
# Random rotation (90, 180, 270 degrees)
|
||||||
if random.random() > 0.5:
|
if random.random() > 0.5:
|
||||||
angle = random.choice([90, 180, 270])
|
angle = random.choice([90, 180, 270])
|
||||||
image = image.rotate(angle)
|
img = img.rotate(angle)
|
||||||
|
|
||||||
# Random brightness adjustment
|
# Color augmentation - more aggressive for long training
|
||||||
if random.random() > 0.5:
|
rand = random.random()
|
||||||
enhancer = ImageEnhance.Brightness(image)
|
if rand > 0.75:
|
||||||
factor = random.uniform(0.8, 1.2)
|
# Brightness
|
||||||
image = enhancer.enhance(factor)
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Brightness(img).enhance(factor)
|
||||||
|
elif rand > 0.5:
|
||||||
|
# Contrast
|
||||||
|
factor = 1.0 + random.uniform(-0.2, 0.2) * strength
|
||||||
|
img = ImageEnhance.Contrast(img).enhance(factor)
|
||||||
|
elif rand > 0.25:
|
||||||
|
# Saturation
|
||||||
|
factor = 1.0 + random.uniform(-0.15, 0.15) * strength
|
||||||
|
img = ImageEnhance.Color(img).enhance(factor)
|
||||||
|
|
||||||
# Random contrast adjustment
|
return img
|
||||||
if random.random() > 0.5:
|
|
||||||
enhancer = ImageEnhance.Contrast(image)
|
|
||||||
factor = random.uniform(0.8, 1.2)
|
|
||||||
image = enhancer.enhance(factor)
|
|
||||||
|
|
||||||
# Random color adjustment
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
if random.random() > 0.5:
|
"""
|
||||||
enhancer = ImageEnhance.Color(image)
|
Dataset class for loading images from a folder with augmentation support
|
||||||
factor = random.uniform(0.8, 1.2)
|
"""
|
||||||
image = enhancer.enhance(factor)
|
|
||||||
|
|
||||||
return image
|
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.7):
|
||||||
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
|
self.augment = augment
|
||||||
|
self.augment_strength = augment_strength
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.imagefiles)
|
||||||
|
|
||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx:int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
@@ -89,18 +88,16 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
image = Image.open(self.imagefiles[index])
|
image = Image.open(self.imagefiles[index])
|
||||||
image = resize(image)
|
image = resize(image)
|
||||||
|
|
||||||
# Apply augmentation if enabled
|
# Apply augmentation
|
||||||
if self.augment:
|
if self.augment:
|
||||||
image = self.augment_image(image)
|
image = augment_image(image, self.augment_strength)
|
||||||
|
|
||||||
image = np.asarray(image)
|
image = np.asarray(image)
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
spacing_x = random.randint(2,6)
|
||||||
# Vary spacing and offset more for additional diversity
|
spacing_y = random.randint(2,6)
|
||||||
spacing_x = random.randint(2,7)
|
offset_x = random.randint(0,8)
|
||||||
spacing_y = random.randint(2,7)
|
offset_y = random.randint(0,8)
|
||||||
offset_x = random.randint(0,10)
|
|
||||||
offset_y = random.randint(0,10)
|
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
|
|||||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 2e-4 # Slightly lower for more stable training
|
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||||
config_dict['weight_decay'] = 5e-5 # Reduced for less aggressive regularization
|
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
config_dict['n_updates'] = 40000 # Extended training
|
||||||
config_dict['batchsize'] = 12 # Larger batch for more stable gradients
|
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 50
|
||||||
config_dict['print_stats_at'] = 100
|
config_dict['print_stats_at'] = 200
|
||||||
config_dict['plot_at'] = 400
|
config_dict['plot_at'] = 500
|
||||||
config_dict['validate_at'] = 200 # Validate frequently but not too often
|
config_dict['validate_at'] = 500 # Regular validation
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 32, # Smaller base for efficiency, depth compensates
|
'base_channels': 64,
|
||||||
'dropout': 0.15 # Slightly more regularization with augmentation
|
'dropout': 0.1 # Proper dropout for regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,49 +10,36 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data import Subset
|
from torch.utils.data import Subset
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class EnhancedRMSELoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + Edge-aware component for better reconstruction"""
|
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||||
def __init__(self, mse_weight=0.7, l1_weight=0.8, edge_weight=0.2):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
|
||||||
self.l1_weight = l1_weight
|
|
||||||
self.edge_weight = edge_weight
|
|
||||||
self.mse = nn.MSELoss()
|
|
||||||
self.l1 = nn.L1Loss()
|
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
|
||||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
|
||||||
|
|
||||||
def edge_loss(self, pred, target):
|
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
|
||||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
|
||||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
|
||||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
|
||||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
|
||||||
return edge_loss
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
mse_loss = self.mse(pred, target)
|
# Compute per-pixel squared error
|
||||||
l1_loss = self.l1(pred, target)
|
se = (pred - target) ** 2
|
||||||
edge_loss = self.edge_loss(pred, target)
|
|
||||||
|
|
||||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
# Weight edges more heavily for sharper results
|
||||||
return total_loss
|
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||||
|
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||||
|
|
||||||
|
# Apply weighting
|
||||||
|
weighted_se = se * edge_weight
|
||||||
|
|
||||||
|
# Compute RMSE
|
||||||
|
mse = weighted_se.mean()
|
||||||
|
rmse = torch.sqrt(mse + 1e-8)
|
||||||
|
return rmse
|
||||||
|
|
||||||
|
|
||||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||||
@@ -68,6 +55,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# Enable mixed precision training for memory efficiency
|
||||||
|
use_amp = torch.cuda.is_available()
|
||||||
|
scaler = torch.amp.GradScaler('cuda') if use_amp else None
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.login()
|
wandb.login()
|
||||||
wandb.init(project="image_inpainting", config={
|
wandb.init(project="image_inpainting", config={
|
||||||
@@ -84,24 +75,20 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
plotpath = os.path.join(results_path, "plots")
|
plotpath = os.path.join(results_path, "plots")
|
||||||
os.makedirs(plotpath, exist_ok=True)
|
os.makedirs(plotpath, exist_ok=True)
|
||||||
|
|
||||||
# Create dataset with augmentation for training, without for validation/test
|
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||||
image_dataset_full = datasets.ImageDataset(datafolder=data_path, augment=False)
|
|
||||||
|
|
||||||
n_total = len(image_dataset_full)
|
n_total = len(image_dataset)
|
||||||
n_test = int(n_total * testset_ratio)
|
n_test = int(n_total * testset_ratio)
|
||||||
n_valid = int(n_total * validset_ratio)
|
n_valid = int(n_total * validset_ratio)
|
||||||
n_train = n_total - n_test - n_valid
|
n_train = n_total - n_test - n_valid
|
||||||
indices = np.random.permutation(n_total)
|
indices = np.random.permutation(n_total)
|
||||||
|
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||||
|
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||||
|
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||||
|
|
||||||
# Create augmented dataset for training
|
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||||
image_dataset_train = datasets.ImageDataset(datafolder=data_path, augment=True)
|
|
||||||
dataset_train = Subset(image_dataset_train, indices=indices[0:n_train])
|
|
||||||
dataset_valid = Subset(image_dataset_full, indices=indices[n_train:n_train + n_valid])
|
|
||||||
dataset_test = Subset(image_dataset_full, indices=indices[n_train + n_valid:n_total])
|
|
||||||
|
|
||||||
assert n_total == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
del image_dataset
|
||||||
|
|
||||||
del image_dataset_full, image_dataset_train
|
|
||||||
|
|
||||||
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
||||||
num_workers=0, shuffle=True)
|
num_workers=0, shuffle=True)
|
||||||
@@ -115,19 +102,17 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss with optimized weights
|
# defining the loss - Enhanced RMSE for sharper predictions
|
||||||
combined_loss = CombinedLoss(mse_weight=0.7, l1_weight=0.8, edge_weight=0.2).to(device)
|
rmse_loss = EnhancedRMSELoss().to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
|
||||||
# Learning rate scheduler with better configuration
|
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=1e-7)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||||
# Mixed precision training for faster computation and lower memory usage
|
)
|
||||||
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
|
|
||||||
use_amp = scaler is not None
|
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||||
@@ -136,13 +121,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
counter = 0
|
counter = 0
|
||||||
best_validation_loss = np.inf
|
best_validation_loss = np.inf
|
||||||
loss_list = []
|
loss_list = []
|
||||||
accumulation_steps = 2 # Gradient accumulation for effective larger batch size
|
|
||||||
|
|
||||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||||
|
|
||||||
print(f"Started training on device {device}")
|
print(f"Started training on device {device}")
|
||||||
print(f"Using mixed precision: {use_amp}")
|
|
||||||
print(f"Gradient accumulation steps: {accumulation_steps}")
|
|
||||||
|
|
||||||
while i < n_updates:
|
while i < n_updates:
|
||||||
|
|
||||||
@@ -153,33 +135,35 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
if (i + 1) % print_train_stats_at == 0:
|
if (i + 1) % print_train_stats_at == 0:
|
||||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||||
|
|
||||||
# Use mixed precision if available
|
optimizer.zero_grad()
|
||||||
if use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output = network(input)
|
|
||||||
loss = combined_loss(output, target)
|
|
||||||
loss = loss / accumulation_steps
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
else:
|
|
||||||
output = network(input)
|
|
||||||
loss = combined_loss(output, target)
|
|
||||||
loss = loss / accumulation_steps
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# Gradient accumulation - update weights every accumulation_steps
|
# Mixed precision training for memory efficiency
|
||||||
if (i + 1) % accumulation_steps == 0:
|
|
||||||
if use_amp:
|
if use_amp:
|
||||||
|
with torch.amp.autocast('cuda'):
|
||||||
|
output = network(input)
|
||||||
|
loss = rmse_loss(output, target)
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# Gradient clipping for training stability
|
||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
output = network(input)
|
||||||
optimizer.step()
|
loss = rmse_loss(output, target)
|
||||||
optimizer.zero_grad()
|
loss.backward()
|
||||||
scheduler.step(i / n_updates)
|
|
||||||
|
|
||||||
loss_list.append(loss.item() * accumulation_steps)
|
# Gradient clipping for training stability
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
loss_list.append(loss.item())
|
||||||
|
|
||||||
# writing the stats to wandb
|
# writing the stats to wandb
|
||||||
if use_wandb and (i+1) % print_stats_at == 0:
|
if use_wandb and (i+1) % print_stats_at == 0:
|
||||||
@@ -188,9 +172,11 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
# plotting
|
# plotting
|
||||||
if (i + 1) % plot_at == 0:
|
if (i + 1) % plot_at == 0:
|
||||||
print(f"Plotting images, current update {i + 1}")
|
print(f"Plotting images, current update {i + 1}")
|
||||||
# Convert to float32 for matplotlib compatibility (mixed precision may produce float16)
|
# Convert to float32 for matplotlib compatibility
|
||||||
plot(input.float().cpu().numpy(), target.detach().float().cpu().numpy(),
|
plot(input.float().cpu().numpy(),
|
||||||
output.detach().float().cpu().numpy(), plotpath, i)
|
target.detach().float().cpu().numpy(),
|
||||||
|
output.detach().float().cpu().numpy(),
|
||||||
|
plotpath, i)
|
||||||
|
|
||||||
# evaluating model every validate_at sample
|
# evaluating model every validate_at sample
|
||||||
if (i + 1) % validate_at == 0:
|
if (i + 1) % validate_at == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user