Fixed NaN prediction errors
@@ -32,7 +32,8 @@ def load_runtime_config(config_path, current_params):
     # Update modifiable parameters
     updated = False
     modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
-                       'print_stats_at', 'print_train_stats_at', 'validate_at']
+                       'print_stats_at', 'print_train_stats_at', 'validate_at',
+                       'learningrate', 'weight_decay']

     for key in modifiable_keys:
         if key in new_config and new_config[key] != current_params.get(key):
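For context, the function this hunk edits boils down to polling a JSON file and folding whitelisted keys back into the live parameter dict. A minimal self-contained sketch follows; the file handling and the (params, updated) return shape are assumptions, not the repository's actual code:

import json
import os

def load_runtime_config(config_path, current_params):
    # Sketch: re-read the JSON config and apply any changed, whitelisted keys.
    if not os.path.exists(config_path):
        return current_params, False
    with open(config_path) as f:
        new_config = json.load(f)

    updated = False
    modifiable_keys = ['n_updates', 'plot_at', 'early_stopping_patience',
                       'print_stats_at', 'print_train_stats_at', 'validate_at',
                       'learningrate', 'weight_decay']
    for key in modifiable_keys:
        if key in new_config and new_config[key] != current_params.get(key):
            current_params[key] = new_config[key]
            updated = True
    return current_params, updated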
@@ -150,6 +151,12 @@ class CombinedLoss(nn.Module):
     def forward(self, pred, target):
         # Clamp predictions to valid range
         pred = torch.clamp(pred, 0.0, 1.0)
         target = torch.clamp(target, 0.0, 1.0)
+
+        # Check for NaN in inputs
+        if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
+            print("Warning: NaN detected in loss inputs")
+            return (torch.tensor(float('nan'), device=pred.device),) * 4
+
         # Primary loss: MSE (equivalent to RMSE but more stable)
         mse = self.mse_loss(pred, target)
@@ -158,6 +165,9 @@ class CombinedLoss(nn.Module):
         if self.use_perceptual:
             # Optional small perceptual component for texture quality
             perceptual = self.perceptual_loss(pred, target)
+            # Check perceptual loss validity
+            if not torch.isfinite(perceptual):
+                perceptual = torch.tensor(0.0, device=pred.device)
             total_loss = self.mse_weight * mse + self.perceptual_weight * perceptual
         else:
             # Pure MSE optimization
@@ -168,6 +178,9 @@ class CombinedLoss(nn.Module):
         if not torch.isfinite(total_loss):
             # Return MSE only as fallback
             total_loss = mse
+            if not torch.isfinite(total_loss):
+                print("Warning: MSE is NaN")
+                return (torch.tensor(float('nan'), device=pred.device),) * 4

         return total_loss, perceptual, mse, rmse

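Pieced together, the guarded forward pass added across these three hunks follows one pattern: clamp, validate the inputs, validate each loss term, and fall back rather than propagate NaN. A self-contained sketch of that pattern, with the constructor and the rmse computation filled in as assumptions:

import torch
import torch.nn as nn

class GuardedLoss(nn.Module):
    # Hypothetical standalone version of the NaN-guard pattern above.
    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        pred = torch.clamp(pred, 0.0, 1.0)
        target = torch.clamp(target, 0.0, 1.0)
        # Bail out with NaN sentinels if the inputs are already corrupted
        if not torch.isfinite(pred).all() or not torch.isfinite(target).all():
            return (torch.tensor(float('nan'), device=pred.device),) * 4
        mse = self.mse_loss(pred, target)
        rmse = torch.sqrt(mse + 1e-12)  # epsilon keeps sqrt's gradient finite at 0
        perceptual = torch.tensor(0.0, device=pred.device)  # perceptual branch disabled
        return mse, perceptual, mse, rmse  # total_loss == mse in this sketch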
@@ -237,7 +250,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st

     # defining the loss - Optimized for RMSE evaluation
     # Set use_perceptual=False for pure MSE training, or keep True with 5% weight for texture quality
-    combined_loss = CombinedLoss(device, use_perceptual=True, perceptual_weight=0.05).to(device)
+    # TEMPORARILY DISABLED due to NaN issues - re-enable once training is stable
+    combined_loss = CombinedLoss(device, use_perceptual=False, perceptual_weight=0.0).to(device)
     mse_loss = torch.nn.MSELoss()  # Keep for evaluation

     # defining the optimizer with AdamW for better weight decay handling
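If the perceptual term's instability comes from mixed precision (fp16 activations overflowing inside the perceptual network), one common workaround when re-enabling it is to force that term to float32. This is an assumption about the failure mode, not something this commit does:

import torch

def perceptual_fp32(perceptual_loss, pred, target):
    # Hypothetical helper: evaluate the perceptual network in float32
    # even inside an autocast region.
    with torch.autocast('cuda', enabled=False):
        return perceptual_loss(pred.float(), target.float())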
@@ -270,6 +284,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
     # Save runtime configuration to JSON file for dynamic updates
     config_json_path = os.path.join(results_path, "runtime_config.json")
     runtime_params = {
+        'learningrate': learningrate,
+        'weight_decay': weight_decay,
         'n_updates': n_updates,
         'plot_at': plot_at,
         'early_stopping_patience': early_stopping_patience,
@@ -307,6 +323,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
             print_train_stats_at = runtime_params['print_train_stats_at']
             validate_at = runtime_params['validate_at']
+
+            # Update optimizer parameters if changed
+            if 'learningrate' in runtime_params:
+                new_lr = runtime_params['learningrate']
+                current_lr = optimizer.param_groups[0]['lr']
+                if abs(new_lr - current_lr) > 1e-10:  # Float comparison with tolerance
+                    for param_group in optimizer.param_groups:
+                        param_group['lr'] = new_lr
+
+            if 'weight_decay' in runtime_params:
+                new_wd = runtime_params['weight_decay']
+                current_wd = optimizer.param_groups[0]['weight_decay']
+                if abs(new_wd - current_wd) > 1e-10:  # Float comparison with tolerance
+                    for param_group in optimizer.param_groups:
+                        param_group['weight_decay'] = new_wd

             # Execute runtime commands
             commands = runtime_params.get('commands', {})

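With 'learningrate' and 'weight_decay' whitelisted, a running job can now be retuned by editing runtime_config.json on disk. For example, halving the learning rate mid-run (the path below is assumed; it is whatever results_path the job was launched with):

import json

path = "results/runtime_config.json"  # assumed results_path
with open(path) as f:
    params = json.load(f)
params['learningrate'] = params['learningrate'] / 2
with open(path, 'w') as f:
    json.dump(params, f, indent=2)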
@@ -388,10 +419,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
                 print(f"Skipping step {i+1}: NaN gradients detected")
                 optimizer.zero_grad()
                 scaler.update()
+                # Reset scaler if NaN persists
+                if (i + 1) % 10 == 0:
+                    scaler = torch.amp.GradScaler('cuda', init_scale=2048.0, growth_interval=100)
                 continue

             # More aggressive gradient clipping for stability
-            grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
+            grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)

             # Skip update if gradient norm is too large
             if grad_norm > 100.0:
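For reference, the ordering these guards slot into is the standard AMP pattern: unscale before clipping so the norm is measured in real units, then let the scaler skip and shrink on inf/NaN. A generic skeleton, not this repository's exact loop:

import torch

def amp_step(network, batch, loss_fn, optimizer, scaler):
    # Generic AMP training step: unscale_ -> clip -> step -> update.
    inputs, targets = batch
    optimizer.zero_grad()
    with torch.autocast('cuda'):
        loss = loss_fn(network(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients now in real (unscaled) units
    grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
    scaler.step(optimizer)      # internally skipped if unscaled grads hold inf/NaN
    scaler.update()             # shrinks the loss scale after a skipped step
    return loss, grad_norm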
@@ -415,8 +449,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st

             # Check for NaN in gradients
             has_nan = False
-            for param in network.parameters():
+            for name, param in network.named_parameters():
                 if param.grad is not None and not torch.isfinite(param.grad).all():
+                    print(f"NaN gradient detected in {name}")
                     has_nan = True
                     break

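Switching to named_parameters() makes the offending layer visible in the log. Factored out as a reusable check (a sketch, not code from this commit):

import torch

def first_nonfinite_grad(network):
    # Return the name of the first parameter with a non-finite gradient,
    # or None if every gradient is finite.
    for name, param in network.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            return name
    return None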
@@ -426,7 +461,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
                 continue

             # More aggressive gradient clipping
-            grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
+            grad_norm = torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)

             if grad_norm > 100.0:
                 print(f"Skipping step {i+1}: Gradient norm too large: {grad_norm:.2f}")