fix fmap calculation for attention
mzouink committed Nov 23, 2023
1 parent e2a2974 commit c45e93c
Showing 1 changed file with 28 additions and 11 deletions.
dacapo/experiments/architectures/cnnectome_unet.py
@@ -320,17 +320,31 @@ def __init__(
                 for _ in range(num_heads)
             ]
         )
 
         # if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out
         if self.use_attention:
             self.attention = nn.ModuleList(
                 [
-                    AttentionBlockModule(
-                        F_g=num_fmaps * fmap_inc_factor ** (self.num_levels - level ),
-                        F_l=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ),
-                        F_int=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ),
-                        dims=self.dims,
-                    )for level in range(1,self.num_levels)
-                ])
+                    nn.ModuleList(
+                        [
+                            AttentionBlockModule(
+                                F_g=num_fmaps * fmap_inc_factor ** (level + 1),
+                                F_l=num_fmaps
+                                * fmap_inc_factor
+                                ** level,
+                                F_int=num_fmaps
+                                * fmap_inc_factor
+                                ** (level + (1 - upsample_channel_contraction[level]))
+                                if num_fmaps_out is None or level != 0
+                                else num_fmaps_out,
+                                dims=self.dims,
+                                upsample_factor=downsample_factors[level],
+                            )
+                            for level in range(self.num_levels - 1)
+                        ]
+                    )
+                    for _ in range(num_heads)
+                ]
+            )
 
         # right convolutional passes
         self.r_conv = nn.ModuleList(
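
A small worked example of the corrected fmap calculation follows; the concrete values (num_fmaps=12, fmap_inc_factor=3, num_levels=4) are hypothetical and only illustrate how F_g, F_l and F_int now track the decoder level directly instead of counting down from num_levels.

# Hypothetical sanity check of the corrected fmap sizes (not part of the commit).
# Assumes num_fmaps=12, fmap_inc_factor=3, num_levels=4, num_fmaps_out=None and
# no channel contraction on upsampling, i.e. upsample_channel_contraction=[False]*3.
num_fmaps, fmap_inc_factor, num_levels = 12, 3, 4
upsample_channel_contraction = [False] * (num_levels - 1)

for level in range(num_levels - 1):
    F_g = num_fmaps * fmap_inc_factor ** (level + 1)   # gating signal from the coarser level
    F_l = num_fmaps * fmap_inc_factor ** level          # skip-connection features at this level
    F_int = num_fmaps * fmap_inc_factor ** (level + (1 - upsample_channel_contraction[level]))
    print(level, F_g, F_l, F_int)
# prints: 0 36 12 36 / 1 108 36 108 / 2 324 108 324
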
@@ -375,7 +389,7 @@ def rec_forward(self, level, f_in):
             gs_out = self.rec_forward(level - 1, g_in)
 
             if self.use_attention:
-                f_left_attented = [self.attention[i-1](gs_out[h],f_left) for h in range(self.num_heads)]
+                f_left_attented = [self.attention[h][i](gs_out[h],f_left) for h in range(self.num_heads)]
                 fs_right = [
                     self.r_up[h][i](gs_out[h], f_left_attented[h])
                     for h in range(self.num_heads)
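
A minimal sketch (assumed, not from the repo) of the container behind the indexing change: self.attention is now a ModuleList per head holding a ModuleList per level, so it is addressed as [h][i], mirroring self.r_up[h][i] in the same loop.

import torch.nn as nn

num_heads, num_levels = 2, 4
# nn.Identity() stands in for AttentionBlockModule purely to show the container's shape
attention = nn.ModuleList(
    [nn.ModuleList([nn.Identity() for _ in range(num_levels - 1)]) for _ in range(num_heads)]
)
h, i = 1, 2               # head index and level index as used inside rec_forward
block = attention[h][i]   # the old code indexed a flat, level-only list with [i-1]
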
@@ -605,7 +619,7 @@ def forward(self, g_out, f_left=None):
 
 
 class AttentionBlockModule(nn.Module):
-    def __init__(self, F_g, F_l, F_int, dims):
+    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
         """Attention Block Module::
         The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).
@@ -645,23 +659,26 @@ def __init__(self, F_g, F_l, F_int, dims):
         super(AttentionBlockModule, self).__init__()
         self.dims = dims
         self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
-        print("kernel_sizes:", self.kernel_sizes)
+        if upsample_factor is not None:
+            self.upsample_factor = upsample_factor
+        else:
+            self.upsample_factor = (2,)*self.dims
 
         self.W_g = ConvPass(
             F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same")
 
         self.W_x = nn.Sequential(
             ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes,
                      activation=None, padding="same"),
-            Downsample((2,)*self.dims)
+            Downsample(upsample_factor)
         )
 
         self.psi = ConvPass(
             F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same")
 
         up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims]
 
-        self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True)
+        self.up = nn.Upsample(scale_factor=upsample_factor, mode=up_mode, align_corners=True)
 
         self.relu = nn.ReLU(inplace=True)
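
A hedged usage sketch for the new keyword (import path, channel counts and tensor sizes are assumptions, not from the commit): with an anisotropic upsample_factor, the skip features are downsampled and the attention map upsampled by the same per-level factor instead of a hardcoded 2.

import torch
from dacapo.experiments.architectures.cnnectome_unet import AttentionBlockModule

# factor (1, 2, 2): no rescaling along z, factor 2 in y and x
attn = AttentionBlockModule(F_g=36, F_l=12, F_int=36, dims=3, upsample_factor=(1, 2, 2))
g = torch.randn(1, 36, 8, 16, 16)   # gating signal from the coarser level
x = torch.randn(1, 12, 8, 32, 32)   # skip-connection features at the finer level
out = attn(g, x)                    # attended skip features, expected to align spatially with x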
