improve baseline

This commit is contained in:
2026-01-24 16:15:34 +01:00
parent 3b1d3c0497
commit 57026695d4
10 changed files with 228 additions and 63 deletions

View File

@@ -23,31 +23,32 @@ if __name__ == '__main__':
config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None
config_dict['learningrate'] = 5e-4 # Slightly lower for more stable training
config_dict['weight_decay'] = 1e-5 # default is 0
config_dict['n_updates'] = 200
config_dict['batchsize'] = 16 # Reduced due to larger model
config_dict['early_stopping_patience'] = 5 # More patience for complex model
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
config_dict['n_updates'] = 300 # More updates for better convergence
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
config_dict['early_stopping_patience'] = 10 # More patience for complex model
config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 100
config_dict['plot_at'] = 100
config_dict['validate_at'] = 100
config_dict['plot_at'] = 300
config_dict['validate_at'] = 300 # Validate less frequently (every 300 updates)
network_config = {
'n_in_channels': 4,
'base_channels': 32 # Start with 32, can increase to 64 for even better results
'base_channels': 48, # Good balance between capacity and memory
'dropout': 0.1 # Regularization
}
config_dict['network_config'] = network_config
train(**config_dict)
rmse_value = train(**config_dict)
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
save_path = os.path.join(config_dict['results_path'], "testset", "my_submission_name.npz")
save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
# Comment out the call below if predictions are not required
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20)
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)