
Commit c7f8388

enable skip gate

mzouink committed Nov 4, 2024
1 parent 937371e commit c7f8388
Showing 2 changed files with 50 additions and 1 deletion.
47 changes: 46 additions & 1 deletion dacapo/experiments/architectures/cnnectome_unet.py
@@ -4,7 +4,9 @@
import torch.nn as nn

import math
import logging

logger = logging.getLogger(__name__)

class CNNectomeUNet(Architecture):
"""
@@ -172,9 +174,19 @@ def __init__(self, architecture_config):
)
self.use_attention = architecture_config.use_attention
self.batch_norm = architecture_config.batch_norm
self._skip_gate = architecture_config.skip_gate

self.unet = self.module()

@property
def skip_gate(self):
return self._skip_gate

@skip_gate.setter
def skip_gate(self, skip):
self._skip_gate = skip
self.unet.skip_gate = skip

@property
def eval_shape_increase(self):
"""
@@ -264,6 +276,7 @@ def module(self):
+ [True] * (len(downsample_factors) - 1),
use_attention=self.use_attention,
batch_norm=self.batch_norm,
skip_gate=self.skip_gate,
)
if len(self.upsample_factors) > 0:
layers = [unet]
@@ -460,6 +473,7 @@ def __init__(
activation_on_upsample=False,
use_attention=False,
batch_norm=True,
skip_gate=True,
):
"""
Create a U-Net::
@@ -579,6 +593,7 @@ def __init__(
self.dims = len(downsample_factors[0])
self.use_attention = use_attention
self.batch_norm = batch_norm
self._skip_gate = skip_gate

# default arguments

@@ -647,6 +662,7 @@ def __init__(
crop_factor=crop_factors[level],
next_conv_kernel_sizes=kernel_size_up[level],
activation=activation if activation_on_upsample else None,
skip_gate=skip_gate,
)
for level in range(self.num_levels - 1)
]
@@ -711,6 +727,33 @@ def __init__(
]
)

@property
def skip_gate(self):
return self._skip_gate

    @skip_gate.setter
    def skip_gate(self, skip):
        # keep the cached flag in sync so the property getter stays accurate
        self._skip_gate = skip
        for head in self.r_up:
            for layer in head:
                if isinstance(layer, Upsample):
                    layer.skip_gate = skip
                else:
                    logger.error(f"Layer {layer} is not an Upsample layer")

    def set_skip(self, skip):
        """
        Set skip_gate on all the Upsample layers.

        Args:
            skip (bool): The value to set for skip_gate.
        """
        # delegate to the property setter to avoid duplicating the loop
        self.skip_gate = skip

def rec_forward(self, level, f_in):
"""
Recursive forward pass of the U-Net.
@@ -1038,6 +1081,7 @@ def __init__(
crop_factor=None,
next_conv_kernel_sizes=None,
activation=None,
        skip_gate=True,
):
"""
Upsample module. This module performs upsampling of the input tensor
@@ -1070,6 +1114,7 @@ def __init__(

self.crop_factor = crop_factor
self.next_conv_kernel_sizes = next_conv_kernel_sizes
self.skip_gate = skip_gate

self.dims = len(scale_factor)

@@ -1250,7 +1295,7 @@ def forward(self, g_out, f_left=None):
else:
g_cropped = g_up

-        if f_left is not None:
+        if f_left is not None and self.skip_gate:
f_cropped = self.crop(f_left, g_cropped.size()[-self.dims :])

return torch.cat([f_cropped, g_cropped], dim=1)
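The changed condition in forward is the heart of the feature: the encoder feature map is concatenated only when the gate is open. A minimal standalone sketch of that gated merge (names, shapes, and the slice-based crop are illustrative stand-ins, not the dacapo API):

import torch

def gated_merge(g_cropped, f_left=None, skip_gate=True):
    # Mirrors the gated concat in Upsample.forward: the encoder (left)
    # feature map is only concatenated when the skip gate is open.
    if f_left is not None and skip_gate:
        # crude stand-in for Upsample.crop: slice f_left down to
        # g_cropped's spatial size
        f_cropped = f_left[..., : g_cropped.size(-2), : g_cropped.size(-1)]
        return torch.cat([f_cropped, g_cropped], dim=1)
    return g_cropped

g = torch.randn(1, 12, 32, 32)  # upsampled decoder (right) features
f = torch.randn(1, 12, 40, 40)  # encoder (left) features, larger from valid convs
print(gated_merge(g, f, skip_gate=True).shape)   # torch.Size([1, 24, 32, 32])
print(gated_merge(g, f, skip_gate=False).shape)  # torch.Size([1, 12, 32, 32])

Note that with the gate closed the merge returns half as many channels, so the flag is a training-regime choice rather than something to flip mid-forward without the downstream convolutions expecting the matching input width.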
4 changes: 4 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -131,3 +131,7 @@ class CNNectomeUNetConfig(ArchitectureConfig):
default=True,
metadata={"help_text": "Whether to use batch normalization."},
)
skip_gate: bool = attr.ib(
default=True,
metadata={"help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training. disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth."},
)
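The two-phase recipe in the help text relies on the new setters propagating the flag down to every Upsample layer. A toy sketch of that propagation pattern (class names and structure are simplified stand-ins for the dacapo classes):

import torch.nn as nn

class ToyUpsample(nn.Module):
    """Stand-in for dacapo's Upsample: just carries the skip_gate flag."""
    def __init__(self, skip_gate=True):
        super().__init__()
        self.skip_gate = skip_gate

class ToyUNet(nn.Module):
    def __init__(self, num_heads=2, skip_gate=True):
        super().__init__()
        self._skip_gate = skip_gate
        # mirrors self.r_up: one ModuleList of Upsample layers per head
        self.r_up = nn.ModuleList(
            nn.ModuleList([ToyUpsample(skip_gate)]) for _ in range(num_heads)
        )

    @property
    def skip_gate(self):
        return self._skip_gate

    @skip_gate.setter
    def skip_gate(self, skip):
        # cache the value, then walk every head's Upsample layers
        self._skip_gate = skip
        for head in self.r_up:
            for layer in head:
                layer.skip_gate = skip

net = ToyUNet(skip_gate=False)  # phase 1: behaves like an encoder-decoder
net.skip_gate = True            # phase 2: reopen the skips for fine-tuning
print(all(l.skip_gate for h in net.r_up for l in h))  # True

This is the same pattern the commit uses: a cached `_skip_gate` plus a setter that walks `r_up`, so one assignment at the architecture level reaches every Upsample module.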
