Compare commits
5 Commits
claude-son
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 77b8b9b3f6 | |||
| 7d4caaf501 | |||
| 248ffb8faf | |||
| 1771377121 | |||
| eaf45f5c72 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
data/*
|
||||
*.zip
|
||||
*.jpg
|
||||
*.pt
|
||||
*.pt
|
||||
__pycache__/
|
||||
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -100,26 +100,6 @@ class ResidualConvBlock(nn.Module):
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
class DilatedResidualBlock(nn.Module):
|
||||
"""Residual block with dilated convolutions for larger receptive field"""
|
||||
def __init__(self, channels, dilation=2, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn1 = nn.BatchNorm2d(channels)
|
||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
|
||||
self.bn2 = nn.BatchNorm2d(channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + residual
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
@@ -178,12 +158,11 @@ class MyModel(nn.Module):
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
|
||||
# Bottleneck with multi-scale dilated convolutions (ASPP-style)
|
||||
# Bottleneck with multiple residual blocks
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
DilatedResidualBlock(base_channels * 16, dilation=2, dropout=dropout),
|
||||
DilatedResidualBlock(base_channels * 16, dilation=4, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
CBAM(base_channels * 16)
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
import random
|
||||
import glob
|
||||
import os
|
||||
from PIL import Image, ImageEnhance
|
||||
from PIL import Image
|
||||
|
||||
IMAGE_DIMENSION = 100
|
||||
|
||||
@@ -38,69 +38,25 @@ def preprocess(input_array: np.ndarray):
|
||||
|
||||
class ImageDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset class for loading images from a folder with data augmentation
|
||||
Dataset class for loading images from a folder
|
||||
"""
|
||||
|
||||
def __init__(self, datafolder: str, augment: bool = True):
|
||||
def __init__(self, datafolder: str):
|
||||
self.imagefiles = sorted(glob.glob(os.path.join(datafolder,"**","*.jpg"),recursive=True))
|
||||
self.augment = augment
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imagefiles)
|
||||
|
||||
def augment_image(self, image: Image) -> Image:
|
||||
"""Apply random augmentations to image"""
|
||||
# Random horizontal flip
|
||||
if random.random() > 0.5:
|
||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
# Random vertical flip
|
||||
if random.random() > 0.5:
|
||||
image = image.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Random rotation (90, 180, 270 degrees)
|
||||
if random.random() > 0.5:
|
||||
angle = random.choice([90, 180, 270])
|
||||
image = image.rotate(angle)
|
||||
|
||||
# Random brightness adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Brightness(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
# Random contrast adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Contrast(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
# Random color adjustment
|
||||
if random.random() > 0.5:
|
||||
enhancer = ImageEnhance.Color(image)
|
||||
factor = random.uniform(0.8, 1.2)
|
||||
image = enhancer.enhance(factor)
|
||||
|
||||
return image
|
||||
|
||||
def __getitem__(self, idx:int):
|
||||
index = int(idx)
|
||||
|
||||
image = Image.open(self.imagefiles[index])
|
||||
image = resize(image)
|
||||
|
||||
# Apply augmentation if enabled
|
||||
if self.augment:
|
||||
image = self.augment_image(image)
|
||||
|
||||
image = np.asarray(image)
|
||||
image = np.asarray(resize(image))
|
||||
image = preprocess(image)
|
||||
|
||||
# Vary spacing and offset more for additional diversity
|
||||
spacing_x = random.randint(2,7)
|
||||
spacing_y = random.randint(2,7)
|
||||
offset_x = random.randint(0,10)
|
||||
offset_y = random.randint(0,10)
|
||||
spacing_x = random.randint(2,6)
|
||||
spacing_y = random.randint(2,6)
|
||||
offset_x = random.randint(0,8)
|
||||
offset_y = random.randint(0,8)
|
||||
spacing = (spacing_x, spacing_y)
|
||||
offset = (offset_x, offset_y)
|
||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||
|
||||
@@ -24,22 +24,22 @@ if __name__ == '__main__':
|
||||
config_dict['results_path'] = os.path.join(project_root, "results")
|
||||
config_dict['data_path'] = os.path.join(project_root, "data", "dataset")
|
||||
config_dict['device'] = None
|
||||
config_dict['learningrate'] = 2e-4 # Slightly lower for more stable training
|
||||
config_dict['weight_decay'] = 5e-5 # Reduced for less aggressive regularization
|
||||
config_dict['n_updates'] = 8000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 12 # Larger batch for more stable gradients
|
||||
config_dict['early_stopping_patience'] = 15 # More patience for complex model
|
||||
config_dict['learningrate'] = 3e-4 # Optimal learning rate for AdamW
|
||||
config_dict['weight_decay'] = 1e-4 # Slightly higher for better regularization
|
||||
config_dict['n_updates'] = 5000 # More updates for better convergence
|
||||
config_dict['batchsize'] = 8 # Smaller batch for better gradient estimates
|
||||
config_dict['early_stopping_patience'] = 10 # More patience for complex model
|
||||
config_dict['use_wandb'] = False
|
||||
|
||||
config_dict['print_train_stats_at'] = 10
|
||||
config_dict['print_stats_at'] = 100
|
||||
config_dict['plot_at'] = 400
|
||||
config_dict['validate_at'] = 200 # Validate frequently but not too often
|
||||
config_dict['plot_at'] = 300
|
||||
config_dict['validate_at'] = 300 # Validate more frequently
|
||||
|
||||
network_config = {
|
||||
'n_in_channels': 4,
|
||||
'base_channels': 32, # Smaller base for efficiency, depth compensates
|
||||
'dropout': 0.15 # Slightly more regularization with augmentation
|
||||
'base_channels': 48, # Good balance between capacity and memory
|
||||
'dropout': 0.1 # Regularization
|
||||
}
|
||||
|
||||
config_dict['network_config'] = network_config
|
||||
|
||||
@@ -21,8 +21,8 @@ import wandb
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""Combined loss: MSE + L1 + Edge-aware component for better reconstruction"""
|
||||
def __init__(self, mse_weight=0.7, l1_weight=0.8, edge_weight=0.2):
|
||||
"""Combined loss: MSE + L1 + SSIM-like perceptual component"""
|
||||
def __init__(self, mse_weight=1.0, l1_weight=0.5, edge_weight=0.1):
|
||||
super().__init__()
|
||||
self.mse_weight = mse_weight
|
||||
self.l1_weight = l1_weight
|
||||
@@ -84,24 +84,20 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
plotpath = os.path.join(results_path, "plots")
|
||||
os.makedirs(plotpath, exist_ok=True)
|
||||
|
||||
# Create dataset with augmentation for training, without for validation/test
|
||||
image_dataset_full = datasets.ImageDataset(datafolder=data_path, augment=False)
|
||||
image_dataset = datasets.ImageDataset(datafolder=data_path)
|
||||
|
||||
n_total = len(image_dataset_full)
|
||||
n_total = len(image_dataset)
|
||||
n_test = int(n_total * testset_ratio)
|
||||
n_valid = int(n_total * validset_ratio)
|
||||
n_train = n_total - n_test - n_valid
|
||||
indices = np.random.permutation(n_total)
|
||||
|
||||
# Create augmented dataset for training
|
||||
image_dataset_train = datasets.ImageDataset(datafolder=data_path, augment=True)
|
||||
dataset_train = Subset(image_dataset_train, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset_full, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset_full, indices=indices[n_train + n_valid:n_total])
|
||||
dataset_train = Subset(image_dataset, indices=indices[0:n_train])
|
||||
dataset_valid = Subset(image_dataset, indices=indices[n_train:n_train + n_valid])
|
||||
dataset_test = Subset(image_dataset, indices=indices[n_train + n_valid:n_total])
|
||||
|
||||
assert n_total == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
assert len(image_dataset) == len(dataset_train) + len(dataset_test) + len(dataset_valid)
|
||||
|
||||
del image_dataset_full, image_dataset_train
|
||||
del image_dataset
|
||||
|
||||
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize,
|
||||
num_workers=0, shuffle=True)
|
||||
@@ -115,19 +111,15 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
network.to(device)
|
||||
network.train()
|
||||
|
||||
# defining the loss - combined loss with optimized weights
|
||||
combined_loss = CombinedLoss(mse_weight=0.7, l1_weight=0.8, edge_weight=0.2).to(device)
|
||||
# defining the loss - combined loss for better reconstruction
|
||||
combined_loss = CombinedLoss(mse_weight=1.0, l1_weight=0.5, edge_weight=0.1).to(device)
|
||||
mse_loss = torch.nn.MSELoss() # Keep for evaluation
|
||||
|
||||
# defining the optimizer with AdamW for better weight decay handling
|
||||
optimizer = torch.optim.AdamW(network.parameters(), lr=learningrate, weight_decay=weight_decay)
|
||||
|
||||
# Learning rate scheduler with better configuration
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100, T_mult=2, eta_min=1e-7)
|
||||
|
||||
# Mixed precision training for faster computation and lower memory usage
|
||||
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
|
||||
use_amp = scaler is not None
|
||||
# Learning rate scheduler for better convergence
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
|
||||
|
||||
if use_wandb:
|
||||
wandb.watch(network, mse_loss, log="all", log_freq=10)
|
||||
@@ -136,13 +128,10 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
counter = 0
|
||||
best_validation_loss = np.inf
|
||||
loss_list = []
|
||||
accumulation_steps = 2 # Gradient accumulation for effective larger batch size
|
||||
|
||||
saved_model_path = os.path.join(results_path, "best_model.pt")
|
||||
|
||||
print(f"Started training on device {device}")
|
||||
print(f"Using mixed precision: {use_amp}")
|
||||
print(f"Gradient accumulation steps: {accumulation_steps}")
|
||||
|
||||
while i < n_updates:
|
||||
|
||||
@@ -153,33 +142,21 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
if (i + 1) % print_train_stats_at == 0:
|
||||
print(f'Update Step {i + 1} of {n_updates}: Current loss: {loss_list[-1]}')
|
||||
|
||||
# Use mixed precision if available
|
||||
if use_amp:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = network(input)
|
||||
loss = combined_loss(output, target)
|
||||
loss = loss / accumulation_steps
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
output = network(input)
|
||||
loss = combined_loss(output, target)
|
||||
loss = loss / accumulation_steps
|
||||
loss.backward()
|
||||
|
||||
# Gradient accumulation - update weights every accumulation_steps
|
||||
if (i + 1) % accumulation_steps == 0:
|
||||
if use_amp:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
scheduler.step(i / n_updates)
|
||||
optimizer.zero_grad()
|
||||
|
||||
loss_list.append(loss.item() * accumulation_steps)
|
||||
output = network(input)
|
||||
|
||||
loss = combined_loss(output, target)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping for training stability
|
||||
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=1.0)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step(i + len(loss_list) / len(dataloader_train))
|
||||
|
||||
loss_list.append(loss.item())
|
||||
|
||||
# writing the stats to wandb
|
||||
if use_wandb and (i+1) % print_stats_at == 0:
|
||||
@@ -188,9 +165,7 @@ def train(seed, testset_ratio, validset_ratio, data_path, results_path, early_st
|
||||
# plotting
|
||||
if (i + 1) % plot_at == 0:
|
||||
print(f"Plotting images, current update {i + 1}")
|
||||
# Convert to float32 for matplotlib compatibility (mixed precision may produce float16)
|
||||
plot(input.float().cpu().numpy(), target.detach().float().cpu().numpy(),
|
||||
output.detach().float().cpu().numpy(), plotpath, i)
|
||||
plot(input.cpu().numpy(), target.detach().cpu().numpy(), output.detach().cpu().numpy(), plotpath, i)
|
||||
|
||||
# evaluating model every validate_at sample
|
||||
if (i + 1) % validate_at == 0:
|
||||
|
||||
Reference in New Issue
Block a user