Compare commits
2 Commits
gemini-3-p
...
21.395
| Author | SHA1 | Date | |
|---|---|---|---|
| d7b7da6fc5 | |||
| 15cfbe315c |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -15,10 +15,8 @@ def init_weights(m):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
@@ -70,37 +68,76 @@ class CBAM(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> InstanceNorm2d -> GELU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0, dilation=1):
|
||||
class MultiScaleFeatureExtraction(nn.Module):
|
||||
"""Multi-scale feature extraction using dilated convolutions"""
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||
# InstanceNorm is preferred for style/inpainting tasks
|
||||
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
||||
self.act = nn.GELU()
|
||||
self.branch1 = nn.Sequential(
|
||||
nn.Conv2d(channels, channels // 4, 1),
|
||||
nn.BatchNorm2d(channels // 4),
|
||||
nn.LeakyReLU(0.1, inplace=True)
|
||||
)
|
||||
self.branch2 = nn.Sequential(
|
||||
nn.Conv2d(channels, channels // 4, 3, padding=2, dilation=2),
|
||||
nn.BatchNorm2d(channels // 4),
|
||||
nn.LeakyReLU(0.1, inplace=True)
|
||||
)
|
||||
self.branch3 = nn.Sequential(
|
||||
nn.Conv2d(channels, channels // 4, 3, padding=4, dilation=4),
|
||||
nn.BatchNorm2d(channels // 4),
|
||||
nn.LeakyReLU(0.1, inplace=True)
|
||||
)
|
||||
self.branch4 = nn.Sequential(
|
||||
nn.Conv2d(channels, channels // 4, 3, padding=8, dilation=8),
|
||||
nn.BatchNorm2d(channels // 4),
|
||||
nn.LeakyReLU(0.1, inplace=True)
|
||||
)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(channels, channels, 1),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.LeakyReLU(0.1, inplace=True)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b1 = self.branch1(x)
|
||||
b2 = self.branch2(x)
|
||||
b3 = self.branch3(x)
|
||||
b4 = self.branch4(x)
|
||||
out = torch.cat([b1, b2, b3, b4], dim=1)
|
||||
out = self.fusion(out)
|
||||
return out + x # Residual connection
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(self.act(self.bn(self.conv(x))))
|
||||
return self.dropout(self.relu(self.bn(self.conv(x))))
|
||||
|
||||
class ResidualConvBlock(nn.Module):
|
||||
"""Residual convolutional block for better gradient flow"""
|
||||
def __init__(self, channels, dropout=0.0, dilation=1):
|
||||
def __init__(self, channels, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn1 = nn.InstanceNorm2d(channels, affine=True)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn2 = nn.InstanceNorm2d(channels, affine=True)
|
||||
self.act = nn.GELU()
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.act(self.bn1(self.conv1(x)))
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + residual
|
||||
return self.act(out)
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
@@ -161,12 +198,13 @@ class MyModel(nn.Module):
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
|
||||
# Bottleneck with multiple residual blocks
|
||||
# Bottleneck with multiple residual blocks and multi-scale features
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=2),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=4),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout, dilation=8),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
MultiScaleFeatureExtraction(base_channels * 16),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
CBAM(base_channels * 16)
|
||||
)
|
||||
|
||||
@@ -186,7 +224,7 @@ class MyModel(nn.Module):
|
||||
# Output layer with smooth transition
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv2d(base_channels, base_channels // 2, kernel_size=3, padding=1),
|
||||
nn.GELU(),
|
||||
nn.LeakyReLU(0.1, inplace=True),
|
||||
nn.Conv2d(base_channels // 2, 3, kernel_size=1),
|
||||
nn.Sigmoid() # Ensure output is in [0, 1] range
|
||||
)
|
||||
|
||||
@@ -10,11 +10,50 @@ import numpy as np
|
||||
import random
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
|
||||
IMAGE_DIMENSION = 100
|
||||
|
||||
|
||||
class DataAugmentation:
|
||||
"""Data augmentation pipeline for improved generalization"""
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, image: Image.Image) -> Image.Image:
|
||||
# Random horizontal flip
|
||||
if random.random() < self.p:
|
||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
# Random vertical flip
|
||||
if random.random() < self.p * 0.5:
|
||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Random rotation (90 degree increments)
|
||||
if random.random() < self.p * 0.3:
|
||||
angle = random.choice([90, 180, 270])
|
||||
image = image.rotate(angle)
|
||||
|
||||
# Color jittering
|
||||
if random.random() < self.p * 0.4:
|
||||
# Brightness
|
||||
enhancer = ImageEnhance.Brightness(image)
|
||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||
|
||||
if random.random() < self.p * 0.4:
|
||||
# Contrast
|
||||
enhancer = ImageEnhance.Contrast(image)
|
||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||
|
||||
if random.random() < self.p * 0.3:
|
||||
# Saturation
|
||||
enhancer = ImageEnhance.Color(image)
|
||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
||||
image_array = np.transpose(image_array, (2, 0, 1))
|
||||
known_array = np.zeros_like(image_array)
|
||||
@@ -26,21 +65,11 @@ def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tu
|
||||
|
||||
return image_array, known_array
|
||||
|
||||
def resize(img: Image, augment: bool = False):
|
||||
transforms_list = [
|
||||
def resize(img: Image):
|
||||
resize_transforms = transforms.Compose([
|
||||
transforms.Resize((IMAGE_DIMENSION, IMAGE_DIMENSION)),
|
||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||
]
|
||||
|
||||
if augment:
|
||||
transforms_list = [
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomVerticalFlip(),
|
||||
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
|
||||
transforms.RandomRotation(10),
|
||||
] + transforms_list
|
||||
|
||||
resize_transforms = transforms.Compose(transforms_list)
|
||||
])
|
||||
return resize_transforms(img)
|
||||
def preprocess(input_array: np.ndarray):
|
||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||
@@ -48,12 +77,13 @@ def preprocess(input_array: np.ndarray):
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder
|
||||
Dataset class for loading images from a folder with augmentation
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str, augment: bool = False):
|
||||
def __init__(self, datafolder: str, augment: bool = True):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder, "**", "*.jpg"), recursive=True))
|
||||
self.augment = augment
|
||||
self.augmentation = DataAugmentation(p=0.5) if augment else None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
@@ -61,18 +91,28 @@ class ImageDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx: int):
|
||||
index = int(idx)
|
||||
|
||||
image = Image.open(self.imagefiles[index])
|
||||
image = np.asarray(resize(image, self.augment))
|
||||
image = Image.open(self.imagefiles[index]).convert('RGB')
|
||||
|
||||
# Apply augmentation before resize
|
||||
if self.augment and self.augmentation is not None:
|
||||
image = self.augmentation(image)
|
||||
|
||||
image = resize(image)
|
||||
image = np.asarray(image)
|
||||
image = preprocess(image)
|
||||
spacing_x = random.randint(2,6)
|
||||
spacing_y = random.randint(2,6)
|
||||
offset_x = random.randint(0,8)
|
||||
offset_y = random.randint(0,8)
|
||||
|
||||
# More varied spacing for better generalization
|
||||
spacing_x = random.randint(2, 8)
|
||||
spacing_y = random.randint(2, 8)
|
||||
offset_x = random.randint(0, min(spacing_x - 1, 8))
|
||||
offset_y = random.randint(0, min(spacing_y - 1, 8))
|
||||
spacing = (spacing_x, spacing_y)
|
||||
offset = (offset_x, offset_y)
|
||||
|
||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||
target_image = torch.from_numpy(np.transpose(image, (2, 0, 1)))
|
||||
input_array = torch.from_numpy(input_array)
|
||||
known_array = torch.from_numpy(known_array)
|
||||
input_array = torch.cat((input_array, known_array), dim=0)
|
||||
|
||||
return input_array, target_image
|
||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
||||
config_dict['learningrate'] = 2e-4 # Slightly lower for stable training
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced weight decay
|
||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 48, # Good balance between capacity and memory
|
||||
'dropout': 0.1 # Regularization
|
||||
'base_channels': 56, # Increased capacity for better feature learning
|
||||
'dropout': 0.08 # Slightly less dropout with augmentation
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -10,6 +10,7 @@ from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
@@ -20,15 +21,58 @@ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
import wandb
|
||||
|
||||
|
||||
def gaussian_kernel(window_size=11, sigma=1.5):
|
||||
"""Create a Gaussian kernel for SSIM computation"""
|
||||
x = torch.arange(window_size).float() - window_size // 2
|
||||
gauss = torch.exp(-x.pow(2) / (2 * sigma ** 2))
|
||||
kernel = gauss / gauss.sum()
|
||||
kernel_2d = kernel.unsqueeze(1) * kernel.unsqueeze(0)
|
||||
return kernel_2d.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
class SSIMLoss(nn.Module):
|
||||
"""Structural Similarity Index Loss for perceptual quality"""
|
||||
def __init__(self, window_size=11, sigma=1.5):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
kernel = gaussian_kernel(window_size, sigma)
|
||||
self.register_buffer('kernel', kernel)
|
||||
self.C1 = 0.01 ** 2
|
||||
self.C2 = 0.03 ** 2
|
||||
|
||||
def forward(self, pred, target):
|
||||
# Apply to each channel
|
||||
channels = pred.shape[1]
|
||||
kernel = self.kernel.repeat(channels, 1, 1, 1)
|
||||
|
||||
mu_pred = F.conv2d(pred, kernel, padding=self.window_size // 2, groups=channels)
|
||||
mu_target = F.conv2d(target, kernel, padding=self.window_size // 2, groups=channels)
|
||||
|
||||
mu_pred_sq = mu_pred.pow(2)
|
||||
mu_target_sq = mu_target.pow(2)
|
||||
mu_pred_target = mu_pred * mu_target
|
||||
|
||||
sigma_pred_sq = F.conv2d(pred * pred, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_sq
|
||||
sigma_target_sq = F.conv2d(target * target, kernel, padding=self.window_size // 2, groups=channels) - mu_target_sq
|
||||
sigma_pred_target = F.conv2d(pred * target, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_target
|
||||
|
||||
ssim = ((2 * mu_pred_target + self.C1) * (2 * sigma_pred_target + self.C2)) / \
|
||||
((mu_pred_sq + mu_target_sq + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2))
|
||||
|
||||
return 1 - ssim.mean()
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||
"""Combined loss: MSE + L1 + SSIM + Edge for comprehensive image reconstruction"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3):
|
||||
super().__init__()
|
||||
self.mse_weight = mse_weight
|
||||
self.l1_weight = l1_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.ssim_weight = ssim_weight
|
||||
self.mse = nn.MSELoss()
|
||||
self.l1 = nn.L1Loss()
|
||||
self.ssim = SSIMLoss(window_size=7)
|
||||
|
||||
# Sobel filters for edge detection
|
||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
@@ -38,10 +82,10 @@ class CombinedLoss(nn.Module):
|
||||
|
||||
def edge_loss(self, pred, target):
|
||||
"""Compute edge-aware loss using Sobel filters"""
|
||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
pred_edge_x = F.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = F.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = F.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = F.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
|
||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||
return edge_loss
|
||||
@@ -50,8 +94,12 @@ class CombinedLoss(nn.Module):
|
||||
mse_loss = self.mse(pred, target)
|
||||
l1_loss = self.l1(pred, target)
|
||||
edge_loss = self.edge_loss(pred, target)
|
||||
ssim_loss = self.ssim(pred, target)
|
||||
|
||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
||||
total_loss = (self.mse_weight * mse_loss +
|
||||
self.l1_weight * l1_loss +
|
||||
self.edge_weight * edge_loss +
|
||||
self.ssim_weight * ssim_loss)
|
||||
return total_loss
|
||||
|
||||
|
||||
@@ -84,21 +132,16 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
plotpath = os.path.join(results_path, "plots")
|
||||
os.makedirs(plotpath, exist_ok=True)
|
||||
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||
|
||||
n_total = len(image_dataset)
|
||||
n_test = int(n_total * testset_ratio)
|
||||
n_valid = int(n_total * validset_ratio)
|
||||
n_train = n_total - n_test - n_valid
|
||||
indices = np.random.permutation(n_total)
|
||||
|
||||
# Create datasets with and without augmentation
|
||||
train_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=True)
|
||||
val_test_dataset_source = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
|
||||
dataset_train = Subset(train_dataset_source, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(val_test_dataset_source, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(val_test_dataset_source, indices=indices[n_train + n_valid:n_total])
|
||||
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||
|
||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
@@ -117,7 +160,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.train()
|
||||
|
||||
# defining the loss - combined loss for better reconstruction
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3).to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
|
||||
@@ -81,9 +81,42 @@ def read_compressed_file(file_path: str):
|
||||
return input_arrays, known_arrays
|
||||
|
||||
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
|
||||
def apply_tta(model, input_tensor, device):
|
||||
"""
|
||||
Here, one might needs to adjust the code based on the used preprocessing
|
||||
Apply Test-Time Augmentation for better predictions.
|
||||
Averages predictions from original and augmented versions.
|
||||
"""
|
||||
outputs = []
|
||||
|
||||
# Original
|
||||
out = model(input_tensor)
|
||||
outputs.append(out)
|
||||
|
||||
# Horizontal flip
|
||||
flipped_h = torch.flip(input_tensor, dims=[3])
|
||||
out_h = model(flipped_h)
|
||||
out_h = torch.flip(out_h, dims=[3])
|
||||
outputs.append(out_h)
|
||||
|
||||
# Vertical flip
|
||||
flipped_v = torch.flip(input_tensor, dims=[2])
|
||||
out_v = model(flipped_v)
|
||||
out_v = torch.flip(out_v, dims=[2])
|
||||
outputs.append(out_v)
|
||||
|
||||
# Both flips
|
||||
flipped_hv = torch.flip(input_tensor, dims=[2, 3])
|
||||
out_hv = model(flipped_hv)
|
||||
out_hv = torch.flip(out_hv, dims=[2, 3])
|
||||
outputs.append(out_hv)
|
||||
|
||||
# Average all predictions
|
||||
return torch.stack(outputs, dim=0).mean(dim=0)
|
||||
|
||||
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None, use_tta=True):
|
||||
"""
|
||||
Create predictions with optional Test-Time Augmentation for improved results.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
@@ -94,7 +127,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
device = torch.device(device)
|
||||
|
||||
model = MyModel(**model_config)
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
model.load_state_dict(torch.load(state_dict_path, weights_only=True))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
@@ -111,9 +144,14 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
with torch.no_grad():
|
||||
for i in range(len(input_arrays)):
|
||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(
|
||||
device)
|
||||
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(device)
|
||||
input_tensor = input_array.unsqueeze(0) if input_array.dim() == 3 else input_array
|
||||
|
||||
if use_tta:
|
||||
output = apply_tta(model, input_tensor, device)
|
||||
else:
|
||||
output = model(input_tensor)
|
||||
|
||||
output = output.cpu().numpy()
|
||||
predictions.append(output)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user