Added baseline
This commit is contained in:
@@ -17,30 +17,34 @@ if __name__ == '__main__':
|
||||
config_dict['seed'] = 42
|
||||
config_dict['testset_ratio'] = 0.1
|
||||
config_dict['validset_ratio'] = 0.1
|
||||
config_dict['results_path'] = os.path.join("results")
|
||||
config_dict['data_path'] = os.path.join("data", "dataset")
|
||||
# Get the absolute path based on the script's location
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(script_dir)
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 1e-3
|
||||
config_dict['learningrate'] = 5e-4 # Slightly lower for more stable training
|
||||
config_dict['weight_decay'] = 1e-5 # default is 0
|
||||
config_dict['n_updates'] = 50000
|
||||
config_dict['batchsize'] = 32
|
||||
config_dict['early_stopping_patience'] = 3
|
||||
config_dict['n_updates'] = 200
|
||||
config_dict['batchsize'] = 16 # Reduced due to larger model
|
||||
config_dict['early_stopping_patience'] = 5 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 100
|
||||
config_dict['plot_at'] = 100
|
||||
config_dict['plot_at'] = 10
|
||||
config_dict['validate_at'] = 100
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 32 # Start with 32, can increase to 64 for even better results
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
train(**config_dict)
|
||||
|
||||
testset_path = os.path.join("data", "challenge_testset.npz")
|
||||
testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
|
||||
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
||||
save_path = os.path.join(config_dict['results_path'], "testset", "my_submission_name.npz")
|
||||
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
||||
|
||||
Reference in New Issue
Block a user