added result, 18.0253
This commit is contained in:
@@ -70,9 +70,9 @@ class CBAM(nn.Module):
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolutional block with Conv2d -> BatchNorm -> LeakyReLU"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dropout=0.0):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, dropout=0.0):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
|
||||
@@ -101,13 +101,13 @@ class ResidualConvBlock(nn.Module):
|
||||
|
||||
|
||||
class DownBlock(nn.Module):
|
||||
"""Downsampling block with conv blocks, residual connection, attention, and max pooling"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Simplified downsampling block with conv blocks, residual connection, and max pooling"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBlock(in_channels, out_channels, dropout=dropout)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
self.attention = CBAM(out_channels) if use_attention else nn.Identity()
|
||||
self.pool = nn.MaxPool2d(2)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -118,15 +118,14 @@ class DownBlock(nn.Module):
|
||||
return self.pool(skip), skip
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
"""Upsampling block with transposed conv, residual connection, attention, and conv blocks"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1):
|
||||
"""Simplified upsampling block with transposed conv, residual connection, and conv blocks"""
|
||||
def __init__(self, in_channels, out_channels, dropout=0.1, use_attention=True):
|
||||
super().__init__()
|
||||
self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
# After concat: out_channels (from upconv) + in_channels (from skip)
|
||||
self.conv1 = ConvBlock(out_channels + in_channels, out_channels, dropout=dropout)
|
||||
self.conv2 = ConvBlock(out_channels, out_channels, dropout=dropout)
|
||||
self.residual = ResidualConvBlock(out_channels, dropout=dropout)
|
||||
self.attention = CBAM(out_channels)
|
||||
self.attention = CBAM(out_channels) if use_attention else nn.Identity()
|
||||
|
||||
def forward(self, x, skip):
|
||||
x = self.up(x)
|
||||
@@ -135,7 +134,6 @@ class UpBlock(nn.Module):
|
||||
x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
|
||||
x = torch.cat([x, skip], dim=1)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.residual(x)
|
||||
x = self.attention(x)
|
||||
return x
|
||||
@@ -145,38 +143,35 @@ class MyModel(nn.Module):
|
||||
def __init__(self, n_in_channels: int, base_channels: int = 64, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
# Initial convolution with larger receptive field
|
||||
# Initial convolution - simplified
|
||||
self.init_conv = nn.Sequential(
|
||||
ConvBlock(n_in_channels, base_channels, kernel_size=7, padding=3),
|
||||
ConvBlock(base_channels, base_channels),
|
||||
ResidualConvBlock(base_channels)
|
||||
ConvBlock(n_in_channels, base_channels, kernel_size=5, padding=2),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
# Encoder (downsampling path)
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout)
|
||||
# Encoder (downsampling path) - attention only on deeper layers
|
||||
self.down1 = DownBlock(base_channels, base_channels * 2, dropout=dropout, use_attention=False)
|
||||
self.down2 = DownBlock(base_channels * 2, base_channels * 4, dropout=dropout, use_attention=False)
|
||||
self.down3 = DownBlock(base_channels * 4, base_channels * 8, dropout=dropout, use_attention=True)
|
||||
self.down4 = DownBlock(base_channels * 8, base_channels * 16, dropout=dropout, use_attention=True)
|
||||
|
||||
# Bottleneck with multiple residual blocks
|
||||
# Simplified bottleneck with dilated convolutions
|
||||
self.bottleneck = nn.Sequential(
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
ConvBlock(base_channels * 16, base_channels * 16, dilation=2, padding=2, dropout=dropout),
|
||||
ResidualConvBlock(base_channels * 16, dropout=dropout),
|
||||
CBAM(base_channels * 16)
|
||||
)
|
||||
|
||||
# Decoder (upsampling path)
|
||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout)
|
||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout)
|
||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout)
|
||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout)
|
||||
# Decoder (upsampling path) - attention only on deeper layers
|
||||
self.up1 = UpBlock(base_channels * 16, base_channels * 8, dropout=dropout, use_attention=True)
|
||||
self.up2 = UpBlock(base_channels * 8, base_channels * 4, dropout=dropout, use_attention=True)
|
||||
self.up3 = UpBlock(base_channels * 4, base_channels * 2, dropout=dropout, use_attention=False)
|
||||
self.up4 = UpBlock(base_channels * 2, base_channels, dropout=dropout, use_attention=False)
|
||||
|
||||
# Final refinement layers
|
||||
# Simplified final refinement layers
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvBlock(base_channels * 2, base_channels),
|
||||
ResidualConvBlock(base_channels),
|
||||
ConvBlock(base_channels, base_channels)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user