Compare commits
5 Commits
claude-son
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 77b8b9b3f6 | |||
| 7d4caaf501 | |||
| 248ffb8faf | |||
| 1771377121 | |||
| eaf45f5c72 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/*
|
data/*
|
||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -100,26 +100,6 @@ class ResidualConvBlock(nn.Module):
|
|||||||
return self.relu(out)
|
return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
class DilatedResidualBlock(nn.Module):
|
|
||||||
"""Residual block with dilated convolutions for larger receptive field"""
|
|
||||||
def __init__(self, channels, dilation=2, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
|
||||||
self.bn1 = nn.BatchNorm2d(channels)
|
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
|
||||||
self.bn2 = nn.BatchNorm2d(channels)
|
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
residual = x
|
|
||||||
out = self.relu(self.bn1(self.conv1(x)))
|
|
||||||
out = self.dropout(out)
|
|
||||||
out = self.bn2(self.conv2(out))
|
|
||||||
out = out + residual
|
|
||||||
return self.relu(out)
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||||
@@ -178,12 +158,11 @@ class MyModel(nn.Module):
|
|||||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||||
|
|
||||||
# Bottleneck with multi-scale dilated convolutions (ASPP-style)
|
# Bottleneck with multiple residual blocks
|
||||||
self.bottleneck = nn.Sequential(
|
self.bottleneck = nn.Sequential(
|
||||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
DilatedResidualBlock(base_channels * 16, dilation=2, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
DilatedResidualBlock(base_channels * 16, dilation=4, dropout=dropout),
|
|
||||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||||
CBAM(base_channels * 16)
|
CBAM(base_channels * 16)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
from PIL import Image, ImageEnhance
|
from PIL import Image
|
||||||
|
|
||||||
IMAGE_DIMENSION = 100
|
IMAGE_DIMENSION = 100
|
||||||
|
|
||||||
@@ -38,69 +38,25 @@ def preprocess(input_array: np.ndarray):
|
|||||||
|
|
||||||
class ImageDataset(torch.utils.data.Dataset):
|
class ImageDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset class for loading images from a folder with data augmentation
|
Dataset class for loading images from a folder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, datafolder: str, augment: bool = True):
|
def __init__(self, datafolder: str):
|
||||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||||
self.augment = augment
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.imagefiles)
|
return len(self.imagefiles)
|
||||||
|
|
||||||
def augment_image(self, image: Image) -> Image:
|
|
||||||
"""Apply random augmentations to image"""
|
|
||||||
# Random horizontal flip
|
|
||||||
if random.random() > 0.5:
|
|
||||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
|
||||||
|
|
||||||
# Random vertical flip
|
|
||||||
if random.random() > 0.5:
|
|
||||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
|
||||||
|
|
||||||
# Random rotation (90, 180, 270 degrees)
|
|
||||||
if random.random() > 0.5:
|
|
||||||
angle = random.choice([90, 180, 270])
|
|
||||||
image = image.rotate(angle)
|
|
||||||
|
|
||||||
# Random brightness adjustment
|
|
||||||
if random.random() > 0.5:
|
|
||||||
enhancer = ImageEnhance.Brightness(image)
|
|
||||||
factor = random.uniform(0.8, 1.2)
|
|
||||||
image = enhancer.enhance(factor)
|
|
||||||
|
|
||||||
# Random contrast adjustment
|
|
||||||
if random.random() > 0.5:
|
|
||||||
enhancer = ImageEnhance.Contrast(image)
|
|
||||||
factor = random.uniform(0.8, 1.2)
|
|
||||||
image = enhancer.enhance(factor)
|
|
||||||
|
|
||||||
# Random color adjustment
|
|
||||||
if random.random() > 0.5:
|
|
||||||
enhancer = ImageEnhance.Color(image)
|
|
||||||
factor = random.uniform(0.8, 1.2)
|
|
||||||
image = enhancer.enhance(factor)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx:int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index])
|
image = Image.open(self.imagefiles[index])
|
||||||
image = resize(image)
|
image = np.asarray(resize(image))
|
||||||
|
|
||||||
# Apply augmentation if enabled
|
|
||||||
if self.augment:
|
|
||||||
image = self.augment_image(image)
|
|
||||||
|
|
||||||
image = np.asarray(image)
|
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
spacing_x = random.randint(2,6)
|
||||||
# Vary spacing and offset more for additional diversity
|
spacing_y = random.randint(2,6)
|
||||||
spacing_x = random.randint(2,7)
|
offset_x = random.randint(0,8)
|
||||||
spacing_y = random.randint(2,7)
|
offset_y = random.randint(0,8)
|
||||||
offset_x = random.randint(0,10)
|
|
||||||
offset_y = random.randint(0,10)
|
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
|
|||||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
|||||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||||
config_dict['device'] = None
|
config_dict['device'] = None
|
||||||
config_dict['learningrate'] = 2e-4 # Slightly lower for more stable training
|
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||||
config_dict['weight_decay'] = 5e-5 # Reduced for less aggressive regularization
|
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
config_dict['n_updates'] = 5000 # More updates for better convergence
|
||||||
config_dict['batchsize'] = 12 # Larger batch for more stable gradients
|
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
||||||
config_dict['use_wandb'] = False
|
config_dict['use_wandb'] = False
|
||||||
|
|
||||||
config_dict['print_train_stats_at'] = 10
|
config_dict['print_train_stats_at'] = 10
|
||||||
config_dict['print_stats_at'] = 100
|
config_dict['print_stats_at'] = 100
|
||||||
config_dict['plot_at'] = 400
|
config_dict['plot_at'] = 300
|
||||||
config_dict['validate_at'] = 200 # Validate frequently but not too often
|
config_dict['validate_at'] = 300 # Validate more frequently
|
||||||
|
|
||||||
network_config = {
|
network_config = {
|
||||||
'n_in_channels': 4,
|
'n_in_channels': 4,
|
||||||
'base_channels': 32, # Smaller base for efficiency, depth compensates
|
'base_channels': 48, # Good balance between capacity and memory
|
||||||
'dropout': 0.15 # Slightly more regularization with augmentation
|
'dropout': 0.1 # Regularization
|
||||||
}
|
}
|
||||||
|
|
||||||
config_dict['network_config'] = network_config
|
config_dict['network_config'] = network_config
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ import wandb
|
|||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Module):
|
class CombinedLoss(nn.Module):
|
||||||
"""Combined loss: MSE + L1 + Edge-aware component for better reconstruction"""
|
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
||||||
def __init__(self, mse_weight=0.7, l1_weight=0.8, edge_weight=0.2):
|
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse_weight = mse_weight
|
self.mse_weight = mse_weight
|
||||||
self.l1_weight = l1_weight
|
self.l1_weight = l1_weight
|
||||||
@@ -84,24 +84,20 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
plotpath = os.path.join(results_path, "plots")
|
plotpath = os.path.join(results_path, "plots")
|
||||||
os.makedirs(plotpath, exist_ok=True)
|
os.makedirs(plotpath, exist_ok=True)
|
||||||
|
|
||||||
# Create dataset with augmentation for training, without for validation/test
|
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||||
image_dataset_full = datasets.ImageDataset(datafolder=data_path, augment=False)
|
|
||||||
|
|
||||||
n_total = len(image_dataset_full)
|
n_total = len(image_dataset)
|
||||||
n_test = int(n_total * testset_ratio)
|
n_test = int(n_total * testset_ratio)
|
||||||
n_valid = int(n_total * validset_ratio)
|
n_valid = int(n_total * validset_ratio)
|
||||||
n_train = n_total - n_test - n_valid
|
n_train = n_total - n_test - n_valid
|
||||||
indices = np.random.permutation(n_total)
|
indices = np.random.permutation(n_total)
|
||||||
|
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||||
# Create augmented dataset for training
|
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||||
image_dataset_train = datasets.ImageDataset(datafolder=data_path, augment=True)
|
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||||
dataset_train = Subset(image_dataset_train, indices=indices[0:n_train])
|
|
||||||
dataset_valid = Subset(image_dataset_full, indices=indices[n_train:n_train + n_valid])
|
|
||||||
dataset_test = Subset(image_dataset_full, indices=indices[n_train + n_valid:n_total])
|
|
||||||
|
|
||||||
assert n_total == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||||
|
|
||||||
del image_dataset_full, image_dataset_train
|
del image_dataset
|
||||||
|
|
||||||
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
||||||
num_workers=0, shuffle=True)
|
num_workers=0, shuffle=True)
|
||||||
@@ -115,19 +111,15 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
network.to(device)
|
network.to(device)
|
||||||
network.train()
|
network.train()
|
||||||
|
|
||||||
# defining the loss - combined loss with optimized weights
|
# defining the loss - combined loss for better reconstruction
|
||||||
combined_loss = CombinedLoss(mse_weight=0.7, l1_weight=0.8, edge_weight=0.2).to(device)
|
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||||
|
|
||||||
# defining the optimizer with AdamW for better weight decay handling
|
# defining the optimizer with AdamW for better weight decay handling
|
||||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
||||||
|
|
||||||
# Learning rate scheduler with better configuration
|
# Learning rate scheduler for better convergence
|
||||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=1e-7)
|
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
||||||
|
|
||||||
# Mixed precision training for faster computation and lower memory usage
|
|
||||||
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
|
|
||||||
use_amp = scaler is not None
|
|
||||||
|
|
||||||
if use_wandb:
|
if use_wandb:
|
||||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||||
@@ -136,13 +128,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
counter = 0
|
counter = 0
|
||||||
best_validation_loss = np.inf
|
best_validation_loss = np.inf
|
||||||
loss_list = []
|
loss_list = []
|
||||||
accumulation_steps = 2 # Gradient accumulation for effective larger batch size
|
|
||||||
|
|
||||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||||
|
|
||||||
print(f"Started training on device {device}")
|
print(f"Started training on device {device}")
|
||||||
print(f"Using mixed precision: {use_amp}")
|
|
||||||
print(f"Gradient accumulation steps: {accumulation_steps}")
|
|
||||||
|
|
||||||
while i < n_updates:
|
while i < n_updates:
|
||||||
|
|
||||||
@@ -153,33 +142,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
if (i + 1) % print_train_stats_at == 0:
|
if (i + 1) % print_train_stats_at == 0:
|
||||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||||
|
|
||||||
# Use mixed precision if available
|
optimizer.zero_grad()
|
||||||
if use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output = network(input)
|
|
||||||
loss = combined_loss(output, target)
|
|
||||||
loss = loss / accumulation_steps
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
else:
|
|
||||||
output = network(input)
|
|
||||||
loss = combined_loss(output, target)
|
|
||||||
loss = loss / accumulation_steps
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# Gradient accumulation - update weights every accumulation_steps
|
|
||||||
if (i + 1) % accumulation_steps == 0:
|
|
||||||
if use_amp:
|
|
||||||
scaler.unscale_(optimizer)
|
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
else:
|
|
||||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
|
||||||
optimizer.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
scheduler.step(i / n_updates)
|
|
||||||
|
|
||||||
loss_list.append(loss.item() * accumulation_steps)
|
output = network(input)
|
||||||
|
|
||||||
|
loss = combined_loss(output, target)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Gradient clipping for training stability
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
||||||
|
|
||||||
|
loss_list.append(loss.item())
|
||||||
|
|
||||||
# writing the stats to wandb
|
# writing the stats to wandb
|
||||||
if use_wandb and (i+1) % print_stats_at == 0:
|
if use_wandb and (i+1) % print_stats_at == 0:
|
||||||
@@ -188,9 +165,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
|||||||
# plotting
|
# plotting
|
||||||
if (i + 1) % plot_at == 0:
|
if (i + 1) % plot_at == 0:
|
||||||
print(f"Plotting images, current update {i + 1}")
|
print(f"Plotting images, current update {i + 1}")
|
||||||
# Convert to float32 for matplotlib compatibility (mixed precision may produce float16)
|
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
||||||
plot(input.float().cpu().numpy(), target.detach().float().cpu().numpy(),
|
|
||||||
output.detach().float().cpu().numpy(), plotpath, i)
|
|
||||||
|
|
||||||
# evaluating model every validate_at sample
|
# evaluating model every validate_at sample
|
||||||
if (i + 1) % validate_at == 0:
|
if (i + 1) % validate_at == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user