Merge branch 'main' into dev/post-process
rhoadesScholar authored Feb 14, 2024
2 parents cb0d30d + 06bf991 commit 45bee7e
Showing 10 changed files with 205 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -8,7 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.10","3.11"]

steps:
- uses: actions/checkout@v2
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,7 +1,7 @@
*.sw[pmno]
*.hdf
*.h5
*.ipynb
# *.ipynb
*.pyc
*.egg-info
*.dat
@@ -12,6 +12,7 @@
dist
build
dacapo.yaml
__pycache__

# vscode stuff
.vscode
28 changes: 28 additions & 0 deletions LICENSE
@@ -0,0 +1,28 @@
BSD 3-Clause License

Copyright (c) 2024, Howard Hughes Medical Institute

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
153 changes: 149 additions & 4 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -25,6 +25,7 @@ def __init__(self, architecture_config):
self.upsample_factors = (
self.upsample_factors if self.upsample_factors is not None else []
)
self.use_attention = architecture_config.use_attention

self.unet = self.module()

@@ -64,6 +65,7 @@ def module(self):
activation_on_upsample=True,
upsample_channel_contraction=[False]
+ [True] * (len(downsample_factors) - 1),
use_attention=self.use_attention,
)
if len(self.upsample_factors) > 0:
layers = [unet]
@@ -125,6 +127,7 @@ def __init__(
padding="valid",
upsample_channel_contraction=False,
activation_on_upsample=False,
use_attention=False,
):
"""Create a U-Net::
@@ -244,6 +247,7 @@ def __init__(
)

self.dims = len(downsample_factors[0])
self.use_attention = use_attention

# default arguments

@@ -316,6 +320,29 @@ def __init__(
for _ in range(num_heads)
]
)
# if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out
if self.use_attention:
self.attention = nn.ModuleList(
[
nn.ModuleList(
[
AttentionBlockModule(
F_g=num_fmaps * fmap_inc_factor ** (level + 1),
F_l=num_fmaps * fmap_inc_factor**level,
F_int=num_fmaps
* fmap_inc_factor
** (level + (1 - upsample_channel_contraction[level]))
if num_fmaps_out is None or level != 0
else num_fmaps_out,
dims=self.dims,
upsample_factor=downsample_factors[level],
)
for level in range(self.num_levels - 1)
]
)
for _ in range(num_heads)
]
)

# right convolutional passes
self.r_conv = nn.ModuleList(
@@ -359,10 +386,19 @@ def rec_forward(self, level, f_in):
# nested levels
gs_out = self.rec_forward(level - 1, g_in)

# up, concat, and crop
fs_right = [
self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
]
if self.use_attention:
f_left_attented = [
self.attention[h][i](gs_out[h], f_left)
for h in range(self.num_heads)
]
fs_right = [
self.r_up[h][i](gs_out[h], f_left_attented[h])
for h in range(self.num_heads)
]
else: # up, concat, and crop
fs_right = [
self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
]

# convolve
fs_out = [self.r_conv[h][i](fs_right[h]) for h in range(self.num_heads)]
@@ -580,3 +616,112 @@ def forward(self, g_out, f_left=None):
return torch.cat([f_cropped, g_cropped], dim=1)
else:
return g_cropped


class AttentionBlockModule(nn.Module):
def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
"""Attention Block Module::
The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).
[g] --> W_g --\                /--> psi --> * --> [output]
               \              /
[x] --> W_x --> [+] --> relu --
Where:
- W_g and W_x are 1x1 Convolution followed by Batch Normalization
- [+] indicates element-wise addition
- relu is the Rectified Linear Unit activation function
- psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation
- * indicates element-wise multiplication between the output of psi and input feature 'x'
- [output] has the same dimensions as input 'x', selectively emphasized by attention weights
Args:
F_g (int): The number of feature channels in the gating signal (g).
This is the input channel dimension for the W_g convolutional layer.
F_l (int): The number of feature channels in the input features (x).
This is the input channel dimension for the W_x convolutional layer.
F_int (int): The number of intermediate feature channels.
This represents the output channel dimension of the W_g and W_x convolutional layers
and the input channel dimension for the psi layer. Typically, F_int is smaller
than F_g and F_l, as it serves to compress the feature representations before
applying the attention mechanism.
The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them,
and applies a sigmoid activation to generate an attention map. This map is then used
to scale the input features 'x', resulting in an output that focuses on important
features as dictated by the gating signal 'g'.
"""

super(AttentionBlockModule, self).__init__()
self.dims = dims
self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
if upsample_factor is not None:
self.upsample_factor = upsample_factor
else:
self.upsample_factor = (2,) * self.dims

self.W_g = ConvPass(
F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
)

self.W_x = nn.Sequential(
ConvPass(
F_l,
F_int,
kernel_sizes=self.kernel_sizes,
activation=None,
padding="same",
),
Downsample(upsample_factor),
)

self.psi = ConvPass(
F_int,
1,
kernel_sizes=self.kernel_sizes,
activation="Sigmoid",
padding="same",
)

up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]

self.up = nn.Upsample(
scale_factor=upsample_factor, mode=up_mode, align_corners=True
)

self.relu = nn.ReLU(inplace=True)

def calculate_and_apply_padding(self, smaller_tensor, larger_tensor):
"""
Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor.
Args:
smaller_tensor (Tensor): The tensor to be padded.
larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match.
Returns:
Tensor: The padded smaller tensor with the same dimensions as the larger tensor.
"""
padding = []
for i in range(2, 2 + self.dims):
diff = larger_tensor.size(i) - smaller_tensor.size(i)
padding.extend([diff // 2, diff - diff // 2])

# Reverse padding to match the 'pad' function's expectation
padding = padding[::-1]

# Apply symmetric padding
return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0)

def forward(self, g, x):
g1 = self.W_g(g)
x1 = self.W_x(x)
g1 = self.calculate_and_apply_padding(g1, x1)
psi = self.relu(g1 + x1)
psi = self.psi(psi)
psi = self.up(psi)
return x * psi
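
For orientation, the attention gate added above follows the additive-attention pattern described in its docstring: the gating signal g and the skip-connection features x are each projected to an intermediate channel count, summed, passed through ReLU, collapsed to a single sigmoid-activated channel, and the resulting map rescales x. A minimal, self-contained sketch of that pattern in plain PyTorch (channel counts are illustrative, both inputs share one spatial size here, and none of the repository's ConvPass/Downsample/upsampling plumbing is reproduced):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyAttentionGate(nn.Module):
        # Simplified 3D additive attention gate, analogous in spirit to AttentionBlockModule.
        def __init__(self, F_g, F_l, F_int):
            super().__init__()
            self.W_g = nn.Conv3d(F_g, F_int, kernel_size=1)  # project gating signal
            self.W_x = nn.Conv3d(F_l, F_int, kernel_size=1)  # project skip features
            self.psi = nn.Conv3d(F_int, 1, kernel_size=1)    # collapse to one attention channel

        def forward(self, g, x):
            a = F.relu(self.W_g(g) + self.W_x(x))            # [+] then relu
            a = torch.sigmoid(self.psi(a))                   # attention map in [0, 1]
            return x * a                                     # emphasize/suppress skip features

    gate = TinyAttentionGate(F_g=64, F_l=32, F_int=32)
    g = torch.randn(1, 64, 16, 16, 16)
    x = torch.randn(1, 32, 16, 16, 16)
    print(gate(g, x).shape)  # torch.Size([1, 32, 16, 16, 16])

In the committed module the gating signal comes from one level deeper in the U-Net, so W_x additionally downsamples x to g's resolution and the attention map is upsampled back before the multiplication.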
6 changes: 6 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,9 @@ class CNNectomeUNetConfig(ArchitectureConfig):
default="valid",
metadata={"help_text": "The padding to use in convolution operations."},
)
use_attention: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
},
)
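
With the new flag in place, attention gates can be switched on from the architecture config alone. A hedged usage sketch (only use_attention is introduced by this commit; the other field names and values are assumed from the existing CNNectomeUNetConfig and are purely illustrative):

    from funlib.geometry import Coordinate
    from dacapo.experiments.architectures import CNNectomeUNetConfig

    architecture_config = CNNectomeUNetConfig(
        name="attention_unet",                 # illustrative values throughout
        input_shape=Coordinate(216, 216, 216),
        fmaps_in=1,
        num_fmaps=12,
        fmaps_out=72,
        fmap_inc_factor=6,
        downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
        constant_upsample=True,
        use_attention=True,                    # new: gate the skip connections
    )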
11 changes: 8 additions & 3 deletions dacapo/experiments/starts/start.py
@@ -23,8 +23,13 @@ def initialize_weights(self, model):
# if the model is not the same, we can try to load the weights
# of the common layers
model_dict = model.state_dict()
common_layers = set(model_dict.keys()) & set(weights.model.keys())
for layer in common_layers:
model_dict[layer] = weights.model[layer]
pretrained_dict = {
k: v
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
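
The rewritten loader keeps a pretrained entry only when both the key and the tensor shape match the target model, so fine-tuning still works after architectural changes (for example, enabling use_attention adds parameters that an older checkpoint cannot supply). The same filtering pattern in isolation, as a minimal sketch with toy models:

    import torch.nn as nn

    old = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))  # "pretrained" model
    new = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 2))  # same first layer, new head

    model_dict = new.state_dict()
    pretrained_dict = {
        k: v
        for k, v in old.state_dict().items()
        if k in model_dict and v.size() == model_dict[k].size()
    }
    model_dict.update(pretrained_dict)  # overwrite only the existing and matching layers
    new.load_state_dict(model_dict)
    print(sorted(pretrained_dict))      # ['0.bias', '0.weight'] -- the shared first layer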
6 changes: 3 additions & 3 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -133,9 +133,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder,
drop_channels=True,
)
+ gp.Pad(raw_key, None, 0)
+ gp.Pad(gt_key, None, 0)
+ gp.Pad(mask_key, None, 0)
+ gp.Pad(raw_key, None)
+ gp.Pad(gt_key, None)
+ gp.Pad(mask_key, None)
+ gp.RandomLocation(
ensure_nonempty=sample_points_key
if points_source is not None
7 changes: 6 additions & 1 deletion dacapo/gp/elastic_augment_fuse.py
@@ -486,10 +486,15 @@ def _affine(self, array, scale, offset, target_roi, dtype=np.float32, order=1):
"""
ndim = array.shape[0]
output = np.empty((ndim,) + target_roi.get_shape(), dtype=dtype)
# Create a diagonal matrix if scale is a 1-D array
if np.isscalar(scale) or np.ndim(scale) == 1:
transform_matrix = np.diag(scale)
else:
transform_matrix = scale
for d in range(ndim):
scipy.ndimage.affine_transform(
input=array[d],
matrix=scale,
matrix=transform_matrix,
offset=offset,
output=output[d],
output_shape=output[d].shape,
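
The fix makes the per-axis scale explicit: scipy.ndimage.affine_transform takes a transformation matrix, and a 1-D per-axis scale is promoted to the equivalent diagonal matrix with np.diag before the call. A small sketch of the conversion (array contents are illustrative):

    import numpy as np
    import scipy.ndimage

    image = np.arange(36, dtype=np.float32).reshape(6, 6)
    scale = np.array([2.0, 0.5])       # per-axis scale factors

    transform_matrix = np.diag(scale)  # [[2.0, 0.0],
                                       #  [0.0, 0.5]]

    warped = scipy.ndimage.affine_transform(
        image,
        matrix=transform_matrix,       # diagonal matrix == independent per-axis scaling
        offset=0,
        order=1,                       # linear interpolation, matching _affine's default
    )
    print(warped.shape)                # (6, 6); output shape defaults to the input shape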
2 changes: 1 addition & 1 deletion dacapo/utils/affinities.py
@@ -9,7 +9,7 @@


def seg_to_affgraph(seg: np.ndarray, neighborhood: List[Coordinate]) -> np.ndarray:
nhood = np.array(neighborhood)
nhood: np.ndarray = np.array(neighborhood)

# constructs an affinity graph from a segmentation
# assume affinity graph is represented as:
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@
entry_points={"console_scripts": ["dacapo=dacapo.cli:cli"]},
include_package_data=True,
install_requires=[
"numpy==1.22.3",
"numpy",
"pyyaml",
"zarr",
"cattrs",
