Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 77b8b9b3f6 | | |
| | 7d4caaf501 | | |
| | 248ffb8faf | | |
| | 1771377121 | | |
| | eaf45f5c72 | | |
| | 8f0fb11926 | | |
1
image-inpainting/.gitignore
vendored
1
image-inpainting/.gitignore
vendored
@@ -2,3 +2,4 @@ data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
*.pt
|
||||
__pycache__/
|
||||
BIN
image-inpainting/results/submissions/tikaiz-1.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-1.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -68,46 +68,6 @@ class CBAM(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class MultiScaleFeatureExtraction(nn.Module):
    """Multi-scale feature extraction using parallel dilated convolutions.

    Four branches each map the input to ``channels // 4`` feature maps: one
    1x1 convolution and three 3x3 convolutions with dilation 2, 4 and 8.
    Every 3x3 branch uses ``padding == dilation``, so the spatial size is
    unchanged. The branch outputs are concatenated along the channel axis,
    fused by a 1x1 convolution and added back to the input.

    Args:
        channels: Number of input and output channels. Must be a positive
            multiple of 4, otherwise the four concatenated branches would not
            add up to ``channels`` again.

    Raises:
        ValueError: If ``channels`` is not a positive multiple of 4.
    """

    def __init__(self, channels):
        super().__init__()
        # Fail at construction time: with a truncating ``channels // 4`` the
        # module would build fine and only crash inside ``fusion`` on the
        # first forward pass.
        if channels <= 0 or channels % 4 != 0:
            raise ValueError(
                f"channels must be a positive multiple of 4, got {channels}"
            )
        branch_channels = channels // 4

        # Attribute names and layer order are kept so existing checkpoints
        # (state_dict keys such as ``branch2.0.weight``) still load.
        self.branch1 = self._conv_bn_act(channels, branch_channels, kernel_size=1)
        self.branch2 = self._conv_bn_act(channels, branch_channels, kernel_size=3, dilation=2)
        self.branch3 = self._conv_bn_act(channels, branch_channels, kernel_size=3, dilation=4)
        self.branch4 = self._conv_bn_act(channels, branch_channels, kernel_size=3, dilation=8)
        self.fusion = self._conv_bn_act(channels, channels, kernel_size=1)

    @staticmethod
    def _conv_bn_act(in_channels, out_channels, kernel_size, dilation=1):
        """Build Conv2d -> BatchNorm2d -> LeakyReLU(0.1) with size-preserving padding."""
        # For odd kernels this padding keeps height/width unchanged:
        # 0 for the 1x1 case, ``dilation`` for the 3x3 case.
        padding = dilation * (kernel_size - 1) // 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def forward(self, x):
        """Return ``fusion(concat(branches(x))) + x``; output has the shape of ``x``."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        out = torch.cat([branch(x) for branch in branches], dim=1)
        out = self.fusion(out)
        return out + x  # Residual connection
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||
@@ -198,11 +158,10 @@ class MyModel(nn.Module):
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
|
||||
# Bottleneck with multiple residual blocks and multi-scale features
|
||||
# Bottleneck with multiple residual blocks
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
MultiScaleFeatureExtraction(base_channels * 16),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
CBAM(base_channels * 16)
|
||||
|
||||
@@ -10,50 +10,11 @@ import numpy as np
|
||||
import random
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image, ImageEnhance, ImageFilter
|
||||
from PIL import Image
|
||||
|
||||
IMAGE_DIMENSION = 100
|
||||
|
||||
|
||||
class DataAugmentation:
    """Random flip / rotation / colour-jitter pipeline for PIL images.

    Each transform is applied independently, with a probability scaled from
    the base probability ``p``:

    * horizontal flip: ``p``
    * vertical flip: ``p * 0.5``
    * rotation by 90, 180 or 270 degrees: ``p * 0.3``
    * brightness and contrast jitter: ``p * 0.4`` each
    * saturation jitter: ``p * 0.3``

    Jitter factors are drawn uniformly from ``[0.85, 1.15]``.

    Args:
        p: Base probability used to derive the per-transform probabilities.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image: Image.Image) -> Image.Image:
        """Return a randomly augmented version of ``image``."""
        # Random horizontal flip
        if random.random() < self.p:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random vertical flip
        if random.random() < self.p * 0.5:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)

        # Random rotation (90 degree increments).
        # transpose() is used instead of rotate(): rotate() without
        # expand=True keeps the original canvas, so a non-square image turned
        # by 90/270 degrees is cropped and padded with black. transpose() is
        # lossless and gives the same result as rotate() for square images.
        if random.random() < self.p * 0.3:
            rotation = random.choice(
                [Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270]
            )
            image = image.transpose(rotation)

        # Colour jitter: brightness, contrast, then saturation. The order and
        # the number of random draws match the previous hand-unrolled version.
        jitter_steps = (
            (ImageEnhance.Brightness, 0.4),
            (ImageEnhance.Contrast, 0.4),
            (ImageEnhance.Color, 0.3),
        )
        for enhancer_cls, scale in jitter_steps:
            if random.random() < self.p * scale:
                image = enhancer_cls(image).enhance(random.uniform(0.85, 1.15))

        return image
