class MultiScaleFeatureExtraction(nn.Module):
    """Multi-scale feature extraction using parallel dilated convolutions.

    Four branches see the same input with growing receptive fields
    (1x1, then 3x3 with dilation 2, 4 and 8).  Their outputs are
    concatenated along the channel axis, fused by a 1x1 convolution and
    added back to the input (residual connection), so the output has the
    same shape as the input.

    Args:
        channels: Number of input (and output) channels.  Must be >= 4.
            When it is not a multiple of 4 the last branch receives the
            remainder so the concatenated width is always ``channels``.

    Raises:
        ValueError: If ``channels`` is smaller than 4 (a branch would end
            up with zero output channels).
    """

    def __init__(self, channels):
        super().__init__()
        if channels < 4:
            raise ValueError(
                f"channels must be at least 4 to feed four branches, got {channels}"
            )

        quarter = channels // 4
        # The last branch absorbs the remainder; for multiples of 4 this is
        # exactly ``quarter`` and the layout matches the original module.
        last = channels - 3 * quarter

        # Creation order (branch1..branch4, fusion) is kept so parameter
        # names in the state_dict stay the same.
        self.branch1 = self._conv_bn_act(channels, quarter, kernel_size=1, dilation=1)
        self.branch2 = self._conv_bn_act(channels, quarter, kernel_size=3, dilation=2)
        self.branch3 = self._conv_bn_act(channels, quarter, kernel_size=3, dilation=4)
        self.branch4 = self._conv_bn_act(channels, last, kernel_size=3, dilation=8)
        self.fusion = self._conv_bn_act(channels, channels, kernel_size=1, dilation=1)

    @staticmethod
    def _conv_bn_act(in_channels, out_channels, kernel_size, dilation):
        """Build Conv2d -> BatchNorm2d -> LeakyReLU(0.1) preserving spatial size."""
        # padding = dilation * (k // 2) keeps H and W unchanged for odd k.
        padding = dilation * (kernel_size // 2)
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def forward(self, x):
        """Return fused multi-scale features plus the residual input."""
        branches = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
        out = torch.cat(branches, dim=1)
        out = self.fusion(out)
        return out + x  # Residual connection
class DataAugmentation:
    """Random geometric and colour augmentation for PIL images.

    Each transform is applied independently with a probability derived
    from ``p``:

    * horizontal flip: ``p``
    * vertical flip: ``p * 0.5``
    * rotation by 90, 180 or 270 degrees: ``p * 0.3``
    * brightness jitter: ``p * 0.4``
    * contrast jitter: ``p * 0.4``
    * saturation jitter: ``p * 0.3``

    Jitter factors are drawn uniformly from ``[0.85, 1.15]``.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image: Image.Image) -> Image.Image:
        """Return a randomly augmented copy of ``image``."""
        # Random horizontal flip
        if random.random() < self.p:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random vertical flip
        if random.random() < self.p * 0.5:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)

        # Random rotation (90 degree increments).  ``transpose`` is used
        # instead of ``rotate(angle)`` because ``rotate`` keeps the original
        # canvas size: on non-square images a 90/270 degree turn would crop
        # the content and fill the corners with black.
        if random.random() < self.p * 0.3:
            turn = random.choice([Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270])
            image = image.transpose(turn)

        # Colour jittering: brightness, contrast, saturation
        if random.random() < self.p * 0.4:
            image = ImageEnhance.Brightness(image).enhance(random.uniform(0.85, 1.15))

        if random.random() < self.p * 0.4:
            image = ImageEnhance.Contrast(image).enhance(random.uniform(0.85, 1.15))

        if random.random() < self.p * 0.3:
            image = ImageEnhance.Color(image).enhance(random.uniform(0.85, 1.15))

        return image


class ImageDataset(torch.utils.data.Dataset):
    """
    Dataset class for loading images from a folder with augmentation

    All ``*.jpg`` files below ``datafolder`` are collected recursively and
    sorted.  Each item is a pair ``(input_array, target_image)`` where the
    input is the masked image concatenated with its known-pixel mask along
    dim 0.
    """

    def __init__(self, datafolder: str, augment: bool = True):
        self.imagefiles = sorted(glob.glob(os.path.join(datafolder, "**", "*.jpg"), recursive=True))
        self.augment = augment
        self.augmentation = DataAugmentation(p=0.5) if augment else None

    def __len__(self):
        return len(self.imagefiles)

    def __getitem__(self, idx: int):
        """Load, augment and mask the image at position ``idx``."""
        index = int(idx)

        # Close the file handle deterministically; ``convert`` returns a
        # fully loaded copy that stays valid after the file is closed.
        with Image.open(self.imagefiles[index]) as handle:
            image = handle.convert('RGB')

        # Apply augmentation before resize
        if self.augmentation is not None and self.augment:
            image = self.augmentation(image)

        # NOTE(review): ``resize`` and ``preprocess`` are defined elsewhere
        # in this module; presumably they yield an HxWxC array - confirm.
        image = resize(image)
        image = np.asarray(image)
        image = preprocess(image)

        # Grid spacing in [2, 8]; the offset is always smaller than the
        # spacing (draw order is kept: spacing x/y, then offset x/y).
        spacing_x = random.randint(2, 8)
        spacing_y = random.randint(2, 8)
        offset_x = random.randint(0, spacing_x - 1)
        offset_y = random.randint(0, spacing_y - 1)
        spacing = (spacing_x, spacing_y)
        offset = (offset_x, offset_y)

        input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
        target_image = torch.from_numpy(np.transpose(image, (2, 0, 1)))
        input_array = torch.from_numpy(input_array)
        known_array = torch.from_numpy(known_array)
        input_array = torch.cat((input_array, known_array), dim=0)

        return input_array, target_image
def gaussian_kernel(window_size=11, sigma=1.5):
    """Create a normalised 2-D Gaussian kernel for SSIM computation.

    Args:
        window_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A float32 tensor of shape ``(1, 1, window_size, window_size)``
        whose entries sum to 1.
    """
    # Centre on (window_size - 1) / 2 so that even-sized windows are
    # symmetric as well; for odd sizes this equals window_size // 2.
    coords = torch.arange(window_size).float() - (window_size - 1) / 2
    gauss = torch.exp(-coords.pow(2) / (2 * sigma ** 2))
    kernel_1d = gauss / gauss.sum()
    kernel_2d = kernel_1d.unsqueeze(1) * kernel_1d.unsqueeze(0)
    return kernel_2d.unsqueeze(0).unsqueeze(0)


class SSIMLoss(nn.Module):
    """Structural Similarity Index Loss for perceptual quality.

    Returns ``1 - mean(SSIM)`` where local statistics are computed per
    channel with a Gaussian window.  The stabilising constants
    ``C1 = 0.01 ** 2`` and ``C2 = 0.03 ** 2`` are the usual SSIM values for
    a unit dynamic range - presumably inputs are scaled to [0, 1]; verify
    against the caller.
    """

    def __init__(self, window_size=11, sigma=1.5):
        super().__init__()
        self.window_size = window_size
        kernel = gaussian_kernel(window_size, sigma)
        self.register_buffer('kernel', kernel)
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, pred, target):
        """Compute the SSIM loss between two ``(N, C, H, W)`` tensors."""
        channels = pred.shape[1]
        # Match the kernel to the input's dtype/device: the buffer is
        # float32, and conv2d rejects mixed dtypes (e.g. float64 inputs).
        kernel = self.kernel.to(device=pred.device, dtype=pred.dtype)
        kernel = kernel.repeat(channels, 1, 1, 1)
        pad = self.window_size // 2

        def local_mean(tensor):
            # Depthwise Gaussian filtering: one kernel copy per channel.
            return F.conv2d(tensor, kernel, padding=pad, groups=channels)

        mu_pred = local_mean(pred)
        mu_target = local_mean(target)

        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target

        # Var[x] = E[x^2] - E[x]^2, Cov[x, y] = E[xy] - E[x]E[y]
        sigma_pred_sq = local_mean(pred * pred) - mu_pred_sq
        sigma_target_sq = local_mean(target * target) - mu_target_sq
        sigma_pred_target = local_mean(pred * target) - mu_pred_target

        numerator = (2 * mu_pred_target + self.C1) * (2 * sigma_pred_target + self.C2)
        denominator = (mu_pred_sq + mu_target_sq + self.C1) * (sigma_pred_sq + sigma_target_sq + self.C2)
        ssim = numerator / denominator

        return 1 - ssim.mean()
def apply_tta(model, input_tensor, device):
    """
    Apply Test-Time Augmentation for better predictions.

    The model is evaluated on the original input and on its horizontal,
    vertical and combined flips.  Each prediction is flipped back to the
    original orientation and the four results are averaged.

    Args:
        model: Callable mapping a 4-D tensor to a tensor whose dims 2 and
            3 can be flipped back.
        input_tensor: Batched input; dims 2 and 3 are flipped.
        device: Device to run on.  When not ``None`` the input is moved
            there first; previously this argument was silently ignored.

    Returns:
        The element-wise mean of the four de-augmented predictions.
    """
    if device is not None:
        input_tensor = input_tensor.to(device)

    # Original orientation first, then H flip, V flip, both flips.
    outputs = [model(input_tensor)]
    for dims in ([3], [2], [2, 3]):
        flipped_out = model(torch.flip(input_tensor, dims=dims))
        # Undo the flip so every prediction is aligned with the input.
        outputs.append(torch.flip(flipped_out, dims=dims))

    # Average all predictions
    return torch.stack(outputs, dim=0).mean(dim=0)