create CNNectomeUNetModule using attention
mzouink committed Nov 22, 2023
1 parent 20b5404 commit 2c17e17
Showing 2 changed files with 94 additions and 75 deletions.
71 changes: 0 additions & 71 deletions dacapo/experiments/architectures/attention_unet.py

This file was deleted.

98 changes: 94 additions & 4 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -125,6 +125,7 @@ def __init__(
        padding="valid",
        upsample_channel_contraction=False,
        activation_on_upsample=False,
        use_attention=False,
    ):
"""Create a U-Net::
@@ -244,6 +245,7 @@ def __init__(
        )

        self.dims = len(downsample_factors[0])
        self.use_attention = use_attention

        # default arguments
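For context, a hypothetical construction sketch showing the new flag in use. CNNectomeUNetModule, use_attention, padding, num_fmaps, fmap_inc_factor, and downsample_factors all appear in this commit; the remaining argument names and values are assumptions and may not match the actual signature:

    from dacapo.experiments.architectures.cnnectome_unet import CNNectomeUNetModule

    unet = CNNectomeUNetModule(
        in_channels=1,                              # assumed argument name
        num_fmaps=12,
        fmap_inc_factor=5,
        downsample_factors=[(2, 2, 2), (2, 2, 2)],
        padding="same",
        use_attention=True,                         # enables the attention gates added below
    )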

@@ -317,6 +319,17 @@ def __init__(
            ]
        )

        if self.use_attention:
            # one attention gate per skip connection
            self.attention = nn.ModuleList(
                [
                    AttentionBlockModule(
                        F_g=num_fmaps * fmap_inc_factor**level,
                        F_l=num_fmaps * fmap_inc_factor**level,
                        F_int=num_fmaps * fmap_inc_factor ** (level - 1),
                        dims=self.dims,
                    )
                    for level in range(1, self.num_levels)
                ]
            )
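        # For orientation (not part of the commit): with, say, num_fmaps=12 and
        # fmap_inc_factor=5, the gate for level 1 is built with
        # F_g = F_l = 12 * 5**1 = 60 and F_int = 12 * 5**0 = 12, i.e. both inputs
        # are compressed by one fmap_inc_factor step before the attention map is
        # formed.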

        # right convolutional passes
        self.r_conv = nn.ModuleList(
            [
@@ -359,10 +372,16 @@ def rec_forward(self, level, f_in):
# nested levels
gs_out = self.rec_forward(level - 1, g_in)

        if self.use_attention:
            # gate each skip connection with the coarser-level features before up-concat
            f_left_attended = [
                self.attention[i - 1](gs_out[h], f_left)
                for h in range(self.num_heads)
            ]
            fs_right = [
                self.r_up[h][i](gs_out[h], f_left_attended[h])
                for h in range(self.num_heads)
            ]
        else:
            # up, concat, and crop
            fs_right = [
                self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
            ]

        # convolve
        fs_out = [self.r_conv[h][i](fs_right[h]) for h in range(self.num_heads)]
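To make the gating concrete, here is a minimal, self-contained sketch of what one attention gate computes, written with plain torch ops in 2D. The channel counts, shapes, and the average-pool downsample are illustrative assumptions, not the module's exact layers (the module itself uses the ConvPass and Downsample helpers from this file):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    F_g, F_l, F_int = 60, 60, 12
    g = torch.randn(1, F_g, 32, 32)  # coarser-level features (gating signal)
    x = torch.randn(1, F_l, 64, 64)  # skip-connection features, one level finer

    W_g = nn.Conv2d(F_g, F_int, kernel_size=1)
    W_x = nn.Conv2d(F_l, F_int, kernel_size=1)
    psi = nn.Conv2d(F_int, 1, kernel_size=1)

    x_down = F.avg_pool2d(W_x(x), 2)                  # bring x down to g's resolution
    attn = torch.sigmoid(psi(F.relu(W_g(g) + x_down)))
    attn = F.interpolate(attn, scale_factor=2, mode="bilinear", align_corners=True)
    gated_skip = x * attn                             # same shape as x: (1, 60, 64, 64)

The gated tensor (gated_skip here, f_left_attended above) replaces the raw f_left in the up-and-concat step.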
@@ -580,3 +599,74 @@ def forward(self, g_out, f_left=None):
            return torch.cat([f_cropped, g_cropped], dim=1)
        else:
            return g_cropped



class AttentionBlockModule(nn.Module):
    def __init__(self, F_g, F_l, F_int, dims):
        """Attention Block Module::

        The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).

            [g] --> W_g --\                /--> psi --> * --> [output]
                           \              /
            [x] --> W_x --> [+] --> relu --

        Where:
        - W_g and W_x are 1x1 ConvPass layers (activation=None); W_x additionally
          downsamples 'x' by a factor of 2 to match the resolution of 'g'
        - [+] indicates element-wise addition
        - relu is the Rectified Linear Unit activation function
        - psi is a 1x1 ConvPass with a Sigmoid activation that produces the attention map
        - the attention map is upsampled back to the resolution of 'x', and * indicates
          element-wise multiplication with the input features 'x'
        - [output] has the same dimensions as input 'x', selectively emphasized by
          attention weights

        Args:
            F_g (int): The number of feature channels in the gating signal (g).
                This is the input channel dimension for the W_g convolutional layer.
            F_l (int): The number of feature channels in the input features (x).
                This is the input channel dimension for the W_x convolutional layer.
            F_int (int): The number of intermediate feature channels.
                This is the output channel dimension of the W_g and W_x convolutional
                layers and the input channel dimension for the psi layer. Typically,
                F_int is smaller than F_g and F_l, as it compresses the feature
                representations before the attention map is computed.
            dims (int): The number of spatial dimensions (2 or 3); selects between
                bilinear and trilinear upsampling of the attention map.

        The AttentionBlockModule uses two separate pathways to process 'g' and 'x',
        combines them, and applies a sigmoid activation to generate an attention map.
        This map is then used to scale the input features 'x', resulting in an output
        that focuses on the features emphasized by the gating signal 'g'.
        """

        super(AttentionBlockModule, self).__init__()
        self.dims = dims
        # two 1x1 (or 1x1x1) convolutions per projection
        self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
        print("kernel_sizes:", self.kernel_sizes)

        self.W_g = ConvPass(
            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
        )

        # project x and bring it to the gating resolution
        # (assumes a factor-2 downsample between adjacent levels)
        self.W_x = nn.Sequential(
            ConvPass(
                F_l, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
            ),
            Downsample((2,) * self.dims),
        )

        self.psi = ConvPass(
            F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same"
        )

        # upsample the attention map back to the resolution of x
        up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]
        self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True)

        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        g1 = self.W_g(g)  # project the gating signal
        x1 = self.W_x(x)  # project x and downsample it to g's resolution
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # per-element attention coefficients in [0, 1]
        psi = self.up(psi)  # back to the resolution of x
        return x * psi
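A minimal usage sketch of the new block, assuming it can be imported from this file and that the ConvPass and Downsample helpers defined earlier behave as shown above (1x1 convolutions with "same" padding and a factor-2 downsample); shapes and channel counts are illustrative:

    import torch
    from dacapo.experiments.architectures.cnnectome_unet import AttentionBlockModule

    attn_block = AttentionBlockModule(F_g=60, F_l=60, F_int=12, dims=3)
    x = torch.randn(1, 60, 64, 64, 64)  # skip-connection features
    g = torch.randn(1, 60, 32, 32, 32)  # gating features, one level coarser (half resolution)

    gated = attn_block(g, x)  # note the argument order: gating signal first
    print(gated.shape)        # expected: torch.Size([1, 60, 64, 64, 64]), same as x

The returned tensor is x rescaled element-wise by the attention map, so it drops into the existing up-and-concat path without any shape changes.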
