Files
DSAI/image-inpainting/src/main.py
2026-01-24 17:30:34 +01:00

63 lines
2.3 KiB
Python

"""
Author: Your Name
HTL-Grieskirchen 5. Jahrgang, Schuljahr 2025/26
main.py
"""
import os
from utils import create_predictions
from train import train
import shutil
def clear_directory(path: str) -> None:
    """Make sure *path* exists as a directory and delete everything inside it.

    Real subdirectories are removed recursively. Every other entry (regular
    files, symlinks, special files) is unlinked. A symlink that points to a
    directory is removed as a link; its target is left untouched.
    """
    os.makedirs(path, exist_ok=True)
    # listdir() returns a complete list up front, so deleting entries while
    # looping is safe.
    for name in os.listdir(path):
        entry = os.path.join(path, name)
        # Check islink first: isdir() follows symlinks, and rmtree() refuses
        # to operate on a symlink.
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.unlink(entry)


def main() -> None:
    """Train the inpainting network, then write predictions for the challenge test set.

    All paths are resolved from this file's location. The project root is the
    parent of the directory containing this script, so the current working
    directory does not matter.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    results_path = os.path.join(project_root, "results")

    network_config = {
        # NOTE(review): 4 input channels is presumably image channels plus a
        # known-pixel mask; confirm against the dataset/model code.
        'n_in_channels': 4,
        'base_channels': 56,  # Increased capacity for better feature learning
        'dropout': 0.08,  # Slightly less dropout with augmentation
    }

    config_dict = {
        'seed': 42,
        'testset_ratio': 0.1,
        'validset_ratio': 0.1,
        'results_path': results_path,
        'data_path': os.path.join(project_root, "data", "dataset"),
        # NOTE(review): None presumably lets train() choose the device; confirm.
        'device': None,
        'learningrate': 2e-4,  # Slightly lower for stable training
        'weight_decay': 5e-5,  # Reduced weight decay
        'n_updates': 8000,  # More updates for better convergence
        'batchsize': 8,  # Smaller batch for better gradient estimates
        'early_stopping_patience': 15,  # More patience for complex model
        'use_wandb': False,
        'print_train_stats_at': 10,
        'print_stats_at': 100,
        'plot_at': 300,
        'validate_at': 300,  # Validate more frequently
        'network_config': network_config,
    }

    rmse_value = train(**config_dict)

    testset_path = os.path.join(project_root, "data", "challenge_testset.npz")
    # Assumes train() stores its best checkpoint under this name; verify in train.py.
    state_dict_path = os.path.join(results_path, "best_model.pt")
    save_path = os.path.join(results_path, "testset", "tikaiz")
    plot_path = os.path.join(results_path, "testset", "plots")

    # Start from an empty plot folder so plots from earlier runs don't mix in.
    clear_directory(plot_path)

    # Comment out if predictions are not required.
    # NOTE(review): the meaning of the positional None argument is defined in
    # utils.create_predictions; document it there.
    create_predictions(
        network_config,
        state_dict_path,
        testset_path,
        None,
        save_path,
        plot_path,
        plot_at=20,
        rmse_value=rmse_value,
    )


if __name__ == '__main__':
    main()