Compare commits
9 Commits
gpt-5.2
...
claude-son
| Author | SHA1 | Date | |
|---|---|---|---|
| 716feac20c | |||
| d979c200f9 | |||
| 4af674b79d | |||
| fd81f3ce2e | |||
| e9ee27bb56 | |||
| 1f859a3d71 | |||
| c00089a97d | |||
| 5545a2f0eb | |||
| 9bf3335da6 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -2,3 +2,6 @@ data/*
|
|||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
|
runtime_predictions.npz
|
||||||
|
results/runtime_config.json
|
||||||
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.6824.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.9248.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.2533.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-17.3305.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -7,6 +7,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def init_weights(m):
|
def init_weights(m):
|
||||||
@@ -20,28 +21,49 @@ def init_weights(m):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class GatedSkipConnection(nn.Module):
|
||||||
"""Channel attention module (squeeze-and-excitation style)"""
|
"""Gated skip connection for better feature fusion"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, up_channels, skip_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
# Project skip to match up_channels if they differ
|
||||||
|
if skip_channels != up_channels:
|
||||||
|
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
|
||||||
|
else:
|
||||||
|
self.skip_proj = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, skip):
|
||||||
|
skip_proj = self.skip_proj(skip)
|
||||||
|
combined = torch.cat([x, skip], dim=1)
|
||||||
|
gate = self.gate(combined)
|
||||||
|
return x * gate + skip_proj * (1 - gate)
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientChannelAttention(nn.Module):
|
||||||
|
"""Efficient channel attention without dimensionality reduction"""
|
||||||
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
|
||||||
reduced = max(channels // reduction, 8)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, reduced, 1, bias=False),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(reduced, channels, 1, bias=False)
|
|
||||||
)
|
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
avg_out = self.fc(self.avg_pool(x))
|
# Global pooling
|
||||||
max_out = self.fc(self.max_pool(x))
|
y = self.avg_pool(x)
|
||||||
return x * self.sigmoid(avg_out + max_out)
|
# 1D convolution on channel dimension - add safety checks
|
||||||
|
if y.size(-1) == 1 and y.size(-2) == 1:
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
y = self.sigmoid(y)
|
||||||
|
y = torch.clamp(y, min=0.0, max=1.0) # Ensure valid range
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SpatialAttention(nn.Module):
|
class SpatialAttention(nn.Module):
|
||||||
"""Spatial attention module"""
|
"""Efficient spatial attention module"""
|
||||||
def __init__(self, kernel_size=7):
|
def __init__(self, kernel_size=7):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
@@ -55,12 +77,12 @@ class SpatialAttention(nn.Module):
|
|||||||
return x * attn
|
return x * attn
|
||||||
|
|
||||||
|
|
||||||
class CBAM(nn.Module):
|
class EfficientAttention(nn.Module):
|
||||||
"""Convolutional Block Attention Module"""
|
"""Lightweight attention module combining channel and spatial"""
|
||||||
def __init__(self, channels, reduction=16):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channel_attn = ChannelAttention(channels, reduction)
|
self.channel_attn = EfficientChannelAttention(channels)
|
||||||
self.spatial_attn = SpatialAttention()
|
self.spatial_attn = SpatialAttention(kernel_size=5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.channel_attn(x)
|
x = self.channel_attn(x)
|
||||||
@@ -68,157 +90,302 @@ class CBAM(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
"""Self-attention module for long-range dependencies"""
|
||||||
|
def __init__(self, in_channels, reduction=8):
|
||||||
|
super().__init__()
|
||||||
|
self.query = nn.Conv2d(in_channels, in_channels // reduction, 1)
|
||||||
|
self.key = nn.Conv2d(in_channels, in_channels // reduction, 1)
|
||||||
|
self.value = nn.Conv2d(in_channels, in_channels, 1)
|
||||||
|
self.gamma = nn.Parameter(torch.zeros(1))
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size, C, H, W = x.size()
|
||||||
|
|
||||||
|
# Generate query, key, value
|
||||||
|
query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
|
||||||
|
key = self.key(x).view(batch_size, -1, H * W)
|
||||||
|
value = self.value(x).view(batch_size, -1, H * W)
|
||||||
|
|
||||||
|
# Attention map with numerical stability
|
||||||
|
attention_logits = torch.bmm(query, key)
|
||||||
|
# Scale for numerical stability
|
||||||
|
attention_logits = attention_logits / math.sqrt(query.size(-1))
|
||||||
|
attention = self.softmax(attention_logits)
|
||||||
|
out = torch.bmm(value, attention.permute(0, 2, 1))
|
||||||
|
out = out.view(batch_size, C, H, W)
|
||||||
|
|
||||||
|
# Residual connection with learnable weight
|
||||||
|
out = self.gamma * out + x
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
if separable and in_channels > 1:
|
||||||
self.bn = nn.BatchNorm2d(out_channels)
|
# Depthwise separable convolution for efficiency
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels),
|
||||||
|
nn.Conv2d(in_channels, out_channels, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||||
|
# Add momentum and eps for numerical stability
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5, track_running_stats=True)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||||
|
|
||||||
|
|
||||||
|
class DenseBlock(nn.Module):
|
||||||
|
"""Lightweight dense block for better gradient flow"""
|
||||||
|
def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
self.layers.append(ConvBlock(channels + i * growth_rate, growth_rate, dropout=dropout))
|
||||||
|
self.fusion = nn.Conv2d(channels + num_layers * growth_rate, channels, 1)
|
||||||
|
self.bn = nn.BatchNorm2d(channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
features = [x]
|
||||||
|
for layer in self.layers:
|
||||||
|
out = layer(torch.cat(features, dim=1))
|
||||||
|
features.append(out)
|
||||||
|
out = self.fusion(torch.cat(features, dim=1))
|
||||||
|
out = self.relu(self.bn(out))
|
||||||
|
return out + x # Residual connection
|
||||||
|
|
||||||
class ResidualConvBlock(nn.Module):
|
class ResidualConvBlock(nn.Module):
|
||||||
"""Residual convolutional block for better gradient flow"""
|
"""Improved residual convolutional block with pre-activation"""
|
||||||
def __init__(self, channels, dropout=0.0):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
|
||||||
self.bn1 = nn.BatchNorm2d(channels)
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.relu1 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = nn.BatchNorm2d(channels)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.relu2 = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
out = self.relu(self.bn1(self.conv1(x)))
|
out = self.relu1(self.bn1(x))
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.relu2(self.bn2(out))
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.bn2(self.conv2(out))
|
out = self.conv2(out)
|
||||||
out = out + residual
|
return out + residual
|
||||||
return self.relu(out)
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Enhanced downsampling block with dense and residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False, use_self_attention=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
|
self.self_attention = SelfAttention(out_channels) if use_self_attention else nn.Identity()
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
skip = self.attention(x)
|
x = self.attention(x)
|
||||||
|
skip = self.self_attention(x)
|
||||||
return self.pool(skip), skip
|
return self.pool(skip), skip
|
||||||
|
|
||||||
class UpBlock(nn.Module):
|
class UpBlock(nn.Module):
|
||||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
"""Enhanced upsampling block with gated skip connections"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False, use_self_attention=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
# Skip connection has in_channels, upsampled has out_channels
|
||||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
self.gated_skip = GatedSkipConnection(out_channels, in_channels)
|
||||||
|
# After gated skip: out_channels
|
||||||
|
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
|
||||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
if use_dense:
|
||||||
self.attention = CBAM(out_channels)
|
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout)
|
||||||
|
else:
|
||||||
|
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
|
||||||
|
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
|
||||||
|
self.self_attention = SelfAttention(out_channels) if use_self_attention else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, skip):
|
def forward(self, x, skip):
|
||||||
x = self.up(x)
|
x = self.up(x)
|
||||||
# Handle dimension mismatch by interpolating x to match skip's size
|
# Handle dimension mismatch
|
||||||
if x.shape[2:] != skip.shape[2:]:
|
if x.shape[2:] != skip.shape[2:]:
|
||||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||||
x = torch.cat([x, skip], dim=1)
|
x = self.gated_skip(x, skip)
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.residual(x)
|
x = self.dense(x)
|
||||||
x = self.attention(x)
|
x = self.attention(x)
|
||||||
|
x = self.self_attention(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class MyModel(nn.Module):
|
class MyModel(nn.Module):
|
||||||
"""Improved U-Net style architecture for image inpainting with attention and residual connections"""
|
"""Enhanced U-Net architecture with dense connections and efficient attention"""
|
||||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Initial convolution with larger receptive field
|
# Separate mask processing for better feature extraction
|
||||||
self.init_conv = nn.Sequential(
|
# Separate mask processing for better feature extraction
|
||||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
self.mask_conv = nn.Sequential(
|
||||||
ConvBlock(base_channels, base_channels),
|
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||||
ResidualConvBlock(base_channels)
|
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
|
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||||
|
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encoder (downsampling path)
|
# Image processing path
|
||||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
self.image_conv = nn.Sequential(
|
||||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
ConvBlock(3, base_channels, kernel_size=5, padding=2),
|
||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
|
||||||
|
|
||||||
# Bottleneck with multiple residual blocks
|
|
||||||
self.bottleneck = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
|
||||||
CBAM(base_channels * 16)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decoder (upsampling path)
|
|
||||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
|
||||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
|
||||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
|
||||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
|
||||||
|
|
||||||
# Final refinement layers
|
|
||||||
self.final_conv = nn.Sequential(
|
|
||||||
ConvBlock(base_channels * 2, base_channels),
|
|
||||||
ResidualConvBlock(base_channels),
|
|
||||||
ConvBlock(base_channels, base_channels)
|
ConvBlock(base_channels, base_channels)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output layer with smooth transition
|
# Fusion of mask and image features
|
||||||
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||||
|
nn.BatchNorm2d(base_channels, momentum=0.1, eps=1e-5, track_running_stats=True),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encoder with progressive feature extraction
|
||||||
|
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout, use_attention=False, use_dense=False)
|
||||||
|
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True, use_self_attention=True)
|
||||||
|
|
||||||
|
# Enhanced bottleneck with multi-scale features, dense connections, and self-attention
|
||||||
|
self.bottleneck = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
|
||||||
|
DenseBlock(base_channels * 8, growth_rate=12, num_layers=3, dropout=dropout),
|
||||||
|
SelfAttention(base_channels * 8, reduction=4),
|
||||||
|
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
|
||||||
|
ResidualConvBlock(base_channels * 8, dropout=dropout),
|
||||||
|
EfficientAttention(base_channels * 8)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decoder with progressive reconstruction
|
||||||
|
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True, use_self_attention=True)
|
||||||
|
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout, use_attention=False, use_dense=False)
|
||||||
|
|
||||||
|
# Multi-scale feature fusion with dense connections
|
||||||
|
self.multiscale_fusion = nn.Sequential(
|
||||||
|
ConvBlock(base_channels * 2, base_channels),
|
||||||
|
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
|
||||||
|
ConvBlock(base_channels, base_channels)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output with residual connection to input
|
||||||
|
self.pre_output = nn.Sequential(
|
||||||
|
ConvBlock(base_channels, base_channels),
|
||||||
|
ConvBlock(base_channels, base_channels // 2)
|
||||||
|
)
|
||||||
|
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1),
|
||||||
nn.LeakyReLU(0.1, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
nn.Conv2d(base_channels // 2, 3, 1),
|
||||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply weight initialization
|
# Apply weight initialization
|
||||||
self.apply(init_weights)
|
self.apply(init_weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Initial convolution
|
# Split input into image and mask
|
||||||
x0 = self.init_conv(x)
|
image = x[:, :3, :, :]
|
||||||
|
mask = x[:, 3:4, :, :]
|
||||||
|
|
||||||
|
# Clamp inputs to valid range
|
||||||
|
image = torch.clamp(image, 0.0, 1.0)
|
||||||
|
mask = torch.clamp(mask, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Process mask and image separately
|
||||||
|
mask_features = self.mask_conv(mask)
|
||||||
|
image_features = self.image_conv(image)
|
||||||
|
|
||||||
|
# Safety check after initial processing
|
||||||
|
if not torch.isfinite(mask_features).all():
|
||||||
|
mask_features = torch.nan_to_num(mask_features, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
if not torch.isfinite(image_features).all():
|
||||||
|
image_features = torch.nan_to_num(image_features, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
|
# Fuse features
|
||||||
|
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||||
|
if not torch.isfinite(x0).all():
|
||||||
|
x0 = torch.nan_to_num(x0, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
# Encoder
|
# Encoder
|
||||||
x1, skip1 = self.down1(x0)
|
x1, skip1 = self.down1(x0)
|
||||||
|
if not torch.isfinite(x1).all():
|
||||||
|
x1 = torch.nan_to_num(x1, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
skip1 = torch.nan_to_num(skip1, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
x2, skip2 = self.down2(x1)
|
x2, skip2 = self.down2(x1)
|
||||||
|
if not torch.isfinite(x2).all():
|
||||||
|
x2 = torch.nan_to_num(x2, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
skip2 = torch.nan_to_num(skip2, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
x3, skip3 = self.down3(x2)
|
x3, skip3 = self.down3(x2)
|
||||||
x4, skip4 = self.down4(x3)
|
if not torch.isfinite(x3).all():
|
||||||
|
x3 = torch.nan_to_num(x3, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
skip3 = torch.nan_to_num(skip3, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
# Bottleneck
|
# Bottleneck
|
||||||
x = self.bottleneck(x4)
|
x = self.bottleneck(x3)
|
||||||
|
if not torch.isfinite(x).all():
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
# Decoder with skip connections
|
# Decoder with skip connections
|
||||||
x = self.up1(x, skip4)
|
x = self.up1(x, skip3)
|
||||||
x = self.up2(x, skip3)
|
if not torch.isfinite(x).all():
|
||||||
x = self.up3(x, skip2)
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
x = self.up4(x, skip1)
|
|
||||||
|
|
||||||
# Handle dimension mismatch for final concatenation
|
x = self.up2(x, skip2)
|
||||||
|
if not torch.isfinite(x).all():
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
|
x = self.up3(x, skip1)
|
||||||
|
if not torch.isfinite(x).all():
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
|
# Handle dimension mismatch for final fusion
|
||||||
if x.shape[2:] != x0.shape[2:]:
|
if x.shape[2:] != x0.shape[2:]:
|
||||||
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
# Concatenate with initial features for better detail preservation
|
# Multi-scale fusion with initial features
|
||||||
x = torch.cat([x, x0], dim=1)
|
x = torch.cat([x, x0], dim=1)
|
||||||
x = self.final_conv(x)
|
x = self.multiscale_fusion(x)
|
||||||
|
if not torch.isfinite(x).all():
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
# Output
|
# Pre-output processing
|
||||||
|
x = self.pre_output(x)
|
||||||
|
if not torch.isfinite(x).all():
|
||||||
|
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||||
|
|
||||||
|
# Concatenate with original masked image for residual learning
|
||||||
|
x = torch.cat([x, image], dim=1)
|
||||||
x = self.output(x)
|
x = self.output(x)
|
||||||
|
|
||||||
|
# Final safety clamp
|
||||||
|
x = torch.clamp(x, 0.0, 1.0)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -10,7 +10,8 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image, ImageEnhance, ImageFilter
|
||||||
|
from scipy.ndimage import gaussian_filter, map_coordinates
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
@@ -32,17 +33,124 @@ def resize(img: Image):
|
|||||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||||
])
|
])
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
|
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
return input_array
|
return input_array
|
||||||
|
|
||||||
|
def elastic_transform(image: np.ndarray, alpha: float = 20, sigma: float = 4) -> np.ndarray:
|
||||||
|
"""Apply elastic deformation to image array"""
|
||||||
|
shape = image.shape[:2]
|
||||||
|
dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
|
||||||
|
dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
|
||||||
|
|
||||||
|
x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
|
||||||
|
indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
|
||||||
|
|
||||||
|
# Apply to each channel
|
||||||
|
transformed = np.zeros_like(image)
|
||||||
|
for i in range(image.shape[2]):
|
||||||
|
transformed[:, :, i] = map_coordinates(image[:, :, i], indices, order=1, mode='reflect').reshape(shape)
|
||||||
|
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
def add_noise(img_array: np.ndarray, noise_type: str = 'gaussian', strength: float = 0.02) -> np.ndarray:
|
||||||
|
"""Add various types of noise to image"""
|
||||||
|
if noise_type == 'gaussian':
|
||||||
|
noise = np.random.normal(0, strength, img_array.shape)
|
||||||
|
noisy = img_array + noise
|
||||||
|
elif noise_type == 'salt_pepper':
|
||||||
|
noisy = img_array.copy()
|
||||||
|
# Salt
|
||||||
|
num_salt = int(strength * img_array.size * 0.5)
|
||||||
|
coords = [np.random.randint(0, i, num_salt) for i in img_array.shape]
|
||||||
|
noisy[coords[0], coords[1], :] = 1
|
||||||
|
# Pepper
|
||||||
|
num_pepper = int(strength * img_array.size * 0.5)
|
||||||
|
coords = [np.random.randint(0, i, num_pepper) for i in img_array.shape]
|
||||||
|
noisy[coords[0], coords[1], :] = 0
|
||||||
|
else:
|
||||||
|
noisy = img_array
|
||||||
|
|
||||||
|
return np.clip(noisy, 0, 1)
|
||||||
|
|
||||||
|
def augment_image(img: Image, strength: float = 0.8) -> Image:
|
||||||
|
"""Apply comprehensive data augmentation for better generalization"""
|
||||||
|
# Random horizontal flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
|
||||||
|
# Random vertical flip
|
||||||
|
if random.random() > 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
# Random rotation (90, 180, 270 degrees, or small angles)
|
||||||
|
if random.random() > 0.5:
|
||||||
|
if random.random() > 0.7:
|
||||||
|
# Large rotation
|
||||||
|
angle = random.choice([90, 180, 270])
|
||||||
|
img = img.rotate(angle)
|
||||||
|
else:
|
||||||
|
# Small rotation for more variation
|
||||||
|
angle = random.uniform(-15, 15)
|
||||||
|
img = img.rotate(angle, fillcolor=(128, 128, 128))
|
||||||
|
|
||||||
|
# More aggressive color augmentation
|
||||||
|
if random.random() > 0.3:
|
||||||
|
# Brightness
|
||||||
|
factor = 1.0 + random.uniform(-0.3, 0.3) * strength
|
||||||
|
img = ImageEnhance.Brightness(img).enhance(factor)
|
||||||
|
|
||||||
|
if random.random() > 0.3:
|
||||||
|
# Contrast
|
||||||
|
factor = 1.0 + random.uniform(-0.3, 0.3) * strength
|
||||||
|
img = ImageEnhance.Contrast(img).enhance(factor)
|
||||||
|
|
||||||
|
if random.random() > 0.3:
|
||||||
|
# Saturation
|
||||||
|
factor = 1.0 + random.uniform(-0.25, 0.25) * strength
|
||||||
|
img = ImageEnhance.Color(img).enhance(factor)
|
||||||
|
|
||||||
|
if random.random() > 0.7:
|
||||||
|
# Sharpness
|
||||||
|
factor = 1.0 + random.uniform(-0.3, 0.5) * strength
|
||||||
|
img = ImageEnhance.Sharpness(img).enhance(factor)
|
||||||
|
|
||||||
|
# Gaussian blur for robustness
|
||||||
|
if random.random() > 0.8:
|
||||||
|
radius = random.uniform(0.5, 1.5) * strength
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=radius))
|
||||||
|
|
||||||
|
# Convert to array for elastic transform and noise
|
||||||
|
img_array = np.array(img).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
# Elastic deformation
|
||||||
|
if random.random() > 0.7:
|
||||||
|
alpha = random.uniform(15, 30) * strength
|
||||||
|
sigma = random.uniform(3, 5)
|
||||||
|
img_array = elastic_transform(img_array, alpha=alpha, sigma=sigma)
|
||||||
|
|
||||||
|
# Add noise
|
||||||
|
if random.random() > 0.6:
|
||||||
|
noise_type = random.choice(['gaussian', 'salt_pepper'])
|
||||||
|
noise_strength = random.uniform(0.01, 0.03) * strength
|
||||||
|
img_array = add_noise(img_array, noise_type=noise_type, strength=noise_strength)
|
||||||
|
|
||||||
|
# Convert back to PIL Image
|
||||||
|
img_array = np.clip(img_array * 255, 0, 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(img_array)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder
|
Dataset class for loading images from a folder with augmentation support
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str):
|
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.8):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
|
self.augment = augment
|
||||||
|
self.augment_strength = augment_strength
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
@@ -51,7 +159,13 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index])
|
image = Image.open(self.imagefiles[index])
|
||||||
image = np.asarray(resize(image))
|
image = resize(image)
|
||||||
|
|
||||||
|
# Apply augmentation
|
||||||
|
if self.augment:
|
||||||
|
image = augment_image(image, self.augment_strength)
|
||||||
|
|
||||||
|
image = np.asarray(image)
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
spacing_x = random.randint(2,6)
|
spacing_x = random.randint(2,6)
|
||||||
spacing_y = random.randint(2,6)
|
spacing_y = random.randint(2,6)
|
||||||
|
|||||||
@@ -17,46 +17,74 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
config_dict['seed'] = 42
|
config_dict['seed'] = 42
|
||||||
config_dict['testset_ratio'] = 0.1
|
config_dict['testset_ratio'] = 0.1
|
||||||
config_dict['validset_ratio'] = 0.1
|
config_dict['validset_ratio'] = 0.05
|
||||||
# Get the absolute path based on the script's location
|
# Get the absolute path based on the script's location
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
project_root = os.path.dirname(script_dir)
|
project_root = os.path.dirname(script_dir)
|
||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
|
||||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
|
||||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
config_dict['n_updates'] = 12000 # Extended training for better convergence
|
||||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
|
||||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
config_dict['early_stopping_patience'] = 20 # More patience for complex model
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 10
|
||||||
config_dict['print_stats_at'] = 100
|
config_dict['print_stats_at'] = 200
|
||||||
config_dict['plot_at'] = 300
|
config_dict['plot_at'] = 500
|
||||||
config_dict['validate_at'] = 300 # Validate more frequently
|
config_dict['validate_at'] = 250 # More frequent validation
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 48, # Good balance between capacity and memory
|
'base_channels': 52, # Increased capacity for better feature extraction
|
||||||
'dropout': 0.1 # Regularization
|
'dropout': 0.15 # Slightly higher dropout for regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|
||||||
|
# Prepare paths for runtime predictions
|
||||||
|
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
|
||||||
|
save_path = os.path.join(config_dict['results_path'], "runtime_predictions")
|
||||||
|
plot_path_predictions = os.path.join(config_dict['results_path'], "runtime_predictions", "plots")
|
||||||
|
|
||||||
|
config_dict['testset_path'] = testset_path
|
||||||
|
config_dict['save_path'] = save_path
|
||||||
|
config_dict['plot_path_predictions'] = plot_path_predictions
|
||||||
|
|
||||||
|
print("="*60)
|
||||||
|
print("RUNTIME CONFIGURATION ENABLED")
|
||||||
|
print("="*60)
|
||||||
|
print("During training, you can modify these parameters by editing:")
|
||||||
|
print(f"{os.path.join(config_dict['results_path'], 'runtime_config.json')}")
|
||||||
|
print("\nModifiable parameters:")
|
||||||
|
print(" - n_updates: Maximum training steps")
|
||||||
|
print(" - plot_at: How often to save plots")
|
||||||
|
print(" - early_stopping_patience: Patience for early stopping")
|
||||||
|
print(" - print_stats_at: How often to print detailed stats")
|
||||||
|
print(" - print_train_stats_at: How often to print training loss")
|
||||||
|
print(" - validate_at: How often to run validation")
|
||||||
|
print("\nRuntime commands (set to true to execute):")
|
||||||
|
print(" - save_checkpoint: Save model at current step")
|
||||||
|
print(" - run_test_validation: Run validation on final test set")
|
||||||
|
print(" - generate_predictions: Generate predictions on challenge testset")
|
||||||
|
print("\nChanges will be applied within 5 steps.")
|
||||||
|
print("="*60)
|
||||||
|
print()
|
||||||
|
|
||||||
rmse_value = train(**config_dict)
|
rmse_value = train(**config_dict)
|
||||||
|
|
||||||
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
|
|
||||||
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
||||||
save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
|
final_save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
|
||||||
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
final_plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
||||||
os.makedirs(plot_path, exist_ok=True)
|
os.makedirs(final_plot_path, exist_ok=True)
|
||||||
for name in os.listdir(plot_path):
|
for name in os.listdir(final_plot_path):
|
||||||
p = os.path.join(plot_path, name)
|
p = os.path.join(final_plot_path, name)
|
||||||
if os.path.isfile(p) or os.path.islink(p):
|
if os.path.isfile(p) or os.path.islink(p):
|
||||||
os.unlink(p)
|
os.unlink(p)
|
||||||
elif os.path.isdir(p):
|
elif os.path.isdir(p):
|
||||||
shutil.rmtree(p)
|
shutil.rmtree(p)
|
||||||
|
|
||||||
# Comment out, if predictions are required
|
# Comment out, if predictions are required
|
||||||
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)
|
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, final_save_path, final_plot_path, plot_at=20, rmse_value=rmse_value)
|
||||||
|
|||||||
@@ -12,52 +12,182 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
from torchvision import models
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.data import Subset
|
from torch.utils.data import Subset
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
def load_runtime_config(config_path, current_params):
|
||||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
"""Load runtime configuration from JSON file and update parameters"""
|
||||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
try:
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
new_config = json.load(f)
|
||||||
|
|
||||||
|
# Update modifiable parameters
|
||||||
|
updated = False
|
||||||
|
modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
|
||||||
|
'print_stats_at', 'print_train_stats_at', 'validate_at',
|
||||||
|
'learningrate', 'weight_decay']
|
||||||
|
|
||||||
|
for key in modifiable_keys:
|
||||||
|
if key in new_config and new_config[key] != current_params.get(key):
|
||||||
|
old_val = current_params.get(key)
|
||||||
|
current_params[key] = new_config[key]
|
||||||
|
print(f"\n[CONFIG UPDATE] {key}: {old_val} -> {new_config[key]}")
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
# Check for command flags
|
||||||
|
commands = new_config.get('commands', {})
|
||||||
|
current_params['commands'] = commands
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
print("[CONFIG UPDATE] Runtime configuration updated successfully!\n")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not load runtime config: {e}")
|
||||||
|
|
||||||
|
return current_params
|
||||||
|
|
||||||
|
|
||||||
|
def clear_command_flag(config_path, command_name):
|
||||||
|
"""Clear a specific command flag after execution"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
if 'commands' in config and command_name in config['commands']:
|
||||||
|
config['commands'][command_name] = False
|
||||||
|
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not clear command flag: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class RMSELoss(nn.Module):
|
||||||
|
"""RMSE loss for direct optimization of evaluation metric"""
|
||||||
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
|
||||||
self.l1_weight = l1_weight
|
|
||||||
self.edge_weight = edge_weight
|
|
||||||
self.mse = nn.MSELoss()
|
self.mse = nn.MSELoss()
|
||||||
self.l1 = nn.L1Loss()
|
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
|
||||||
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
|
|
||||||
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
|
|
||||||
|
|
||||||
def edge_loss(self, pred, target):
|
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
|
||||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
|
||||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
|
||||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
|
||||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
|
||||||
return edge_loss
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
mse_loss = self.mse(pred, target)
|
mse = self.mse(pred, target)
|
||||||
l1_loss = self.l1(pred, target)
|
# Larger epsilon for numerical stability
|
||||||
edge_loss = self.edge_loss(pred, target)
|
rmse = torch.sqrt(mse + 1e-6)
|
||||||
|
return rmse
|
||||||
|
|
||||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
|
||||||
return total_loss
|
class PerceptualLoss(nn.Module):
|
||||||
|
"""Perceptual loss using VGG16 features for better texture and detail preservation"""
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
# Load pre-trained VGG16 and use specific layers
|
||||||
|
vgg = models.vgg16(pretrained=True).features.to(device).eval()
|
||||||
|
# Freeze VGG parameters
|
||||||
|
for param in vgg.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Use early and middle layers for perceptual loss
|
||||||
|
self.slice1 = nn.Sequential(*list(vgg.children())[:4]) # relu1_2
|
||||||
|
self.slice2 = nn.Sequential(*list(vgg.children())[4:9]) # relu2_2
|
||||||
|
self.slice3 = nn.Sequential(*list(vgg.children())[9:16]) # relu3_3
|
||||||
|
|
||||||
|
# Normalization for VGG (ImageNet stats)
|
||||||
|
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||||
|
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||||
|
|
||||||
|
def normalize(self, x):
|
||||||
|
"""Normalize images for VGG with clamping for stability"""
|
||||||
|
# Clamp input to valid range
|
||||||
|
x = torch.clamp(x, 0.0, 1.0)
|
||||||
|
return (x - self.mean) / (self.std + 1e-8)
|
||||||
|
|
||||||
|
def forward(self, pred, target):
|
||||||
|
# Clamp inputs to prevent extreme values
|
||||||
|
pred = torch.clamp(pred, 0.0, 1.0)
|
||||||
|
target = torch.clamp(target, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Normalize inputs
|
||||||
|
pred = self.normalize(pred)
|
||||||
|
target = self.normalize(target)
|
||||||
|
|
||||||
|
# Extract features from multiple layers
|
||||||
|
pred_f1 = self.slice1(pred)
|
||||||
|
pred_f2 = self.slice2(pred_f1)
|
||||||
|
pred_f3 = self.slice3(pred_f2)
|
||||||
|
|
||||||
|
target_f1 = self.slice1(target)
|
||||||
|
target_f2 = self.slice2(target_f1)
|
||||||
|
target_f3 = self.slice3(target_f2)
|
||||||
|
|
||||||
|
# Compute losses at multiple scales
|
||||||
|
loss = F.l1_loss(pred_f1, target_f1) + \
|
||||||
|
F.l1_loss(pred_f2, target_f2) + \
|
||||||
|
F.l1_loss(pred_f3, target_f3)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedLoss(nn.Module):
|
||||||
|
"""Combined loss optimized for RMSE evaluation with optional perceptual component"""
|
||||||
|
def __init__(self, device, use_perceptual=True, perceptual_weight=0.05):
|
||||||
|
super().__init__()
|
||||||
|
self.use_perceptual = use_perceptual
|
||||||
|
if use_perceptual:
|
||||||
|
self.perceptual_loss = PerceptualLoss(device)
|
||||||
|
# Use MSE instead of RMSE for training (more stable gradients)
|
||||||
|
self.mse_loss = nn.MSELoss()
|
||||||
|
self.rmse_loss = RMSELoss() # For logging only
|
||||||
|
|
||||||
|
self.perceptual_weight = perceptual_weight
|
||||||
|
self.mse_weight = 1.0 - perceptual_weight
|
||||||
|
|
||||||
|
def forward(self, pred, target):
|
||||||
|
# Clamp predictions to valid range
|
||||||
|
pred = torch.clamp(pred, 0.0, 1.0)
|
||||||
|
target = torch.clamp(target, 0.0, 1.0)
|
||||||
|
|
||||||
|
# Check for NaN in inputs
|
||||||
|
if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
|
||||||
|
print("Warning: NaN detected in loss inputs")
|
||||||
|
return (torch.tensor(float('nan'), device=pred.device),) * 4
|
||||||
|
|
||||||
|
# Primary loss: MSE (equivalent to RMSE but more stable)
|
||||||
|
mse = self.mse_loss(pred, target)
|
||||||
|
rmse = self.rmse_loss(pred, target) # For logging
|
||||||
|
|
||||||
|
if self.use_perceptual:
|
||||||
|
# Optional small perceptual component for texture quality
|
||||||
|
perceptual = self.perceptual_loss(pred, target)
|
||||||
|
# Check perceptual loss validity
|
||||||
|
if not torch.isfinite(perceptual):
|
||||||
|
perceptual = torch.tensor(0.0, device=pred.device)
|
||||||
|
total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
|
||||||
|
else:
|
||||||
|
# Pure MSE optimization
|
||||||
|
perceptual = torch.tensor(0.0, device=pred.device)
|
||||||
|
total_loss = mse
|
||||||
|
|
||||||
|
# Validate loss is not NaN or Inf
|
||||||
|
if not torch.isfinite(total_loss):
|
||||||
|
# Return MSE only as fallback
|
||||||
|
total_loss = mse
|
||||||
|
if not torch.isfinite(total_loss):
|
||||||
|
print("Warning: MSE is NaN")
|
||||||
|
return (torch.tensor(float('nan'), device=pred.device),) * 4
|
||||||
|
|
||||||
|
return total_loss, perceptual, mse, rmse
|
||||||
|
|
||||||
|
|
||||||
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
|
||||||
weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize,
|
weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize,
|
||||||
network_config: dict):
|
network_config: dict, testset_path=None, save_path=None, plot_path_predictions=None):
|
||||||
np.random.seed(seed=seed)
|
np.random.seed(seed=seed)
|
||||||
torch.manual_seed(seed=seed)
|
torch.manual_seed(seed=seed)
|
||||||
|
|
||||||
@@ -68,6 +198,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
# Enable mixed precision training for memory efficiency
|
||||||
|
use_amp = torch.cuda.is_available()
|
||||||
|
if use_amp:
|
||||||
|
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
|
||||||
|
else:
|
||||||
|
scaler = None
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.login()
|
wandb.login()
|
||||||
wandb.init(project="image_inpainting", config={
|
wandb.init(project="image_inpainting", config={
|
||||||
@@ -111,15 +248,28 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss for better reconstruction
|
# defining the loss - Optimized for RMSE evaluation
|
||||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
# Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality
|
||||||
|
# TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
|
||||||
|
combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999))
|
||||||
|
|
||||||
# Learning rate scheduler for better convergence
|
# Learning rate warmup
|
||||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
warmup_steps = min(1000, n_updates // 10)
|
||||||
|
|
||||||
|
# Cosine annealing with warm restarts for long training
|
||||||
|
scheduler_main = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup scheduler
|
||||||
|
def get_lr_scale(step):
|
||||||
|
if step < warmup_steps:
|
||||||
|
return step / warmup_steps
|
||||||
|
return 1.0
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||||
@@ -131,7 +281,31 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||||
|
|
||||||
|
# Save runtime configuration to JSON file for dynamic updates
|
||||||
|
config_json_path = os.path.join(results_path, "runtime_config.json")
|
||||||
|
runtime_params = {
|
||||||
|
'learningrate': learningrate,
|
||||||
|
'weight_decay': weight_decay,
|
||||||
|
'n_updates': n_updates,
|
||||||
|
'plot_at': plot_at,
|
||||||
|
'early_stopping_patience': early_stopping_patience,
|
||||||
|
'print_stats_at': print_stats_at,
|
||||||
|
'print_train_stats_at': print_train_stats_at,
|
||||||
|
'validate_at': validate_at,
|
||||||
|
'commands': {
|
||||||
|
'save_checkpoint': False,
|
||||||
|
'run_test_validation': False,
|
||||||
|
'generate_predictions': False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(config_json_path, 'w') as f:
|
||||||
|
json.dump(runtime_params, f, indent=2)
|
||||||
|
|
||||||
print(f"Started training on device {device}")
|
print(f"Started training on device {device}")
|
||||||
|
print(f"Runtime config saved to: {config_json_path}")
|
||||||
|
print(f"You can modify this file during training to change parameters dynamically!")
|
||||||
|
print(f"Set command flags to true to trigger actions (save_checkpoint, run_test_validation, generate_predictions)\n")
|
||||||
|
|
||||||
while i < n_updates:
|
while i < n_updates:
|
||||||
|
|
||||||
@@ -139,33 +313,191 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
input, target = input.to(device), target.to(device)
|
input, target = input.to(device), target.to(device)
|
||||||
|
|
||||||
|
# Check for runtime config updates every 5 steps
|
||||||
|
if i % 5 == 0 and i > 0:
|
||||||
|
runtime_params = load_runtime_config(config_json_path, runtime_params)
|
||||||
|
n_updates = runtime_params['n_updates']
|
||||||
|
plot_at = runtime_params['plot_at']
|
||||||
|
early_stopping_patience = runtime_params['early_stopping_patience']
|
||||||
|
print_stats_at = runtime_params['print_stats_at']
|
||||||
|
print_train_stats_at = runtime_params['print_train_stats_at']
|
||||||
|
validate_at = runtime_params['validate_at']
|
||||||
|
|
||||||
|
# Update optimizer parameters if changed
|
||||||
|
if 'learningrate' in runtime_params:
|
||||||
|
new_lr = runtime_params['learningrate']
|
||||||
|
current_lr = optimizer.param_groups[0]['lr']
|
||||||
|
if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = new_lr
|
||||||
|
|
||||||
|
if 'weight_decay' in runtime_params:
|
||||||
|
new_wd = runtime_params['weight_decay']
|
||||||
|
current_wd = optimizer.param_groups[0]['weight_decay']
|
||||||
|
if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['weight_decay'] = new_wd
|
||||||
|
|
||||||
|
# Execute runtime commands
|
||||||
|
commands = runtime_params.get('commands', {})
|
||||||
|
|
||||||
|
# Command: Save checkpoint
|
||||||
|
if commands.get('save_checkpoint', False):
|
||||||
|
checkpoint_path = os.path.join(results_path, f"checkpoint_step_{i}.pt")
|
||||||
|
torch.save(network.state_dict(), checkpoint_path)
|
||||||
|
print(f"\n[COMMAND] Checkpoint saved to: {checkpoint_path}\n")
|
||||||
|
clear_command_flag(config_json_path, 'save_checkpoint')
|
||||||
|
|
||||||
|
# Command: Generate predictions
|
||||||
|
if commands.get('generate_predictions', False) and testset_path is not None:
|
||||||
|
print(f"\n[COMMAND] Generating predictions at step {i}...")
|
||||||
|
try:
|
||||||
|
from utils import create_predictions
|
||||||
|
pred_save_path = save_path or os.path.join(results_path, "runtime_predictions", f"step_{i}")
|
||||||
|
pred_plot_path = plot_path_predictions or os.path.join(results_path, "runtime_predictions", "plots", f"step_{i}")
|
||||||
|
os.makedirs(pred_plot_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Save current state temporarily
|
||||||
|
temp_state_path = os.path.join(results_path, f"temp_state_step_{i}.pt")
|
||||||
|
torch.save(network.state_dict(), temp_state_path)
|
||||||
|
|
||||||
|
# Generate predictions
|
||||||
|
create_predictions(network_config, temp_state_path, testset_path, None,
|
||||||
|
pred_save_path, pred_plot_path, plot_at=20, rmse_value=None)
|
||||||
|
|
||||||
|
print(f"[COMMAND] Predictions saved to: {pred_save_path}")
|
||||||
|
print(f"[COMMAND] Plots saved to: {pred_plot_path}\n")
|
||||||
|
|
||||||
|
# Clean up temp file
|
||||||
|
if os.path.exists(temp_state_path):
|
||||||
|
os.remove(temp_state_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[COMMAND] Error generating predictions: {e}\n")
|
||||||
|
|
||||||
|
network.train()
|
||||||
|
clear_command_flag(config_json_path, 'generate_predictions')
|
||||||
|
|
||||||
|
# Command: Run test validation
|
||||||
|
if commands.get('run_test_validation', False):
|
||||||
|
print(f"\n[COMMAND] Running test set validation at step {i}...")
|
||||||
|
network.eval()
|
||||||
|
test_loss, test_rmse = evaluate_model(network, dataloader_test, mse_loss, device)
|
||||||
|
print(f"[COMMAND] Test Loss: {test_loss:.6f}, Test RMSE: {test_rmse:.6f}\n")
|
||||||
|
network.train()
|
||||||
|
clear_command_flag(config_json_path, 'run_test_validation')
|
||||||
|
|
||||||
if (i + 1) % print_train_stats_at == 0:
|
if (i + 1) % print_train_stats_at == 0:
|
||||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
output = network(input)
|
# Mixed precision training for memory efficiency
|
||||||
|
if use_amp:
|
||||||
|
with torch.amp.autocast('cuda'):
|
||||||
|
output = network(input)
|
||||||
|
total_loss, perceptual, mse, rmse = combined_loss(output, target)
|
||||||
|
|
||||||
loss = combined_loss(output, target)
|
# Check for NaN before backward
|
||||||
|
if not torch.isfinite(total_loss):
|
||||||
|
continue
|
||||||
|
|
||||||
loss.backward()
|
scaler.scale(total_loss).backward()
|
||||||
|
|
||||||
# Gradient clipping for training stability
|
# Unscale and check gradients
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
scaler.unscale_(optimizer)
|
||||||
|
|
||||||
optimizer.step()
|
# Check for NaN in gradients
|
||||||
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
has_nan = False
|
||||||
|
for name, param in network.named_parameters():
|
||||||
|
if param.grad is not None:
|
||||||
|
if not torch.isfinite(param.grad).all():
|
||||||
|
print(f"NaN gradient detected in {name}")
|
||||||
|
has_nan = True
|
||||||
|
break
|
||||||
|
|
||||||
loss_list.append(loss.item())
|
if has_nan:
|
||||||
|
print(f"Skipping step {i+1}: NaN gradients detected")
|
||||||
|
optimizer.zero_grad()
|
||||||
|
scaler.update()
|
||||||
|
# Reset scaler if NaN persists
|
||||||
|
if (i + 1) % 10 == 0:
|
||||||
|
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# More aggressive gradient clipping for stability
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
# Skip update if gradient norm is too large
|
||||||
|
if grad_norm > 100.0:
|
||||||
|
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")
|
||||||
|
optimizer.zero_grad()
|
||||||
|
scaler.update()
|
||||||
|
continue
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
output = network(input)
|
||||||
|
total_loss, perceptual, mse, rmse = combined_loss(output, target)
|
||||||
|
|
||||||
|
# Check for NaN before backward
|
||||||
|
if not torch.isfinite(total_loss):
|
||||||
|
print(f"Skipping step {i+1}: NaN or Inf loss detected")
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_loss.backward()
|
||||||
|
|
||||||
|
# Check for NaN in gradients
|
||||||
|
has_nan = False
|
||||||
|
for name, param in network.named_parameters():
|
||||||
|
if param.grad is not None and not torch.isfinite(param.grad).all():
|
||||||
|
print(f"NaN gradient detected in {name}")
|
||||||
|
has_nan = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if has_nan:
|
||||||
|
print(f"Skipping step {i+1}: NaN gradients detected")
|
||||||
|
optimizer.zero_grad()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# More aggressive gradient clipping
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
if grad_norm > 100.0:
|
||||||
|
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")
|
||||||
|
optimizer.zero_grad()
|
||||||
|
continue
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# Apply learning rate scheduling with warmup
|
||||||
|
lr_scale = get_lr_scale(i)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = learningrate * lr_scale
|
||||||
|
|
||||||
|
if i >= warmup_steps:
|
||||||
|
scheduler_main.step()
|
||||||
|
|
||||||
|
loss_list.append(total_loss.item())
|
||||||
|
|
||||||
# writing the stats to wandb
|
# writing the stats to wandb
|
||||||
if use_wandb and (i+1) % print_stats_at == 0:
|
if use_wandb and (i+1) % print_stats_at == 0:
|
||||||
wandb.log({"training/loss_per_batch": loss.item()}, step=i)
|
wandb.log({
|
||||||
|
"training/loss_total": total_loss.item(),
|
||||||
|
"training/loss_mse": mse.item(),
|
||||||
|
"training/loss_rmse": rmse.item(),
|
||||||
|
"training/loss_perceptual": perceptual.item() if isinstance(perceptual, torch.Tensor) else perceptual,
|
||||||
|
"training/learning_rate": optimizer.param_groups[0]['lr']
|
||||||
|
}, step=i)
|
||||||
|
|
||||||
# plotting
|
# plotting
|
||||||
if (i + 1) % plot_at == 0:
|
if (i + 1) % plot_at == 0:
|
||||||
print(f"Plotting images, current update {i + 1}")
|
print(f"Plotting images, current update {i + 1}")
|
||||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
# Convert to float32 for matplotlib compatibility
|
||||||
|
plot(input.float().cpu().numpy(),
|
||||||
|
target.detach().float().cpu().numpy(),
|
||||||
|
output.detach().float().cpu().numpy(),
|
||||||
|
plotpath, i)
|
||||||
|
|
||||||
# evaluating model every validate_at sample
|
# evaluating model every validate_at sample
|
||||||
if (i + 1) % validate_at == 0:
|
if (i + 1) % validate_at == 0:
|
||||||
|
|||||||
@@ -18,12 +18,14 @@ def plot(inputs, targets, predictions, path, update):
|
|||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
|
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
|
||||||
|
|
||||||
for i in range(5):
|
# Only plot up to min(5, batch_size) images
|
||||||
|
num_images = min(5, inputs.shape[0])
|
||||||
|
|
||||||
|
for i in range(num_images):
|
||||||
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
|
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
|
||||||
ax.clear()
|
ax.clear()
|
||||||
ax.set_title(title)
|
ax.set_title(title)
|
||||||
img = data[i:i + 1:, 0:3, :, :]
|
img = data[i, 0:3, :, :]
|
||||||
img = np.squeeze(img)
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
img = np.transpose(img, (1, 2, 0))
|
||||||
img = np.clip(img, 0, 1)
|
img = np.clip(img, 0, 1)
|
||||||
ax.imshow(img)
|
ax.imshow(img)
|
||||||
@@ -54,24 +56,58 @@ def testset_plot(input_array, output_array, path, index):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
||||||
"""Returnse MSE and RMSE of the model on the provided dataloader"""
|
"""Returns MSE and RMSE of the model on the provided dataloader"""
|
||||||
|
# Save training mode and switch to eval
|
||||||
|
was_training = network.training
|
||||||
network.eval()
|
network.eval()
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data in dataloader:
|
for data in dataloader:
|
||||||
input_array, target = data
|
input_array, target = data
|
||||||
input_array = input_array.to(device)
|
input_array = input_array.to(device)
|
||||||
target = target.to(device)
|
target = target.to(device)
|
||||||
|
|
||||||
|
# Check input validity
|
||||||
|
if not torch.isfinite(input_array).all() or not torch.isfinite(target).all():
|
||||||
|
print(f"Warning: NaN detected in evaluation inputs")
|
||||||
|
continue
|
||||||
|
|
||||||
outputs = network(input_array)
|
outputs = network(input_array)
|
||||||
|
|
||||||
loss += loss_fn(outputs, target).item()
|
# Clamp outputs to valid range
|
||||||
|
outputs = torch.clamp(outputs, 0.0, 1.0)
|
||||||
|
|
||||||
loss = loss / len(dataloader)
|
# Check for NaN in outputs
|
||||||
|
if not torch.isfinite(outputs).all():
|
||||||
|
print(f"Warning: NaN detected in model outputs during evaluation")
|
||||||
|
continue
|
||||||
|
|
||||||
network.train()
|
batch_loss = loss_fn(outputs, target).item()
|
||||||
|
|
||||||
return loss, 255.0 * np.sqrt(loss)
|
# Check for NaN in loss
|
||||||
|
if not np.isfinite(batch_loss):
|
||||||
|
print(f"Warning: NaN detected in loss during evaluation")
|
||||||
|
continue
|
||||||
|
|
||||||
|
loss += batch_loss
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
if num_batches == 0:
|
||||||
|
print("Error: No valid batches in evaluation")
|
||||||
|
if was_training:
|
||||||
|
network.train()
|
||||||
|
return float('nan'), float('nan')
|
||||||
|
|
||||||
|
loss = loss / num_batches
|
||||||
|
rmse = 255.0 * np.sqrt(loss)
|
||||||
|
|
||||||
|
# Restore training mode
|
||||||
|
if was_training:
|
||||||
|
network.train()
|
||||||
|
|
||||||
|
return loss, rmse
|
||||||
|
|
||||||
|
|
||||||
def read_compressed_file(file_path: str):
|
def read_compressed_file(file_path: str):
|
||||||
@@ -122,6 +158,13 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
|||||||
|
|
||||||
predictions = np.stack(predictions, axis=0)
|
predictions = np.stack(predictions, axis=0)
|
||||||
|
|
||||||
|
# Handle NaN and inf values before conversion
|
||||||
|
nan_mask = ~np.isfinite(predictions)
|
||||||
|
if nan_mask.any():
|
||||||
|
nan_count = nan_mask.sum()
|
||||||
|
print(f"Warning: Found {nan_count} NaN/Inf values in predictions. Replacing with 0.")
|
||||||
|
predictions = np.nan_to_num(predictions, nan=0.0, posinf=1.0, neginf=0.0)
|
||||||
|
|
||||||
predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8)
|
predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
Reference in New Issue
Block a user