Update runtime configuration and model parameters for improved training efficiency
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"learningrate": 0.005,
|
"learningrate": 0.0005,
|
||||||
"weight_decay": 5e-05,
|
"weight_decay": 5e-05,
|
||||||
"n_updates": 35000,
|
"n_updates": 12000,
|
||||||
"plot_at": 500,
|
"plot_at": 500,
|
||||||
"early_stopping_patience": 20,
|
"early_stopping_patience": 20,
|
||||||
"print_stats_at": 200,
|
"print_stats_at": 200,
|
||||||
"print_train_stats_at": 50,
|
"print_train_stats_at": 10,
|
||||||
"validate_at": 500,
|
"validate_at": 250,
|
||||||
"commands": {
|
"commands": {
|
||||||
"save_checkpoint": false,
|
"save_checkpoint": false,
|
||||||
"run_test_validation": false,
|
"run_test_validation": false,
|
||||||
|
|||||||
@@ -241,11 +241,14 @@ class MyModel(nn.Module):
|
|||||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# Separate mask processing for better feature extraction
|
||||||
# Separate mask processing for better feature extraction
|
# Separate mask processing for better feature extraction
|
||||||
self.mask_conv = nn.Sequential(
|
self.mask_conv = nn.Sequential(
|
||||||
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
nn.Conv2d(1, base_channels // 4, 3, padding=1),
|
||||||
|
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
nn.Conv2d(base_channels // 4, base_channels // 4, 3, padding=1),
|
||||||
|
nn.BatchNorm2d(base_channels // 4, momentum=0.1, eps=1e-5),
|
||||||
nn.LeakyReLU(0.2, inplace=True)
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,24 +17,24 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
config_dict['seed'] = 42
|
config_dict['seed'] = 42
|
||||||
config_dict['testset_ratio'] = 0.1
|
config_dict['testset_ratio'] = 0.1
|
||||||
config_dict['validset_ratio'] = 0.1
|
config_dict['validset_ratio'] = 0.05
|
||||||
# Get the absolute path based on the script's location
|
# Get the absolute path based on the script's location
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
project_root = os.path.dirname(script_dir)
|
project_root = os.path.dirname(script_dir)
|
||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 5e-3 # Lower initial LR with warmup
|
config_dict['learningrate'] = 5e-4 # Lower initial LR with warmup
|
||||||
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
|
config_dict['weight_decay'] = 5e-5 # Reduced for more capacity
|
||||||
config_dict['n_updates'] = 35000 # Extended training for better convergence
|
config_dict['n_updates'] = 12000 # Shortened training run (reduced from 35000)
|
||||||
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
|
config_dict['batchsize'] = 64 # Reduced for larger model and mixed precision
|
||||||
config_dict['early_stopping_patience'] = 20 # More patience for complex model
|
config_dict['early_stopping_patience'] = 20 # More patience for complex model
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 50
|
config_dict['print_train_stats_at'] = 10
|
||||||
config_dict['print_stats_at'] = 200
|
config_dict['print_stats_at'] = 200
|
||||||
config_dict['plot_at'] = 500
|
config_dict['plot_at'] = 500
|
||||||
config_dict['validate_at'] = 500 # More frequent validation
|
config_dict['validate_at'] = 250 # More frequent validation
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
@@ -69,7 +69,7 @@ if __name__ == '__main__':
|
|||||||
print(" - save_checkpoint: Save model at current step")
|
print(" - save_checkpoint: Save model at current step")
|
||||||
print(" - run_test_validation: Run validation on final test set")
|
print(" - run_test_validation: Run validation on final test set")
|
||||||
print(" - generate_predictions: Generate predictions on challenge testset")
|
print(" - generate_predictions: Generate predictions on challenge testset")
|
||||||
print("\nChanges will be applied within 50 steps.")
|
print("\nChanges will be applied within 5 steps.")
|
||||||
print("="*60)
|
print("="*60)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|||||||
@@ -313,8 +313,8 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
|
|
||||||
input, target = input.to(device), target.to(device)
|
input, target = input.to(device), target.to(device)
|
||||||
|
|
||||||
# Check for runtime config updates every 50 steps
|
# Check for runtime config updates every 5 steps
|
||||||
if i % 50 == 0 and i > 0:
|
if i % 5 == 0 and i > 0:
|
||||||
runtime_params = load_runtime_config(config_json_path, runtime_params)
|
runtime_params = load_runtime_config(config_json_path, runtime_params)
|
||||||
n_updates = runtime_params['n_updates']
|
n_updates = runtime_params['n_updates']
|
||||||
plot_at = runtime_params['plot_at']
|
plot_at = runtime_params['plot_at']
|
||||||
|
|||||||
Reference in New Issue
Block a user