From e2a29749c57bcbff3a334ea90a8ee3792846027d Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Wed, 22 Nov 2023 19:16:37 -0500
Subject: [PATCH] unet using attention

---
 .../architectures/cnnectome_unet.py        | 32 +++++++++++++++++--
 .../architectures/cnnectome_unet_config.py |  6 ++++
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py
index 8f3e74dfe..32cbe1744 100644
--- a/dacapo/experiments/architectures/cnnectome_unet.py
+++ b/dacapo/experiments/architectures/cnnectome_unet.py
@@ -25,6 +25,7 @@ def __init__(self, architecture_config):
         self.upsample_factors = (
             self.upsample_factors if self.upsample_factors is not None else []
         )
+        self.use_attention = architecture_config.use_attention
 
         self.unet = self.module()
 
@@ -64,6 +65,7 @@ def module(self):
             activation_on_upsample=True,
             upsample_channel_contraction=[False]
             + [True] * (len(downsample_factors) - 1),
+            use_attention=self.use_attention,
         )
         if len(self.upsample_factors) > 0:
             layers = [unet]
@@ -323,9 +325,9 @@ def __init__(
         self.attention = nn.ModuleList(
             [
                 AttentionBlockModule(
-                    F_g=num_fmaps * fmap_inc_factor ** (level ),
-                    F_l=num_fmaps * fmap_inc_factor ** (level ),
-                    F_int=num_fmaps * fmap_inc_factor ** (level - 1),
+                    F_g=num_fmaps * fmap_inc_factor ** (self.num_levels - level ),
+                    F_l=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ),
+                    F_int=num_fmaps * fmap_inc_factor ** (self.num_levels - level -1 ),
                     dims=self.dims,
                 )for level in range(1,self.num_levels)
             ])
@@ -663,9 +665,33 @@ def __init__(self, F_g, F_l, F_int, dims):
 
         self.relu = nn.ReLU(inplace=True)
 
+    def calculate_and_apply_padding(self, smaller_tensor, larger_tensor):
+        """
+        Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor.
+
+        Args:
+            smaller_tensor (Tensor): The tensor to be padded.
+            larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match.
+
+        Returns:
+            Tensor: The padded smaller tensor with the same dimensions as the larger tensor.
+        """
+        padding = []
+        for i in range(2, 2 + self.dims):
+            diff = larger_tensor.size(i) - smaller_tensor.size(i)
+            padding.extend([diff // 2, diff - diff // 2])
+
+        # Reverse padding to match the 'pad' function's expectation
+        padding = padding[::-1]
+
+        # Apply symmetric padding
+        return nn.functional.pad(smaller_tensor, padding, mode='constant', value=0)
+
+
     def forward(self, g, x):
         g1 = self.W_g(g)
         x1 = self.W_x(x)
+        g1 = self.calculate_and_apply_padding(g1, x1)
         psi = self.relu(g1 + x1)
         psi = self.psi(psi)
         psi = self.up(psi)
diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py
index 5a40cca6d..c0e9e5b9d 100644
--- a/dacapo/experiments/architectures/cnnectome_unet_config.py
+++ b/dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -82,3 +82,9 @@ class CNNectomeUNetConfig(ArchitectureConfig):
         default="valid",
         metadata={"help_text": "The padding to use in convolution operations."},
     )
+    use_attention: bool = attr.ib(
+        default=False,
+        metadata={
+            "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
+        },
+    )
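
Note on the padding step: the new calculate_and_apply_padding method is what lets the gating signal (after W_g) be added to the skip-connection features (after W_x) when their spatial sizes differ in a valid-padded UNet. Below is a minimal standalone sketch of that logic outside the class; the helper name pad_to_match, dims=3, and the example shapes are illustrative assumptions, not part of the patch.

    # Standalone sketch of the symmetric padding used in AttentionBlockModule.
    import torch
    import torch.nn.functional as F

    def pad_to_match(smaller, larger, dims):
        # Collect [front, back] padding for each spatial axis (axes 2 .. 2 + dims - 1).
        padding = []
        for i in range(2, 2 + dims):
            diff = larger.size(i) - smaller.size(i)
            padding.extend([diff // 2, diff - diff // 2])
        # F.pad pads the last dimension first, so reverse the flat list
        # (this also swaps the per-axis halves, which still yields the target shape).
        padding = padding[::-1]
        return F.pad(smaller, padding, mode="constant", value=0)

    g1 = torch.zeros(1, 8, 30, 30, 30)  # gating signal after W_g (smaller spatially)
    x1 = torch.zeros(1, 8, 36, 36, 36)  # skip-connection features after W_x
    assert pad_to_match(g1, x1, dims=3).shape == x1.shape  # padded to (1, 8, 36, 36, 36)

With the config change in the second file, the gated skip connections are enabled by setting use_attention=True on CNNectomeUNetConfig; the flag defaults to False, so existing configurations are unaffected.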