added result, 18.0253
@@ -15,44 +15,21 @@ import os
 from torch.utils.data import DataLoader
 from torch.utils.data import Subset
-from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+from torch.optim.lr_scheduler import OneCycleLR
 
 import wandb
 
 
-class CombinedLoss(nn.Module):
-    """Combined loss: MSE + L1 + SSIM-like perceptual component"""
-    def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
+class RMSELoss(nn.Module):
+    """RMSE loss for direct optimization of evaluation metric"""
+    def __init__(self):
         super().__init__()
-        self.mse_weight = mse_weight
-        self.l1_weight = l1_weight
-        self.edge_weight = edge_weight
         self.mse = nn.MSELoss()
-        self.l1 = nn.L1Loss()
 
-        # Sobel filters for edge detection
-        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
-        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
-        self.register_buffer('sobel_x', sobel_x.repeat(3, 1, 1, 1))
-        self.register_buffer('sobel_y', sobel_y.repeat(3, 1, 1, 1))
-
-    def edge_loss(self, pred, target):
-        """Compute edge-aware loss using Sobel filters"""
-        pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
-        pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
-        target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
-        target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
-
-        edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
-        return edge_loss
-
     def forward(self, pred, target):
-        mse_loss = self.mse(pred, target)
-        l1_loss = self.l1(pred, target)
-        edge_loss = self.edge_loss(pred, target)
-
-        total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
-        return total_loss
+        mse = self.mse(pred, target)
+        rmse = torch.sqrt(mse + 1e-8)  # Add epsilon for numerical stability
+        return rmse
 
 
 def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_stopping_patience, device, learningrate,
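For context on the swap above: RMSE is just the square root of MSE, and the 1e-8 epsilon matters because the derivative of sqrt(x) is 1/(2*sqrt(x)), which blows up as the loss approaches zero. A minimal self-contained sketch of the new loss (random tensors and shapes are made up for illustration):

import torch
import torch.nn as nn

# Same structure as the RMSELoss added in the hunk above
class RMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        # the epsilon keeps the sqrt gradient finite when pred == target
        return torch.sqrt(self.mse(pred, target) + 1e-8)

pred = torch.rand(4, 3, 32, 32, requires_grad=True)
target = torch.rand(4, 3, 32, 32)
loss = RMSELoss()(pred, target)
loss.backward()  # gradients flow through the sqrt without NaNs

# up to the epsilon, the value matches the square root of the plain MSE metric
print(torch.isclose(loss, torch.sqrt(nn.functional.mse_loss(pred, target)), atol=1e-4))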
@@ -111,15 +88,16 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
     network.to(device)
     network.train()
 
-    # defining the loss - combined loss for better reconstruction
-    combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
+    # defining the loss - RMSE for direct optimization of evaluation metric
+    rmse_loss = RMSELoss().to(device)
     mse_loss = torch.nn.MSELoss()  # Keep for evaluation
 
     # defining the optimizer with AdamW for better weight decay handling
-    optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
+    optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay, betas=(0.9, 0.99))
 
-    # Learning rate scheduler for better convergence
-    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
+    # OneCycleLR for fast convergence - ramps up then down over entire training
+    scheduler = OneCycleLR(optimizer, max_lr=learningrate, total_steps=n_updates,
+                           pct_start=0.3, anneal_strategy='cos', div_factor=25.0, final_div_factor=1e4)
 
     if use_wandb:
         wandb.watch(network, mse_loss, log="all", log_freq=10)
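Unlike the warm-restart schedule it replaces, OneCycleLR needs the total update count up front and traces a single ramp-up/ramp-down curve. A hedged sketch of the resulting shape, with stand-in numbers (lr=1e-3, 100 steps) in place of the real learningrate and n_updates:

import torch
from torch.optim.lr_scheduler import OneCycleLR

# Stand-in parameter and hyperparameters, just to trace the schedule shape
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-3, weight_decay=1e-2, betas=(0.9, 0.99))
sched = OneCycleLR(opt, max_lr=1e-3, total_steps=100,
                   pct_start=0.3, anneal_strategy='cos',
                   div_factor=25.0, final_div_factor=1e4)

lrs = []
for _ in range(100):
    opt.step()        # one scheduler step per optimizer step
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# warms up from max_lr / div_factor, peaks near step 30 (pct_start * total_steps),
# then cosine-anneals down to max_lr / (div_factor * final_div_factor)
print(f"start={lrs[0]:.2e}  peak={max(lrs):.2e}  end={lrs[-1]:.2e}")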
@@ -146,7 +124,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
 
         output = network(input)
 
-        loss = combined_loss(output, target)
+        loss = rmse_loss(output, target)
 
         loss.backward()
 
@@ -154,7 +132,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
         torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
 
         optimizer.step()
-        scheduler.step(i + len(loss_list) / len(dataloader_train))
+        scheduler.step()  # OneCycleLR steps once per optimizer step
 
         loss_list.append(loss.item())
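The dropped argument is the real behavioral change here: CosineAnnealingWarmRestarts.step() accepts a fractional epoch, which is what the removed expression was computing per batch, while OneCycleLR.step() takes no argument, expects exactly one call per optimizer step, and raises once its total_steps budget is spent. A toy sketch of that contract (throwaway parameter, made-up step count):

import torch
from torch.optim.lr_scheduler import OneCycleLR

# Toy parameter and a made-up budget of 5 steps
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = OneCycleLR(opt, max_lr=0.1, total_steps=5)

for _ in range(5):
    opt.step()
    sched.step()   # no argument: exactly one call per optimizer step

try:
    sched.step()   # a sixth call exceeds total_steps
except ValueError as err:
    print(err)     # OneCycleLR enforces its total_steps budget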