improve baseline

This commit is contained in:
2026-01-24 16:15:34 +01:00
parent 3b1d3c0497
commit 57026695d4
10 changed files with 228 additions and 63 deletions

View File

@@ -1,3 +1,4 @@
data/* data/*
*.zip *.zip
*.jpg *.jpg
*.pt

View File

@@ -6,78 +6,190 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
def init_weights(m):
    """Initialize one module's parameters in place (intended for ``Module.apply``).

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: Kaiming-normal weights
      (``mode='fan_out'``, ``nonlinearity='relu'``) and a zero bias, if the
      layer has a bias.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0, if the layer has
      affine parameters.
    * Any other module type is left untouched.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None, so there
        # is nothing to reset in that case.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class ChannelAttention(nn.Module):
    """Channel-wise gating computed from globally pooled features.

    The input is reduced to one value per channel twice (average pooling and
    max pooling). Both descriptors pass through the same two-layer 1x1-conv
    bottleneck, are summed, squashed with a sigmoid, and used to rescale the
    input channels.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor for the bottleneck width. The bottleneck never
            gets narrower than 8 channels.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Bottleneck width, floored at 8 so small layers keep some capacity.
        hidden = max(channels // reduction, 8)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # One shared bottleneck is applied to both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by a per-channel gate in (0, 1)."""
        avg_logits = self.fc(self.avg_pool(x))
        max_logits = self.fc(self.max_pool(x))
        gate = self.sigmoid(avg_logits + max_logits)
        return x * gate
class SpatialAttention(nn.Module):
    """Spatial gating computed from channel-pooled features.

    The input is collapsed along the channel axis into a mean map and a max
    map. The two maps are stacked, convolved down to a single channel,
    squashed with a sigmoid, and used to rescale every spatial location of the
    input.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd number so that ``padding=kernel_size // 2`` keeps the
            attention map the same size as the input.

    Raises:
        ValueError: If ``kernel_size`` is even or smaller than 1.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # An even kernel with padding k // 2 yields an (H+1, W+1) map, which
        # cannot be multiplied with the (H, W) input in forward(). Fail early
        # with a clear message instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by a per-location gate in (0, 1)."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = torch.cat([avg_out, max_out], dim=1)
        attn = self.sigmoid(self.conv(attn))
        return x * attn
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating, then spatial gating.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck divisor forwarded to ``ChannelAttention``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.channel_attn = ChannelAttention(channels, reduction)
        self.spatial_attn = SpatialAttention()

    def forward(self, x):
        """Apply channel attention first and spatial attention second."""
        return self.spatial_attn(self.channel_attn(x))
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) -> optional channel dropout.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        dropout: ``Dropout2d`` probability; values ``<= 0`` disable dropout
            (an ``nn.Identity`` is used instead).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        # The attribute keeps the name "relu" although it holds a LeakyReLU.
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Run the conv / norm / activation / dropout pipeline on ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.dropout(x)
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in an identity skip connection.

    The channel count and spatial size are preserved, so the input can be
    added to the second stage's output before the final activation.

    Args:
        channels: Number of input (and output) channels.
        dropout: ``Dropout2d`` probability applied after the first stage;
            values ``<= 0`` disable dropout (an ``nn.Identity`` is used).
    """

    def __init__(self, channels, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.LeakyReLU(0.1, inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Return ``activation(F(x) + x)`` where ``F`` is the two-stage conv path."""
        # First stage: conv -> norm -> activation -> dropout.
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        # Second stage is left un-activated until the skip has been added.
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + x)
class DownBlock(nn.Module):
    """Encoder stage: two conv blocks, a residual block, CBAM attention, 2x max-pool.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by this stage.
        dropout: Dropout probability forwarded to the conv and residual blocks.
    """

    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
        self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
        self.residual = ResidualConvBlock(out_channels, dropout=dropout)
        self.attention = CBAM(out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(pooled, skip)``.

        ``skip`` is the attended feature map before pooling, and ``pooled`` is
        the same map after 2x max-pooling.
        """
        features = self.residual(self.conv2(self.conv1(x)))
        skip = self.attention(features)
        return self.pool(skip), skip
class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, conv/residual/attention.

    The first conv block is sized for ``out_channels + in_channels`` inputs,
    i.e. the skip tensor is expected to carry ``in_channels`` channels.

    Args:
        in_channels: Channels of the incoming (low-resolution) feature map,
            and also the expected channel count of ``skip``.
        out_channels: Channels produced by this stage.
        dropout: Dropout probability forwarded to the conv and residual blocks.
    """

    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # After the concat: out_channels (upsampled path) + in_channels (skip).
        self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
        self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
        self.residual = ResidualConvBlock(out_channels, dropout=dropout)
        self.attention = CBAM(out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` and refine the result."""
        x = self.up(x)

        # The transposed conv exactly doubles the size; if the encoder pooled
        # an odd-sized map the shapes differ, so resample onto the skip's grid.
        target_size = skip.shape[2:]
        if x.shape[2:] != target_size:
            x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)

        fused = torch.cat([x, skip], dim=1)
        fused = self.conv2(self.conv1(fused))
        return self.attention(self.residual(fused))
class MyModel(nn.Module): class MyModel(nn.Module):
"""U-Net style architecture for image inpainting""" """Improved U-Net style architecture for image inpainting with attention and residual connections"""
def __init__(self, n_in_channels: int, base_channels: int = 64): def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
super().__init__() super().__init__()
# Initial convolution # Initial convolution with larger receptive field
self.init_conv = ConvBlock(n_in_channels, base_channels) self.init_conv = nn.Sequential(
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
ConvBlock(base_channels, base_channels),
ResidualConvBlock(base_channels)
)
# Encoder (downsampling path) # Encoder (downsampling path)
self.down1 = DownBlock(base_channels, base_channels * 2) self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
self.down2 = DownBlock(base_channels * 2, base_channels * 4) self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
self.down3 = DownBlock(base_channels * 4, base_channels * 8) self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
# Bottleneck # Bottleneck with multiple residual blocks
self.bottleneck1 = ConvBlock(base_channels * 8, base_channels * 16) self.bottleneck = nn.Sequential(
self.bottleneck2 = ConvBlock(base_channels * 16, base_channels * 16) ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
ResidualConvBlock(base_channels * 16, dropout=dropout),
ResidualConvBlock(base_channels * 16, dropout=dropout),
ResidualConvBlock(base_channels * 16, dropout=dropout),
CBAM(base_channels * 16)
)
# Decoder (upsampling path) # Decoder (upsampling path)
self.up1 = UpBlock(base_channels * 16, base_channels * 8) self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
self.up2 = UpBlock(base_channels * 8, base_channels * 4) self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
self.up3 = UpBlock(base_channels * 4, base_channels * 2) self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
# Final upsampling and output # Final refinement layers
self.final_up = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2) self.final_conv = nn.Sequential(
self.final_conv1 = ConvBlock(base_channels * 2, base_channels) ConvBlock(base_channels * 2, base_channels),
self.final_conv2 = ConvBlock(base_channels, base_channels) ResidualConvBlock(base_channels),
ConvBlock(base_channels, base_channels)
)
# Output layer # Output layer with smooth transition
self.output = nn.Conv2d(base_channels, 3, kernel_size=1) self.output = nn.Sequential(
self.sigmoid = nn.Sigmoid() # To ensure output is in [0, 1] range nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
nn.Sigmoid() # Ensure output is in [0, 1] range
)
# Apply weight initialization
self.apply(init_weights)
def forward(self, x): def forward(self, x):
# Initial convolution # Initial convolution
@@ -87,27 +199,26 @@ class MyModel(nn.Module):
x1, skip1 = self.down1(x0) x1, skip1 = self.down1(x0)
x2, skip2 = self.down2(x1) x2, skip2 = self.down2(x1)
x3, skip3 = self.down3(x2) x3, skip3 = self.down3(x2)
x4, skip4 = self.down4(x3)
# Bottleneck # Bottleneck
x = self.bottleneck1(x3) x = self.bottleneck(x4)
x = self.bottleneck2(x)
# Decoder with skip connections # Decoder with skip connections
x = self.up1(x, skip3) x = self.up1(x, skip4)
x = self.up2(x, skip2) x = self.up2(x, skip3)
x = self.up3(x, skip1) x = self.up3(x, skip2)
x = self.up4(x, skip1)
# Final layers
x = self.final_up(x)
# Handle dimension mismatch for final concatenation # Handle dimension mismatch for final concatenation
if x.shape[2:] != x0.shape[2:]: if x.shape[2:] != x0.shape[2:]:
x = nn.functional.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False) x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
# Concatenate with initial features for better detail preservation
x = torch.cat([x, x0], dim=1) x = torch.cat([x, x0], dim=1)
x = self.final_conv1(x) x = self.final_conv(x)
x = self.final_conv2(x)
# Output # Output
x = self.output(x) x = self.output(x)
x = self.sigmoid(x)
return x return x

View File

@@ -23,31 +23,32 @@ if __name__ == '__main__':
config_dict['results_path'] = os.path.join(project_root, "results") config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset") config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None config_dict['device'] = None
config_dict['learningrate'] = 5e-4 # Slightly lower for more stable training config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
config_dict['weight_decay'] = 1e-5 # default is 0 config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
config_dict['n_updates'] = 200 config_dict['n_updates'] = 300 # More updates for better convergence
config_dict['batchsize'] = 16 # Reduced due to larger model config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
config_dict['early_stopping_patience'] = 5 # More patience for complex model config_dict['early_stopping_patience'] = 10 # More patience for complex model
config_dict['use_wandb'] = False config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 10 config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 100 config_dict['print_stats_at'] = 100
config_dict['plot_at'] = 100 config_dict['plot_at'] = 300
config_dict['validate_at'] = 100 config_dict['validate_at'] = 300 # Validate every 300 updates (less often than the previous 100)
network_config = { network_config = {
'n_in_channels': 4, 'n_in_channels': 4,
'base_channels': 32 # Start with 32, can increase to 64 for even better results 'base_channels': 48, # Good balance between capacity and memory
'dropout': 0.1 # Regularization
} }
config_dict['network_config'] = network_config config_dict['network_config'] = network_config
train(**config_dict) rmse_value = train(**config_dict)
testset_path = os.path.join(project_root, "data", "challenge_testset.npz") testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt") state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
save_path = os.path.join(config_dict['results_path'], "testset", "my_submission_name.npz") save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
plot_path = os.path.join(config_dict['results_path'], "testset", "plots") plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
# Comment out, if predictions are required # Comment out, if predictions are required
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20) create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)

View File

@@ -9,15 +9,52 @@ from architecture import MyModel
from utils import plot, evaluate_model from utils import plot, evaluate_model
import torch import torch
import torch.nn as nn
import numpy as np import numpy as np
import os import os
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import Subset from torch.utils.data import Subset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import wandb import wandb
class CombinedLoss(nn.Module):
    """Weighted sum of MSE, L1 and a Sobel-gradient (edge) L1 term.

    ``loss = mse_weight * MSE + l1_weight * L1 + edge_weight * edge``, where
    ``edge`` is the L1 distance between the horizontal and vertical Sobel
    responses of prediction and target. There is no SSIM or other perceptual
    component.

    Args:
        mse_weight: Weight of the mean-squared-error term.
        l1_weight: Weight of the mean-absolute-error term.
        edge_weight: Weight of the Sobel edge term.
    """

    def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
        super().__init__()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.edge_weight = edge_weight

        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

        # Sobel filters for edge detection. They are registered as buffers so
        # that .to(device) moves them with the module. The buffers hold one
        # copy of the kernel per channel of a 3-channel image.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
        self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))

    def _sobel_kernels(self, channels, dtype):
        """Return depthwise Sobel kernels for ``channels`` input channels."""
        kernel_x, kernel_y = self.sobel_x, self.sobel_y
        if channels != kernel_x.shape[0]:
            # Every row of the buffer is the same 3x3 kernel, so any channel
            # count can be served by tiling the first row.
            kernel_x = kernel_x[:1].repeat(channels, 1, 1, 1)
            kernel_y = kernel_y[:1].repeat(channels, 1, 1, 1)
        return kernel_x.to(dtype=dtype), kernel_y.to(dtype=dtype)

    def edge_loss(self, pred, target):
        """Compute the edge-aware loss using per-channel Sobel filters.

        Args:
            pred: Predicted images, indexed as (batch, channels, H, W).
            target: Ground-truth images with the same shape as ``pred``.

        Returns:
            Scalar tensor: L1 distance of the x-gradients plus L1 distance of
            the y-gradients.
        """
        channels = pred.shape[1]
        kernel_x, kernel_y = self._sobel_kernels(channels, pred.dtype)
        conv2d = torch.nn.functional.conv2d

        # groups=channels filters each channel independently (depthwise conv).
        pred_edge_x = conv2d(pred, kernel_x, padding=1, groups=channels)
        pred_edge_y = conv2d(pred, kernel_y, padding=1, groups=channels)
        target_edge_x = conv2d(target, kernel_x, padding=1, groups=channels)
        target_edge_y = conv2d(target, kernel_y, padding=1, groups=channels)

        edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
        return edge_loss

    def forward(self, pred, target):
        """Return the weighted sum of the MSE, L1 and edge terms as a scalar tensor."""
        mse_loss = self.mse(pred, target)
        l1_loss = self.l1(pred, target)
        edge_loss = self.edge_loss(pred, target)

        total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
        return total_loss
def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate, def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize, weight_decay, n_updates, use_wandb, print_train_stats_at, print_stats_at, plot_at, validate_at, batchsize,
network_config: dict): network_config: dict):
@@ -74,11 +111,15 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
network.to(device) network.to(device)
network.train() network.train()
# defining the loss # defining the loss - combined loss for better reconstruction
mse_loss = torch.nn.MSELoss() combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
mse_loss = torch.nn.MSELoss() # Keep for evaluation
# defining the optimizer # defining the optimizer with AdamW for better weight decay handling
optimizer = torch.optim.Adam(network.parameters(), lr=learningrate, weight_decay=weight_decay) optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
# Learning rate scheduler for better convergence
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
if use_wandb: if use_wandb:
wandb.watch(network, mse_loss, log="all", log_freq=10) wandb.watch(network, mse_loss, log="all", log_freq=10)
@@ -105,11 +146,15 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
output = network(input) output = network(input)
loss = mse_loss(output, target) loss = combined_loss(output, target)
loss.backward() loss.backward()
# Gradient clipping for training stability
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
optimizer.step() optimizer.step()
scheduler.step(i + len(loss_list) / len(dataloader_train))
loss_list.append(loss.item()) loss_list.append(loss.item())
@@ -164,3 +209,5 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
wandb.summary["testset/loss"] = testset_loss wandb.summary["testset/loss"] = testset_loss
wandb.summary["testset/RMSE"] = testset_rmse wandb.summary["testset/RMSE"] = testset_rmse
wandb.finish() wandb.finish()
return testset_rmse

View File

@@ -81,7 +81,7 @@ def read_compressed_file(file_path: str):
return input_arrays, known_arrays return input_arrays, known_arrays
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20): def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
""" """
Here, one might needs to adjust the code based on the used preprocessing Here, one might needs to adjust the code based on the used preprocessing
""" """
@@ -128,6 +128,11 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
"predictions": predictions "predictions": predictions
} }
# Modify save_path to include RMSE value if provided
if rmse_value is not None:
base_path = save_path.rsplit('.npz', 1)[0]
save_path = f"{base_path}-{rmse_value:.4f}.npz"
np.savez_compressed(save_path, **data) np.savez_compressed(save_path, **data)
print(f"Predictions saved at {save_path}") print(f"Predictions saved at {save_path}")