Skip to content

Commit

Permalink
Tests v0 3 5 (#338)
Browse files Browse the repository at this point in the history
Fixed a couple bugs
1) CNNectome UNet wasn't following the api defined by the Architecture
`ABC`. It was supposed to be returning `Coordinate` class instances for
voxel_size, input_shape etc.
2) CNNectome UNet had logical errors in the definition of kernel_size_up
and down.
3) New tests was expecting multi channel data for the 2D UNet and single
channel for the 3D unet despite always getting single channel data.
4) fix voxel_size attr in test fixtures
  • Loading branch information
mzouink authored Nov 13, 2024
2 parents ea4f2b1 + a5e30e0 commit 080f00a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 16 deletions.
23 changes: 12 additions & 11 deletions dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn as nn

from funlib.geometry import Coordinate

import math


Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(self, architecture_config):
self.unet = self.module()

@property
def eval_shape_increase(self):
def eval_shape_increase(self) -> Coordinate:
"""
The increase in shape due to the U-Net.
Expand All @@ -192,7 +194,7 @@ def eval_shape_increase(self):
"""
if self._eval_shape_increase is None:
return super().eval_shape_increase
return self._eval_shape_increase
return Coordinate(self._eval_shape_increase)

def module(self):
"""
Expand Down Expand Up @@ -235,16 +237,15 @@ def module(self):
"""
fmaps_in = self.fmaps_in
levels = len(self.downsample_factors) + 1
dims = len(self.downsample_factors[0])

if hasattr(self, "kernel_size_down"):
if self.kernel_size_down is not None:
kernel_size_down = self.kernel_size_down
else:
kernel_size_down = [[(3,) * dims, (3,) * dims]] * levels
if hasattr(self, "kernel_size_up"):
kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
if self.kernel_size_up is not None:
kernel_size_up = self.kernel_size_up
else:
kernel_size_up = [[(3,) * dims, (3,) * dims]] * (levels - 1)
kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)

# downsample factors has to be a list of tuples
downsample_factors = [tuple(x) for x in self.downsample_factors]
Expand Down Expand Up @@ -280,7 +281,7 @@ def module(self):
conv = ConvPass(
self.fmaps_out,
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
kernel_size_up[-1],
activation="ReLU",
batch_norm=self.batch_norm,
)
Expand All @@ -306,11 +307,11 @@ def scale(self, voxel_size):
The voxel size should be given as a tuple ``(z, y, x)``.
"""
for upsample_factor in self.upsample_factors:
voxel_size = voxel_size / upsample_factor
voxel_size = voxel_size / Coordinate(upsample_factor)
return voxel_size

@property
def input_shape(self):
def input_shape(self) -> Coordinate:
"""
Return the input shape of the U-Net.
Expand All @@ -324,7 +325,7 @@ def input_shape(self):
Note:
The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``.
"""
return self._input_shape
return Coordinate(self._input_shape)

@property
def num_in_channels(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
assert isinstance(dataset.weight, int), dataset

raw_source = gp.ArraySource(raw_key, dataset.raw)
if dataset.raw.channel_dims == 0:
raw_source += gp.Unsqueeze([raw_key], axis=0)
if self.clip_raw:
raw_source += gp.Crop(
raw_key, dataset.gt.roi.snap_to_grid(dataset.raw.voxel_size)
Expand Down
8 changes: 4 additions & 4 deletions tests/fixtures/datasplits.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def twelve_class_datasplit(tmp_path):
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
raw_dataset.attrs["resolution"] = (2, 2, 2)
raw_dataset.attrs["voxel_size"] = (2, 2, 2)
raw_dataset.attrs["axis_names"] = ("z", "y", "x")
gt_dataset.attrs["offset"] = (0, 0, 0)
gt_dataset.attrs["resolution"] = (2, 2, 2)
gt_dataset.attrs["voxel_size"] = (2, 2, 2)
gt_dataset.attrs["axis_names"] = ("z", "y", "x")

crop1 = RawGTDatasetConfig(name="crop1", raw_config=crop1_raw, gt_config=crop1_gt)
Expand Down Expand Up @@ -184,10 +184,10 @@ def six_class_datasplit(tmp_path):
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
raw_dataset.attrs["resolution"] = (2, 2, 2)
raw_dataset.attrs["voxel_size"] = (2, 2, 2)
raw_dataset.attrs["axis_names"] = ("z", "y", "x")
gt_dataset.attrs["offset"] = (0, 0, 0)
gt_dataset.attrs["resolution"] = (2, 2, 2)
gt_dataset.attrs["voxel_size"] = (2, 2, 2)
gt_dataset.attrs["axis_names"] = ("z", "y", "x")

crop1 = RawGTDatasetConfig(
Expand Down
2 changes: 1 addition & 1 deletion tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def unet_architecture(batch_norm, upsample, use_attention, three_d):
name=name,
input_shape=(2, 132, 132),
eval_shape_increase=(8, 32, 32),
fmaps_in=2,
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
Expand Down

0 comments on commit 080f00a

Please sign in to comment.