diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py
index d89e902ac..0b03b4ecc 100644
--- a/dacapo/experiments/architectures/cnnectome_unet.py
+++ b/dacapo/experiments/architectures/cnnectome_unet.py
@@ -4,6 +4,9 @@
 import torch.nn as nn
 
 import math
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 class CNNectomeUNet(Architecture):
@@ -172,9 +175,19 @@ def __init__(self, architecture_config):
         )
         self.use_attention = architecture_config.use_attention
         self.batch_norm = architecture_config.batch_norm
+        self._skip_gate = architecture_config.skip_gate
 
         self.unet = self.module()
 
+    @property
+    def skip_gate(self):
+        return self._skip_gate
+
+    @skip_gate.setter
+    def skip_gate(self, skip):
+        self._skip_gate = skip
+        self.unet.skip_gate = skip
+
     @property
     def eval_shape_increase(self):
         """
@@ -264,6 +277,7 @@ def module(self):
             + [True] * (len(downsample_factors) - 1),
             use_attention=self.use_attention,
             batch_norm=self.batch_norm,
+            skip_gate=self.skip_gate,
         )
         if len(self.upsample_factors) > 0:
             layers = [unet]
@@ -460,6 +474,7 @@ def __init__(
         activation_on_upsample=False,
         use_attention=False,
         batch_norm=True,
+        skip_gate=True,
     ):
         """
         Create a U-Net::
@@ -579,6 +594,7 @@ def __init__(
         self.dims = len(downsample_factors[0])
         self.use_attention = use_attention
         self.batch_norm = batch_norm
+        self._skip_gate = skip_gate
 
         # default arguments
 
@@ -647,6 +663,7 @@ def __init__(
                         crop_factor=crop_factors[level],
                         next_conv_kernel_sizes=kernel_size_up[level],
                         activation=activation if activation_on_upsample else None,
+                        skip_gate=skip_gate,
                     )
                     for level in range(self.num_levels - 1)
                 ]
@@ -711,6 +728,30 @@ def __init__(
             ]
         )
 
+    @property
+    def skip_gate(self):
+        return self._skip_gate
+
+    @skip_gate.setter
+    def skip_gate(self, skip):
+        # Keep the cached value in sync and propagate to every Upsample layer.
+        self._skip_gate = skip
+        for head in self.r_up:
+            for layer in head:
+                if isinstance(layer, Upsample):
+                    layer.skip_gate = skip
+                else:
+                    logger.error(f"Layer {layer} is not an Upsample layer")
+
+    def set_skip(self, skip):
+        """
+        Set the skip_gate for all the Upsample layers.
+
+        Args:
+            skip (bool): The value to set for skip_gate.
+        """
+        self.skip_gate = skip
+
     def rec_forward(self, level, f_in):
         """
         Recursive forward pass of the U-Net.
@@ -1038,6 +1079,7 @@ def __init__(
         crop_factor=None,
         next_conv_kernel_sizes=None,
         activation=None,
+        skip_gate=True,
     ):
         """
         Upsample module. This module performs upsampling of the input tensor
@@ -1070,6 +1112,7 @@ def __init__(
 
         self.crop_factor = crop_factor
         self.next_conv_kernel_sizes = next_conv_kernel_sizes
+        self.skip_gate = skip_gate
 
         self.dims = len(scale_factor)
 
@@ -1250,7 +1293,7 @@ def forward(self, g_out, f_left=None):
         else:
             g_cropped = g_up
 
-        if f_left is not None:
+        if f_left is not None and self.skip_gate:
             f_cropped = self.crop(f_left, g_cropped.size()[-self.dims :])
 
             return torch.cat([f_cropped, g_cropped], dim=1)
diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py
index 7eab80115..d6a17c6e7 100644
--- a/dacapo/experiments/architectures/cnnectome_unet_config.py
+++ b/dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -131,3 +131,14 @@ class CNNectomeUNetConfig(ArchitectureConfig):
         default=True,
         metadata={"help_text": "Whether to use batch normalization."},
     )
+    skip_gate: bool = attr.ib(
+        default=True,
+        metadata={
+            "help_text": "Whether to use skip gates. When enabled, each skip "
+            "connection concatenates the encoder (left) feature map with the "
+            "decoder (right) feature map, which helps training. With skip "
+            "gates disabled, the model behaves like a plain encoder-decoder. "
+            "Example pipeline: pretrain with skip_gate=False on raw data "
+            "only, then fine-tune with skip_gate=True on ground truth."
+        },
+    )
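
For reviewers, a minimal usage sketch of the two-stage pipeline the new help text describes (not part of the diff). The constructor call matches `CNNectomeUNet.__init__(self, architecture_config)` above and the import paths follow the file locations in this diff; every config value except `skip_gate` is illustrative only:

```python
# Hypothetical sketch: all values except skip_gate are illustrative.
from dacapo.experiments.architectures.cnnectome_unet import CNNectomeUNet
from dacapo.experiments.architectures.cnnectome_unet_config import (
    CNNectomeUNetConfig,
)

config = CNNectomeUNetConfig(
    name="unet_skip_gate_demo",
    input_shape=(216, 216, 216),
    fmaps_in=1,
    num_fmaps=12,
    fmaps_out=72,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    skip_gate=False,  # stage 1: no concatenation, plain encoder-decoder
)
unet = CNNectomeUNet(config)

# ... stage 1: train on raw data only ...

# Stage 2: re-enable the skip connections and fine-tune on ground truth.
# The setter propagates the flag down to every Upsample layer.
unet.skip_gate = True
```

The property cascade (CNNectomeUNet to CNNectomeUNetModule to Upsample) is what allows the flag to be toggled at runtime without rebuilding the module.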