Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 77b8b9b3f6 | |||
| 7d4caaf501 | |||
| 248ffb8faf | |||
| 1771377121 | |||
| eaf45f5c72 | |||
| 8f0fb11926 |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
|||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
BIN
image-inpainting/results/submissions/tikaiz-1.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-1.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -68,46 +68,6 @@ class CBAM(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class MultiScaleFeatureExtraction(nn.Module):
|
|
||||||
"""Multi-scale feature extraction using dilated convolutions"""
|
|
||||||
def __init__(self, channels):
|
|
||||||
super().__init__()
|
|
||||||
self.branch1 = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, channels // 4, 1),
|
|
||||||
nn.BatchNorm2d(channels // 4),
|
|
||||||
nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
)
|
|
||||||
self.branch2 = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, channels // 4, 3, padding=2, dilation=2),
|
|
||||||
nn.BatchNorm2d(channels // 4),
|
|
||||||
nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
)
|
|
||||||
self.branch3 = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, channels // 4, 3, padding=4, dilation=4),
|
|
||||||
nn.BatchNorm2d(channels // 4),
|
|
||||||
nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
)
|
|
||||||
self.branch4 = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, channels // 4, 3, padding=8, dilation=8),
|
|
||||||
nn.BatchNorm2d(channels // 4),
|
|
||||||
nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
)
|
|
||||||
self.fusion = nn.Sequential(
|
|
||||||
nn.Conv2d(channels, channels, 1),
|
|
||||||
nn.BatchNorm2d(channels),
|
|
||||||
nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b1 = self.branch1(x)
|
|
||||||
b2 = self.branch2(x)
|
|
||||||
b3 = self.branch3(x)
|
|
||||||
b4 = self.branch4(x)
|
|
||||||
out = torch.cat([b1, b2, b3, b4], dim=1)
|
|
||||||
out = self.fusion(out)
|
|
||||||
return out + x # Residual connection
|
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||||
@@ -198,11 +158,10 @@ class MyModel(nn.Module):
|
|||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||||
|
|
||||||
# Bottleneck with multiple residual blocks and multi-scale features
|
# Bottleneck with multiple residual blocks
|
||||||
self.bottleneck = nn.Sequential(
|
self.bottleneck = nn.Sequential(
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
MultiScaleFeatureExtraction(base_channels * 16),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
CBAM(base_channels * 16)
|
CBAM(base_channels * 16)
|
||||||
|
|||||||
@@ -10,50 +10,11 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image, ImageEnhance, ImageFilter
|
from PIL import Image
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
|
|
||||||
class DataAugmentation:
|
|
||||||
"""Data augmentation pipeline for improved generalization"""
|
|
||||||
|
|
||||||
def __init__(self, p=0.5):
|
|
||||||
self.p = p
|
|
||||||
|
|
||||||
def __call__(self, image: Image.Image) -> Image.Image:
|
|
||||||
# Random horizontal flip
|
|
||||||
if random.random() < self.p:
|
|
||||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
|
||||||
|
|
||||||
# Random vertical flip
|
|
||||||
if random.random() < self.p * 0.5:
|
|
||||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
|
||||||
|
|
||||||
# Random rotation (90 degree increments)
|
|
||||||
if random.random() < self.p * 0.3:
|
|
||||||
angle = random.choice([90, 180, 270])
|
|
||||||
image = image.rotate(angle)
|
|
||||||
|
|
||||||
# Color jittering
|
|
||||||
if random.random() < self.p * 0.4:
|
|
||||||
# Brightness
|
|
||||||
enhancer = ImageEnhance.Brightness(image)
|
|
||||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
|
||||||
|
|
||||||
if random.random() < self.p * 0.4:
|
|
||||||
# Contrast
|
|
||||||
enhancer = ImageEnhance.Contrast(image)
|
|
||||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
|
||||||
|
|
||||||
if random.random() < self.p * 0.3:
|
|
||||||
# Saturation
|
|
||||||
enhancer = ImageEnhance.Color(image)
|
|
||||||
image = enhancer.enhance(random.uniform(0.85, 1.15))
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
||||||
image_array = np.transpose(image_array, (2, 0, 1))
|
image_array = np.transpose(image_array, (2, 0, 1))
|
||||||
known_array = np.zeros_like(image_array)
|
known_array = np.zeros_like(image_array)
|
||||||
@@ -77,13 +38,11 @@ def preprocess(input_array: np.ndarray):
|
|||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder with augmentation
|
Dataset class for loading images from a folder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str, augment: bool = True):
|
def __init__(self, datafolder: str):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
self.augment = augment
|
|
||||||
self.augmentation = DataAugmentation(p=0.5) if augment else None
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
@@ -91,28 +50,18 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx:int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index]).convert('RGB')
|
image = Image.open(self.imagefiles[index])
|
||||||
|
image = np.asarray(resize(image))
|
||||||
# Apply augmentation before resize
|
|
||||||
if self.augment and self.augmentation is not None:
|
|
||||||
image = self.augmentation(image)
|
|
||||||
|
|
||||||
image = resize(image)
|
|
||||||
image = np.asarray(image)
|
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
spacing_x = random.randint(2,6)
|
||||||
# More varied spacing for better generalization
|
spacing_y = random.randint(2,6)
|
||||||
spacing_x = random.randint(2, 8)
|
offset_x = random.randint(0,8)
|
||||||
spacing_y = random.randint(2, 8)
|
offset_y = random.randint(0,8)
|
||||||
offset_x = random.randint(0, min(spacing_x - 1, 8))
|
|
||||||
offset_y = random.randint(0, min(spacing_y - 1, 8))
|
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
|
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
target_image = torch.from_numpy(np.transpose(image, (2,0,1)))
|
target_image = torch.from_numpy(np.transpose(image, (2,0,1)))
|
||||||
input_array = torch.from_numpy(input_array)
|
input_array = torch.from_numpy(input_array)
|
||||||
known_array = torch.from_numpy(known_array)
|
known_array = torch.from_numpy(known_array)
|
||||||
input_array = torch.cat((input_array, known_array), dim=0)
|
input_array = torch.cat((input_array, known_array), dim=0)
|
||||||
|
|
||||||
return input_array, target_image
|
return input_array, target_image
|
||||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 2e-4 # Slightly lower for stable training
|
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||||
config_dict['weight_decay'] = 5e-5 # Reduced weight decay
|
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
config_dict['n_updates'] = 5000 # Lowered from 8000
|
||||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
config_dict['early_stopping_patience'] = 10 # Lowered from 15
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 10
|
||||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 56, # Increased capacity for better feature learning
|
'base_channels': 48, # Good balance between capacity and memory
|
||||||
'dropout': 0.08 # Slightly less dropout with augmentation
|
'dropout': 0.1 # Regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from utils import plot, evaluate_model
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -21,58 +20,15 @@ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
def gaussian_kernel(window_size=11, sigma=1.5):
|
|
||||||
"""Create a Gaussian kernel for SSIM computation"""
|
|
||||||
x = torch.arange(window_size).float() - window_size // 2
|
|
||||||
gauss = torch.exp(-x.pow(2) / (2 * sigma ** 2))
|
|
||||||
kernel = gauss / gauss.sum()
|
|
||||||
kernel_2d = kernel.unsqueeze(1) * kernel.unsqueeze(0)
|
|
||||||
return kernel_2d.unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class SSIMLoss(nn.Module):
|
|
||||||
"""Structural Similarity Index Loss for perceptual quality"""
|
|
||||||
def __init__(self, window_size=11, sigma=1.5):
|
|
||||||
super().__init__()
|
|
||||||
self.window_size = window_size
|
|
||||||
kernel = gaussian_kernel(window_size, sigma)
|
|
||||||
self.register_buffer('kernel', kernel)
|
|
||||||
self.C1 = 0.01 ** 2
|
|
||||||
self.C2 = 0.03 ** 2
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
|
||||||
# Apply to each channel
|
|
||||||
channels = pred.shape[1]
|
|
||||||
kernel = self.kernel.repeat(channels, 1, 1, 1)
|
|
||||||
|
|
||||||
mu_pred = F.conv2d(pred, kernel, padding=self.window_size // 2, groups=channels)
|
|
||||||
mu_target = F.conv2d(target, kernel, padding=self.window_size // 2, groups=channels)
|
|
||||||
|
|
||||||
mu_pred_sq = mu_pred.pow(2)
|
|
||||||
mu_target_sq = mu_target.pow(2)
|
|
||||||
mu_pred_target = mu_pred * mu_target
|
|
||||||
|
|
||||||
sigma_pred_sq = F.conv2d(pred * pred, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_sq
|
|
||||||
sigma_target_sq = F.conv2d(target * target, kernel, padding=self.window_size // 2, groups=channels) - mu_target_sq
|
|
||||||
sigma_pred_target = F.conv2d(pred * target, kernel, padding=self.window_size // 2, groups=channels) - mu_pred_target
|
|
||||||
|
|
||||||
ssim = ((2 * mu_pred_target + self.C1) * (2 * sigma_pred_target + self.C2)) / \
|
|
||||||
((mu_pred_sq + mu_target_sq + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2))
|
|
||||||
|
|
||||||
return 1 - ssim.mean()
|
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class CombinedLoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + SSIM + Edge for comprehensive image reconstruction"""
|
"""Combined loss: MSE + L1 + Sobel edge loss"""
|
||||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3):
|
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
self.mse_weight = mse_weight
|
||||||
self.l1_weight = l1_weight
|
self.l1_weight = l1_weight
|
||||||
self.edge_weight = edge_weight
|
self.edge_weight = edge_weight
|
||||||
self.ssim_weight = ssim_weight
|
|
||||||
self.mse = nn.MSELoss()
|
self.mse = nn.MSELoss()
|
||||||
self.l1 = nn.L1Loss()
|
self.l1 = nn.L1Loss()
|
||||||
self.ssim = SSIMLoss(window_size=7)
|
|
||||||
|
|
||||||
# Sobel filters for edge detection
|
# Sobel filters for edge detection
|
||||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||||
@@ -82,10 +38,10 @@ class CombinedLoss(nn.Module):
|
|||||||
|
|
||||||
def edge_loss(self, pred, target):
|
def edge_loss(self, pred, target):
|
||||||
"""Compute edge-aware loss using Sobel filters"""
|
"""Compute edge-aware loss using Sobel filters"""
|
||||||
pred_edge_x = F.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||||
pred_edge_y = F.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||||
target_edge_x = F.conv2d(target, self.sobel_x, padding=1, groups=3)
|
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||||
target_edge_y = F.conv2d(target, self.sobel_y, padding=1, groups=3)
|
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||||
|
|
||||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||||
return edge_loss
|
return edge_loss
|
||||||
@@ -94,12 +50,8 @@ class CombinedLoss(nn.Module):
|
|||||||
mse_loss = self.mse(pred, target)
|
mse_loss = self.mse(pred, target)
|
||||||
l1_loss = self.l1(pred, target)
|
l1_loss = self.l1(pred, target)
|
||||||
edge_loss = self.edge_loss(pred, target)
|
edge_loss = self.edge_loss(pred, target)
|
||||||
ssim_loss = self.ssim(pred, target)
|
|
||||||
|
|
||||||
total_loss = (self.mse_weight * mse_loss +
|
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
||||||
self.l1_weight * l1_loss +
|
|
||||||
self.edge_weight * edge_loss +
|
|
||||||
self.ssim_weight * ssim_loss)
|
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
|
||||||
@@ -160,7 +112,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss for better reconstruction
|
# defining the loss - combined loss for better reconstruction
|
||||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3).to(device)
|
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
|
|||||||
@@ -81,42 +81,9 @@ def read_compressed_file(file_path: str):
|
|||||||
return input_arrays, known_arrays
|
return input_arrays, known_arrays
|
||||||
|
|
||||||
|
|
||||||
def apply_tta(model, input_tensor, device):
|
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
|
||||||
"""
|
"""
|
||||||
Apply Test-Time Augmentation for better predictions.
|
Here, one might need to adjust the code based on the preprocessing used
|
||||||
Averages predictions from original and augmented versions.
|
|
||||||
"""
|
|
||||||
outputs = []
|
|
||||||
|
|
||||||
# Original
|
|
||||||
out = model(input_tensor)
|
|
||||||
outputs.append(out)
|
|
||||||
|
|
||||||
# Horizontal flip
|
|
||||||
flipped_h = torch.flip(input_tensor, dims=[3])
|
|
||||||
out_h = model(flipped_h)
|
|
||||||
out_h = torch.flip(out_h, dims=[3])
|
|
||||||
outputs.append(out_h)
|
|
||||||
|
|
||||||
# Vertical flip
|
|
||||||
flipped_v = torch.flip(input_tensor, dims=[2])
|
|
||||||
out_v = model(flipped_v)
|
|
||||||
out_v = torch.flip(out_v, dims=[2])
|
|
||||||
outputs.append(out_v)
|
|
||||||
|
|
||||||
# Both flips
|
|
||||||
flipped_hv = torch.flip(input_tensor, dims=[2, 3])
|
|
||||||
out_hv = model(flipped_hv)
|
|
||||||
out_hv = torch.flip(out_hv, dims=[2, 3])
|
|
||||||
outputs.append(out_hv)
|
|
||||||
|
|
||||||
# Average all predictions
|
|
||||||
return torch.stack(outputs, dim=0).mean(dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None, use_tta=True):
|
|
||||||
"""
|
|
||||||
Create predictions with optional Test-Time Augmentation for improved results.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -127,7 +94,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
|||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
model = MyModel(**model_config)
|
model = MyModel(**model_config)
|
||||||
model.load_state_dict(torch.load(state_dict_path, weights_only=True))
|
model.load_state_dict(torch.load(state_dict_path))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@@ -144,14 +111,9 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in range(len(input_arrays)):
|
for i in range(len(input_arrays)):
|
||||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||||
input_array = torch.from_numpy(input_arrays[i]).to(device)
|
input_array = torch.from_numpy(input_arrays[i]).to(
|
||||||
input_tensor = input_array.unsqueeze(0) if input_array.dim() == 3 else input_array
|
device)
|
||||||
|
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
|
||||||
if use_tta:
|
|
||||||
output = apply_tta(model, input_tensor, device)
|
|
||||||
else:
|
|
||||||
output = model(input_tensor)
|
|
||||||
|
|
||||||
output = output.cpu().numpy()
|
output = output.cpu().numpy()
|
||||||
predictions.append(output)
|
predictions.append(output)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user