Compare commits
2 Commits
gemini-3-p
...
21.395
| Author | SHA1 | Date | |
|---|---|---|---|
| d7b7da6fc5 | |||
| 15cfbe315c |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -15,11 +15,9 @@ def init_weights(m):
|
|||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
if m.weight is not None:
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.bias, 0)
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class ChannelAttention(nn.Module):
|
||||||
@@ -70,37 +68,76 @@ class CBAM(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class MultiScaleFeatureExtraction(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> InstanceNorm2d -> GELU"""
|
"""Multi-scale feature extraction using dilated convolutions"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0, dilation=1):
|
def __init__(self, channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
self.branch1 = nn.Sequential(
|
||||||
# InstanceNorm is preferred for style/inpainting tasks
|
nn.Conv2d(channels, channels // 4, 1),
|
||||||
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
nn.BatchNorm2d(channels // 4),
|
||||||
self.act = nn.GELU()
|
nn.LeakyReLU(0.1, inplace=True)
|
||||||
|
)
|
||||||
|
self.branch2 = nn.Sequential(
|
||||||
|
nn.Conv2d(channels, channels // 4, 3, padding=2, dilation=2),
|
||||||
|
nn.BatchNorm2d(channels // 4),
|
||||||
|
nn.LeakyReLU(0.1, inplace=True)
|
||||||
|
)
|
||||||
|
self.branch3 = nn.Sequential(
|
||||||
|
nn.Conv2d(channels, channels // 4, 3, padding=4, dilation=4),
|
||||||
|
nn.BatchNorm2d(channels // 4),
|
||||||
|
nn.LeakyReLU(0.1, inplace=True)
|
||||||
|
)
|
||||||
|
self.branch4 = nn.Sequential(
|
||||||
|
nn.Conv2d(channels, channels // 4, 3, padding=8, dilation=8),
|
||||||
|
nn.BatchNorm2d(channels // 4),
|
||||||
|
nn.LeakyReLU(0.1, inplace=True)
|
||||||
|
)
|
||||||
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Conv2d(channels, channels, 1),
|
||||||
|
nn.BatchNorm2d(channels),
|
||||||
|
nn.LeakyReLU(0.1, inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b1 = self.branch1(x)
|
||||||
|
b2 = self.branch2(x)
|
||||||
|
b3 = self.branch3(x)
|
||||||
|
b4 = self.branch4(x)
|
||||||
|
out = torch.cat([b1, b2, b3, b4], dim=1)
|
||||||
|
out = self.fusion(out)
|
||||||
|
return out + x # Residual connection
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock(nn.Module):
|
||||||
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
|
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.dropout(self.act(self.bn(self.conv(x))))
|
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||||
|
|
||||||
class ResidualConvBlock(nn.Module):
|
class ResidualConvBlock(nn.Module):
|
||||||
"""Residual convolutional block for better gradient flow"""
|
"""Residual convolutional block for better gradient flow"""
|
||||||
def __init__(self, channels, dropout=0.0, dilation=1):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn1 = nn.InstanceNorm2d(channels, affine=True)
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = nn.InstanceNorm2d(channels, affine=True)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
self.act = nn.GELU()
|
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
out = self.act(self.bn1(self.conv1(x)))
|
out = self.relu(self.bn1(self.conv1(x)))
|
||||||
out = self.dropout(out)
|
out = self.dropout(out)
|
||||||
out = self.bn2(self.conv2(out))
|
out = self.bn2(self.conv2(out))
|
||||||
out = out + residual
|
out = out + residual
|
||||||
return self.act(out)
|
return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
@@ -161,12 +198,13 @@ class MyModel(nn.Module):
|
|||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||||
|
|
||||||
# Bottleneck with multiple residual blocks
|
# Bottleneck with multiple residual blocks and multi-scale features
|
||||||
self.bottleneck = nn.Sequential(
|
self.bottleneck = nn.Sequential(
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=2),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=4),
|
MultiScaleFeatureExtraction(base_channels * 16),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=8),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
CBAM(base_channels * 16)
|
CBAM(base_channels * 16)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -186,7 +224,7 @@ class MyModel(nn.Module):
|
|||||||
# Output layer with smooth transition
|
# Output layer with smooth transition
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
||||||
nn.GELU(),
|
nn.LeakyReLU(0.1, inplace=True),
|
||||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
||||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
nn.Sigmoid() # Ensure output is in [0, 1] range
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,11 +10,50 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image, ImageEnhance, ImageFilter
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
|
|
||||||
|
class DataAugmentation:
|
||||||
|
"""Data augmentation pipeline for improved generalization"""
|
||||||
|
|
||||||
|
def __init__(self, p=0.5):
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def __call__(self, image: Image.Image) -> Image.Image:
|
||||||
|
# Random horizontal flip
|
||||||
|
if random.random() < self.p:
|
||||||
|
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
|
||||||
|
# Random vertical flip
|
||||||
|
if random.random() < self.p * 0.5:
|
||||||
|
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
||||||
|
|
||||||
|
# Random rotation (90 degree increments)
|
||||||
|
if random.random() < self.p * 0.3:
|
||||||
|
angle = random.choice([90, 180, 270])
|
||||||
|
image = image.rotate(angle)
|
||||||
|
|
||||||
|
# Color jittering
|
||||||
|
if random.random() < self.p * 0.4:
|
||||||
|
# Brightness
|
||||||
|
enhancer = ImageEnhance.Brightness(image)
|
||||||
|
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||||
|
|
||||||
|
if random.random() < self.p * 0.4:
|
||||||
|
# Contrast
|
||||||
|
enhancer = ImageEnhance.Contrast(image)
|
||||||
|
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||||
|
|
||||||
|
if random.random() < self.p * 0.3:
|
||||||
|
# Saturation
|
||||||
|
enhancer = ImageEnhance.Color(image)
|
||||||
|
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
||||||
image_array = np.transpose(image_array, (2, 0, 1))
|
image_array = np.transpose(image_array, (2, 0, 1))
|
||||||
known_array = np.zeros_like(image_array)
|
known_array = np.zeros_like(image_array)
|
||||||
@@ -26,21 +65,11 @@ def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tu
|
|||||||
|
|
||||||
return image_array, known_array
|
return image_array, known_array
|
||||||
|
|
||||||
def resize(img: Image, augment: bool = False):
|
def resize(img: Image):
|
||||||
transforms_list = [
|
resize_transforms = transforms.Compose([
|
||||||
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
||||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||||
]
|
])
|
||||||
|
|
||||||
if augment:
|
|
||||||
transforms_list = [
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.RandomVerticalFlip(),
|
|
||||||
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
|
|
||||||
transforms.RandomRotation(10),
|
|
||||||
] + transforms_list
|
|
||||||
|
|
||||||
resize_transforms = transforms.Compose(transforms_list)
|
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
@@ -48,31 +77,42 @@ def preprocess(input_array: np.ndarray):
|
|||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder
|
Dataset class for loading images from a folder with augmentation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str, augment: bool = False):
|
def __init__(self, datafolder: str, augment: bool = True):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder, "**", "*.jpg"), recursive=True))
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
|
self.augmentation = DataAugmentation(p=0.5) if augment else None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
|
|
||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx: int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index])
|
image = Image.open(self.imagefiles[index]).convert('RGB')
|
||||||
image = np.asarray(resize(image, self.augment))
|
|
||||||
|
# Apply augmentation before resize
|
||||||
|
if self.augment and self.augmentation is not None:
|
||||||
|
image = self.augmentation(image)
|
||||||
|
|
||||||
|
image = resize(image)
|
||||||
|
image = np.asarray(image)
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
spacing_x = random.randint(2,6)
|
|
||||||
spacing_y = random.randint(2,6)
|
# More varied spacing for better generalization
|
||||||
offset_x = random.randint(0,8)
|
spacing_x = random.randint(2, 8)
|
||||||
offset_y = random.randint(0,8)
|
spacing_y = random.randint(2, 8)
|
||||||
|
offset_x = random.randint(0, min(spacing_x - 1, 8))
|
||||||
|
offset_y = random.randint(0, min(spacing_y - 1, 8))
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
|
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
target_image = torch.from_numpy(np.transpose(image, (2,0,1)))
|
target_image = torch.from_numpy(np.transpose(image, (2, 0, 1)))
|
||||||
input_array = torch.from_numpy(input_array)
|
input_array = torch.from_numpy(input_array)
|
||||||
known_array = torch.from_numpy(known_array)
|
known_array = torch.from_numpy(known_array)
|
||||||
input_array = torch.cat((input_array, known_array), dim=0)
|
input_array = torch.cat((input_array, known_array), dim=0)
|
||||||
|
|
||||||
return input_array, target_image
|
return input_array, target_image
|
||||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
config_dict['learningrate'] = 2e-4 # Slightly lower for stable training
|
||||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
config_dict['weight_decay'] = 5e-5 # Reduced weight decay
|
||||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
config_dict['n_updates'] = 8000 # More updates for better convergence
|
||||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 10
|
||||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 48, # Good balance between capacity and memory
|
'base_channels': 56, # Increased capacity for better feature learning
|
||||||
'dropout': 0.1 # Regularization
|
'dropout': 0.08 # Slightly less dropout with augmentation
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -20,15 +21,58 @@ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian_kernel(window_size=11, sigma=1.5):
|
||||||
|
"""Create a Gaussian kernel for SSIM computation"""
|
||||||
|
x = torch.arange(window_size).float() - window_size // 2
|
||||||
|
gauss = torch.exp(-x.pow(2) / (2 * sigma ** 2))
|
||||||
|
kernel = gauss / gauss.sum()
|
||||||
|
kernel_2d = kernel.unsqueeze(1) * kernel.unsqueeze(0)
|
||||||
|
return kernel_2d.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class SSIMLoss(nn.Module):
|
||||||
|
"""Structural Similarity Index Loss for perceptual quality"""
|
||||||
|
def __init__(self, window_size=11, sigma=1.5):
|
||||||
|
super().__init__()
|
||||||
|
self.window_size = window_size
|
||||||
|
kernel = gaussian_kernel(window_size, sigma)
|
||||||
|
self.register_buffer('kernel', kernel)
|
||||||
|
self.C1 = 0.01 ** 2
|
||||||
|
self.C2 = 0.03 ** 2
|
||||||
|
|
||||||
|
def forward(self, pred, target):
|
||||||
|
# Apply to each channel
|
||||||
|
channels = pred.shape[1]
|
||||||
|
kernel = self.kernel.repeat(channels, 1, 1, 1)
|
||||||
|
|
||||||
|
mu_pred = F.conv2d(pred, kernel, padding=self.window_size // 2, groups=channels)
|
||||||
|
mu_target = F.conv2d(target, kernel, padding=self.window_size // 2, groups=channels)
|
||||||
|
|
||||||
|
mu_pred_sq = mu_pred.pow(2)
|
||||||
|
mu_target_sq = mu_target.pow(2)
|
||||||
|
mu_pred_target = mu_pred * mu_target
|
||||||
|
|
||||||
|
sigma_pred_sq = F.conv2d(pred * pred, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_sq
|
||||||
|
sigma_target_sq = F.conv2d(target * target, kernel, padding=self.window_size // 2, groups=channels) - mu_target_sq
|
||||||
|
sigma_pred_target = F.conv2d(pred * target, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_target
|
||||||
|
|
||||||
|
ssim = ((2 * mu_pred_target + self.C1) * (2 * sigma_pred_target + self.C2)) / \
|
||||||
|
((mu_pred_sq + mu_target_sq + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2))
|
||||||
|
|
||||||
|
return 1 - ssim.mean()
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class CombinedLoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
"""Combined loss: MSE + L1 + SSIM + Edge for comprehensive image reconstruction"""
|
||||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
self.mse_weight = mse_weight
|
||||||
self.l1_weight = l1_weight
|
self.l1_weight = l1_weight
|
||||||
self.edge_weight = edge_weight
|
self.edge_weight = edge_weight
|
||||||
|
self.ssim_weight = ssim_weight
|
||||||
self.mse = nn.MSELoss()
|
self.mse = nn.MSELoss()
|
||||||
self.l1 = nn.L1Loss()
|
self.l1 = nn.L1Loss()
|
||||||
|
self.ssim = SSIMLoss(window_size=7)
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
# Sobel filters for edge detection
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||||
@@ -38,10 +82,10 @@ class CombinedLoss(nn.Module):
|
|||||||
|
|
||||||
def edge_loss(self, pred, target):
|
def edge_loss(self, pred, target):
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
"""Compute edge-aware loss using Sobel filters"""
|
||||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
pred_edge_x = F.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
pred_edge_y = F.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
target_edge_x = F.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
target_edge_y = F.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||||
return edge_loss
|
return edge_loss
|
||||||
@@ -50,8 +94,12 @@ class CombinedLoss(nn.Module):
|
|||||||
mse_loss = self.mse(pred, target)
|
mse_loss = self.mse(pred, target)
|
||||||
l1_loss = self.l1(pred, target)
|
l1_loss = self.l1(pred, target)
|
||||||
edge_loss = self.edge_loss(pred, target)
|
edge_loss = self.edge_loss(pred, target)
|
||||||
|
ssim_loss = self.ssim(pred, target)
|
||||||
|
|
||||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
total_loss = (self.mse_weight * mse_loss +
|
||||||
|
self.l1_weight * l1_loss +
|
||||||
|
self.edge_weight * edge_loss +
|
||||||
|
self.ssim_weight * ssim_loss)
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
@@ -84,21 +132,16 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
plotpath = os.path.join(results_path, "plots")
|
plotpath = os.path.join(results_path, "plots")
|
||||||
os.makedirs(plotpath, exist_ok=True)
|
os.makedirs(plotpath, exist_ok=True)
|
||||||
|
|
||||||
image_dataset = datasets.ImageDataset(datafolder=data_path, augment=False)
|
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||||
|
|
||||||
n_total = len(image_dataset)
|
n_total = len(image_dataset)
|
||||||
n_test = int(n_total * testset_ratio)
|
n_test = int(n_total * testset_ratio)
|
||||||
n_valid = int(n_total * validset_ratio)
|
n_valid = int(n_total * validset_ratio)
|
||||||
n_train = n_total - n_test - n_valid
|
n_train = n_total - n_test - n_valid
|
||||||
indices = np.random.permutation(n_total)
|
indices = np.random.permutation(n_total)
|
||||||
|
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||||
# Create datasets with and without augmentation
|
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||||
train_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=True)
|
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||||
val_test_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=False)
|
|
||||||
|
|
||||||
dataset_train = Subset(train_dataset_source, indices=indices[0:n_train])
|
|
||||||
dataset_valid = Subset(val_test_dataset_source, indices=indices[n_train:n_train + n_valid])
|
|
||||||
dataset_test = Subset(val_test_dataset_source, indices=indices[n_train + n_valid:n_total])
|
|
||||||
|
|
||||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||||
|
|
||||||
@@ -117,7 +160,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss for better reconstruction
|
# defining the loss - combined loss for better reconstruction
|
||||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3).to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
|
|||||||
@@ -81,9 +81,42 @@ def read_compressed_file(file_path: str):
|
|||||||
return input_arrays, known_arrays
|
return input_arrays, known_arrays
|
||||||
|
|
||||||
|
|
||||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
|
def apply_tta(model, input_tensor, device):
|
||||||
"""
|
"""
|
||||||
Here, one might needs to adjust the code based on the used preprocessing
|
Apply Test-Time Augmentation for better predictions.
|
||||||
|
Averages predictions from original and augmented versions.
|
||||||
|
"""
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
# Original
|
||||||
|
out = model(input_tensor)
|
||||||
|
outputs.append(out)
|
||||||
|
|
||||||
|
# Horizontal flip
|
||||||
|
flipped_h = torch.flip(input_tensor, dims=[3])
|
||||||
|
out_h = model(flipped_h)
|
||||||
|
out_h = torch.flip(out_h, dims=[3])
|
||||||
|
outputs.append(out_h)
|
||||||
|
|
||||||
|
# Vertical flip
|
||||||
|
flipped_v = torch.flip(input_tensor, dims=[2])
|
||||||
|
out_v = model(flipped_v)
|
||||||
|
out_v = torch.flip(out_v, dims=[2])
|
||||||
|
outputs.append(out_v)
|
||||||
|
|
||||||
|
# Both flips
|
||||||
|
flipped_hv = torch.flip(input_tensor, dims=[2, 3])
|
||||||
|
out_hv = model(flipped_hv)
|
||||||
|
out_hv = torch.flip(out_hv, dims=[2, 3])
|
||||||
|
outputs.append(out_hv)
|
||||||
|
|
||||||
|
# Average all predictions
|
||||||
|
return torch.stack(outputs, dim=0).mean(dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None, use_tta=True):
|
||||||
|
"""
|
||||||
|
Create predictions with optional Test-Time Augmentation for improved results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -94,7 +127,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
|||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
model = MyModel(**model_config)
|
model = MyModel(**model_config)
|
||||||
model.load_state_dict(torch.load(state_dict_path))
|
model.load_state_dict(torch.load(state_dict_path, weights_only=True))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@@ -111,9 +144,14 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in range(len(input_arrays)):
|
for i in range(len(input_arrays)):
|
||||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||||
input_array = torch.from_numpy(input_arrays[i]).to(
|
input_array = torch.from_numpy(input_arrays[i]).to(device)
|
||||||
device)
|
input_tensor = input_array.unsqueeze(0) if input_array.dim() == 3 else input_array
|
||||||
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
|
|
||||||
|
if use_tta:
|
||||||
|
output = apply_tta(model, input_tensor, device)
|
||||||
|
else:
|
||||||
|
output = model(input_tensor)
|
||||||
|
|
||||||
output = output.cpu().numpy()
|
output = output.cpu().numpy()
|
||||||
predictions.append(output)
|
predictions.append(output)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user