Compare commits
2 Commits
claude-son
...
beforeRunt
| Author | SHA1 | Date | |
|---|---|---|---|
| 846bf3ee77 | |||
| 06a0e58ea0 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/*
|
data/*
|
||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
16
image-inpainting/results/runtime_config.json
Normal file
16
image-inpainting/results/runtime_config.json
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"learningrate": 0.0003,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"n_updates": 150000,
|
||||||
|
"plot_at": 400,
|
||||||
|
"early_stopping_patience": 40,
|
||||||
|
"print_stats_at": 200,
|
||||||
|
"print_train_stats_at": 50,
|
||||||
|
"validate_at": 200,
|
||||||
|
"accumulation_steps": 1,
|
||||||
|
"commands": {
|
||||||
|
"save_checkpoint": false,
|
||||||
|
"run_test_validation": false,
|
||||||
|
"generate_predictions": false
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
image-inpainting/results/runtime_predictions.npz
Normal file
BIN
image-inpainting/results/runtime_predictions.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
BIN
image-inpainting/results/testset/tikaiz-16.1240.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -223,8 +223,8 @@ class MyModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Encoder with progressive feature extraction
|
# Encoder with progressive feature extraction
|
||||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout, use_attention=False, use_dense=False)
|
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
|
|
||||||
# Enhanced bottleneck with multi-scale features and dense connections
|
# Enhanced bottleneck with multi-scale features and dense connections
|
||||||
@@ -238,8 +238,8 @@ class MyModel(nn.Module):
|
|||||||
|
|
||||||
# Decoder with progressive reconstruction
|
# Decoder with progressive reconstruction
|
||||||
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
self.up1 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True, use_dense=True)
|
||||||
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout, use_attention=True, use_dense=True)
|
self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout*0.7, use_attention=True, use_dense=True)
|
||||||
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout, use_attention=False, use_dense=False)
|
self.up3 = UpBlock(base_channels * 2, base_channels, dropout=dropout*0.5, use_attention=False, use_dense=False)
|
||||||
|
|
||||||
# Multi-scale feature fusion with dense connections
|
# Multi-scale feature fusion with dense connections
|
||||||
self.multiscale_fusion = nn.Sequential(
|
self.multiscale_fusion = nn.Sequential(
|
||||||
|
|||||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 8e-4 # Optimized for long training
|
config_dict['learningrate'] = 3e-4 # More stable learning rate
|
||||||
config_dict['weight_decay'] = 1e-4 # Better regularization for long training
|
config_dict['weight_decay'] = 1e-4 # Proper regularization
|
||||||
config_dict['n_updates'] = 30000 # Full day of training (~24 hours)
|
config_dict['n_updates'] = 40000 # Extended training
|
||||||
config_dict['batchsize'] = 64 # Balanced for memory and quality
|
config_dict['batchsize'] = 96 # Maximize batch size for better gradients
|
||||||
config_dict['early_stopping_patience'] = 15 # More patience for better convergence
|
config_dict['early_stopping_patience'] = 20 # More patience for convergence
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 50
|
config_dict['print_train_stats_at'] = 50
|
||||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 44, # Optimal capacity for 16GB VRAM
|
'base_channels': 64,
|
||||||
'dropout': 0.12 # Higher dropout for longer training
|
'dropout': 0.1 # Proper dropout for regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -19,15 +20,25 @@ from torch.utils.data import Subset
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
class RMSELoss(nn.Module):
|
class EnhancedRMSELoss(nn.Module):
|
||||||
"""RMSE loss for direct optimization of evaluation metric"""
|
"""Enhanced RMSE loss with edge weighting for sharper predictions"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse = nn.MSELoss()
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, target):
|
||||||
mse = self.mse(pred, target)
|
# Compute per-pixel squared error
|
||||||
rmse = torch.sqrt(mse + 1e-8) # Add epsilon for numerical stability
|
se = (pred - target) ** 2
|
||||||
|
|
||||||
|
# Weight edges more heavily for sharper results
|
||||||
|
edge_weight = 1.0 + 0.3 * torch.abs(target[:, :, 1:, :] - target[:, :, :-1, :]).mean(dim=1, keepdim=True)
|
||||||
|
edge_weight = F.pad(edge_weight, (0, 0, 0, 1), value=1.0)
|
||||||
|
|
||||||
|
# Apply weighting
|
||||||
|
weighted_se = se * edge_weight
|
||||||
|
|
||||||
|
# Compute RMSE
|
||||||
|
mse = weighted_se.mean()
|
||||||
|
rmse = torch.sqrt(mse + 1e-8)
|
||||||
return rmse
|
return rmse
|
||||||
|
|
||||||
|
|
||||||
@@ -91,14 +102,14 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - RMSE for direct optimization of evaluation metric
|
# defining the loss - Enhanced RMSE for sharper predictions
|
||||||
rmse_loss = RMSELoss().to(device)
|
rmse_loss = EnhancedRMSELoss().to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999))
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
|
||||||
|
|
||||||
# Cosine annealing with warm restarts for long training
|
# Cosine annealing with warm restarts for gradual learning rate decay
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
optimizer, T_0=n_updates//4, T_mult=1, eta_min=learningrate/100
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user