1 Commits

Author SHA1 Message Date
0d23053e31 added result, 21.9691 2026-01-24 19:21:24 +01:00
19 changed files with 212 additions and 880 deletions

View File

@@ -1,7 +1,4 @@
data/* data/*
*.zip *.zip
*.jpg *.jpg
*.pt *.pt
__pycache__/
runtime_predictions.npz
results/runtime_config.json

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -7,7 +7,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math
def init_weights(m): def init_weights(m):
@@ -16,54 +15,35 @@ def init_weights(m):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
nn.init.constant_(m.weight, 1) if m.weight is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class GatedSkipConnection(nn.Module): class ChannelAttention(nn.Module):
"""Gated skip connection for better feature fusion""" """Channel attention module (squeeze-and-excitation style)"""
def __init__(self, up_channels, skip_channels): def __init__(self, channels, reduction=16):
super().__init__()
self.gate = nn.Sequential(
nn.Conv2d(up_channels + skip_channels, up_channels, 1),
nn.Sigmoid()
)
# Project skip to match up_channels if they differ
if skip_channels != up_channels:
self.skip_proj = nn.Conv2d(skip_channels, up_channels, 1)
else:
self.skip_proj = nn.Identity()
def forward(self, x, skip):
skip_proj = self.skip_proj(skip)
combined = torch.cat([x, skip], dim=1)
gate = self.gate(combined)
return x * gate + skip_proj * (1 - gate)
class EfficientChannelAttention(nn.Module):
"""Efficient channel attention without dimensionality reduction"""
def __init__(self, channels):
super().__init__() super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False) self.max_pool = nn.AdaptiveMaxPool2d(1)
reduced = max(channels // reduction, 8)
self.fc = nn.Sequential(
nn.Conv2d(channels, reduced, 1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(reduced, channels, 1, bias=False)
)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
# Global pooling avg_out = self.fc(self.avg_pool(x))
y = self.avg_pool(x) max_out = self.fc(self.max_pool(x))
# 1D convolution on channel dimension - add safety checks return x * self.sigmoid(avg_out + max_out)
if y.size(-1) == 1 and y.size(-2) == 1:
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
y = torch.clamp(y, min=0.0, max=1.0) # Ensure valid range
return x * y.expand_as(x)
return x
class SpatialAttention(nn.Module): class SpatialAttention(nn.Module):
"""Efficient spatial attention module""" """Spatial attention module"""
def __init__(self, kernel_size=7): def __init__(self, kernel_size=7):
super().__init__() super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
@@ -77,12 +57,12 @@ class SpatialAttention(nn.Module):
return x * attn return x * attn
class EfficientAttention(nn.Module): class CBAM(nn.Module):
"""Lightweight attention module combining channel and spatial""" """Convolutional Block Attention Module"""
def __init__(self, channels): def __init__(self, channels, reduction=16):
super().__init__() super().__init__()
self.channel_attn = EfficientChannelAttention(channels) self.channel_attn = ChannelAttention(channels, reduction)
self.spatial_attn = SpatialAttention(kernel_size=5) self.spatial_attn = SpatialAttention()
def forward(self, x): def forward(self, x):
x = self.channel_attn(x) x = self.channel_attn(x)
@@ -90,302 +70,158 @@ class EfficientAttention(nn.Module):
return x return x
class SelfAttention(nn.Module):
    """Self-attention over all spatial positions of a 2-D feature map.

    Query and key are 1x1-convolution projections of the input to
    ``in_channels // reduction`` channels (but never fewer than one); the
    value projection keeps ``in_channels`` channels. The attended features
    are scaled by the learnable scalar ``gamma`` and added back to the input.
    ``gamma`` is created as zero, so a freshly constructed module returns its
    input unchanged.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``in_channels`` to obtain the
            query/key projection width.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        # With in_channels < reduction the floor division yields 0, which
        # would create empty projections and make the sqrt(d) scaling in
        # forward() divide by zero. Keep at least one channel.
        qk_channels = max(in_channels // reduction, 1)
        self.query = nn.Conv2d(in_channels, qk_channels, 1)
        self.key = nn.Conv2d(in_channels, qk_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` for a 4-D input ``x``."""
        batch_size, C, H, W = x.size()

        # Flatten the spatial grid: query -> (B, H*W, C'), key -> (B, C', H*W),
        # value -> (B, C, H*W).
        query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, H * W)
        value = self.value(x).view(batch_size, -1, H * W)

        # Pairwise position logits, divided by sqrt of the projection width
        # before the softmax for numerical stability.
        attention_logits = torch.bmm(query, key) / math.sqrt(query.size(-1))
        attention = self.softmax(attention_logits)

        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)

        # Residual connection with learnable weight.
        return self.gamma * out + x
class ConvBlock(nn.Module): class ConvBlock(nn.Module):
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU""" """Convolutional block with Conv2d -> InstanceNorm2d -> GELU"""
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0, separable=False): def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0, dilation=1):
super().__init__() super().__init__()
if separable and in_channels > 1: self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
# Depthwise separable convolution for efficiency # InstanceNorm is preferred for style/inpainting tasks
self.conv = nn.Sequential( self.bn = nn.InstanceNorm2d(out_channels, affine=True)
nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, groups=in_channels), self.act = nn.GELU()
nn.Conv2d(in_channels, out_channels, 1)
)
else:
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
# Add momentum and eps for numerical stability
self.bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5, track_running_stats=True)
self.relu = nn.LeakyReLU(0.2, inplace=True)
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
def forward(self, x): def forward(self, x):
return self.dropout(self.relu(self.bn(self.conv(x)))) return self.dropout(self.act(self.bn(self.conv(x))))
class DenseBlock(nn.Module):
    """Densely connected stack of ConvBlocks with a residual output.

    Layer ``i`` receives the block input concatenated with the outputs of all
    earlier layers and contributes ``growth_rate`` new channels. A 1x1
    convolution then maps the full concatenation back to ``channels``,
    followed by batch norm, LeakyReLU and the addition of the block input.
    """

    def __init__(self, channels, growth_rate=8, num_layers=2, dropout=0.0):
        super().__init__()
        # Input width of layer i grows by growth_rate per preceding layer.
        self.layers = nn.ModuleList(
            ConvBlock(channels + depth * growth_rate, growth_rate, dropout=dropout)
            for depth in range(num_layers)
        )
        total_channels = channels + num_layers * growth_rate
        self.fusion = nn.Conv2d(total_channels, channels, 1)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        # Keep one running concatenation along the channel axis; each layer
        # sees everything produced so far and appends its own output.
        stacked = x
        for layer in self.layers:
            stacked = torch.cat([stacked, layer(stacked)], dim=1)

        fused = self.bn(self.fusion(stacked))
        fused = self.relu(fused)
        return fused + x  # Residual connection
class ResidualConvBlock(nn.Module): class ResidualConvBlock(nn.Module):
"""Improved residual convolutional block with pre-activation""" """Residual convolutional block for better gradient flow"""
def __init__(self, channels, dropout=0.0): def __init__(self, channels, dropout=0.0, dilation=1):
super().__init__() super().__init__()
self.bn1 = nn.BatchNorm2d(channels) self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
self.relu1 = nn.LeakyReLU(0.2, inplace=True) self.bn1 = nn.InstanceNorm2d(channels, affine=True)
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
self.bn2 = nn.BatchNorm2d(channels) self.bn2 = nn.InstanceNorm2d(channels, affine=True)
self.relu2 = nn.LeakyReLU(0.2, inplace=True) self.act = nn.GELU()
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
def forward(self, x): def forward(self, x):
residual = x residual = x
out = self.relu1(self.bn1(x)) out = self.act(self.bn1(self.conv1(x)))
out = self.conv1(out)
out = self.relu2(self.bn2(out))
out = self.dropout(out) out = self.dropout(out)
out = self.conv2(out) out = self.bn2(self.conv2(out))
return out + residual out = out + residual
return self.act(out)
class DownBlock(nn.Module): class DownBlock(nn.Module):
"""Enhanced downsampling block with dense and residual connections""" """Downsampling block with conv blocks, residual connection, attention, and max pooling"""
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False, use_self_attention=False): def __init__(self, in_channels, out_channels, dropout=0.1):
super().__init__() super().__init__()
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout, separable=True) self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout) self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
if use_dense: self.residual = ResidualConvBlock(out_channels, dropout=dropout)
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout) self.attention = CBAM(out_channels)
else:
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
self.self_attention = SelfAttention(out_channels) if use_self_attention else nn.Identity()
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.dense(x) x = self.residual(x)
x = self.attention(x) skip = self.attention(x)
skip = self.self_attention(x)
return self.pool(skip), skip return self.pool(skip), skip
class UpBlock(nn.Module): class UpBlock(nn.Module):
"""Enhanced upsampling block with gated skip connections""" """Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True, use_dense=False, use_self_attention=False): def __init__(self, in_channels, out_channels, dropout=0.1):
super().__init__() super().__init__()
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
# Skip connection has in_channels, upsampled has out_channels # After concat: out_channels (from upconv) + in_channels (from skip)
self.gated_skip = GatedSkipConnection(out_channels, in_channels) self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
# After gated skip: out_channels
self.conv1 = ConvBlock(out_channels, out_channels, dropout=dropout, separable=True)
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout) self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
if use_dense: self.residual = ResidualConvBlock(out_channels, dropout=dropout)
self.dense = DenseBlock(out_channels, growth_rate=8, num_layers=2, dropout=dropout) self.attention = CBAM(out_channels)
else:
self.dense = ResidualConvBlock(out_channels, dropout=dropout)
self.attention = EfficientAttention(out_channels) if use_attention else nn.Identity()
self.self_attention = SelfAttention(out_channels) if use_self_attention else nn.Identity()
def forward(self, x, skip): def forward(self, x, skip):
x = self.up(x) x = self.up(x)
# Handle dimension mismatch # Handle dimension mismatch by interpolating x to match skip's size
if x.shape[2:] != skip.shape[2:]: if x.shape[2:] != skip.shape[2:]:
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False) x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
x = self.gated_skip(x, skip) x = torch.cat([x, skip], dim=1)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.dense(x) x = self.residual(x)
x = self.attention(x) x = self.attention(x)
x = self.self_attention(x)
return x return x
class MyModel(nn.Module): class MyModel(nn.Module):
"""Enhanced U-Net architecture with dense connections and efficient attention""" """Improved U-Net style architecture for image inpainting with attention and residual connections"""
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1): def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
super().__init__() super().__init__()
# Separate mask processing for better feature extraction # Initial convolution with larger receptive field
# Separate mask processing for better feature extraction self.init_conv = nn.Sequential(
self.mask_conv = nn.Sequential( ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
nn.Conv2d(1, base_channels // 4, 3, padding=1),
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
nn.LeakyReLU(0.2, inplace=True)
)
# Image processing path
self.image_conv = nn.Sequential(
ConvBlock(3, base_channels, kernel_size=5, padding=2),
ConvBlock(base_channels, base_channels)
)
# Fusion of mask and image features
self.fusion = nn.Sequential(
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
nn.BatchNorm2d(base_channels, momentum=0.1, eps=1e-5, track_running_stats=True),
nn.LeakyReLU(0.2, inplace=True)
)
# Encoder with progressive feature extraction
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout, use_attention=False, use_dense=False)
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True, use_self_attention=True)
# Enhanced bottleneck with multi-scale features, dense connections, and self-attention
self.bottleneck = nn.Sequential(
ConvBlock(base_channels * 8, base_channels * 8, dropout=dropout),
DenseBlock(base_channels * 8, growth_rate=12, num_layers=3, dropout=dropout),
SelfAttention(base_channels * 8, reduction=4),
ConvBlock(base_channels * 8, base_channels * 8, dilation=2, padding=2, dropout=dropout),
ResidualConvBlock(base_channels * 8, dropout=dropout),
EfficientAttention(base_channels * 8)
)
# Decoder with progressive reconstruction
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True, use_self_attention=True)
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout, use_attention=True, use_dense=True)
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout, use_attention=False, use_dense=False)
# Multi-scale feature fusion with dense connections
self.multiscale_fusion = nn.Sequential(
ConvBlock(base_channels * 2, base_channels),
DenseBlock(base_channels, growth_rate=8, num_layers=2, dropout=dropout//2),
ConvBlock(base_channels, base_channels)
)
# Output with residual connection to input
self.pre_output = nn.Sequential(
ConvBlock(base_channels, base_channels), ConvBlock(base_channels, base_channels),
ConvBlock(base_channels, base_channels // 2) ResidualConvBlock(base_channels)
) )
# Encoder (downsampling path)
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
# Bottleneck with multiple residual blocks
self.bottleneck = nn.Sequential(
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=2),
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=4),
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=8),
CBAM(base_channels * 16)
)
# Decoder (upsampling path)
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
# Final refinement layers
self.final_conv = nn.Sequential(
ConvBlock(base_channels * 2, base_channels),
ResidualConvBlock(base_channels),
ConvBlock(base_channels, base_channels)
)
# Output layer with smooth transition
self.output = nn.Sequential( self.output = nn.Sequential(
nn.Conv2d(base_channels // 2 + 3, base_channels // 2, 3, padding=1), nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True), nn.GELU(),
nn.Conv2d(base_channels // 2, 3, 1), nn.Conv2d(base_channels // 2, 3, kernel_size=1),
nn.Sigmoid() nn.Sigmoid() # Ensure output is in [0, 1] range
) )
# Apply weight initialization # Apply weight initialization
self.apply(init_weights) self.apply(init_weights)
def forward(self, x): def forward(self, x):
# Split input into image and mask # Initial convolution
image = x[:, :3, :, :] x0 = self.init_conv(x)
mask = x[:, 3:4, :, :]
# Clamp inputs to valid range
image = torch.clamp(image, 0.0, 1.0)
mask = torch.clamp(mask, 0.0, 1.0)
# Process mask and image separately
mask_features = self.mask_conv(mask)
image_features = self.image_conv(image)
# Safety check after initial processing
if not torch.isfinite(mask_features).all():
mask_features = torch.nan_to_num(mask_features, nan=0.0, posinf=1.0, neginf=-1.0)
if not torch.isfinite(image_features).all():
image_features = torch.nan_to_num(image_features, nan=0.0, posinf=1.0, neginf=-1.0)
# Fuse features
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
if not torch.isfinite(x0).all():
x0 = torch.nan_to_num(x0, nan=0.0, posinf=1.0, neginf=-1.0)
# Encoder # Encoder
x1, skip1 = self.down1(x0) x1, skip1 = self.down1(x0)
if not torch.isfinite(x1).all():
x1 = torch.nan_to_num(x1, nan=0.0, posinf=1.0, neginf=-1.0)
skip1 = torch.nan_to_num(skip1, nan=0.0, posinf=1.0, neginf=-1.0)
x2, skip2 = self.down2(x1) x2, skip2 = self.down2(x1)
if not torch.isfinite(x2).all():
x2 = torch.nan_to_num(x2, nan=0.0, posinf=1.0, neginf=-1.0)
skip2 = torch.nan_to_num(skip2, nan=0.0, posinf=1.0, neginf=-1.0)
x3, skip3 = self.down3(x2) x3, skip3 = self.down3(x2)
if not torch.isfinite(x3).all(): x4, skip4 = self.down4(x3)
x3 = torch.nan_to_num(x3, nan=0.0, posinf=1.0, neginf=-1.0)
skip3 = torch.nan_to_num(skip3, nan=0.0, posinf=1.0, neginf=-1.0)
# Bottleneck # Bottleneck
x = self.bottleneck(x3) x = self.bottleneck(x4)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Decoder with skip connections # Decoder with skip connections
x = self.up1(x, skip3) x = self.up1(x, skip4)
if not torch.isfinite(x).all(): x = self.up2(x, skip3)
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) x = self.up3(x, skip2)
x = self.up4(x, skip1)
x = self.up2(x, skip2) # Handle dimension mismatch for final concatenation
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
x = self.up3(x, skip1)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Handle dimension mismatch for final fusion
if x.shape[2:] != x0.shape[2:]: if x.shape[2:] != x0.shape[2:]:
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False) x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
# Multi-scale fusion with initial features # Concatenate with initial features for better detail preservation
x = torch.cat([x, x0], dim=1) x = torch.cat([x, x0], dim=1)
x = self.multiscale_fusion(x) x = self.final_conv(x)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Pre-output processing # Output
x = self.pre_output(x)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Concatenate with original masked image for residual learning
x = torch.cat([x, image], dim=1)
x = self.output(x) x = self.output(x)
# Final safety clamp
x = torch.clamp(x, 0.0, 1.0)
return x return x

View File

@@ -10,8 +10,7 @@ import numpy as np
import random import random
import glob import glob
import os import os
from PIL import Image, ImageEnhance, ImageFilter from PIL import Image
from scipy.ndimage import gaussian_filter, map_coordinates
IMAGE_DIMENSION = 100 IMAGE_DIMENSION = 100
@@ -27,130 +26,34 @@ def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tu
return image_array, known_array return image_array, known_array
def resize(img: Image): def resize(img: Image, augment: bool = False):
resize_transforms = transforms.Compose([ transforms_list = [
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)), transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION)) transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
]) ]
if augment:
transforms_list = [
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
transforms.RandomRotation(10),
] + transforms_list
resize_transforms = transforms.Compose(transforms_list)
return resize_transforms(img) return resize_transforms(img)
def preprocess(input_array: np.ndarray): def preprocess(input_array: np.ndarray):
input_array = np.asarray(input_array, dtype=np.float32) / 255.0 input_array = np.asarray(input_array, dtype=np.float32) / 255.0
return input_array return input_array
def elastic_transform(image: np.ndarray, alpha: float = 20, sigma: float = 4) -> np.ndarray:
    """Apply a random elastic deformation to an image array.

    Two uniform random fields in [-1, 1) are smoothed with a Gaussian filter
    of width ``sigma`` and scaled by ``alpha`` to obtain per-pixel x and y
    displacements. The image is then resampled at the displaced positions
    with linear interpolation and reflecting borders.

    Args:
        image: Array whose first two axes are the sampling grid. Either 2-D
            (single plane) or 3-D, in which case every slice along the third
            axis is warped with the same displacement field.
        alpha: Scale factor applied to the smoothed displacement fields.
        sigma: Standard deviation passed to ``gaussian_filter``.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    shape = image.shape[:2]

    # Draw order (dx first, then dy) is kept so seeded runs stay reproducible.
    dx = gaussian_filter(np.random.rand(*shape) * 2 - 1, sigma) * alpha
    dy = gaussian_filter(np.random.rand(*shape) * 2 - 1, sigma) * alpha

    x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
    indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))

    def _warp(plane):
        # Resample one 2-D plane at the displaced coordinates.
        return map_coordinates(plane, indices, order=1, mode='reflect').reshape(shape)

    transformed = np.zeros_like(image)
    if image.ndim == 2:
        # Single-plane input has no third axis to iterate over.
        transformed[...] = _warp(image)
        return transformed

    # Apply to each channel
    for i in range(image.shape[2]):
        transformed[:, :, i] = _warp(image[:, :, i])
    return transformed
def add_noise(img_array: np.ndarray, noise_type: str = 'gaussian', strength: float = 0.02) -> np.ndarray:
    """Return a noisy copy of ``img_array`` clipped to the range [0, 1].

    Args:
        img_array: Array with at least two axes; the first two are treated as
            the pixel grid for salt-and-pepper noise.
        noise_type: ``'gaussian'`` adds zero-mean normal noise with standard
            deviation ``strength``. ``'salt_pepper'`` overwrites randomly
            chosen grid positions with 1 (salt) and then with 0 (pepper).
            Any other value adds no noise.
        strength: Standard deviation for Gaussian noise; for salt-and-pepper,
            ``int(strength * img_array.size * 0.5)`` positions are drawn for
            each of salt and pepper (positions may repeat).

    Returns:
        The clipped result; the input array is not modified.
    """
    if noise_type == 'gaussian':
        noisy = img_array + np.random.normal(0, strength, img_array.shape)
    elif noise_type == 'salt_pepper':
        noisy = img_array.copy()
        count = int(strength * img_array.size * 0.5)
        # Salt (1) first, then pepper (0). One index array is drawn per axis
        # exactly as before so seeded runs consume the same random stream,
        # even though only the first two are used for indexing.
        for fill in (1, 0):
            coords = [np.random.randint(0, dim, count) for dim in img_array.shape]
            # Ellipsis instead of a trailing ':' so 2-D arrays work as well
            # as arrays with a channel axis.
            noisy[coords[0], coords[1], ...] = fill
    else:
        noisy = img_array

    return np.clip(noisy, 0, 1)
def augment_image(img: Image, strength: float = 0.8) -> Image:
    """Randomly augment a PIL image and return the result as a PIL image.

    Each step fires independently based on a fresh ``random.random()`` draw:
    horizontal flip, vertical flip, rotation (by a multiple of 90 degrees or
    by a small angle with grey fill), brightness / contrast / saturation /
    sharpness changes, Gaussian blur, elastic deformation and additive noise.
    ``strength`` scales the magnitude of the colour, blur, elastic and noise
    steps; flips and rotations do not depend on it.
    """
    # Mirror along each axis with probability one half.
    for flip in (Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM):
        if random.random() > 0.5:
            img = img.transpose(flip)

    # Rotation: occasionally by 90/180/270 degrees, otherwise by a small
    # angle whose uncovered corners are filled with mid-grey.
    if random.random() > 0.5:
        if random.random() > 0.7:
            img = img.rotate(random.choice([90, 180, 270]))
        else:
            img = img.rotate(random.uniform(-15, 15), fillcolor=(128, 128, 128))

    # (enhancer, firing threshold, lowest offset, highest offset)
    enhancements = (
        (ImageEnhance.Brightness, 0.3, -0.3, 0.3),
        (ImageEnhance.Contrast, 0.3, -0.3, 0.3),
        (ImageEnhance.Color, 0.3, -0.25, 0.25),
        (ImageEnhance.Sharpness, 0.7, -0.3, 0.5),
    )
    for enhancer, threshold, low, high in enhancements:
        if random.random() > threshold:
            gain = 1.0 + random.uniform(low, high) * strength
            img = enhancer(img).enhance(gain)

    # Gaussian blur for robustness
    if random.random() > 0.8:
        blur_radius = random.uniform(0.5, 1.5) * strength
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # The remaining steps work on a float array scaled to [0, 1].
    pixels = np.array(img).astype(np.float32) / 255.0

    # Elastic deformation
    if random.random() > 0.7:
        warp_alpha = random.uniform(15, 30) * strength
        warp_sigma = random.uniform(3, 5)
        pixels = elastic_transform(pixels, alpha=warp_alpha, sigma=warp_sigma)

    # Additive noise
    if random.random() > 0.6:
        kind = random.choice(['gaussian', 'salt_pepper'])
        level = random.uniform(0.01, 0.03) * strength
        pixels = add_noise(pixels, noise_type=kind, strength=level)

    # Back to 8-bit and PIL.
    as_bytes = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
class ImageDataset(torch.utils.data.Dataset): class ImageDataset(torch.utils.data.Dataset):
""" """
Dataset class for loading images from a folder with augmentation support Dataset class for loading images from a folder
""" """
def __init__(self, datafolder: str, augment: bool = True, augment_strength: float = 0.8): def __init__(self, datafolder: str, augment: bool = False):
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True)) self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
self.augment = augment self.augment = augment
self.augment_strength = augment_strength
def __len__(self): def __len__(self):
return len(self.imagefiles) return len(self.imagefiles)
@@ -159,13 +62,7 @@ class ImageDataset(torch.utils.data.Dataset):
index = int(idx) index = int(idx)
image = Image.open(self.imagefiles[index]) image = Image.open(self.imagefiles[index])
image = resize(image) image = np.asarray(resize(image, self.augment))
# Apply augmentation
if self.augment:
image = augment_image(image, self.augment_strength)
image = np.asarray(image)
image = preprocess(image) image = preprocess(image)
spacing_x = random.randint(2,6) spacing_x = random.randint(2,6)
spacing_y = random.randint(2,6) spacing_y = random.randint(2,6)

View File

@@ -17,74 +17,46 @@ if __name__ == '__main__':
config_dict['seed'] = 42 config_dict['seed'] = 42
config_dict['testset_ratio'] = 0.1 config_dict['testset_ratio'] = 0.1
config_dict['validset_ratio'] = 0.05 config_dict['validset_ratio'] = 0.1
# Get the absolute path based on the script's location # Get the absolute path based on the script's location
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir) project_root = os.path.dirname(script_dir)
config_dict['results_path'] = os.path.join(project_root, "results") config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset") config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None config_dict['device'] = None
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
config_dict['n_updates'] = 12000 # Extended training for better convergence config_dict['n_updates'] = 5000 # More updates for better convergence
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
config_dict['early_stopping_patience'] = 20 # More patience for complex model config_dict['early_stopping_patience'] = 10 # More patience for complex model
config_dict['use_wandb'] = False config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 10 config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 200 config_dict['print_stats_at'] = 100
config_dict['plot_at'] = 500 config_dict['plot_at'] = 300
config_dict['validate_at'] = 250 # More frequent validation config_dict['validate_at'] = 300 # Validate more frequently
network_config = { network_config = {
'n_in_channels': 4, 'n_in_channels': 4,
'base_channels': 52, # Increased capacity for better feature extraction 'base_channels': 48, # Good balance between capacity and memory
'dropout': 0.15 # Slightly higher dropout for regularization 'dropout': 0.1 # Regularization
} }
config_dict['network_config'] = network_config config_dict['network_config'] = network_config
# Prepare paths for runtime predictions
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
save_path = os.path.join(config_dict['results_path'], "runtime_predictions")
plot_path_predictions = os.path.join(config_dict['results_path'], "runtime_predictions", "plots")
config_dict['testset_path'] = testset_path
config_dict['save_path'] = save_path
config_dict['plot_path_predictions'] = plot_path_predictions
print("="*60)
print("RUNTIME CONFIGURATION ENABLED")
print("="*60)
print("During training, you can modify these parameters by editing:")
print(f"{os.path.join(config_dict['results_path'], 'runtime_config.json')}")
print("\nModifiable parameters:")
print(" - n_updates: Maximum training steps")
print(" - plot_at: How often to save plots")
print(" - early_stopping_patience: Patience for early stopping")
print(" - print_stats_at: How often to print detailed stats")
print(" - print_train_stats_at: How often to print training loss")
print(" - validate_at: How often to run validation")
print("\nRuntime commands (set to true to execute):")
print(" - save_checkpoint: Save model at current step")
print(" - run_test_validation: Run validation on final test set")
print(" - generate_predictions: Generate predictions on challenge testset")
print("\nChanges will be applied within 5 steps.")
print("="*60)
print()
rmse_value = train(**config_dict) rmse_value = train(**config_dict)
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt") state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
final_save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz") save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
final_plot_path = os.path.join(config_dict['results_path'], "testset", "plots") plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
os.makedirs(final_plot_path, exist_ok=True) os.makedirs(plot_path, exist_ok=True)
for name in os.listdir(final_plot_path): for name in os.listdir(plot_path):
p = os.path.join(final_plot_path, name) p = os.path.join(plot_path, name)
if os.path.isfile(p) or os.path.islink(p): if os.path.isfile(p) or os.path.islink(p):
os.unlink(p) os.unlink(p)
elif os.path.isdir(p): elif os.path.isdir(p):
shutil.rmtree(p) shutil.rmtree(p)
# Comment out, if predictions are required # Comment out, if predictions are required
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, final_save_path, final_plot_path, plot_at=20, rmse_value=rmse_value) create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)

View File

@@ -12,182 +12,52 @@ import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import os import os
import json
from torchvision import models
import torch.nn.functional as F
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import Subset from torch.utils.data import Subset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import wandb import wandb
def load_runtime_config(config_path, current_params):
    """Merge live overrides from a JSON file into ``current_params``.

    The file at ``config_path`` is expected to hold a JSON object. For each
    whitelisted key whose value differs from the one in ``current_params``,
    the new value is copied over and the change is printed. The ``commands``
    entry of the file (or an empty dict if absent) always replaces
    ``current_params['commands']``.

    Loading is best-effort: a missing file is ignored silently, and an
    unreadable or malformed file only prints a warning.

    Args:
        config_path: Path of the JSON file to read.
        current_params: Dict of the currently active parameters; it is
            updated in place.

    Returns:
        The same ``current_params`` dict that was passed in.
    """
    modifiable_keys = ('n_updates', 'plot_at', 'early_stopping_patience',
                       'print_stats_at', 'print_train_stats_at', 'validate_at',
                       'learningrate', 'weight_decay')

    # Open directly instead of checking for existence first: the file may
    # disappear between the check and the open.
    try:
        with open(config_path, 'r') as f:
            new_config = json.load(f)
    except FileNotFoundError:
        return current_params
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and decoding errors, e.g.
        # when the file is read while it is only partially written.
        print(f"Warning: Could not load runtime config: {e}")
        return current_params

    if not isinstance(new_config, dict):
        print("Warning: Could not load runtime config: "
              f"expected a JSON object, got {type(new_config).__name__}")
        return current_params

    updated = False
    for key in modifiable_keys:
        if key not in new_config:
            continue
        new_val = new_config[key]
        old_val = current_params.get(key)
        if new_val == old_val:
            continue
        # Every whitelisted key is numeric; accepting anything else would
        # only defer the failure to wherever the value is used next.
        if isinstance(new_val, bool) or not isinstance(new_val, (int, float)):
            print(f"Warning: Ignoring non-numeric runtime config value for {key}: {new_val!r}")
            continue
        current_params[key] = new_val
        print(f"\n[CONFIG UPDATE] {key}: {old_val} -> {new_val}")
        updated = True

    # Command flags are taken over wholesale, but only if they form a mapping.
    commands = new_config.get('commands', {})
    if not isinstance(commands, dict):
        print("Warning: Ignoring runtime config 'commands': expected a JSON object")
        commands = {}
    current_params['commands'] = commands

    if updated:
        print("[CONFIG UPDATE] Runtime configuration updated successfully!\n")
    return current_params
def clear_command_flag(config_path, command_name):
    """Reset one command flag in the JSON config file to ``False``.

    The file is only rewritten when it holds a JSON object whose ``commands``
    mapping contains ``command_name``. The operation is best-effort: a
    missing file is ignored silently and any read/write problem only prints
    a warning.

    Args:
        config_path: Path of the JSON config file.
        command_name: Key inside the ``commands`` mapping to reset.
    """
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
    except FileNotFoundError:
        return
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and decoding errors.
        print(f"Warning: Could not clear command flag: {e}")
        return

    commands = config.get('commands') if isinstance(config, dict) else None
    if not isinstance(commands, dict) or command_name not in commands:
        return
    commands[command_name] = False

    # Write to a sibling temp file and swap it in, so a failure mid-write
    # cannot leave a truncated config file behind.
    tmp_path = f"{config_path}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(config, f, indent=2)
        os.replace(tmp_path, config_path)
    except OSError as e:
        print(f"Warning: Could not clear command flag: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss.

    Computes ``sqrt(MSE(pred, target) + eps)``. The additive ``eps`` keeps the
    gradient of the square root finite when the MSE is exactly zero.

    Args:
        eps: Constant added to the MSE before taking the square root.
            Defaults to ``1e-6``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar RMSE between ``pred`` and ``target``."""
        return torch.sqrt(self.mse(pred, target) + self.eps)
class PerceptualLoss(nn.Module):
    """Perceptual loss built on frozen, pre-trained VGG16 features.

    The loss is the sum of L1 distances between the VGG16 activations of the
    prediction and of the target, taken after three consecutive slices of the
    feature extractor.

    Args:
        device: Device on which the VGG16 feature layers and the
            normalization constants are placed.
    """

    def __init__(self, device):
        super().__init__()
        # NOTE(review): `pretrained=True` is kept as in the original call;
        # newer torchvision releases prefer the `weights=` argument - confirm
        # against the installed version before changing it.
        vgg = models.vgg16(pretrained=True).features.to(device).eval()

        # The feature extractor is a fixed metric and must not be trained.
        for param in vgg.parameters():
            param.requires_grad = False

        # Split the extractor into three consecutive stages.
        layers = list(vgg.children())
        self.slice1 = nn.Sequential(*layers[:4])    # relu1_2
        self.slice2 = nn.Sequential(*layers[4:9])   # relu2_2
        self.slice3 = nn.Sequential(*layers[9:16])  # relu3_3

        # ImageNet normalization constants, created on the same device as
        # the VGG layers so the module works without a later `.to(device)`.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def normalize(self, x):
        """Clamp ``x`` to [0, 1] and apply the ImageNet mean/std normalization."""
        x = torch.clamp(x, 0.0, 1.0)
        return (x - self.mean) / (self.std + 1e-8)

    def forward(self, pred, target):
        """Return the summed L1 feature distance between ``pred`` and ``target``."""
        # `normalize` already clamps to [0, 1], so no separate clamp is needed.
        pred = self.normalize(pred)
        target = self.normalize(target)

        # Each slice consumes the previous slice's output.
        pred_f1 = self.slice1(pred)
        pred_f2 = self.slice2(pred_f1)
        pred_f3 = self.slice3(pred_f2)

        target_f1 = self.slice1(target)
        target_f2 = self.slice2(target_f1)
        target_f3 = self.slice3(target_f2)

        return (F.l1_loss(pred_f1, target_f1)
                + F.l1_loss(pred_f2, target_f2)
                + F.l1_loss(pred_f3, target_f3))
class CombinedLoss(nn.Module): class CombinedLoss(nn.Module):
"""Combined loss optimized for RMSE evaluation with optional perceptual component""" """Combined loss: MSE + L1 + SSIM-like perceptual component"""
def __init__(self, device, use_perceptual=True, perceptual_weight=0.05): def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
super().__init__() super().__init__()
self.use_perceptual = use_perceptual self.mse_weight = mse_weight
if use_perceptual: self.l1_weight = l1_weight
self.perceptual_loss = PerceptualLoss(device) self.edge_weight = edge_weight
# Use MSE instead of RMSE for training (more stable gradients) self.mse = nn.MSELoss()
self.mse_loss = nn.MSELoss() self.l1 = nn.L1Loss()
self.rmse_loss = RMSELoss() # For logging only
self.perceptual_weight = perceptual_weight # Sobel filters for edge detection
self.mse_weight = 1.0 - perceptual_weight sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
def edge_loss(self, pred, target):
"""Compute edge-aware loss using Sobel filters"""
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
return edge_loss
def forward(self, pred, target): def forward(self, pred, target):
# Clamp predictions to valid range mse_loss = self.mse(pred, target)
pred = torch.clamp(pred, 0.0, 1.0) l1_loss = self.l1(pred, target)
target = torch.clamp(target, 0.0, 1.0) edge_loss = self.edge_loss(pred, target)
# Check for NaN in inputs total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
if not torch.isfinite(pred).all() or not torch.isfinite(target).all(): return total_loss
print("Warning: NaN detected in loss inputs")
return (torch.tensor(float('nan'), device=pred.device),) * 4
# Primary loss: MSE (equivalent to RMSE but more stable)
mse = self.mse_loss(pred, target)
rmse = self.rmse_loss(pred, target) # For logging
if self.use_perceptual:
# Optional small perceptual component for texture quality
perceptual = self.perceptual_loss(pred, target)
# Check perceptual loss validity
if not torch.isfinite(perceptual):
perceptual = torch.tensor(0.0, device=pred.device)
total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
else:
# Pure MSE optimization
perceptual = torch.tensor(0.0, device=pred.device)
total_loss = mse
# Validate loss is not NaN or Inf
if not torch.isfinite(total_loss):
# Return MSE only as fallback
total_loss = mse
if not torch.isfinite(total_loss):
print("Warning: MSE is NaN")
return (torch.tensor(float('nan'), device=pred.device),) * 4
return total_loss, perceptual, mse, rmse
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate, def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize, weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize,
network_config: dict, testset_path=None, save_path=None, plot_path_predictions=None): network_config: dict):
np.random.seed(seed=seed) np.random.seed(seed=seed)
torch.manual_seed(seed=seed) torch.manual_seed(seed=seed)
@@ -197,13 +67,6 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
if isinstance(device, str): if isinstance(device, str):
device = torch.device(device) device = torch.device(device)
# Enable mixed precision training for memory efficiency
use_amp = torch.cuda.is_available()
if use_amp:
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
else:
scaler = None
if use_wandb: if use_wandb:
wandb.login() wandb.login()
@@ -221,16 +84,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
plotpath = os.path.join(results_path, "plots") plotpath = os.path.join(results_path, "plots")
os.makedirs(plotpath, exist_ok=True) os.makedirs(plotpath, exist_ok=True)
image_dataset = datasets.ImageDataset(datafolder=data_path) image_dataset = datasets.ImageDataset(datafolder=data_path, augment=False)
n_total = len(image_dataset) n_total = len(image_dataset)
n_test = int(n_total * testset_ratio) n_test = int(n_total * testset_ratio)
n_valid = int(n_total * validset_ratio) n_valid = int(n_total * validset_ratio)
n_train = n_total - n_test - n_valid n_train = n_total - n_test - n_valid
indices = np.random.permutation(n_total) indices = np.random.permutation(n_total)
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid]) # Create datasets with and without augmentation
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total]) train_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=True)
val_test_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=False)
dataset_train = Subset(train_dataset_source, indices=indices[0:n_train])
dataset_valid = Subset(val_test_dataset_source, indices=indices[n_train:n_train + n_valid])
dataset_test = Subset(val_test_dataset_source, indices=indices[n_train + n_valid:n_total])
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid) assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
@@ -248,28 +116,15 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
network.to(device) network.to(device)
network.train() network.train()
# defining the loss - Optimized for RMSE evaluation # defining the loss - combined loss for better reconstruction
# Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
# TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
mse_loss = torch.nn.MSELoss() # Keep for evaluation mse_loss = torch.nn.MSELoss() # Keep for evaluation
# defining the optimizer with AdamW for better weight decay handling # defining the optimizer with AdamW for better weight decay handling
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999)) optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
# Learning rate warmup # Learning rate scheduler for better convergence
warmup_steps = min(1000, n_updates // 10) scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
# Cosine annealing with warm restarts for long training
scheduler_main = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
)
# Warmup scheduler
def get_lr_scale(step):
if step < warmup_steps:
return step / warmup_steps
return 1.0
if use_wandb: if use_wandb:
wandb.watch(network, mse_loss, log="all", log_freq=10) wandb.watch(network, mse_loss, log="all", log_freq=10)
@@ -280,224 +135,42 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
loss_list = [] loss_list = []
saved_model_path = os.path.join(results_path, "best_model.pt") saved_model_path = os.path.join(results_path, "best_model.pt")
# Save runtime configuration to JSON file for dynamic updates
config_json_path = os.path.join(results_path, "runtime_config.json")
runtime_params = {
'learningrate': learningrate,
'weight_decay': weight_decay,
'n_updates': n_updates,
'plot_at': plot_at,
'early_stopping_patience': early_stopping_patience,
'print_stats_at': print_stats_at,
'print_train_stats_at': print_train_stats_at,
'validate_at': validate_at,
'commands': {
'save_checkpoint': False,
'run_test_validation': False,
'generate_predictions': False
}
}
with open(config_json_path, 'w') as f:
json.dump(runtime_params, f, indent=2)
print(f"Started training on device {device}") print(f"Started training on device {device}")
print(f"Runtime config saved to: {config_json_path}")
print(f"You can modify this file during training to change parameters dynamically!")
print(f"Set command flags to true to trigger actions (save_checkpoint, run_test_validation, generate_predictions)\n")
while i < n_updates: while i < n_updates:
for input, target in dataloader_train: for input, target in dataloader_train:
input, target = input.to(device), target.to(device) input, target = input.to(device), target.to(device)
# Check for runtime config updates every 5 steps
if i % 5 == 0 and i > 0:
runtime_params = load_runtime_config(config_json_path, runtime_params)
n_updates = runtime_params['n_updates']
plot_at = runtime_params['plot_at']
early_stopping_patience = runtime_params['early_stopping_patience']
print_stats_at = runtime_params['print_stats_at']
print_train_stats_at = runtime_params['print_train_stats_at']
validate_at = runtime_params['validate_at']
# Update optimizer parameters if changed
if 'learningrate' in runtime_params:
new_lr = runtime_params['learningrate']
current_lr = optimizer.param_groups[0]['lr']
if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
if 'weight_decay' in runtime_params:
new_wd = runtime_params['weight_decay']
current_wd = optimizer.param_groups[0]['weight_decay']
if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance
for param_group in optimizer.param_groups:
param_group['weight_decay'] = new_wd
# Execute runtime commands
commands = runtime_params.get('commands', {})
# Command: Save checkpoint
if commands.get('save_checkpoint', False):
checkpoint_path = os.path.join(results_path, f"checkpoint_step_{i}.pt")
torch.save(network.state_dict(), checkpoint_path)
print(f"\n[COMMAND] Checkpoint saved to: {checkpoint_path}\n")
clear_command_flag(config_json_path, 'save_checkpoint')
# Command: Generate predictions
if commands.get('generate_predictions', False) and testset_path is not None:
print(f"\n[COMMAND] Generating predictions at step {i}...")
try:
from utils import create_predictions
pred_save_path = save_path or os.path.join(results_path, "runtime_predictions", f"step_{i}")
pred_plot_path = plot_path_predictions or os.path.join(results_path, "runtime_predictions", "plots", f"step_{i}")
os.makedirs(pred_plot_path, exist_ok=True)
# Save current state temporarily
temp_state_path = os.path.join(results_path, f"temp_state_step_{i}.pt")
torch.save(network.state_dict(), temp_state_path)
# Generate predictions
create_predictions(network_config, temp_state_path, testset_path, None,
pred_save_path, pred_plot_path, plot_at=20, rmse_value=None)
print(f"[COMMAND] Predictions saved to: {pred_save_path}")
print(f"[COMMAND] Plots saved to: {pred_plot_path}\n")
# Clean up temp file
if os.path.exists(temp_state_path):
os.remove(temp_state_path)
except Exception as e:
print(f"[COMMAND] Error generating predictions: {e}\n")
network.train()
clear_command_flag(config_json_path, 'generate_predictions')
# Command: Run test validation
if commands.get('run_test_validation', False):
print(f"\n[COMMAND] Running test set validation at step {i}...")
network.eval()
test_loss, test_rmse = evaluate_model(network, dataloader_test, mse_loss, device)
print(f"[COMMAND] Test Loss: {test_loss:.6f}, Test RMSE: {test_rmse:.6f}\n")
network.train()
clear_command_flag(config_json_path, 'run_test_validation')
if (i + 1) % print_train_stats_at == 0: if (i + 1) % print_train_stats_at == 0:
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}') print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
optimizer.zero_grad() optimizer.zero_grad()
# Mixed precision training for memory efficiency output = network(input)
if use_amp:
with torch.amp.autocast('cuda'):
output = network(input)
total_loss, perceptual, mse, rmse = combined_loss(output, target)
# Check for NaN before backward
if not torch.isfinite(total_loss):
continue
scaler.scale(total_loss).backward()
# Unscale and check gradients
scaler.unscale_(optimizer)
# Check for NaN in gradients
has_nan = False
for name, param in network.named_parameters():
if param.grad is not None:
if not torch.isfinite(param.grad).all():
print(f"NaN gradient detected in {name}")
has_nan = True
break
if has_nan:
print(f"Skipping step {i+1}: NaN gradients detected")
optimizer.zero_grad()
scaler.update()
# Reset scaler if NaN persists
if (i + 1) % 10 == 0:
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
continue
# More aggressive gradient clipping for stability
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
# Skip update if gradient norm is too large
if grad_norm > 100.0:
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")
optimizer.zero_grad()
scaler.update()
continue
scaler.step(optimizer)
scaler.update()
else:
output = network(input)
total_loss, perceptual, mse, rmse = combined_loss(output, target)
# Check for NaN before backward
if not torch.isfinite(total_loss):
print(f"Skipping step {i+1}: NaN or Inf loss detected")
continue
total_loss.backward()
# Check for NaN in gradients
has_nan = False
for name, param in network.named_parameters():
if param.grad is not None and not torch.isfinite(param.grad).all():
print(f"NaN gradient detected in {name}")
has_nan = True
break
if has_nan:
print(f"Skipping step {i+1}: NaN gradients detected")
optimizer.zero_grad()
continue
# More aggressive gradient clipping
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
if grad_norm > 100.0:
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")
optimizer.zero_grad()
continue
optimizer.step()
# Apply learning rate scheduling with warmup
lr_scale = get_lr_scale(i)
for param_group in optimizer.param_groups:
param_group['lr'] = learningrate * lr_scale
if i >= warmup_steps:
scheduler_main.step()
loss_list.append(total_loss.item()) loss = combined_loss(output, target)
loss.backward()
# Gradient clipping for training stability
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step(i + len(loss_list) / len(dataloader_train))
loss_list.append(loss.item())
# writing the stats to wandb # writing the stats to wandb
if use_wandb and (i+1) % print_stats_at == 0: if use_wandb and (i+1) % print_stats_at == 0:
wandb.log({ wandb.log({"training/loss_per_batch": loss.item()}, step=i)
"training/loss_total": total_loss.item(),
"training/loss_mse": mse.item(),
"training/loss_rmse": rmse.item(),
"training/loss_perceptual": perceptual.item() if isinstance(perceptual, torch.Tensor) else perceptual,
"training/learning_rate": optimizer.param_groups[0]['lr']
}, step=i)
# plotting # plotting
if (i + 1) % plot_at == 0: if (i + 1) % plot_at == 0:
print(f"Plotting images, current update {i + 1}") print(f"Plotting images, current update {i + 1}")
# Convert to float32 for matplotlib compatibility plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
plot(input.float().cpu().numpy(),
target.detach().float().cpu().numpy(),
output.detach().float().cpu().numpy(),
plotpath, i)
# evaluating model every validate_at sample # evaluating model every validate_at sample
if (i + 1) % validate_at == 0: if (i + 1) % validate_at == 0:

View File

@@ -18,14 +18,12 @@ def plot(inputs, targets, predictions, path, update):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
fig, axes = plt.subplots(ncols=3, figsize=(15, 5)) fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
# Only plot up to min(5, batch_size) images for i in range(5):
num_images = min(5, inputs.shape[0])
for i in range(num_images):
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]): for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
ax.clear() ax.clear()
ax.set_title(title) ax.set_title(title)
img = data[i, 0:3, :, :] img = data[i:i + 1:, 0:3, :, :]
img = np.squeeze(img)
img = np.transpose(img, (1, 2, 0)) img = np.transpose(img, (1, 2, 0))
img = np.clip(img, 0, 1) img = np.clip(img, 0, 1)
ax.imshow(img) ax.imshow(img)
@@ -56,58 +54,24 @@ def testset_plot(input_array, output_array, path, index):
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device): def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
"""Returns MSE and RMSE of the model on the provided dataloader""" """Returnse MSE and RMSE of the model on the provided dataloader"""
# Save training mode and switch to eval
was_training = network.training
network.eval() network.eval()
loss = 0.0 loss = 0.0
num_batches = 0
with torch.no_grad(): with torch.no_grad():
for data in dataloader: for data in dataloader:
input_array, target = data input_array, target = data
input_array = input_array.to(device) input_array = input_array.to(device)
target = target.to(device) target = target.to(device)
# Check input validity
if not torch.isfinite(input_array).all() or not torch.isfinite(target).all():
print(f"Warning: NaN detected in evaluation inputs")
continue
outputs = network(input_array) outputs = network(input_array)
# Clamp outputs to valid range
outputs = torch.clamp(outputs, 0.0, 1.0)
# Check for NaN in outputs
if not torch.isfinite(outputs).all():
print(f"Warning: NaN detected in model outputs during evaluation")
continue
batch_loss = loss_fn(outputs, target).item()
# Check for NaN in loss
if not np.isfinite(batch_loss):
print(f"Warning: NaN detected in loss during evaluation")
continue
loss += batch_loss
num_batches += 1
if num_batches == 0:
print("Error: No valid batches in evaluation")
if was_training:
network.train()
return float('nan'), float('nan')
loss = loss / num_batches
rmse = 255.0 * np.sqrt(loss)
# Restore training mode loss += loss_fn(outputs, target).item()
if was_training:
network.train()
return loss, rmse loss = loss / len(dataloader)
network.train()
return loss, 255.0 * np.sqrt(loss)
def read_compressed_file(file_path: str): def read_compressed_file(file_path: str):
@@ -158,13 +122,6 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
predictions = np.stack(predictions, axis=0) predictions = np.stack(predictions, axis=0)
# Handle NaN and inf values before conversion
nan_mask = ~np.isfinite(predictions)
if nan_mask.any():
nan_count = nan_mask.sum()
print(f"Warning: Found {nan_count} NaN/Inf values in predictions. Replacing with 0.")
predictions = np.nan_to_num(predictions, nan=0.0, posinf=1.0, neginf=0.0)
predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8) predictions = (np.clip(predictions, 0, 1) * 255.0).astype(np.uint8)
data = { data = {