diff --git a/image-inpainting/results/testset/tikaiz-21.3905.npz b/image-inpainting/results/testset/tikaiz-21.3905.npz new file mode 100644 index 0000000..e83ce6a Binary files /dev/null and b/image-inpainting/results/testset/tikaiz-21.3905.npz differ diff --git a/image-inpainting/src/__pycache__/architecture.cpython-313.pyc b/image-inpainting/src/__pycache__/architecture.cpython-313.pyc index 5295f71..7f9190e 100644 Binary files a/image-inpainting/src/__pycache__/architecture.cpython-313.pyc and b/image-inpainting/src/__pycache__/architecture.cpython-313.pyc differ diff --git a/image-inpainting/src/__pycache__/datasets.cpython-313.pyc b/image-inpainting/src/__pycache__/datasets.cpython-313.pyc index c101f32..91c5054 100644 Binary files a/image-inpainting/src/__pycache__/datasets.cpython-313.pyc and b/image-inpainting/src/__pycache__/datasets.cpython-313.pyc differ diff --git a/image-inpainting/src/__pycache__/train.cpython-313.pyc b/image-inpainting/src/__pycache__/train.cpython-313.pyc index 3b0020b..1d2f61e 100644 Binary files a/image-inpainting/src/__pycache__/train.cpython-313.pyc and b/image-inpainting/src/__pycache__/train.cpython-313.pyc differ diff --git a/image-inpainting/src/__pycache__/utils.cpython-313.pyc b/image-inpainting/src/__pycache__/utils.cpython-313.pyc index 251d5ec..dc2f6aa 100644 Binary files a/image-inpainting/src/__pycache__/utils.cpython-313.pyc and b/image-inpainting/src/__pycache__/utils.cpython-313.pyc differ diff --git a/image-inpainting/src/architecture.py b/image-inpainting/src/architecture.py index 76a6e1e..d9aece7 100644 --- a/image-inpainting/src/architecture.py +++ b/image-inpainting/src/architecture.py @@ -100,6 +100,26 @@ class ResidualConvBlock(nn.Module): return self.relu(out) +class DilatedResidualBlock(nn.Module): + """Residual block with dilated convolutions for larger receptive field""" + def __init__(self, channels, dilation=2, dropout=0.0): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation) + self.bn1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation) + self.bn2 = nn.BatchNorm2d(channels) + self.relu = nn.LeakyReLU(0.1, inplace=True) + self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() + + def forward(self, x): + residual = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.dropout(out) + out = self.bn2(self.conv2(out)) + out = out + residual + return self.relu(out) + + class DownBlock(nn.Module): """Downsampling block with conv blocks, residual connection, attention, and max pooling""" def __init__(self, in_channels, out_channels, dropout=0.1): @@ -158,11 +178,12 @@ class MyModel(nn.Module): self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout) self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout) - # Bottleneck with multiple residual blocks + # Bottleneck with multi-scale dilated convolutions (ASPP-style) self.bottleneck = nn.Sequential( ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout), ResidualConvBlock(base_channels * 16, dropout=dropout), - ResidualConvBlock(base_channels * 16, dropout=dropout), + DilatedResidualBlock(base_channels * 16, dilation=2, dropout=dropout), + DilatedResidualBlock(base_channels * 16, dilation=4, dropout=dropout), ResidualConvBlock(base_channels * 16, dropout=dropout), CBAM(base_channels * 16) ) diff --git a/image-inpainting/src/datasets.py b/image-inpainting/src/datasets.py index d5e74eb..db95954 100644 --- a/image-inpainting/src/datasets.py +++ b/image-inpainting/src/datasets.py @@ -10,7 +10,7 @@ import numpy as np import random import glob import os -from PIL import Image +from PIL import Image, ImageEnhance IMAGE_DIMENSION = 100 @@ -38,25 +38,69 @@ def preprocess(input_array: np.ndarray): class ImageDataset(torch.utils.data.Dataset): """ - Dataset class for loading images from a folder + Dataset class for loading images from a folder with data augmentation """ - def __init__(self, datafolder: str): + def __init__(self, datafolder: str, augment: bool = True): self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True)) + self.augment = augment def __len__(self): return len(self.imagefiles) + def augment_image(self, image: Image) -> Image: + """Apply random augmentations to image""" + # Random horizontal flip + if random.random() > 0.5: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + + # Random vertical flip + if random.random() > 0.5: + image = image.transpose(Image.FLIP_TOP_BOTTOM) + + # Random rotation (90, 180, 270 degrees) + if random.random() > 0.5: + angle = random.choice([90, 180, 270]) + image = image.rotate(angle) + + # Random brightness adjustment + if random.random() > 0.5: + enhancer = ImageEnhance.Brightness(image) + factor = random.uniform(0.8, 1.2) + image = enhancer.enhance(factor) + + # Random contrast adjustment + if random.random() > 0.5: + enhancer = ImageEnhance.Contrast(image) + factor = random.uniform(0.8, 1.2) + image = enhancer.enhance(factor) + + # Random color adjustment + if random.random() > 0.5: + enhancer = ImageEnhance.Color(image) + factor = random.uniform(0.8, 1.2) + image = enhancer.enhance(factor) + + return image + def __getitem__(self, idx:int): index = int(idx) image = Image.open(self.imagefiles[index]) - image = np.asarray(resize(image)) + image = resize(image) + + # Apply augmentation if enabled + if self.augment: + image = self.augment_image(image) + + image = np.asarray(image) image = preprocess(image) - spacing_x = random.randint(2,6) - spacing_y = random.randint(2,6) - offset_x = random.randint(0,8) - offset_y = random.randint(0,8) + + # Vary spacing and offset more for additional diversity + spacing_x = random.randint(2,7) + spacing_y = random.randint(2,7) + offset_x = random.randint(0,10) + offset_y = random.randint(0,10) spacing = (spacing_x, spacing_y) offset = (offset_x, offset_y) input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing) diff --git a/image-inpainting/src/main.py b/image-inpainting/src/main.py index 9f61cd6..0569cbb 100644 --- a/image-inpainting/src/main.py +++ b/image-inpainting/src/main.py @@ -24,22 +24,22 @@ if __name__ == '__main__': config_dict['results_path'] = os.path.join(project_root, "results") config_dict['data_path'] = os.path.join(project_root, "data", "dataset") config_dict['device'] = None - config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW - config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization - config_dict['n_updates'] = 5000 # More updates for better convergence - config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates - config_dict['early_stopping_patience'] = 10 # More patience for complex model + config_dict['learningrate'] = 2e-4 # Slightly lower for more stable training + config_dict['weight_decay'] = 5e-5 # Reduced for less aggressive regularization + config_dict['n_updates'] = 8000 # More updates for better convergence + config_dict['batchsize'] = 12 # Larger batch for more stable gradients + config_dict['early_stopping_patience'] = 15 # More patience for complex model config_dict['use_wandb'] = False config_dict['print_train_stats_at'] = 10 config_dict['print_stats_at'] = 100 - config_dict['plot_at'] = 300 - config_dict['validate_at'] = 300 # Validate more frequently + config_dict['plot_at'] = 400 + config_dict['validate_at'] = 200 # Validate frequently but not too often network_config = { 'n_in_channels': 4, - 'base_channels': 48, # Good balance between capacity and memory - 'dropout': 0.1 # Regularization + 'base_channels': 32, # Smaller base for efficiency, depth compensates + 'dropout': 0.15 # Slightly more regularization with augmentation } config_dict['network_config'] = network_config diff --git a/image-inpainting/src/train.py b/image-inpainting/src/train.py index 10bf917..524e04b 100644 --- a/image-inpainting/src/train.py +++ b/image-inpainting/src/train.py @@ -21,8 +21,8 @@ import wandb class CombinedLoss(nn.Module): - """Combined loss: MSE + L1 + SSIM-like perceptual component""" - def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1): + """Combined loss: MSE + L1 + Edge-aware component for better reconstruction""" + def __init__(self, mse_weight=0.7, l1_weight=0.8, edge_weight=0.2): super().__init__() self.mse_weight = mse_weight self.l1_weight = l1_weight @@ -84,20 +84,24 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st plotpath = os.path.join(results_path, "plots") os.makedirs(plotpath, exist_ok=True) - image_dataset = datasets.ImageDataset(datafolder=data_path) + # Create dataset with augmentation for training, without for validation/test + image_dataset_full = datasets.ImageDataset(datafolder=data_path, augment=False) - n_total = len(image_dataset) + n_total = len(image_dataset_full) n_test = int(n_total * testset_ratio) n_valid = int(n_total * validset_ratio) n_train = n_total - n_test - n_valid indices = np.random.permutation(n_total) - dataset_train = Subset(image_dataset, indices=indices[0:n_train]) - dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid]) - dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total]) + + # Create augmented dataset for training + image_dataset_train = datasets.ImageDataset(datafolder=data_path, augment=True) + dataset_train = Subset(image_dataset_train, indices=indices[0:n_train]) + dataset_valid = Subset(image_dataset_full, indices=indices[n_train:n_train + n_valid]) + dataset_test = Subset(image_dataset_full, indices=indices[n_train + n_valid:n_total]) - assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid) + assert n_total == len(dataset_train) + len(dataset_test) + len(dataset_valid) - del image_dataset + del image_dataset_full, image_dataset_train dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize, num_workers=0, shuffle=True) @@ -111,15 +115,19 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st network.to(device) network.train() - # defining the loss - combined loss for better reconstruction - combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device) + # defining the loss - combined loss with optimized weights + combined_loss = CombinedLoss(mse_weight=0.7, l1_weight=0.8, edge_weight=0.2).to(device) mse_loss = torch.nn.MSELoss() # Keep for evaluation # defining the optimizer with AdamW for better weight decay handling optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay) - # Learning rate scheduler for better convergence - scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6) + # Learning rate scheduler with better configuration + scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=1e-7) + + # Mixed precision training for faster computation and lower memory usage + scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None + use_amp = scaler is not None if use_wandb: wandb.watch(network, mse_loss, log="all", log_freq=10) @@ -128,10 +136,13 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st counter = 0 best_validation_loss = np.inf loss_list = [] + accumulation_steps = 2 # Gradient accumulation for effective larger batch size saved_model_path = os.path.join(results_path, "best_model.pt") print(f"Started training on device {device}") + print(f"Using mixed precision: {use_amp}") + print(f"Gradient accumulation steps: {accumulation_steps}") while i < n_updates: @@ -142,21 +153,33 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st if (i + 1) % print_train_stats_at == 0: print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}') - optimizer.zero_grad() - - output = network(input) - - loss = combined_loss(output, target) - - loss.backward() + # Use mixed precision if available + if use_amp: + with torch.cuda.amp.autocast(): + output = network(input) + loss = combined_loss(output, target) + loss = loss / accumulation_steps + scaler.scale(loss).backward() + else: + output = network(input) + loss = combined_loss(output, target) + loss = loss / accumulation_steps + loss.backward() - # Gradient clipping for training stability - torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0) + # Gradient accumulation - update weights every accumulation_steps + if (i + 1) % accumulation_steps == 0: + if use_amp: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + else: + torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + scheduler.step(i / n_updates) - optimizer.step() - scheduler.step(i + len(loss_list) / len(dataloader_train)) - - loss_list.append(loss.item()) + loss_list.append(loss.item() * accumulation_steps) # writing the stats to wandb if use_wandb and (i+1) % print_stats_at == 0: @@ -165,7 +188,9 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st # plotting if (i + 1) % plot_at == 0: print(f"Plotting images, current update {i + 1}") - plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i) + # Convert to float32 for matplotlib compatibility (mixed precision may produce float16) + plot(input.float().cpu().numpy(), target.detach().float().cpu().numpy(), + output.detach().float().cpu().numpy(), plotpath, i) # evaluating model every validate_at sample if (i + 1) % validate_at == 0: