Fixed NaN prediction errors

2026-01-31 21:05:32 +01:00
parent fd81f3ce2e
commit 4af674b79d
5 changed files with 149 additions and 28 deletions

@@ -32,7 +32,8 @@ def load_runtime_config(config_path, current_params):
     # Update modifiable parameters
     updated = False
     modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
-                       'print_stats_at', 'print_train_stats_at', 'validate_at']
+                       'print_stats_at', 'print_train_stats_at', 'validate_at',
+                       'learningrate', 'weight_decay']
     for key in modifiable_keys:
         if key in new_config and new_config[key] != current_params.get(key):
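With 'learningrate' and 'weight_decay' now on the modifiable list, the learning rate can be nudged mid-run by editing the runtime_config.json that the trainer writes into results_path. A minimal sketch of such an edit, assuming the file already exists; the numeric values are placeholders, not from the commit:

import json

# Rewrite runtime_config.json in place; load_runtime_config() picks up
# the new values on its next poll and pushes them into the optimizer.
with open("runtime_config.json", "r+") as f:
    cfg = json.load(f)
    cfg["learningrate"] = 1e-4   # placeholder value
    cfg["weight_decay"] = 1e-5   # placeholder value
    f.seek(0)
    json.dump(cfg, f, indent=2)
    f.truncate()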
@@ -150,6 +151,12 @@ class CombinedLoss(nn.Module):
     def forward(self, pred, target):
+        # Clamp predictions to valid range
+        pred = torch.clamp(pred, 0.0, 1.0)
+        target = torch.clamp(target, 0.0, 1.0)
+        # Check for NaN in inputs
+        if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
+            print("Warning: NaN detected in loss inputs")
+            return (torch.tensor(float('nan'), device=pred.device),) * 4
         # Primary loss: MSE (equivalent to RMSE but more stable)
         mse = self.mse_loss(pred, target)
@@ -158,6 +165,9 @@ class CombinedLoss(nn.Module):
         if self.use_perceptual:
             # Optional small perceptual component for texture quality
             perceptual = self.perceptual_loss(pred, target)
+            # Check perceptual loss validity
+            if not torch.isfinite(perceptual):
+                perceptual = torch.tensor(0.0, device=pred.device)
             total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
         else:
             # Pure MSE optimization
@@ -168,6 +178,9 @@ class CombinedLoss(nn.Module):
         if not torch.isfinite(total_loss):
             # Return MSE only as fallback
             total_loss = mse
+            if not torch.isfinite(total_loss):
+                print("Warning: MSE is NaN")
+                return (torch.tensor(float('nan'), device=pred.device),) * 4
         return total_loss, perceptual, mse, rmse
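Since forward() now signals failure by returning a 4-tuple of NaNs, the caller can catch a poisoned batch with one isfinite check and drop it before backward(). A sketch of that caller side, with variable and function names assumed rather than taken from this file:

import torch

def guarded_step(network, combined_loss, optimizer, inputs, targets):
    # Returns False when the batch is skipped because the loss is non-finite.
    optimizer.zero_grad()
    total_loss, perceptual, mse, rmse = combined_loss(network(inputs), targets)
    if not torch.isfinite(total_loss):
        return False  # NaN sentinel from CombinedLoss.forward()
    total_loss.backward()
    optimizer.step()
    return True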
@@ -237,7 +250,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
     # defining the loss - Optimized for RMSE evaluation
     # Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality
-    combined_loss = CombinedLoss(device, use_perceptual=True, perceptual_weight=0.05).to(device)
+    # TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
+    combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
     mse_loss = torch.nn.MSELoss() # Keep for evaluation
     # defining the optimizer with AdamW for better weight decay handling
@@ -270,6 +284,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
     # Save runtime configuration to JSON file for dynamic updates
     config_json_path = os.path.join(results_path, "runtime_config.json")
     runtime_params = {
+        'learningrate': learningrate,
+        'weight_decay': weight_decay,
         'n_updates': n_updates,
         'plot_at': plot_at,
         'early_stopping_patience': early_stopping_patience,
@@ -307,6 +323,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
             print_train_stats_at = runtime_params['print_train_stats_at']
             validate_at = runtime_params['validate_at']
+            # Update optimizer parameters if changed
+            if 'learningrate' in runtime_params:
+                new_lr = runtime_params['learningrate']
+                current_lr = optimizer.param_groups[0]['lr']
+                if abs(new_lr - current_lr) > 1e-10: # Float comparison with tolerance
+                    for param_group in optimizer.param_groups:
+                        param_group['lr'] = new_lr
+            if 'weight_decay' in runtime_params:
+                new_wd = runtime_params['weight_decay']
+                current_wd = optimizer.param_groups[0]['weight_decay']
+                if abs(new_wd - current_wd) > 1e-10: # Float comparison with tolerance
+                    for param_group in optimizer.param_groups:
+                        param_group['weight_decay'] = new_wd
             # Execute runtime commands
             commands = runtime_params.get('commands', {})
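AdamW keeps lr and weight_decay per param group, so writing every group (as above) and reading only group 0 stays consistent as long as all groups share one schedule. A quick sanity check after an update, sketched under that assumption:

for group in optimizer.param_groups:
    assert group['lr'] == runtime_params['learningrate']
    assert group['weight_decay'] == runtime_params['weight_decay']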
@@ -388,10 +419,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
                     print(f"Skipping step {i+1}: NaN gradients detected")
                     optimizer.zero_grad()
                     scaler.update()
+                    # Reset scaler if NaN persists
+                    if (i + 1) % 10 == 0:
+                        scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
                     continue
                 # More aggressive gradient clipping for stability
-                grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
+                grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
                 # Skip update if gradient norm is too large
                 if grad_norm > 100.0:
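For context, the scaler reset above sits inside the standard torch.amp step. A condensed sketch of that loop, with the structure inferred from the visible fragments rather than the full file (trainloader is an assumed name):

scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
for i, (inputs, targets) in enumerate(trainloader):
    optimizer.zero_grad()
    with torch.amp.autocast('cuda'):
        total_loss = combined_loss(network(inputs), targets)[0]
    scaler.scale(total_loss).backward()
    scaler.unscale_(optimizer)  # so clip_grad_norm_ sees true gradient values
    grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # internally skips the step on inf/NaN gradients
    scaler.update()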
@@ -415,8 +449,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
                 # Check for NaN in gradients
                 has_nan = False
-                for param in network.parameters():
+                for name, param in network.named_parameters():
                     if param.grad is not None and not torch.isfinite(param.grad).all():
+                        print(f"NaN gradient detected in {name}")
                         has_nan = True
                         break
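The same scan can be factored into a small helper (hypothetical, not in this commit) that reports every offending parameter instead of stopping at the first hit, which helps localize the layer producing NaNs:

import torch

def nonfinite_grad_names(network):
    # Names of all parameters whose gradients contain NaN or Inf.
    return [name for name, param in network.named_parameters()
            if param.grad is not None and not torch.isfinite(param.grad).all()]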
@@ -426,7 +461,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
                     continue
                 # More aggressive gradient clipping
-                grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
+                grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
                 if grad_norm > 100.0:
                     print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")