Added baseline

This commit is contained in:
2026-01-23 11:15:55 +01:00
parent 9a2092cbde
commit 09d1911feb
9 changed files with 149 additions and 19 deletions

View File

@@ -17,30 +17,34 @@ if __name__ == '__main__':
config_dict['seed'] = 42
config_dict['testset_ratio'] = 0.1
config_dict['validset_ratio'] = 0.1
config_dict['results_path'] = os.path.join("results")
config_dict['data_path'] = os.path.join("data", "dataset")
# Get the absolute path based on the script's location
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None
config_dict['learningrate'] = 1e-3
config_dict['learningrate'] = 5e-4 # Slightly lower for more stable training
config_dict['weight_decay'] = 1e-5 # default is 0
config_dict['n_updates'] = 50000
config_dict['batchsize'] = 32
config_dict['early_stopping_patience'] = 3
config_dict['n_updates'] = 200
config_dict['batchsize'] = 16 # Reduced due to larger model
config_dict['early_stopping_patience'] = 5 # More patience for complex model
config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 100
config_dict['plot_at'] = 100
config_dict['plot_at'] = 10
config_dict['validate_at'] = 100
network_config = {
'n_in_channels': 4
'n_in_channels': 4,
'base_channels': 32 # Start with 32, can increase to 64 for even better results
}
config_dict['network_config'] = network_config
train(**config_dict)
testset_path = os.path.join("data", "challenge_testset.npz")
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
save_path = os.path.join(config_dict['results_path'], "testset", "my_submission_name.npz")
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")