|
||||
|
||||
|
||||
def create_arrays_from_image(image_array: np.ndarray, offset: tuple, spacing: tuple) -> tuple[np.ndarray, np.ndarray]:
|
||||
image_array = np.transpose(image_array, (2, 0, 1))
|
||||
known_array = np.zeros_like(image_array)
|
||||
@@ -77,13 +38,11 @@ def preprocess(input_array: np.ndarray):
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder with augmentation
|
||||
Dataset class for loading images from a folder
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str, augment: bool = True):
|
||||
def __init__(self, datafolder: str):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||
self.augment = augment
|
||||
self.augmentation = DataAugmentation(p=0.5) if augment else None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
@@ -91,28 +50,18 @@ class ImageDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx:int):
|
||||
index = int(idx)
|
||||
|
||||
image = Image.open(self.imagefiles[index]).convert('RGB')
|
||||
|
||||
# Apply augmentation before resize
|
||||
if self.augment and self.augmentation is not None:
|
||||
image = self.augmentation(image)
|
||||
|
||||
image = resize(image)
|
||||
image = np.asarray(image)
|
||||
image = Image.open(self.imagefiles[index])
|
||||
image = np.asarray(resize(image))
|
||||
image = preprocess(image)
|
||||
|
||||
# More varied spacing for better generalization
|
||||
spacing_x = random.randint(2, 8)
|
||||
spacing_y = random.randint(2, 8)
|
||||
offset_x = random.randint(0, min(spacing_x - 1, 8))
|
||||
offset_y = random.randint(0, min(spacing_y - 1, 8))
|
||||
spacing_x = random.randint(2,6)
|
||||
spacing_y = random.randint(2,6)
|
||||
offset_x = random.randint(0,8)
|
||||
offset_y = random.randint(0,8)
|
||||
spacing = (spacing_x, spacing_y)
|
||||
offset = (offset_x, offset_y)
|
||||
|
||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||
target_image = torch.from_numpy(np.transpose(image, (2,0,1)))
|
||||
input_array = torch.from_numpy(input_array)
|
||||
known_array = torch.from_numpy(known_array)
|
||||
input_array = torch.cat((input_array, known_array), dim=0)
|
||||
|
||||
return input_array, target_image
|
||||
@@ -24,11 +24,11 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 2e-4 # Slightly lower for stable training
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced weight decay
|
||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
@@ -38,8 +38,8 @@ if __name__ == '__main__':
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 56, # Increased capacity for better feature learning
|
||||
'dropout': 0.08 # Slightly less dropout with augmentation
|
||||
'base_channels': 48, # Good balance between capacity and memory
|
||||
'dropout': 0.1 # Regularization
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -10,7 +10,6 @@ from utils import plot, evaluate_model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
@@ -21,58 +20,15 @@ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
import wandb
|
||||
|
||||
|
||||
def gaussian_kernel(window_size=11, sigma=1.5):
    """Build a normalised 2-D Gaussian window for SSIM computation.

    Args:
        window_size: Side length of the square window.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A float tensor of shape ``(1, 1, window_size, window_size)`` whose
        entries sum to 1 (outer product of a normalised 1-D Gaussian with
        itself).
    """
    # Tap positions relative to index ``window_size // 2``.
    offsets = torch.arange(window_size).float() - window_size // 2
    weights = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    taps = weights / weights.sum()
    # Separable Gaussian: 2-D window = column vector x row vector.
    window = torch.outer(taps, taps)
    # Add the leading (out_channels, in_channels) dims expected by conv2d.
    return window[None, None]
