fixed nan prediction errors

This commit is contained in:
2026-01-31 21:05:32 +01:00
parent fd81f3ce2e
commit 4af674b79d
5 changed files with 149 additions and 28 deletions

View File

@@ -1,10 +1,12 @@
{
"learningrate": 0.005,
"weight_decay": 5e-05,
"n_updates": 35000,
"plot_at": 500,
"early_stopping_patience": 20,
"print_stats_at": 200,
"print_train_stats_at": 50,
"validate_at": 400,
"validate_at": 500,
"commands": {
"save_checkpoint": false,
"run_test_validation": false,

View File

@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def init_weights(m):
@@ -52,10 +53,13 @@ class EfficientChannelAttention(nn.Module):
def forward(self, x):
# Global pooling
y = self.avg_pool(x)
# 1D convolution on channel dimension
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
# 1D convolution on channel dimension - add safety checks
if y.size(-1) == 1 and y.size(-2) == 1:
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
y = torch.clamp(y, min=0.0, max=1.0) # Ensure valid range
return x * y.expand_as(x)
return x
class SpatialAttention(nn.Module):
@@ -104,8 +108,11 @@ class SelfAttention(nn.Module):
key = self.key(x).view(batch_size, -1, H * W)
value = self.value(x).view(batch_size, -1, H * W)
# Attention map
attention = self.softmax(torch.bmm(query, key))
# Attention map with numerical stability
attention_logits = torch.bmm(query, key)
# Scale for numerical stability
attention_logits = attention_logits / math.sqrt(query.size(-1))
attention = self.softmax(attention_logits)
out = torch.bmm(value, attention.permute(0, 2, 1))
out = out.view(batch_size, C, H, W)
@@ -126,7 +133,8 @@ class ConvBlock(nn.Module):
)
else:
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
self.bn = nn.BatchNorm2d(out_channels)
# NOTE(review): momentum=0.1, eps=1e-5, track_running_stats=True are PyTorch's BatchNorm2d defaults — this makes them explicit but does not change behavior
self.bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5, track_running_stats=True)
self.relu = nn.LeakyReLU(0.2, inplace=True)
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
@@ -250,7 +258,7 @@ class MyModel(nn.Module):
# Fusion of mask and image features
self.fusion = nn.Sequential(
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
nn.BatchNorm2d(base_channels),
nn.BatchNorm2d(base_channels, momentum=0.1, eps=1e-5, track_running_stats=True),
nn.LeakyReLU(0.2, inplace=True)
)
@@ -302,25 +310,58 @@ class MyModel(nn.Module):
image = x[:, :3, :, :]
mask = x[:, 3:4, :, :]
# Clamp inputs to valid range
image = torch.clamp(image, 0.0, 1.0)
mask = torch.clamp(mask, 0.0, 1.0)
# Process mask and image separately
mask_features = self.mask_conv(mask)
image_features = self.image_conv(image)
# Safety check after initial processing
if not torch.isfinite(mask_features).all():
mask_features = torch.nan_to_num(mask_features, nan=0.0, posinf=1.0, neginf=-1.0)
if not torch.isfinite(image_features).all():
image_features = torch.nan_to_num(image_features, nan=0.0, posinf=1.0, neginf=-1.0)
# Fuse features
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
if not torch.isfinite(x0).all():
x0 = torch.nan_to_num(x0, nan=0.0, posinf=1.0, neginf=-1.0)
# Encoder
x1, skip1 = self.down1(x0)
if not torch.isfinite(x1).all():
x1 = torch.nan_to_num(x1, nan=0.0, posinf=1.0, neginf=-1.0)
skip1 = torch.nan_to_num(skip1, nan=0.0, posinf=1.0, neginf=-1.0)
x2, skip2 = self.down2(x1)
if not torch.isfinite(x2).all():
x2 = torch.nan_to_num(x2, nan=0.0, posinf=1.0, neginf=-1.0)
skip2 = torch.nan_to_num(skip2, nan=0.0, posinf=1.0, neginf=-1.0)
x3, skip3 = self.down3(x2)
if not torch.isfinite(x3).all():
x3 = torch.nan_to_num(x3, nan=0.0, posinf=1.0, neginf=-1.0)
skip3 = torch.nan_to_num(skip3, nan=0.0, posinf=1.0, neginf=-1.0)
# Bottleneck
x = self.bottleneck(x3)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Decoder with skip connections
x = self.up1(x, skip3)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
x = self.up2(x, skip2)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
x = self.up3(x, skip1)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Handle dimension mismatch for final fusion
if x.shape[2:] != x0.shape[2:]:
@@ -329,12 +370,19 @@ class MyModel(nn.Module):
# Multi-scale fusion with initial features
x = torch.cat([x, x0], dim=1)
x = self.multiscale_fusion(x)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Pre-output processing
x = self.pre_output(x)
if not torch.isfinite(x).all():
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
# Concatenate with original masked image for residual learning
x = torch.cat([x, image], dim=1)
x = self.output(x)
# Final safety clamp
x = torch.clamp(x, 0.0, 1.0)
return x

