Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip gate as option #320

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import torch.nn as nn

import math
import logging

logger = logging.getLogger(__name__)


class CNNectomeUNet(Architecture):
Expand Down Expand Up @@ -172,9 +175,19 @@ def __init__(self, architecture_config):
)
self.use_attention = architecture_config.use_attention
self.batch_norm = architecture_config.batch_norm
self._skip_gate = architecture_config.skip_gate

self.unet = self.module()

@property
def skip_gate(self):
return self._skip_gate

@skip_gate.setter
def skip_gate(self, skip):
self._skip_gate = skip
self.unet.skip_gate = skip

@property
def eval_shape_increase(self):
"""
Expand Down Expand Up @@ -264,6 +277,7 @@ def module(self):
+ [True] * (len(downsample_factors) - 1),
use_attention=self.use_attention,
batch_norm=self.batch_norm,
skip_gate=self.skip_gate,
)
if len(self.upsample_factors) > 0:
layers = [unet]
Expand Down Expand Up @@ -460,6 +474,7 @@ def __init__(
activation_on_upsample=False,
use_attention=False,
batch_norm=True,
skip_gate=True,
):
"""
Create a U-Net::
Expand Down Expand Up @@ -579,6 +594,7 @@ def __init__(
self.dims = len(downsample_factors[0])
self.use_attention = use_attention
self.batch_norm = batch_norm
self._skip_gate = skip_gate

# default arguments

Expand Down Expand Up @@ -647,6 +663,7 @@ def __init__(
crop_factor=crop_factors[level],
next_conv_kernel_sizes=kernel_size_up[level],
activation=activation if activation_on_upsample else None,
skip_gate=skip_gate,
)
for level in range(self.num_levels - 1)
]
Expand Down Expand Up @@ -711,6 +728,33 @@ def __init__(
]
)

@property
def skip_gate(self):
return self._skip_gate

@skip_gate.setter
def skip_gate(self, skip):
for head in self.r_up:
for layer in head:
if isinstance(layer, Upsample):
layer.skip_gate = skip
else:
logger.error(f"Layer {layer} is not an Upsample layer")

def set_skip(self, skip):
"""
Set the skip_gate for all the Upsample layers.

Args:
skip (bool): The value to set for skip_gate.
"""
for head in self.r_up:
for layer in head:
if isinstance(layer, Upsample):
layer.skip_gate = skip
else:
logger.error(f"Layer {layer} is not an Upsample layer")

def rec_forward(self, level, f_in):
"""
Recursive forward pass of the U-Net.
Expand Down Expand Up @@ -1038,6 +1082,7 @@ def __init__(
crop_factor=None,
next_conv_kernel_sizes=None,
activation=None,
skip_gate=True,
):
"""
Upsample module. This module performs upsampling of the input tensor
Expand Down Expand Up @@ -1070,6 +1115,7 @@ def __init__(

self.crop_factor = crop_factor
self.next_conv_kernel_sizes = next_conv_kernel_sizes
self.skip_gate = skip_gate

self.dims = len(scale_factor)

Expand Down Expand Up @@ -1250,7 +1296,7 @@ def forward(self, g_out, f_left=None):
else:
g_cropped = g_up

if f_left is not None:
if f_left is not None and self.skip_gate:
f_cropped = self.crop(f_left, g_cropped.size()[-self.dims :])

return torch.cat([f_cropped, g_cropped], dim=1)
Expand Down
6 changes: 6 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,9 @@ class CNNectomeUNetConfig(ArchitectureConfig):
default=True,
metadata={"help_text": "Whether to use batch normalization."},
)
skip_gate: bool = attr.ib(
default=True,
metadata={
"help_text": "Whether to use skip gates. using skip gates concatenates the left feature map with the right feature map which helps for training. disabling the skip gate will make the model like a encoder-decoder model. example pipeline: start with skip gate false, we can train with only raw data. then we can train with skip gate true to fine tune the model with groundtruth."
},
)
Loading