From 2c17e176d79650bcf92d81092cd7975b11ed637c Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Wed, 22 Nov 2023 16:04:21 -0500
Subject: [PATCH] create CNNectomeUNetModule using attention

---
 .../architectures/attention_unet.py |  71 --------------
 .../architectures/cnnectome_unet.py |  98 ++++++++++++++++++-
 2 files changed, 94 insertions(+), 75 deletions(-)
 delete mode 100644 dacapo/experiments/architectures/attention_unet.py

diff --git a/dacapo/experiments/architectures/attention_unet.py b/dacapo/experiments/architectures/attention_unet.py
deleted file mode 100644
index f9c7f767f..000000000
--- a/dacapo/experiments/architectures/attention_unet.py
+++ /dev/null
@@ -1,71 +0,0 @@
-
-import torch
-import torch.nn as nn
-from .cnnectome_unet import ConvPass,Downsample,Upsample
-
-class AttentionBlockModule(nn.Module):
-    def __init__(self, F_g, F_l, F_int, dims):
-        """Attention Block Module::
-
-        The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).
-
-            [g] --> W_g --\                /--> psi --> * --> [output]
-                           \              /
-            [x] --> W_x --> [+] --> relu --
-
-        Where:
-        - W_g and W_x are 1x1 Convolution followed by Batch Normalization
-        - [+] indicates element-wise addition
-        - relu is the Rectified Linear Unit activation function
-        - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation
-        - * indicates element-wise multiplication between the output of psi and input feature 'x'
-        - [output] has the same dimensions as input 'x', selectively emphasized by attention weights
-
-        Args:
-            F_g (int): The number of feature channels in the gating signal (g).
-                This is the input channel dimension for the W_g convolutional layer.
-
-            F_l (int): The number of feature channels in the input features (x).
-                This is the input channel dimension for the W_x convolutional layer.
-
-            F_int (int): The number of intermediate feature channels.
-                This represents the output channel dimension of the W_g and W_x convolutional layers
-                and the input channel dimension for the psi layer. Typically, F_int is smaller
-                than F_g and F_l, as it serves to compress the feature representations before
-                applying the attention mechanism.
-
-        The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them,
-        and applies a sigmoid activation to generate an attention map. This map is then used
-        to scale the input features 'x', resulting in an output that focuses on important
-        features as dictated by the gating signal 'g'.
-
-        """
-
-
-        super(AttentionBlockModule, self).__init__()
-        self.dims = dims
-        self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
-        print("kernel_sizes:",self.kernel_sizes)
-
-        self.W_g = ConvPass(F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same")
-
-        self.W_x = nn.Sequential(
-            ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes, activation=None,padding="same"),
-            Downsample((2,)*self.dims)
-        )
-
-        self.psi = ConvPass(F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid",padding="same")
-
-        up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims]
-
-        self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True)
-
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, g, x):
-        g1 = self.W_g(g)
-        x1 = self.W_x(x)
-        psi = self.relu(g1 + x1)
-        psi = self.psi(psi)
-        psi = self.up(psi)
-        return x * psi
\ No newline at end of file
diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py
index 01a261d09..8f3e74dfe 100644
--- a/dacapo/experiments/architectures/cnnectome_unet.py
+++ b/dacapo/experiments/architectures/cnnectome_unet.py
@@ -125,6 +125,7 @@ def __init__(
         padding="valid",
         upsample_channel_contraction=False,
         activation_on_upsample=False,
+        use_attention=False,
     ):
         """Create a U-Net::

@@ -244,6 +245,7 @@ def __init__(
         )
 
         self.dims = len(downsample_factors[0])
+        self.use_attention = use_attention
 
 
         # default arguments
@@ -317,6 +319,17 @@ def __init__(
             ]
         )
 
+        if self.use_attention:
+            self.attention = nn.ModuleList(
+                [
+                    AttentionBlockModule(
+                        F_g=num_fmaps * fmap_inc_factor ** (level),
+                        F_l=num_fmaps * fmap_inc_factor ** (level),
+                        F_int=num_fmaps * fmap_inc_factor ** (level - 1),
+                        dims=self.dims,
+                    )for level in range(1,self.num_levels)
+                ])
+
         # right convolutional passes
         self.r_conv = nn.ModuleList(
             [
@@ -359,10 +372,16 @@ def rec_forward(self, level, f_in):
             # nested levels
             gs_out = self.rec_forward(level - 1, g_in)
 
-            # up, concat, and crop
-            fs_right = [
-                self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
-            ]
+            if self.use_attention:
+                f_left_attented = [self.attention[i-1](gs_out[h],f_left) for h in range(self.num_heads)]
+                fs_right = [
+                    self.r_up[h][i](gs_out[h], f_left_attented[h])
+                    for h in range(self.num_heads)
+                ]
+            else:  # up, concat, and crop
+                fs_right = [
+                    self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
+                ]
 
             # convolve
             fs_out = [self.r_conv[h][i](fs_right[h]) for h in range(self.num_heads)]
@@ -580,3 +599,74 @@ def forward(self, g_out, f_left=None):
             return torch.cat([f_cropped, g_cropped], dim=1)
         else:
             return g_cropped
+
+
+
+class AttentionBlockModule(nn.Module):
+    def __init__(self, F_g, F_l, F_int, dims):
+        """Attention Block Module::
+
+        The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).
+
+            [g] --> W_g --\                /--> psi --> * --> [output]
+                           \              /
+            [x] --> W_x --> [+] --> relu --
+
+        Where:
+        - W_g and W_x are 1x1 Convolution followed by Batch Normalization
+        - [+] indicates element-wise addition
+        - relu is the Rectified Linear Unit activation function
+        - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation
+        - * indicates element-wise multiplication between the output of psi and input feature 'x'
+        - [output] has the same dimensions as input 'x', selectively emphasized by attention weights
+
+        Args:
+            F_g (int): The number of feature channels in the gating signal (g).
+                This is the input channel dimension for the W_g convolutional layer.
+
+            F_l (int): The number of feature channels in the input features (x).
+                This is the input channel dimension for the W_x convolutional layer.
+
+            F_int (int): The number of intermediate feature channels.
+                This represents the output channel dimension of the W_g and W_x convolutional layers
+                and the input channel dimension for the psi layer. Typically, F_int is smaller
+                than F_g and F_l, as it serves to compress the feature representations before
+                applying the attention mechanism.
+
+        The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them,
+        and applies a sigmoid activation to generate an attention map. This map is then used
+        to scale the input features 'x', resulting in an output that focuses on important
+        features as dictated by the gating signal 'g'.
+
+        """
+
+        super(AttentionBlockModule, self).__init__()
+        self.dims = dims
+        self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
+        print("kernel_sizes:", self.kernel_sizes)
+
+        self.W_g = ConvPass(
+            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same")
+
+        self.W_x = nn.Sequential(
+            ConvPass(F_l, F_int, kernel_sizes=self.kernel_sizes,
+                     activation=None, padding="same"),
+            Downsample((2,)*self.dims)
+        )
+
+        self.psi = ConvPass(
+            F_int, 1, kernel_sizes=self.kernel_sizes, activation="Sigmoid", padding="same")
+
+        up_mode = {2: 'bilinear', 3: 'trilinear'}[self.dims]
+
+        self.up = nn.Upsample(scale_factor=2, mode=up_mode, align_corners=True)
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, g, x):
+        g1 = self.W_g(g)
+        x1 = self.W_x(x)
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)
+        psi = self.up(psi)
+        return x * psi
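
For reference, the attention gate wired into rec_forward above takes the coarser gating signal gs_out and the skip-connection features f_left, projects both to F_int channels, adds them, and multiplies the skip features by an upsampled sigmoid map. Per the ModuleList in the patch, each gate uses F_g = F_l = num_fmaps * fmap_inc_factor ** level and F_int = num_fmaps * fmap_inc_factor ** (level - 1). Below is a minimal standalone sketch of that data flow in plain PyTorch; AttentionGateSketch, the Conv3d/MaxPool3d stand-ins for dacapo's ConvPass/Downsample, the factor-2 downsample, and the example sizes (num_fmaps=12, fmap_inc_factor=6 at level 1) are assumptions for illustration, not part of the patch.

import torch
import torch.nn as nn


class AttentionGateSketch(nn.Module):
    # Hypothetical stand-in for AttentionBlockModule (3D case), built from
    # plain torch layers instead of dacapo's ConvPass/Downsample.
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Conv3d(F_g, F_int, kernel_size=1)  # project gating signal g
        self.W_x = nn.Sequential(
            nn.Conv3d(F_l, F_int, kernel_size=1),        # project skip features x
            nn.MaxPool3d(2),                             # assumed factor-2 downsample to g's grid
        )
        self.psi = nn.Sequential(nn.Conv3d(F_int, 1, kernel_size=1), nn.Sigmoid())
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        a = self.psi(self.relu(self.W_g(g) + self.W_x(x)))  # attention map at g's resolution
        return x * self.up(a)                                # gate the skip features


g = torch.randn(1, 72, 8, 8, 8)     # gating signal, e.g. F_g = 12 * 6**1 = 72 at level 1
x = torch.randn(1, 72, 16, 16, 16)  # skip-connection features at twice the resolution
gated = AttentionGateSketch(F_g=72, F_l=72, F_int=12)(g, x)
print(gated.shape)  # torch.Size([1, 72, 16, 16, 16]), same shape as x

In the patch itself, ConvPass with 1x1 kernels and padding="same" plays the role of the Conv3d projections, Downsample((2,)*dims) plays the role of the pooling step, and the upsampling mode is chosen as 'bilinear' or 'trilinear' from dims.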