|
||||
|
||||
|
||||
class SSIMLoss(nn.Module):
    """Structural Similarity (SSIM) loss: ``1 - mean(SSIM(pred, target))``.

    Local means, variances and the covariance are estimated with a Gaussian
    window applied per channel (depthwise convolution).

    Args:
        window_size: Side length of the Gaussian window.
        sigma: Standard deviation of the Gaussian window.
    """

    def __init__(self, window_size=11, sigma=1.5):
        super().__init__()
        self.window_size = window_size
        # Registered as a buffer so it follows the module across .to(device).
        self.register_buffer('kernel', gaussian_kernel(window_size, sigma))
        # Stabilising constants of the SSIM formula.
        # NOTE(review): 0.01**2 / 0.03**2 presumably assume pixel values
        # scaled to [0, 1] -- confirm against the preprocessing.
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, pred, target):
        """Return the scalar SSIM loss between ``pred`` and ``target``.

        Both tensors are passed to ``F.conv2d`` with ``groups`` equal to
        ``pred.shape[1]``, so they are expected to be 4-D with the channel
        axis at index 1.
        """
        n_channels = pred.shape[1]
        # One copy of the window per channel for the depthwise convolution.
        window = self.kernel.repeat(n_channels, 1, 1, 1)
        pad = self.window_size // 2

        def local_mean(tensor):
            return F.conv2d(tensor, window, padding=pad, groups=n_channels)

        mean_pred = local_mean(pred)
        mean_target = local_mean(target)

        mean_pred_sq = mean_pred.pow(2)
        mean_target_sq = mean_target.pow(2)
        mean_product = mean_pred * mean_target

        # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y].
        var_pred = local_mean(pred * pred) - mean_pred_sq
        var_target = local_mean(target * target) - mean_target_sq
        covariance = local_mean(pred * target) - mean_product

        numerator = (2 * mean_product + self.C1) * (2 * covariance + self.C2)
        denominator = (mean_pred_sq + mean_target_sq + self.C1) * (var_pred + var_target + self.C2)
        ssim_map = numerator / denominator

        return 1 - ssim_map.mean()
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""Combined loss: MSE + L1 + SSIM + Edge for comprehensive image reconstruction"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3):
|
||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||
super().__init__()
|
||||
self.mse_weight = mse_weight
|
||||
self.l1_weight = l1_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.ssim_weight = ssim_weight
|
||||
self.mse = nn.MSELoss()
|
||||
self.l1 = nn.L1Loss()
|
||||
self.ssim = SSIMLoss(window_size=7)
|
||||
|
||||
# Sobel filters for edge detection
|
||||
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
|
||||
@@ -82,10 +38,10 @@ class CombinedLoss(nn.Module):
|
||||
|
||||
def edge_loss(self, pred, target):
|
||||
"""Compute edge-aware loss using Sobel filters"""
|
||||
pred_edge_x = F.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = F.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = F.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = F.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
pred_edge_x = torch.nn.functional.conv2d(pred, self.sobel_x, padding=1, groups=3)
|
||||
pred_edge_y = torch.nn.functional.conv2d(pred, self.sobel_y, padding=1, groups=3)
|
||||
target_edge_x = torch.nn.functional.conv2d(target, self.sobel_x, padding=1, groups=3)
|
||||
target_edge_y = torch.nn.functional.conv2d(target, self.sobel_y, padding=1, groups=3)
|
||||
|
||||
edge_loss = self.l1(pred_edge_x, target_edge_x) + self.l1(pred_edge_y, target_edge_y)
|
||||
return edge_loss
|
||||
@@ -94,12 +50,8 @@ class CombinedLoss(nn.Module):
|
||||
mse_loss = self.mse(pred, target)
|
||||
l1_loss = self.l1(pred, target)
|
||||
edge_loss = self.edge_loss(pred, target)
|
||||
ssim_loss = self.ssim(pred, target)
|
||||
|
||||
total_loss = (self.mse_weight * mse_loss +
|
||||
self.l1_weight * l1_loss +
|
||||
self.edge_weight * edge_loss +
|
||||
self.ssim_weight * ssim_loss)
|
||||
total_loss = self.mse_weight * mse_loss + self.l1_weight * l1_loss + self.edge_weight * edge_loss
|
||||
return total_loss
|
||||
|
||||
|
||||
@@ -160,7 +112,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.train()
|
||||
|
||||
# defining the loss - combined loss for better reconstruction
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.15, ssim_weight=0.3).to(device)
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
|
||||
@@ -81,42 +81,9 @@ def read_compressed_file(file_path: str):
|
||||
return input_arrays, known_arrays
|
||||
|
||||
|
||||
def apply_tta(model, input_tensor, device):
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
|
||||
"""
|
||||
Apply Test-Time Augmentation for better predictions.
|
||||
Averages predictions from original and augmented versions.
|
||||
"""
|
||||
outputs = []
|
||||
|
||||
# Original
|
||||
out = model(input_tensor)
|
||||
outputs.append(out)
|
||||
|
||||
# Horizontal flip
|
||||
flipped_h = torch.flip(input_tensor, dims=[3])
|
||||
out_h = model(flipped_h)
|
||||
out_h = torch.flip(out_h, dims=[3])
|
||||
outputs.append(out_h)
|
||||
|
||||
# Vertical flip
|
||||
flipped_v = torch.flip(input_tensor, dims=[2])
|
||||
out_v = model(flipped_v)
|
||||
out_v = torch.flip(out_v, dims=[2])
|
||||
outputs.append(out_v)
|
||||
|
||||
# Both flips
|
||||
flipped_hv = torch.flip(input_tensor, dims=[2, 3])
|
||||
out_hv = model(flipped_hv)
|
||||
out_hv = torch.flip(out_hv, dims=[2, 3])
|
||||
outputs.append(out_hv)
|
||||
|
||||
# Average all predictions
|
||||
return torch.stack(outputs, dim=0).mean(dim=0)
|
||||
|
||||
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None, use_tta=True):
|
||||
"""
|
||||
Create predictions with optional Test-Time Augmentation for improved results.
|
||||
Here, one might need to adjust the code based on the used preprocessing
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
@@ -127,7 +94,7 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
device = torch.device(device)
|
||||
|
||||
model = MyModel(**model_config)
|
||||
model.load_state_dict(torch.load(state_dict_path, weights_only=True))
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
@@ -144,14 +111,9 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
with torch.no_grad():
|
||||
for i in range(len(input_arrays)):
|
||||
print(f"Processing image {i + 1}/{len(input_arrays)}")
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(device)
|
||||
input_tensor = input_array.unsqueeze(0) if input_array.dim() == 3 else input_array
|
||||
|
||||
if use_tta:
|
||||
output = apply_tta(model, input_tensor, device)
|
||||
else:
|
||||
output = model(input_tensor)
|
||||
|
||||
input_array = torch.from_numpy(input_arrays[i]).to(
|
||||
device)
|
||||
output = model(input_array.unsqueeze(0) if hasattr(input_array, 'dim') and input_array.dim() == 3 else input_array)
|
||||
output = output.cpu().numpy()
|
||||
predictions.append(output)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user