diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py
index 01a261d09..ddf847456 100644
--- a/dacapo/experiments/architectures/cnnectome_unet.py
+++ b/dacapo/experiments/architectures/cnnectome_unet.py
@@ -25,6 +25,7 @@ def __init__(self, architecture_config):
         self.upsample_factors = (
             self.upsample_factors if self.upsample_factors is not None else []
         )
+        self.use_attention = architecture_config.use_attention
 
         self.unet = self.module()
 
@@ -64,6 +65,7 @@ def module(self):
             activation_on_upsample=True,
             upsample_channel_contraction=[False]
             + [True] * (len(downsample_factors) - 1),
+            use_attention=self.use_attention,
         )
         if len(self.upsample_factors) > 0:
             layers = [unet]
@@ -125,6 +127,7 @@ def __init__(
         padding="valid",
         upsample_channel_contraction=False,
         activation_on_upsample=False,
+        use_attention=False,
     ):
         """Create a U-Net::
 
@@ -244,6 +247,7 @@ def __init__(
         )
 
         self.dims = len(downsample_factors[0])
+        self.use_attention = use_attention
 
         # default arguments
 
@@ -316,6 +320,29 @@ def __init__(
                 for _ in range(num_heads)
             ]
         )
+        # attention gates: one AttentionBlockModule per decoder level, per head
+        if self.use_attention:
+            self.attention = nn.ModuleList(
+                [
+                    nn.ModuleList(
+                        [
+                            AttentionBlockModule(
+                                F_g=num_fmaps * fmap_inc_factor ** (level + 1),
+                                F_l=num_fmaps * fmap_inc_factor**level,
+                                F_int=num_fmaps
+                                * fmap_inc_factor
+                                ** (level + (1 - upsample_channel_contraction[level]))
+                                if num_fmaps_out is None or level != 0
+                                else num_fmaps_out,
+                                dims=self.dims,
+                                upsample_factor=downsample_factors[level],
+                            )
+                            for level in range(self.num_levels - 1)
+                        ]
+                    )
+                    for _ in range(num_heads)
+                ]
+            )
 
         # right convolutional passes
         self.r_conv = nn.ModuleList(
@@ -359,10 +386,19 @@ def rec_forward(self, level, f_in):
             # nested levels
             gs_out = self.rec_forward(level - 1, g_in)
 
-            # up, concat, and crop
-            fs_right = [
-                self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
-            ]
+            if self.use_attention:
+                f_left_attended = [
+                    self.attention[h][i](gs_out[h], f_left)
+                    for h in range(self.num_heads)
+                ]
+                fs_right = [
+                    self.r_up[h][i](gs_out[h], f_left_attended[h])
+                    for h in range(self.num_heads)
+                ]
+            else:  # up, concat, and crop
+                fs_right = [
+                    self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
+                ]
 
             # convolve
             fs_out = [self.r_conv[h][i](fs_right[h]) for h in range(self.num_heads)]
@@ -580,3 +616,112 @@ def forward(self, g_out, f_left=None):
             return torch.cat([f_cropped, g_cropped], dim=1)
         else:
             return g_cropped
+
+
+class AttentionBlockModule(nn.Module):
+    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
+        """Attention Block Module::
+
+        The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).
+
+            [g] --> W_g --\                 /--> psi --> * --> [output]
+                           \               /
+            [x] --> W_x --> [+] --> relu --
+
+        Where:
+        - W_g and W_x are 1x1 convolution passes that project g and x to F_int channels (W_x also downsamples x to g's resolution)
+        - [+] indicates element-wise addition
+        - relu is the Rectified Linear Unit activation function
+        - psi is a 1x1 convolution followed by a Sigmoid activation, producing one attention weight per voxel
+        - * indicates element-wise multiplication between the upsampled output of psi and the input feature 'x'
+        - [output] has the same dimensions as input 'x', selectively emphasized by attention weights
+
+        Args:
+            F_g (int): The number of feature channels in the gating signal (g).
+                This is the input channel dimension for the W_g convolutional layer.
+
+            F_l (int): The number of feature channels in the input features (x).
+                This is the input channel dimension for the W_x convolutional layer.
+
+            F_int (int): The number of intermediate feature channels.
+                This represents the output channel dimension of the W_g and W_x convolutional layers
+                and the input channel dimension for the psi layer. Typically, F_int is smaller
+                than F_g and F_l, as it serves to compress the feature representations before
+                applying the attention mechanism.
+
+        The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them,
+        and applies a sigmoid activation to generate an attention map. This map is then used
+        to scale the input features 'x', resulting in an output that focuses on important
+        features as dictated by the gating signal 'g'.
+
+        """
+
+        super(AttentionBlockModule, self).__init__()
+        self.dims = dims
+        self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
+        if upsample_factor is not None:
+            self.upsample_factor = upsample_factor
+        else:
+            self.upsample_factor = (2,) * self.dims
+
+        self.W_g = ConvPass(
+            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
+        )
+
+        self.W_x = nn.Sequential(
+            ConvPass(
+                F_l,
+                F_int,
+                kernel_sizes=self.kernel_sizes,
+                activation=None,
+                padding="same",
+            ),
+            Downsample(self.upsample_factor),
+        )
+
+        self.psi = ConvPass(
+            F_int,
+            1,
+            kernel_sizes=self.kernel_sizes,
+            activation="Sigmoid",
+            padding="same",
+        )
+
+        up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]
+
+        self.up = nn.Upsample(
+            scale_factor=self.upsample_factor, mode=up_mode, align_corners=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def calculate_and_apply_padding(self, smaller_tensor, larger_tensor):
+        """
+        Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor.
+
+        Args:
+            smaller_tensor (Tensor): The tensor to be padded.
+            larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match.
+
+        Returns:
+            Tensor: The padded smaller tensor with the same dimensions as the larger tensor.
+        """
+        padding = []
+        for i in range(2, 2 + self.dims):
+            diff = larger_tensor.size(i) - smaller_tensor.size(i)
+            padding.extend([diff // 2, diff - diff // 2])
+
+        # Reverse so the last spatial dimension comes first, as torch.nn.functional.pad expects
+        padding = padding[::-1]
+
+        # Apply symmetric padding
+        return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0)
+
+    def forward(self, g, x):
+        g1 = self.W_g(g)
+        x1 = self.W_x(x)
+        g1 = self.calculate_and_apply_padding(g1, x1)
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)
+        psi = self.up(psi)
+        return x * psi
diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py
index 5a40cca6d..c0e9e5b9d 100644
--- a/dacapo/experiments/architectures/cnnectome_unet_config.py
+++ b/dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -82,3 +82,9 @@ class CNNectomeUNetConfig(ArchitectureConfig):
         default="valid",
         metadata={"help_text": "The padding to use in convolution operations."},
     )
+    use_attention: bool = attr.ib(
+        default=False,
+        metadata={
+            "help_text": "Whether to use attention blocks in the UNet. Attention is only supported for 2D and 3D data."
+        },
+    )
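For reference, the sketch below mirrors the data flow that AttentionBlockModule implements (project g and x to F_int channels, downsample x to g's resolution, add, ReLU, 1x1 sigmoid convolution, upsample the attention map, multiply it onto x), but with plain torch layers instead of dacapo's ConvPass and Downsample. The class name ToyAttentionGate, the channel counts, and the tensor shapes are illustrative assumptions, not part of this PR.

# Minimal, self-contained sketch of the attention gate added in this PR.
# ToyAttentionGate and the shapes below are illustrative, not dacapo API.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int, downsample_factor=(2, 2, 2)):
        super().__init__()
        self.factor = downsample_factor
        self.W_g = nn.Conv3d(F_g, F_int, kernel_size=1)  # project gating signal
        self.W_x = nn.Conv3d(F_l, F_int, kernel_size=1)  # project skip features
        self.psi = nn.Sequential(nn.Conv3d(F_int, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, g, x):
        g1 = self.W_g(g)  # stays at the coarse (gating) resolution
        x1 = F.max_pool3d(self.W_x(x), kernel_size=self.factor)  # bring x to g's resolution
        # pad g1 so it matches x1 (valid convolutions can leave g slightly smaller)
        pads = []
        for d in range(x1.dim() - 1, 1, -1):  # last spatial dimension first, as F.pad expects
            diff = x1.size(d) - g1.size(d)
            pads.extend([diff // 2, diff - diff // 2])
        g1 = F.pad(g1, pads)
        attn = self.psi(torch.relu(g1 + x1))  # one gating weight per voxel, in [0, 1]
        attn = F.interpolate(
            attn, scale_factor=self.factor, mode="trilinear", align_corners=True
        )
        return x * attn  # gated skip connection, same shape as x


gate = ToyAttentionGate(F_g=64, F_l=32, F_int=16)
g = torch.randn(1, 64, 8, 8, 8)     # coarser decoder features (gating signal)
x = torch.randn(1, 32, 16, 16, 16)  # finer skip-connection features
print(gate(g, x).shape)             # torch.Size([1, 32, 16, 16, 16])

In a dacapo run, this behaviour is switched on per architecture by setting the new use_attention field to True on CNNectomeUNetConfig; it defaults to False, so existing configs are unaffected.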