fixed nan prediction errors
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
{
|
||||
"learningrate": 0.005,
|
||||
"weight_decay": 5e-05,
|
||||
"n_updates": 35000,
|
||||
"plot_at": 500,
|
||||
"early_stopping_patience": 20,
|
||||
"print_stats_at": 200,
|
||||
"print_train_stats_at": 50,
|
||||
"validate_at": 400,
|
||||
"validate_at": 500,
|
||||
"commands": {
|
||||
"save_checkpoint": false,
|
||||
"run_test_validation": false,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
@@ -52,10 +53,13 @@ class EfficientChannelAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
# Global pooling
|
||||
y = self.avg_pool(x)
|
||||
# 1D convolution on channel dimension
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
y = self.sigmoid(y)
|
||||
return x * y.expand_as(x)
|
||||
# 1D convolution on channel dimension - add safety checks
|
||||
if y.size(-1) == 1 and y.size(-2) == 1:
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
y = self.sigmoid(y)
|
||||
y = torch.clamp(y, min=0.0, max=1.0) # Ensure valid range
|
||||
return x * y.expand_as(x)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
@@ -104,8 +108,11 @@ class SelfAttention(nn.Module):
|
||||
key = self.key(x).view(batch_size, -1, H * W)
|
||||
value = self.value(x).view(batch_size, -1, H * W)
|
||||
|
||||
# Attention map
|
||||
attention = self.softmax(torch.bmm(query, key))
|
||||
# Attention map with numerical stability
|
||||
attention_logits = torch.bmm(query, key)
|
||||
# Scale for numerical stability
|
||||
attention_logits = attention_logits / math.sqrt(query.size(-1))
|
||||
attention = self.softmax(attention_logits)
|
||||
out = torch.bmm(value, attention.permute(0, 2, 1))
|
||||
out = out.view(batch_size, C, H, W)
|
||||
|
||||
@@ -126,7 +133,8 @@ class ConvBlock(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
# Add momentum and eps for numerical stability
|
||||
self.bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5, track_running_stats=True)
|
||||
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
@@ -250,7 +258,7 @@ class MyModel(nn.Module):
|
||||
# Fusion of mask and image features
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(base_channels + base_channels // 4, base_channels, 1),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.BatchNorm2d(base_channels, momentum=0.1, eps=1e-5, track_running_stats=True),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
@@ -302,25 +310,58 @@ class MyModel(nn.Module):
|
||||
image = x[:, :3, :, :]
|
||||
mask = x[:, 3:4, :, :]
|
||||
|
||||
# Clamp inputs to valid range
|
||||
image = torch.clamp(image, 0.0, 1.0)
|
||||
mask = torch.clamp(mask, 0.0, 1.0)
|
||||
|
||||
# Process mask and image separately
|
||||
mask_features = self.mask_conv(mask)
|
||||
image_features = self.image_conv(image)
|
||||
|
||||
# Safety check after initial processing
|
||||
if not torch.isfinite(mask_features).all():
|
||||
mask_features = torch.nan_to_num(mask_features, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
if not torch.isfinite(image_features).all():
|
||||
image_features = torch.nan_to_num(image_features, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Fuse features
|
||||
x0 = self.fusion(torch.cat([image_features, mask_features], dim=1))
|
||||
if not torch.isfinite(x0).all():
|
||||
x0 = torch.nan_to_num(x0, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Encoder
|
||||
x1, skip1 = self.down1(x0)
|
||||
if not torch.isfinite(x1).all():
|
||||
x1 = torch.nan_to_num(x1, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
skip1 = torch.nan_to_num(skip1, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
x2, skip2 = self.down2(x1)
|
||||
if not torch.isfinite(x2).all():
|
||||
x2 = torch.nan_to_num(x2, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
skip2 = torch.nan_to_num(skip2, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
x3, skip3 = self.down3(x2)
|
||||
if not torch.isfinite(x3).all():
|
||||
x3 = torch.nan_to_num(x3, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
skip3 = torch.nan_to_num(skip3, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Bottleneck
|
||||
x = self.bottleneck(x3)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Decoder with skip connections
|
||||
x = self.up1(x, skip3)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
x = self.up2(x, skip2)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
x = self.up3(x, skip1)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Handle dimension mismatch for final fusion
|
||||
if x.shape[2:] != x0.shape[2:]:
|
||||
@@ -329,12 +370,19 @@ class MyModel(nn.Module):
|
||||
# Multi-scale fusion with initial features
|
||||
x = torch.cat([x, x0], dim=1)
|
||||
x = self.multiscale_fusion(x)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Pre-output processing
|
||||
x = self.pre_output(x)
|
||||
if not torch.isfinite(x).all():
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)
|
||||
|
||||
# Concatenate with original masked image for residual learning
|
||||
x = torch.cat([x, image], dim=1)
|
||||
x = self.output(x)
|
||||
|
||||
# Final safety clamp
|
||||
x = torch.clamp(x, 0.0, 1.0)
|
||||
|
||||
return x
|
||||
@@ -24,17 +24,17 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
|
||||
config_dict['learningrate'] = 5e-3 # Lower initial LR with warmup
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
|
||||
config_dict['n_updates'] = 35000 # Extended training for better convergence
|
||||
config_dict['batchsize'] = 48 # Reduced for larger model and mixed precision
|
||||
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
|
||||
config_dict['early_stopping_patience'] = 20 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 50
|
||||
config_dict['print_stats_at'] = 200
|
||||
config_dict['plot_at'] = 500
|
||||
config_dict['validate_at'] = 400 # More frequent validation
|
||||
config_dict['validate_at'] = 500 # More frequent validation
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
@@ -69,7 +69,7 @@ if __name__ == '__main__':
|
||||
print(" - save_checkpoint: Save model at current step")
|
||||
print(" - run_test_validation: Run validation on final test set")
|
||||
print(" - generate_predictions: Generate predictions on challenge testset")
|
||||
print("\nChanges will be applied within 100 steps.")
|
||||
print("\nChanges will be applied within 50 steps.")
|
||||
print("="*60)
|
||||
print()
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ def load_runtime_config(config_path, current_params):
|
||||
# Update modifiable parameters
|
||||
updated = False
|
||||
modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
|
||||
'print_stats_at', 'print_train_stats_at', 'validate_at']
|
||||
'print_stats_at', 'print_train_stats_at', 'validate_at',
|
||||
'learningrate', 'weight_decay']
|
||||
|
||||
for key in modifiable_keys:
|
||||
if key in new_config and new_config[key] != current_params.get(key):
|
||||
@@ -150,6 +151,12 @@ class CombinedLoss(nn.Module):
|
||||
def forward(self, pred, target):
|
||||
# Clamp predictions to valid range
|
||||
pred = torch.clamp(pred, 0.0, 1.0)
|
||||
target = torch.clamp(target, 0.0, 1.0)
|
||||
|
||||
# Check for NaN in inputs
|
||||
if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
|
||||
print("Warning: NaN detected in loss inputs")
|
||||
return (torch.tensor(float('nan'), device=pred.device),) * 4
|
||||
|
||||
# Primary loss: MSE (equivalent to RMSE but more stable)
|
||||
mse = self.mse_loss(pred, target)
|
||||
@@ -158,6 +165,9 @@ class CombinedLoss(nn.Module):
|
||||
if self.use_perceptual:
|
||||
# Optional small perceptual component for texture quality
|
||||
perceptual = self.perceptual_loss(pred, target)
|
||||
# Check perceptual loss validity
|
||||
if not torch.isfinite(perceptual):
|
||||
perceptual = torch.tensor(0.0, device=pred.device)
|
||||
total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
|
||||
else:
|
||||
# Pure MSE optimization
|
||||
@@ -168,6 +178,9 @@ class CombinedLoss(nn.Module):
|
||||
if not torch.isfinite(total_loss):
|
||||
# Return MSE only as fallback
|
||||
total_loss = mse
|
||||
if not torch.isfinite(total_loss):
|
||||
print("Warning: MSE is NaN")
|
||||
return (torch.tensor(float('nan'), device=pred.device),) * 4
|
||||
|
||||
return total_loss, perceptual, mse, rmse
|
||||
|
||||
@@ -237,7 +250,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
|
||||
# defining the loss - Optimized for RMSE evaluation
|
||||
# Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality
|
||||
combined_loss = CombinedLoss(device, use_perceptual=True, perceptual_weight=0.05).to(device)
|
||||
# TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
|
||||
combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
@@ -270,6 +284,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
# Save runtime configuration to JSON file for dynamic updates
|
||||
config_json_path = os.path.join(results_path, "runtime_config.json")
|
||||
runtime_params = {
|
||||
'learningrate': learningrate,
|
||||
'weight_decay': weight_decay,
|
||||
'n_updates': n_updates,
|
||||
'plot_at': plot_at,
|
||||
'early_stopping_patience': early_stopping_patience,
|
||||
@@ -307,6 +323,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
print_train_stats_at = runtime_params['print_train_stats_at']
|
||||
validate_at = runtime_params['validate_at']
|
||||
|
||||
# Update optimizer parameters if changed
|
||||
if 'learningrate' in runtime_params:
|
||||
new_lr = runtime_params['learningrate']
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = new_lr
|
||||
|
||||
if 'weight_decay' in runtime_params:
|
||||
new_wd = runtime_params['weight_decay']
|
||||
current_wd = optimizer.param_groups[0]['weight_decay']
|
||||
if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['weight_decay'] = new_wd
|
||||
|
||||
# Execute runtime commands
|
||||
commands = runtime_params.get('commands', {})
|
||||
|
||||
@@ -388,10 +419,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
print(f"Skipping step {i+1}: NaN gradients detected")
|
||||
optimizer.zero_grad()
|
||||
scaler.update()
|
||||
# Reset scaler if NaN persists
|
||||
if (i + 1) % 10 == 0:
|
||||
scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
|
||||
continue
|
||||
|
||||
# More aggressive gradient clipping for stability
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
# Skip update if gradient norm is too large
|
||||
if grad_norm > 100.0:
|
||||
@@ -415,8 +449,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
|
||||
# Check for NaN in gradients
|
||||
has_nan = False
|
||||
for param in network.parameters():
|
||||
for name, param in network.named_parameters():
|
||||
if param.grad is not None and not torch.isfinite(param.grad).all():
|
||||
print(f"NaN gradient detected in {name}")
|
||||
has_nan = True
|
||||
break
|
||||
|
||||
@@ -426,7 +461,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
continue
|
||||
|
||||
# More aggressive gradient clipping
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
if grad_norm > 100.0:
|
||||
print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")
|
||||
|
||||
@@ -18,12 +18,14 @@ def plot(inputs, targets, predictions, path, update):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
|
||||
|
||||
for i in range(5):
|
||||
# Only plot up to min(5, batch_size) images
|
||||
num_images = min(5, inputs.shape[0])
|
||||
|
||||
for i in range(num_images):
|
||||
for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]):
|
||||
ax.clear()
|
||||
ax.set_title(title)
|
||||
img = data[i:i + 1:, 0:3, :, :]
|
||||
img = np.squeeze(img)
|
||||
img = data[i, 0:3, :, :]
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
img = np.clip(img, 0, 1)
|
||||
ax.imshow(img)
|
||||
@@ -54,24 +56,58 @@ def testset_plot(input_array, output_array, path, index):
|
||||
|
||||
|
||||
def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device):
|
||||
"""Returnse MSE and RMSE of the model on the provided dataloader"""
|
||||
"""Returns MSE and RMSE of the model on the provided dataloader"""
|
||||
# Save training mode and switch to eval
|
||||
was_training = network.training
|
||||
network.eval()
|
||||
|
||||
loss = 0.0
|
||||
num_batches = 0
|
||||
with torch.no_grad():
|
||||
for data in dataloader:
|
||||
input_array, target = data
|
||||
input_array = input_array.to(device)
|
||||
target = target.to(device)
|
||||
|
||||
# Check input validity
|
||||
if not torch.isfinite(input_array).all() or not torch.isfinite(target).all():
|
||||
print(f"Warning: NaN detected in evaluation inputs")
|
||||
continue
|
||||
|
||||
outputs = network(input_array)
|
||||
|
||||
# Clamp outputs to valid range
|
||||
outputs = torch.clamp(outputs, 0.0, 1.0)
|
||||
|
||||
# Check for NaN in outputs
|
||||
if not torch.isfinite(outputs).all():
|
||||
print(f"Warning: NaN detected in model outputs during evaluation")
|
||||
continue
|
||||
|
||||
batch_loss = loss_fn(outputs, target).item()
|
||||
|
||||
# Check for NaN in loss
|
||||
if not np.isfinite(batch_loss):
|
||||
print(f"Warning: NaN detected in loss during evaluation")
|
||||
continue
|
||||
|
||||
loss += batch_loss
|
||||
num_batches += 1
|
||||
|
||||
if num_batches == 0:
|
||||
print("Error: No valid batches in evaluation")
|
||||
if was_training:
|
||||
network.train()
|
||||
return float('nan'), float('nan')
|
||||
|
||||
loss = loss / num_batches
|
||||
rmse = 255.0 * np.sqrt(loss)
|
||||
|
||||
loss += loss_fn(outputs, target).item()
|
||||
# Restore training mode
|
||||
if was_training:
|
||||
network.train()
|
||||
|
||||
loss = loss / len(dataloader)
|
||||
|
||||
network.train()
|
||||
|
||||
return loss, 255.0 * np.sqrt(loss)
|
||||
return loss, rmse
|
||||
|
||||
|
||||
def read_compressed_file(file_path: str):
|
||||
|
||||
Reference in New Issue
Block a user