Compare commits
2 Commits
claude-son
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| | 846bf3ee77 | | |
| | 06a0e58ea0 | | |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
*.pt
|
||||
*.pt
|
||||
__pycache__/
|
||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"learningrate": 0.0003,
|
||||
"weight_decay": 1e-05,
|
||||
"n_updates": 150000,
|
||||
"plot_at": 400,
|
||||
"early_stopping_patience": 40,
|
||||
"print_stats_at": 200,
|
||||
"print_train_stats_at": 50,
|
||||
"validate_at": 200,
|
||||
"accumulation_steps": 1,
|
||||
"commands": {
|
||||
"save_checkpoint": false,
|
||||
"run_test_validation": false,
|
||||
"generate_predictions": false
|
||||
}
|
||||
}
|
||||
BIN
image-inpainting/results/runtime_predictions.npz
Normal file
BIN
image-inpainting/results/runtime_predictions.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -223,8 +223,8 @@ class MyModel(nn.Module):
|
||||
)
|
||||
|
||||
# Encoder with progressive feature extraction
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout, use_attention=False, use_dense=False)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||
|
||||
# Enhanced bottleneck with multi-scale features and dense connections
|
||||
@@ -238,8 +238,8 @@ class MyModel(nn.Module):
|
||||
|
||||
# Decoder with progressive reconstruction
|
||||
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout, use_attention=True, use_dense=True)
|
||||
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout, use_attention=False, use_dense=False)
|
||||
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||
|
||||
# Multi-scale feature fusion with dense connections
|
||||
self.multiscale_fusion = nn.Sequential(
|
||||
|
||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 8e-4 # Optimized for long training
|
||||
config_dict['weight_decay'] = 1e-4 # Better regularization for long training
|
||||
config_dict['n_updates'] = 30000 # Full day of training (~24 hours)
|
||||
config_dict['batchsize'] = 64 # Balanced for memory and quality
|
||||
config_dict['early_stopping_patience'] = 15 # More patience for better convergence
|
||||
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||
config_dict['n_updates'] = 40000 # Extended training
|
||||
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 50
|
||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 44, # Optimal capacity for 16GB VRAM
|
||||
'dropout': 0.12 # Higher dropout for longer training
|
||||
'base_channels': 64,
|
||||
'dropout': 0.1 # Proper dropout for regularization
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -10,6 +10,7 @@ from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
@@ -19,15 +20,25 @@ from torch.utils.data import Subset
|
||||
import wandb
|
||||
|
||||
|
||||
class RMSELoss(nn.Module):
|
||||
"""RMSE loss for direct optimization of evaluation metric"""
|
||||
class EnhancedRMSELoss(nn.Module):
|
||||
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def forward(self, pred, target):
|
||||
mse = self.mse(pred, target)
|
||||
rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
|
||||
# Compute per-pixel squared error
|
||||
se = (pred - target) ** 2
|
||||
|
||||
# Weight edges more heavily for sharper results
|
||||
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||
|
||||
# Apply weighting
|
||||
weighted_se = se * edge_weight
|
||||
|
||||
# Compute RMSE
|
||||
mse = weighted_se.mean()
|
||||
rmse = torch.sqrt(mse + 1e-8)
|
||||
return rmse
|
||||
|
||||
|
||||
@@ -91,14 +102,14 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.to(device)
|
||||
network.train()
|
||||
|
||||
# defining the loss - RMSE for direct optimization of evaluation metric
|
||||
rmse_loss = RMSELoss().to(device)
|
||||
# defining the loss - Enhanced RMSE for sharper predictions
|
||||
rmse_loss = EnhancedRMSELoss().to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999))
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||
|
||||
# Cosine annealing with warm restarts for long training
|
||||
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user