From c7f83889af279e9c62fbbb7bdd9309df97f078a0 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Mon, 4 Nov 2024 15:45:13 -0500 Subject: [PATCH 1/2] enable skip gate --- .../architectures/cnnectome_unet.py | 47 ++++++++++++++++++- .../architectures/cnnectome_unet_config.py | 4 ++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index d89e902ac..b1ad1a821 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -4,7 +4,9 @@ import torch.nn as nn import math +import logging +logger = logging.getLogger(__name__) class CNNectomeUNet(Architecture): """ @@ -172,9 +174,19 @@ def __init__(self, architecture_config): ) self.use_attention = architecture_config.use_attention self.batch_norm = architecture_config.batch_norm + self._skip_gate = architecture_config.skip_gate self.unet = self.module() + @property + def skip_gate(self): + return self._skip_gate + + @skip_gate.setter + def skip_gate(self, skip): + self._skip_gate = skip + self.unet.skip_gate = skip + @property def eval_shape_increase(self): """ @@ -264,6 +276,7 @@ def module(self): + [True] * (len(downsample_factors) - 1), use_attention=self.use_attention, batch_norm=self.batch_norm, + skip_gate=self.skip_gate, ) if len(self.upsample_factors) > 0: layers = [unet] @@ -460,6 +473,7 @@ def __init__( activation_on_upsample=False, use_attention=False, batch_norm=True, + skip_gate=True, ): """ Create a U-Net:: @@ -579,6 +593,7 @@ def __init__( self.dims = len(downsample_factors[0]) self.use_attention = use_attention self.batch_norm = batch_norm + self._skip_gate = skip_gate # default arguments @@ -647,6 +662,7 @@ def __init__( crop_factor=crop_factors[level], next_conv_kernel_sizes=kernel_size_up[level], activation=activation if activation_on_upsample else None, + skip_gate=skip_gate, ) for level in range(self.num_levels - 1) ] @@ -711,6 +727,33 @@ def __init__( ] ) + @property + def skip_gate(self): + return self._skip_gate + + @skip_gate.setter + def skip_gate(self, skip): + for head in self.r_up: + for layer in head: + if isinstance(layer, Upsample): + layer.skip_gate = skip + else: + logger.error(f"Layer {layer} is not an Upsample layer") + + def set_skip(self, skip): + """ + Set the skip_gate for all the Upsample layers. + + Args: + skip (bool): The value to set for skip_gate. + """ + for head in self.r_up: + for layer in head: + if isinstance(layer, Upsample): + layer.skip_gate = skip + else: + logger.error(f"Layer {layer} is not an Upsample layer") + def rec_forward(self, level, f_in): """ Recursive forward pass of the U-Net. @@ -1038,6 +1081,7 @@ def __init__( crop_factor=None, next_conv_kernel_sizes=None, activation=None, + skip_gate = True, ): """ Upsample module. This module performs upsampling of the input tensor @@ -1070,6 +1114,7 @@ def __init__( self.crop_factor = crop_factor self.next_conv_kernel_sizes = next_conv_kernel_sizes + self.skip_gate = skip_gate self.dims = len(scale_factor) @@ -1250,7 +1295,7 @@ def forward(self, g_out, f_left=None): else: g_cropped = g_up - if f_left is not None: + if f_left is not None and self.skip_gate: f_cropped = self.crop(f_left, g_cropped.size()[-self.dims :]) return torch.cat([f_cropped, g_cropped], dim=1) diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index 7eab80115..b2d84bff1 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -131,3 +131,7 @@ class CNNectomeUNetConfig(ArchitectureConfig): default=True, metadata={"help_text": "Whether to use batch normalization."}, ) + skip_gate: bool = attr.ib( + default=True, + metadata={"help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training. disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth."}, + ) From 59f00c6cf6284ba1d4f2cba3145fb2c9a85b9cac Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 7 Nov 2024 20:05:39 +0000 Subject: [PATCH 2/2] :art: Format Python code with psf/black --- dacapo/experiments/architectures/cnnectome_unet.py | 5 +++-- dacapo/experiments/architectures/cnnectome_unet_config.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index b1ad1a821..0b03b4ecc 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) + class CNNectomeUNet(Architecture): """ A U-Net architecture for 3D or 4D data. The U-Net expects 3D or 4D tensors @@ -181,7 +182,7 @@ def __init__(self, architecture_config): @property def skip_gate(self): return self._skip_gate - + @skip_gate.setter def skip_gate(self, skip): self._skip_gate = skip @@ -1081,7 +1082,7 @@ def __init__( crop_factor=None, next_conv_kernel_sizes=None, activation=None, - skip_gate = True, + skip_gate=True, ): """ Upsample module. This module performs upsampling of the input tensor diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index b2d84bff1..d6a17c6e7 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -133,5 +133,7 @@ class CNNectomeUNetConfig(ArchitectureConfig): ) skip_gate: bool = attr.ib( default=True, - metadata={"help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training. disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth."}, + metadata={ + "help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training. disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth." + }, )