View File

@@ -24,17 +24,17 @@ if __name__ == '__main__':
config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
config_dict['learningrate'] = 5e-3 # Raised initial LR (was 5e-4); warmup still applies
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
config_dict['n_updates'] = 35000 # Extended training for better convergence
config_dict['batchsize'] = 48 # Reduced for larger model and mixed precision
config_dict['batchsize'] = 64 # Increased batch size (was 48) — watch GPU memory with the larger model
config_dict['early_stopping_patience'] = 20 # More patience for complex model
config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 50
config_dict['print_stats_at'] = 200
config_dict['plot_at'] = 500
config_dict['validate_at'] = 400 # More frequent validation
config_dict['validate_at'] = 500 # Validation interval raised from 400 (less frequent validation)
network_config = {
'n_in_channels': 4,
@@ -69,7 +69,7 @@ if __name__ == '__main__':
print(" - save_checkpoint: Save model at current step")
print(" - run_test_validation: Run validation on final test set")
print(" - generate_predictions: Generate predictions on challenge testset")
print("\nChanges will be applied within 100 steps.")
print("\nChanges will be applied within 50 steps.")
print("="*60)
print()

View File

@@ -32,7 +32,8 @@ def load_runtime_config(config_path, current_params):
# Update modifiable parameters
updated = False
modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
'print_stats_at', 'print_train_stats_at', 'validate_at']
'print_stats_at', 'print_train_stats_at', 'validate_at',
'learningrate', 'weight_decay']
for key in modifiable_keys:
if key in new_config and new_config[key] != current_params.get(key):
@@ -150,6 +151,12 @@ class CombinedLoss(nn.Module):
def forward(self, pred, target):
# Clamp predictions to valid range
pred = torch.clamp(pred, 0.0, 1.0)
target = torch.clamp(target, 0.0, 1.0)
# Check for NaN in inputs
if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
print("Warning: NaN detected in loss inputs")
return (torch.tensor(float('nan'), device=pred.device),) * 4
# Primary loss: MSE (equivalent to RMSE but more stable)
mse = self.mse_loss(pred, target)
@@ -158,6 +165,9 @@ class CombinedLoss(nn.Module):
if self.use_perceptual:
# Optional small perceptual component for texture quality
perceptual = self.perceptual_loss(pred, target)
# Check perceptual loss validity
if not torch.isfinite(perceptual):
perceptual = torch.tensor(0.0, device=pred.device)
total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
else:
# Pure MSE optimization
@@ -168,6 +178,9 @@ class CombinedLoss(nn.Module):
if not torch.isfinite(total_loss):
# Return MSE only as fallback
total_loss = mse
if not torch.isfinite(total_loss):
print("Warning: MSE is NaN")
return (torch.tensor(float('nan'), device=pred.device),) * 4
return total_loss, perceptual, mse, rmse
@@ -237,7 +250,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
# defining the loss - Optimized for RMSE evaluation
# Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality
combined_loss = CombinedLoss(device, use_perceptual=True, perceptual_weight=0.05).to(device)
# TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
mse_loss = torch.nn.MSELoss() # Keep for evaluation
# defining the optimizer with AdamW for better weight decay handling
@@ -270,6 +284,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
# Save runtime configuration to JSON file for dynamic updates
config_json_path = os.path.join(results_path, "runtime_config.json")
runtime_params = {
'learningrate': learningrate,
'weight_decay': weight_decay,
'n_updates': n_updates,
'plot_at': plot_at,
'early_stopping_patience': early_stopping_patience,
@@ -307,6 +323,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
print_train_stats_at = runtime_params['print_train_stats_at']
validate_at = runtime_params['validate_at']
# Update optimizer parameters if changed
if 'learningrate' in runtime_params:
new_lr = runtime_params['learningrate']
current_lr = optimizer.param_groups[0]['lr']
if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
if 'weight_decay' in runtime_params:
new_wd = runtime_params['weight_decay']
current_wd = optimizer.param_groups[0]['weight_decay']
if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance
for param_group in optimizer.param_groups:
param_group['weight_decay'] = new_wd
# Execute runtime commands
commands = runtime_params.get('commands', {})
@@ -388,10 +419,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
print(f"Skipping step {i+1}: NaN gradients detected")
optimizer.zero_grad()
scaler.update()
# Reset scaler if NaN persists
if (i + 1) % 10 == 0:
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
continue
# Gradient clipping for stability (threshold relaxed from 0.5 to 1.0)
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
# Skip update if gradient norm is too large
if grad_norm > 100.0:
@@ -415,8 +449,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
# Check for NaN in gradients
has_nan = False
for param in network.parameters():
for name, param in network.named_parameters():
if param.grad is not None and not torch.isfinite(param.grad).all():
print(f"NaN gradient detected in {name}")
has_nan = True
break
@@ -426,7 +461,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
continue
# Gradient clipping (threshold relaxed from 0.5 to 1.0)
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
if grad_norm > 100.0:
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")

View File

@@ -18,12 +18,14 @@ def plot(inputs, targets, predictions, path, update):
os.makedirs(path, exist_ok=True)
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
for i in range(5):
# Only plot up to min(5, batch_size) images
num_images = min(5, inputs.shape[0])
for i in range(num_images):
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
ax.clear()
ax.set_title(title)
img = data[i:i + 1:, 0:3, :, :]
img = np.squeeze(img)
img = data[i, 0:3, :, :]
img = np.transpose(img, (1, 2, 0))
img = np.clip(img, 0, 1)
ax.imshow(img)
@@ -54,24 +56,58 @@ def testset_plot(input_array, output_array, path, index):
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
"""Returnse MSE and RMSE of the model on the provided dataloader"""
"""Returns MSE and RMSE of the model on the provided dataloader"""
# Save training mode and switch to eval
was_training = network.training
network.eval()
loss = 0.0
num_batches = 0
with torch.no_grad():
for data in dataloader:
input_array, target = data
input_array = input_array.to(device)
target = target.to(device)
# Check input validity
if not torch.isfinite(input_array).all() or not torch.isfinite(target).all():
print(f"Warning: NaN detected in evaluation inputs")
continue
outputs = network(input_array)
# Clamp outputs to valid range
outputs = torch.clamp(outputs, 0.0, 1.0)
# Check for NaN in outputs
if not torch.isfinite(outputs).all():
print(f"Warning: NaN detected in model outputs during evaluation")
continue
batch_loss = loss_fn(outputs, target).item()
# Check for NaN in loss
if not np.isfinite(batch_loss):
print(f"Warning: NaN detected in loss during evaluation")
continue
loss += batch_loss
num_batches += 1
if num_batches == 0:
print("Error: No valid batches in evaluation")
if was_training:
network.train()
return float('nan'), float('nan')
loss = loss / num_batches
rmse = 255.0 * np.sqrt(loss)
loss += loss_fn(outputs, target).item()
# Restore training mode
if was_training:
network.train()
loss = loss / len(dataloader)
network.train()
return loss, 255.0 * np.sqrt(loss)
return loss, rmse
def read_compressed_file(file_path: str):