Update runtime configuration and model parameters for improved training efficiency
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"learningrate": 0.005,
|
||||
"learningrate": 0.0005,
|
||||
"weight_decay": 5e-05,
|
||||
"n_updates": 35000,
|
||||
"n_updates": 12000,
|
||||
"plot_at": 500,
|
||||
"early_stopping_patience": 20,
|
||||
"print_stats_at": 200,
|
||||
"print_train_stats_at": 50,
|
||||
"validate_at": 500,
|
||||
"print_train_stats_at": 10,
|
||||
"validate_at": 250,
|
||||
"commands": {
|
||||
"save_checkpoint": false,
|
||||
"run_test_validation": false,
|
||||
|
||||
@@ -241,11 +241,14 @@ class MyModel(nn.Module):
|
||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
# Separate mask processing for better feature extraction
|
||||
# Separate mask processing for better feature extraction
|
||||
self.mask_conv = nn.Sequential(
|
||||
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
)
|
||||
|
||||
|
||||
@@ -17,24 +17,24 @@ if __name__ == '__main__':
|
||||
|
||||
config_dict['seed'] = 42
|
||||
config_dict['testset_ratio'] = 0.1
|
||||
config_dict['validset_ratio'] = 0.1
|
||||
config_dict['validset_ratio'] = 0.05
|
||||
# Get the absolute path based on the script's location
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(script_dir)
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 5e-3 # Lower initial LR with warmup
|
||||
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
|
||||
config_dict['n_updates'] = 35000 # Extended training for better convergence
|
||||
config_dict['n_updates'] = 12000 # Shorter training run (reduced from 35000)
|
||||
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
|
||||
config_dict['early_stopping_patience'] = 20 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 50
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 200
|
||||
config_dict['plot_at'] = 500
|
||||
config_dict['validate_at'] = 500 # More frequent validation
|
||||
config_dict['validate_at'] = 250 # More frequent validation
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
@@ -69,7 +69,7 @@ if __name__ == '__main__':
|
||||
print(" - save_checkpoint: Save model at current step")
|
||||
print(" - run_test_validation: Run validation on final test set")
|
||||
print(" - generate_predictions: Generate predictions on challenge testset")
|
||||
print("\nChanges will be applied within 50 steps.")
|
||||
print("\nChanges will be applied within 5 steps.")
|
||||
print("="*60)
|
||||
print()
|
||||
|
||||
|
||||
@@ -313,8 +313,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
|
||||
input, target = input.to(device), target.to(device)
|
||||
|
||||
# Check for runtime config updates every 50 steps
|
||||
if i % 50 == 0 and i > 0:
|
||||
# Check for runtime config updates every 5 steps
|
||||
if i % 5 == 0 and i > 0:
|
||||
runtime_params = load_runtime_config(config_json_path, runtime_params)
|
||||
n_updates = runtime_params['n_updates']
|
||||
plot_at = runtime_params['plot_at']
|
||||
|
||||
Reference in New Issue
Block a user