From c7f83889af279e9c62fbbb7bdd9309df97f078a0 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 4 Nov 2024 15:45:13 -0500 Subject: [PATCH] enable skip gate --- .../architectures/cnnectome_unet.py | 47 ++++++++++++++++++- .../architectures/cnnectome_unet_config.py | 4 ++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index d89e902ac..b1ad1a821 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -4,7 +4,9 @@ import torch.nn as nn import math +import logging +logger = logging.getLogger(__name__) class CNNectomeUNet(Architecture): """ @@ -172,9 +174,19 @@ def __init__(self, architecture_config): ) self.use_attention = architecture_config.use_attention self.batch_norm = architecture_config.batch_norm + self._skip_gate = architecture_config.skip_gate self.unet = self.module() + @property + def skip_gate(self): + return self._skip_gate + + @skip_gate.setter + def skip_gate(self, skip): + self._skip_gate = skip + self.unet.skip_gate = skip + @property def eval_shape_increase(self): """ @@ -264,6 +276,7 @@ def module(self): + [True] * (len(downsample_factors) - 1), use_attention=self.use_attention, batch_norm=self.batch_norm, + skip_gate=self.skip_gate, ) if len(self.upsample_factors) > 0: layers = [unet] @@ -460,6 +473,7 @@ def __init__( activation_on_upsample=False, use_attention=False, batch_norm=True, + skip_gate=True, ): """ Create a U-Net:: @@ -579,6 +593,7 @@ def __init__( self.dims = len(downsample_factors[0]) self.use_attention = use_attention self.batch_norm = batch_norm + self._skip_gate = skip_gate # default arguments @@ -647,6 +662,7 @@ def __init__( crop_factor=crop_factors[level], next_conv_kernel_sizes=kernel_size_up[level], activation=activation if activation_on_upsample else None, + skip_gate=skip_gate, ) for level in range(self.num_levels - 1) ] @@ -711,6 
+727,33 @@ def __init__( ] ) + @property + def skip_gate(self): + """Whether the Upsample layers concatenate the left (skip) features.""" + return self._skip_gate + + @skip_gate.setter + def skip_gate(self, skip): + """Store the flag and propagate it to every Upsample layer.""" + self._skip_gate = skip + for head in self.r_up: + for layer in head: + if isinstance(layer, Upsample): + layer.skip_gate = skip + else: + logger.error(f"Layer {layer} is not an Upsample layer") + + def set_skip(self, skip): + """ + Set the skip_gate for all the Upsample layers. + + Equivalent to assigning to the ``skip_gate`` property. + + Args: + skip (bool): The value to set for skip_gate. + """ + self.skip_gate = skip + def rec_forward(self, level, f_in): """ Recursive forward pass of the U-Net. @@ -1038,6 +1081,7 @@ def __init__( crop_factor=None, next_conv_kernel_sizes=None, activation=None, + skip_gate=True, ): """ Upsample module. This module performs upsampling of the input tensor @@ -1070,6 +1114,7 @@ def __init__( self.crop_factor = crop_factor self.next_conv_kernel_sizes = next_conv_kernel_sizes + self.skip_gate = skip_gate self.dims = len(scale_factor) @@ -1250,7 +1295,7 @@ def forward(self, g_out, f_left=None): else: g_cropped = g_up - if f_left is not None: + if f_left is not None and self.skip_gate: f_cropped = self.crop(f_left, g_cropped.size()[-self.dims :]) return torch.cat([f_cropped, g_cropped], dim=1) diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 7eab80115..b2d84bff1 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -131,3 +131,7 @@ class CNNectomeUNetConfig(ArchitectureConfig): default=True, metadata={"help_text": "Whether to use batch normalization."}, ) + skip_gate: bool = attr.ib( + default=True, + metadata={"help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training.
disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth."}, + )