Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 77b8b9b3f6 | |||
| 7d4caaf501 | |||
| 248ffb8faf | |||
| 1771377121 | |||
| eaf45f5c72 |
3
image-inpainting/.gitignore
vendored
3
image-inpainting/.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/*
|
data/*
|
||||||
*.zip
|
*.zip
|
||||||
*.jpg
|
*.jpg
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-2.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-3.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-4.npz
Normal file
Binary file not shown.
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
BIN
image-inpainting/results/submissions/tikaiz-5.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -18,20 +18,6 @@ def init_weights(m):
|
|||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
nn.init.constant_(m.weight, 1)
|
nn.init.constant_(m.weight, 1)
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.GroupNorm):
|
|
||||||
if m.weight is not None:
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
if m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_norm(num_channels: int) -> nn.Module:
|
|
||||||
"""Batch-size independent normalization (works well for batch_size=1 eval)."""
|
|
||||||
# Choose a group count that divides num_channels.
|
|
||||||
num_groups = min(32, num_channels)
|
|
||||||
while num_groups > 1 and (num_channels % num_groups) != 0:
|
|
||||||
num_groups //= 2
|
|
||||||
return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
|
class ChannelAttention(nn.Module):
|
||||||
@@ -87,7 +73,7 @@ class ConvBlock(nn.Module):
|
|||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
||||||
self.bn = _make_norm(out_channels)
|
self.bn = nn.BatchNorm2d(out_channels)
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
@@ -99,9 +85,9 @@ class ResidualConvBlock(nn.Module):
|
|||||||
def __init__(self, channels, dropout=0.0):
|
def __init__(self, channels, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn1 = _make_norm(channels)
|
self.bn1 = nn.BatchNorm2d(channels)
|
||||||
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
|
||||||
self.bn2 = _make_norm(channels)
|
self.bn2 = nn.BatchNorm2d(channels)
|
||||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
@@ -114,26 +100,6 @@ class ResidualConvBlock(nn.Module):
|
|||||||
return self.relu(out)
|
return self.relu(out)
|
||||||
|
|
||||||
|
|
||||||
class GatedConvBlock(nn.Module):
|
|
||||||
"""Gated convolution block (helps the network condition on the mask channel)."""
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
self.feature = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
|
||||||
self.gate = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
|
||||||
self.norm = _make_norm(out_channels)
|
|
||||||
self.act = nn.LeakyReLU(0.1, inplace=True)
|
|
||||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
feat = self.feature(x)
|
|
||||||
gate = torch.sigmoid(self.gate(x))
|
|
||||||
out = feat * gate
|
|
||||||
out = self.norm(out)
|
|
||||||
out = self.act(out)
|
|
||||||
out = self.dropout(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DownBlock(nn.Module):
|
class DownBlock(nn.Module):
|
||||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||||
@@ -181,7 +147,7 @@ class MyModel(nn.Module):
|
|||||||
|
|
||||||
# Initial convolution with larger receptive field
|
# Initial convolution with larger receptive field
|
||||||
self.init_conv = nn.Sequential(
|
self.init_conv = nn.Sequential(
|
||||||
GatedConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3, dropout=dropout),
|
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
||||||
ConvBlock(base_channels, base_channels),
|
ConvBlock(base_channels, base_channels),
|
||||||
ResidualConvBlock(base_channels)
|
ResidualConvBlock(base_channels)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,25 +32,6 @@ def resize(img: Image):
|
|||||||
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
transforms.CenterCrop((IMAGE_DIMENSION, IMAGE_DIMENSION))
|
||||||
])
|
])
|
||||||
return resize_transforms(img)
|
return resize_transforms(img)
|
||||||
|
|
||||||
|
|
||||||
def augment_geometric(img: Image.Image) -> Image.Image:
|
|
||||||
"""Lightweight, label-preserving augmentation (safe for train/val/test splits)."""
|
|
||||||
# Horizontal flip
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
|
|
||||||
# Vertical flip (less frequent)
|
|
||||||
if random.random() < 0.2:
|
|
||||||
img = img.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
|
||||||
# 90-degree rotations (no interpolation artifacts)
|
|
||||||
r = random.random()
|
|
||||||
if r < 0.25:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_90)
|
|
||||||
elif r < 0.5:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_180)
|
|
||||||
elif r < 0.75:
|
|
||||||
img = img.transpose(Image.Transpose.ROTATE_270)
|
|
||||||
return img
|
|
||||||
def preprocess(input_array: np.ndarray):
|
def preprocess(input_array: np.ndarray):
|
||||||
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
input_array = np.asarray(input_array, dtype=np.float32) / 255.0
|
||||||
return input_array
|
return input_array
|
||||||
@@ -69,17 +50,13 @@ class ImageDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, idx:int):
|
def __getitem__(self, idx:int):
|
||||||
index = int(idx)
|
index = int(idx)
|
||||||
|
|
||||||
image = Image.open(self.imagefiles[index]).convert("RGB")
|
image = Image.open(self.imagefiles[index])
|
||||||
image = augment_geometric(image)
|
|
||||||
image = np.asarray(resize(image))
|
image = np.asarray(resize(image))
|
||||||
image = preprocess(image)
|
image = preprocess(image)
|
||||||
|
spacing_x = random.randint(2,6)
|
||||||
# Sample a grid-mask similar in density to the challenge testset (~8% known pixels).
|
spacing_y = random.randint(2,6)
|
||||||
# IMPORTANT: offset ranges must be tied to spacing to avoid accidental distribution shift.
|
offset_x = random.randint(0,8)
|
||||||
spacing_x = random.randint(4, 6)
|
offset_y = random.randint(0,8)
|
||||||
spacing_y = random.randint(2, 4)
|
|
||||||
offset_x = random.randint(0, spacing_x - 1)
|
|
||||||
offset_y = random.randint(0, spacing_y - 1)
|
|
||||||
spacing = (spacing_x, spacing_y)
|
spacing = (spacing_x, spacing_y)
|
||||||
offset = (offset_x, offset_y)
|
offset = (offset_x, offset_y)
|
||||||
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
input_array, known_array = create_arrays_from_image(image.copy(), offset, spacing)
|
||||||
|
|||||||
Reference in New Issue
Block a user