unet using attention
mzouink committed Nov 23, 2023
1 parent 2c17e17 commit e2a2974
Showing 2 changed files with 35 additions and 3 deletions.
32 changes: 29 additions & 3 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -25,6 +25,7 @@ def __init__(self, architecture_config):
self.upsample_factors = (
self.upsample_factors if self.upsample_factors is not None else []
)
self.use_attention = architecture_config.use_attention

self.unet = self.module()

@@ -64,6 +65,7 @@ def module(self):
activation_on_upsample=True,
upsample_channel_contraction=[False]
+ [True] * (len(downsample_factors) - 1),
use_attention=self.use_attention,
)
if len(self.upsample_factors) > 0:
layers = [unet]
@@ -323,9 +325,9 @@ def __init__(
self.attention = nn.ModuleList(
[
AttentionBlockModule(
- F_g=num_fmaps * fmap_inc_factor ** (level),
- F_l=num_fmaps * fmap_inc_factor ** (level),
- F_int=num_fmaps * fmap_inc_factor ** (level - 1),
+ F_g=num_fmaps * fmap_inc_factor ** (self.num_levels - level),
+ F_l=num_fmaps * fmap_inc_factor ** (self.num_levels - level - 1),
+ F_int=num_fmaps * fmap_inc_factor ** (self.num_levels - level - 1),
dims=self.dims,
) for level in range(1, self.num_levels)
])
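For a sense of what the changed exponents do: the gate's channel counts now follow the decoder side of the network (num_levels - level) instead of growing with level, so F_g, the gating input from the coarser level, carries fmap_inc_factor times more feature maps than F_l and F_int. A quick arithmetic sketch, using made-up hyperparameters (num_fmaps=12, fmap_inc_factor=5 and num_levels=4 are illustrative, not values from this repository):

    # Hypothetical values, chosen only to illustrate the channel arithmetic above.
    num_fmaps, fmap_inc_factor, num_levels = 12, 5, 4

    for level in range(1, num_levels):
        F_g = num_fmaps * fmap_inc_factor ** (num_levels - level)        # gating input (coarser level)
        F_l = num_fmaps * fmap_inc_factor ** (num_levels - level - 1)    # skip connection at this level
        F_int = num_fmaps * fmap_inc_factor ** (num_levels - level - 1)  # internal gate channels
        print(level, F_g, F_l, F_int)
    # level 1: 1500 300 300
    # level 2:  300  60  60
    # level 3:   60  12  12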
@@ -663,9 +665,33 @@ def __init__(self, F_g, F_l, F_int, dims):

self.relu = nn.ReLU(inplace=True)

def calculate_and_apply_padding(self, smaller_tensor, larger_tensor):
"""
Calculate and apply symmetric padding to the smaller tensor so that its spatial dimensions match those of the larger tensor.
Args:
smaller_tensor (Tensor): The tensor to be padded.
larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match.
Returns:
Tensor: The padded tensor, with spatial dimensions matching the larger tensor.
"""
padding = []
for i in range(2, 2 + self.dims):
diff = larger_tensor.size(i) - smaller_tensor.size(i)
padding.extend([diff // 2, diff - diff // 2])

# Reverse padding to match the 'pad' function's expectation
padding = padding[::-1]

# Apply symmetric padding
return nn.functional.pad(smaller_tensor, padding, mode='constant', value=0)
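A concrete illustration of the helper above, with made-up shapes: padding a (1, 8, 10, 13, 13) tensor up to (1, 8, 12, 16, 16) collects a (before, after) pair per spatial dimension and hands the reversed list to torch.nn.functional.pad, which expects padding amounts starting from the last dimension.

    import torch
    import torch.nn.functional as F

    small = torch.zeros(1, 8, 10, 13, 13)  # e.g. the projected gating signal g1
    large = torch.zeros(1, 8, 12, 16, 16)  # e.g. the projected skip connection x1

    padding = []
    for i in range(2, 2 + 3):                          # three spatial dimensions
        diff = large.size(i) - small.size(i)           # 2, 3, 3
        padding.extend([diff // 2, diff - diff // 2])  # builds [1, 1, 1, 2, 1, 2]
    padding = padding[::-1]                            # F.pad starts from the last dimension

    padded = F.pad(small, padding)
    assert padded.shape == large.shape                 # (1, 8, 12, 16, 16)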


def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
g1 = self.calculate_and_apply_padding(g1, x1)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
psi = self.up(psi)
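For context on the forward pass in this hunk: W_g and W_x project the gating signal g (coarser decoder features) and the skip connection x to a common channel count, their sum goes through a ReLU and the sigmoid-terminated psi to produce per-voxel attention weights, which are upsampled and used to gate the skip connection. A minimal, self-contained sketch of such a gate follows. The diff does not show how W_g, W_x, psi and up are built, so the layer definitions, the stride/upsample factor of 2, and the final x * psi step are assumptions in the style of the original Attention U-Net, not the exact DaCapo implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class AttentionGateSketch(nn.Module):
        """Illustrative 3D attention gate; the layer layout is assumed, not copied from DaCapo."""

        def __init__(self, F_g, F_l, F_int):
            super().__init__()
            # 1x1 projection of the gating signal (coarser decoder features).
            self.W_g = nn.Sequential(nn.Conv3d(F_g, F_int, kernel_size=1), nn.BatchNorm3d(F_int))
            # Strided projection of the skip connection down to the gating signal's resolution.
            self.W_x = nn.Sequential(nn.Conv3d(F_l, F_int, kernel_size=1, stride=2), nn.BatchNorm3d(F_int))
            # Collapse to a single attention channel in [0, 1].
            self.psi = nn.Sequential(nn.Conv3d(F_int, 1, kernel_size=1), nn.BatchNorm3d(1), nn.Sigmoid())
            self.relu = nn.ReLU(inplace=True)
            self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)

        def forward(self, g, x):
            g1 = self.W_g(g)
            x1 = self.W_x(x)
            # Pad g1 to x1's spatial size, mirroring calculate_and_apply_padding above.
            pad = []
            for d in reversed(range(2, g1.dim())):     # F.pad starts from the last dimension
                diff = x1.size(d) - g1.size(d)
                pad.extend([diff // 2, diff - diff // 2])
            g1 = F.pad(g1, pad)
            psi = self.psi(self.relu(g1 + x1))         # per-voxel attention coefficients
            psi = self.up(psi)                         # back to the skip connection's resolution
            return x * psi                             # gate the skip connection

    # Toy usage; channel counts follow the level-1 arithmetic shown earlier.
    gate = AttentionGateSketch(F_g=1500, F_l=300, F_int=300)
    g = torch.randn(1, 1500, 8, 8, 8)      # coarse decoder features
    x = torch.randn(1, 300, 16, 16, 16)    # skip connection from the encoder path
    out = gate(g, x)                       # shape (1, 300, 16, 16, 16)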
6 changes: 6 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -82,3 +82,9 @@ class CNNectomeUNetConfig(ArchitectureConfig):
default="valid",
metadata={"help_text": "The padding to use in convolution operations."},
)
use_attention: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
},
)
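Downstream, the new flag is set on the architecture config like any other attribute. A hedged usage sketch: the import path follows the file location above and use_attention is the field added in this commit, but every other constructor argument shown here is illustrative and may not match the config class's actual required fields:

    from dacapo.experiments.architectures.cnnectome_unet_config import CNNectomeUNetConfig

    # Only use_attention is taken from this diff; the remaining fields are placeholders,
    # check CNNectomeUNetConfig for the attributes your version of DaCapo actually requires.
    config = CNNectomeUNetConfig(
        name="attention_unet_example",
        num_fmaps=12,
        fmap_inc_factor=6,
        downsample_factors=[(2, 2, 2), (2, 2, 2), (3, 3, 3)],
        use_attention=True,  # defaults to False
    )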
