Update runtime configuration and model parameters for improved training efficiency

This commit is contained in:
2026-01-31 22:35:48 +01:00
parent 4af674b79d
commit d979c200f9
4 changed files with 15 additions and 12 deletions

View File

@@ -1,12 +1,12 @@
{
"learningrate": 0.005,
"learningrate": 0.0005,
"weight_decay": 5e-05,
"n_updates": 35000,
"n_updates": 12000,
"plot_at": 500,
"early_stopping_patience": 20,
"print_stats_at": 200,
"print_train_stats_at": 50,
"validate_at": 500,
"print_train_stats_at": 10,
"validate_at": 250,
"commands": {
"save_checkpoint": false,
"run_test_validation": false,

View File

@@ -241,11 +241,14 @@ class MyModel(nn.Module):
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
super().__init__()
# Separate mask processing for better feature extraction
# Separate mask processing for better feature extraction
self.mask_conv = nn.Sequential(
nn.Conv2d(1, base_channels // 4, 3, padding=1),
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
nn.LeakyReLU(0.2, inplace=True)
)

View File

@@ -17,24 +17,24 @@ if __name__ == '__main__':
config_dict['seed'] = 42
config_dict['testset_ratio'] = 0.1
config_dict['validset_ratio'] = 0.1
config_dict['validset_ratio'] = 0.05
# Get the absolute path based on the script's location
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
config_dict['results_path'] = os.path.join(project_root, "results")
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
config_dict['device'] = None
config_dict['learningrate'] = 5e-3 # Lower initial LR with warmup
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
config_dict['n_updates'] = 35000 # Extended training for better convergence
config_dict['n_updates'] = 12000  # Shortened training run for efficiency (was 35000)
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
config_dict['early_stopping_patience'] = 20 # More patience for complex model
config_dict['use_wandb'] = False
config_dict['print_train_stats_at'] = 50
config_dict['print_train_stats_at'] = 10
config_dict['print_stats_at'] = 200
config_dict['plot_at'] = 500
config_dict['validate_at'] = 500 # More frequent validation
config_dict['validate_at'] = 250 # More frequent validation
network_config = {
'n_in_channels': 4,
@@ -69,7 +69,7 @@ if __name__ == '__main__':
print(" - save_checkpoint: Save model at current step")
print(" - run_test_validation: Run validation on final test set")
print(" - generate_predictions: Generate predictions on challenge testset")
print("\nChanges will be applied within 50 steps.")
print("\nChanges will be applied within 5 steps.")
print("="*60)
print()

View File

@@ -313,8 +313,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
input, target = input.to(device), target.to(device)
# Check for runtime config updates every 50 steps
if i % 50 == 0 and i > 0:
# Check for runtime config updates every 5 steps
if i % 5 == 0 and i > 0:
runtime_params = load_runtime_config(config_json_path, runtime_params)
n_updates = runtime_params['n_updates']
plot_at = runtime_params['plot_at']