Fix local predict (#305)
mzouink authored Oct 17, 2024
2 parents 8de17a3 + 27f531c commit 82df104
Showing 23 changed files with 468 additions and 304 deletions.
7 changes: 6 additions & 1 deletion dacapo/blockwise/scheduler.py
@@ -74,7 +74,12 @@ def run_blockwise(
     )
     print("Running blockwise with worker_file: ", worker_file)
     print(f"Using compute context: {create_compute_context()}")
-    success = daisy.run_blockwise([task])
+    compute_context = create_compute_context()
+    print(f"Using compute context: {compute_context}")
+
+    multiprocessing = compute_context.distribute_workers
+
+    success = daisy.run_blockwise([task], multiprocessing=multiprocessing)
     return success
 
 
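The fix routes daisy's multiprocessing choice through the compute context. A minimal sketch of the same dispatch, assuming only that the context object exposes a boolean `distribute_workers` (the stand-in context class here is hypothetical, not dacapo's):

```python
import daisy

class LocalContext:
    # Hypothetical stand-in for dacapo's compute context: all the
    # scheduler needs is a boolean `distribute_workers` attribute.
    distribute_workers = False

def run_blockwise_sketch(task: daisy.Task) -> bool:
    context = LocalContext()
    # With distribute_workers=False, daisy executes blocks in-process,
    # which is what makes local predict work without spawning workers.
    return daisy.run_blockwise([task], multiprocessing=context.distribute_workers)
```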
15 changes: 13 additions & 2 deletions dacapo/cli.py
@@ -76,8 +76,19 @@ def cli(log_level):
 @click.option(
     "-r", "--run-name", required=True, type=str, help="The NAME of the run to train."
 )
-def train(run_name):
-    dacapo.train(run_name)  # TODO: run with compute_context
+@click.option(
+    "--no-validation", is_flag=True, help="Disable validation after training."
+)
+def train(run_name, no_validation):
+    """
+    Train a model with the specified run name.
+
+    Args:
+        run_name (str): The name of the run to train.
+        no_validation (bool): Flag to disable validation after training.
+    """
+    do_validate = not no_validation
+    dacapo.train(run_name, do_validate=do_validate)
 
 
 @cli.command()
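With the new flag, skipping validation from the shell is `dacapo train -r <run> --no-validation`; from Python, the equivalent is the call the command now makes (the run name below is a placeholder):

```python
import dacapo

# Same effect as `dacapo train -r my_run --no-validation`;
# "my_run" is a hypothetical run name.
dacapo.train("my_run", do_validate=False)
```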
38 changes: 35 additions & 3 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -130,6 +130,7 @@ def __init__(self, architecture_config):
                 activation after the upsample operation.
             - use_attention (optional): Whether or not to use an attention block
                 in the U-Net.
+            - batch_norm (optional): Whether to use batch normalization.
         Raises:
             ValueError: If the input shape is not given.
         Examples:
@@ -170,6 +171,7 @@ def __init__(self, architecture_config):
             self.upsample_factors if self.upsample_factors is not None else []
         )
         self.use_attention = architecture_config.use_attention
+        self.batch_norm = architecture_config.batch_norm
 
         self.unet = self.module()
 
@@ -261,6 +263,7 @@ def module(self):
             upsample_channel_contraction=[False]
             + [True] * (len(downsample_factors) - 1),
             use_attention=self.use_attention,
+            batch_norm=self.batch_norm,
         )
         if len(self.upsample_factors) > 0:
             layers = [unet]
@@ -279,6 +282,7 @@ def module(self):
                 self.fmaps_out,
                 [(3,) * len(upsample_factor)] * 2,
                 activation="ReLU",
+                batch_norm=self.batch_norm,
             )
             layers.append(conv)
             unet = torch.nn.Sequential(*layers)
@@ -455,6 +459,7 @@ def __init__(
         upsample_channel_contraction=False,
         activation_on_upsample=False,
         use_attention=False,
+        batch_norm=True,
     ):
         """
         Create a U-Net::
@@ -573,6 +578,7 @@ def __init__(
 
         self.dims = len(downsample_factors[0])
         self.use_attention = use_attention
+        self.batch_norm = batch_norm
 
         # default arguments
 
@@ -611,6 +617,7 @@ def __init__(
                     kernel_size_down[level],
                     activation=activation,
                     padding=padding,
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels)
             ]
@@ -668,6 +675,7 @@ def __init__(
                     ),
                     dims=self.dims,
                     upsample_factor=downsample_factors[level],
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels - 1)
             ]
@@ -694,6 +702,7 @@ def __init__(
                     kernel_size_up[level],
                     activation=activation,
                     padding=padding,
+                    batch_norm=self.batch_norm,
                 )
                 for level in range(self.num_levels - 1)
             ]
@@ -827,7 +836,13 @@ class ConvPass(torch.nn.Module):
     """
 
     def __init__(
-        self, in_channels, out_channels, kernel_sizes, activation, padding="valid"
+        self,
+        in_channels,
+        out_channels,
+        kernel_sizes,
+        activation,
+        padding="valid",
+        batch_norm=True,
     ):
         """
         Convolutional pass module. This module performs a series of
@@ -869,6 +884,15 @@ def __init__(
 
             try:
                 layers.append(conv(in_channels, out_channels, kernel_size, padding=pad))
+                if batch_norm:
+                    layers.append(
+                        {
+                            2: torch.nn.BatchNorm2d,
+                            3: torch.nn.BatchNorm3d,
+                        }[
+                            self.dims
+                        ](out_channels)
+                    )
             except KeyError:
                 raise RuntimeError("%dD convolution not implemented" % self.dims)
 
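The batch-norm insertion above picks the normalization layer with a dict keyed on dimensionality. In isolation the pattern looks like this (a sketch, not the dacapo code):

```python
import torch

def make_batch_norm(dims: int, out_channels: int) -> torch.nn.Module:
    # Same dispatch as in ConvPass: choose the BatchNorm flavour matching
    # the convolution dimensionality, and fail loudly for anything else.
    try:
        return {
            2: torch.nn.BatchNorm2d,
            3: torch.nn.BatchNorm3d,
        }[dims](out_channels)
    except KeyError:
        raise RuntimeError(f"{dims}D batch norm not implemented")

print(make_batch_norm(3, 12))  # BatchNorm3d(12, ...)
```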
@@ -1283,7 +1307,7 @@ class AttentionBlockModule(nn.Module):
     The AttentionBlockModule is an instance of the ``torch.nn.Module`` class.
     """
 
-    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
+    def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True):
         """
         Initialize the Attention Block Module.
@@ -1306,13 +1330,19 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
         super(AttentionBlockModule, self).__init__()
         self.dims = dims
         self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
+        self.batch_norm = batch_norm
         if upsample_factor is not None:
             self.upsample_factor = upsample_factor
         else:
             self.upsample_factor = (2,) * self.dims
 
         self.W_g = ConvPass(
-            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
+            F_g,
+            F_int,
+            kernel_sizes=self.kernel_sizes,
+            activation=None,
+            padding="same",
+            batch_norm=self.batch_norm,
         )
 
         self.W_x = nn.Sequential(
@@ -1322,6 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
                 kernel_sizes=self.kernel_sizes,
                 activation=None,
                 padding="same",
+                batch_norm=self.batch_norm,
             ),
             Downsample(upsample_factor),
         )
@@ -1332,6 +1363,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
             kernel_sizes=self.kernel_sizes,
             activation="Sigmoid",
             padding="same",
+            batch_norm=self.batch_norm,
         )
 
         up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]
4 changes: 4 additions & 0 deletions dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -127,3 +127,7 @@ class CNNectomeUNetConfig(ArchitectureConfig):
             "help_text": "Whether to use attention blocks in the UNet. This is supported for 2D and 3D."
         },
     )
+    batch_norm: bool = attr.ib(
+        default=True,
+        metadata={"help_text": "Whether to use batch normalization."},
+    )
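Opting out of the new normalization is then a single keyword on the architecture config. A sketch, assuming the usual required UNet fields; the import path and all values besides `batch_norm` are illustrative:

```python
from funlib.geometry import Coordinate

from dacapo.experiments.architectures import CNNectomeUNetConfig

config = CNNectomeUNetConfig(
    name="unet_no_bn",                      # illustrative
    input_shape=Coordinate(216, 216, 216),  # illustrative
    fmaps_in=1,
    num_fmaps=12,
    fmaps_out=72,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    batch_norm=False,  # the knob added in this commit
)
```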
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/__init__.py
@@ -4,4 +4,4 @@
 from .dummy_datasplit_config import DummyDataSplitConfig
 from .train_validate_datasplit import TrainValidateDataSplit
 from .train_validate_datasplit_config import TrainValidateDataSplitConfig
-from .datasplit_generator import DataSplitGenerator
+from .datasplit_generator import DataSplitGenerator, DatasetSpec
@@ -171,7 +171,9 @@ def roi(self) -> Roi:
             This method returns the region of interest of the resampled array.
         """
-        return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink")
+        return self._source_array.roi.snap_to_grid(
+            np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink"
+        )
 
     @property
     def writable(self) -> bool:
@@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
         Note:
             This method returns the data of the resampled array within the given region of interest.
         """
-        snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow")
+        snapped_roi = roi.snap_to_grid(
+            np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow"
+        )
         resampled_array = funlib.persistence.Array(
             rescale(
                 self._source_array[snapped_roi].astype(np.float32),
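Both hunks replace snapping to a single grid with snapping to the element-wise least common multiple of the source and target voxel sizes, so the snapped ROI lands on both grids at once. The effect in numbers:

```python
import numpy as np

# An ROI aligned to the lcm of both voxel sizes is aligned to each grid
# separately, so resampling neither clips nor pads a partial voxel.
source_voxel_size = np.array([8, 8, 8])
target_voxel_size = np.array([12, 12, 12])
print(np.lcm(source_voxel_size, target_voxel_size))  # [24 24 24]
```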
8 changes: 4 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -13,8 +13,6 @@
 
 from collections import OrderedDict
 import logging
-from upath import UPath as Path
-import json
 from typing import Dict, Tuple, Any, Optional, List
 
 logger = logging.getLogger(__name__)
@@ -273,7 +271,9 @@ def roi(self) -> Roi:
             This method is used to return the region of interest of the array.
         """
         if self.snap_to_grid is not None:
-            return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink")
+            return self._daisy_array.roi.snap_to_grid(
+                np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink"
+            )
         else:
             return self._daisy_array.roi
 
@@ -469,7 +469,7 @@ def create_from_array_identifier(
         write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size
         write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape)))
         zarr_container = zarr.open(array_identifier.container, "a")
-        if num_channels is None or num_channels == 1:
+        if num_channels is None:
            axes = [axis for axis in axes if "c" not in axis]
            num_channels = None
        else:
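The tightened condition means an explicit `num_channels=1` now keeps its channel axis; only `num_channels=None` drops the "c" axis. A sketch of just that branch:

```python
# Sketch of the axis bookkeeping above, reduced to the changed branch.
def channel_axes(axes, num_channels):
    if num_channels is None:
        # Only an unspecified channel count drops the channel axis now;
        # num_channels == 1 keeps a real (size-1) channel dimension.
        return [axis for axis in axes if "c" not in axis], None
    return axes, num_channels

print(channel_axes(["c", "z", "y", "x"], None))  # (['z', 'y', 'x'], None)
print(channel_axes(["c", "z", "y", "x"], 1))     # (['c', 'z', 'y', 'x'], 1)
```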
4 changes: 2 additions & 2 deletions dacapo/experiments/datasplits/datasets/dataset.py
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
         Notes:
             This method is used to return the official string representation of the dataset object.
         """
-        return f"Dataset({self.name})"
+        return f"ds_{self.name.replace('/', '_')}"
 
     def __str__(self) -> str:
         """
@@ -109,7 +109,7 @@ def __str__(self) -> str:
         Notes:
             This method is used to return the string representation of the dataset object.
         """
-        return f"Dataset({self.name})"
+        return f"ds_{self.name.replace('/', '_')}"
 
     def _neuroglancer_layers(self, prefix="", exclude_layers=None):
         """
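Dataset names can contain "/" (zarr group paths), which breaks when the representation is reused in file or layer names; the new form sanitizes it. For example, with a hypothetical name:

```python
# What the new __repr__/__str__ produce for a nested dataset name.
name = "crop_1/raw"
print(f"ds_{name.replace('/', '_')}")  # ds_crop_1_raw
```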
38 changes: 21 additions & 17 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -85,6 +85,9 @@ def resize_if_needed(
             f"have different dimensions {zarr_array.dims}"
         )
     if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]):
+        print(
+            f"dataset {array_config} needs resampling to {target_resolution}, upsample: {raw_upsample}, downsample: {raw_downsample}"
+        )
         return ResampledArrayConfig(
             name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled",
             source_array_config=array_config,
@@ -93,6 +96,7 @@ def resize_if_needed(
             interp_order=False,
         )
     else:
+        # print(f"dataset {array_config.dataset} does not need resampling found {raw_voxel_size}=={target_resolution}")
         return array_config
 
 
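When either factor exceeds 1, the generator wraps the source in a resampling config like the one returned above. A sketch of such a wrapper, assuming the field names visible in the diff; `raw_config` and all values are illustrative:

```python
from funlib.geometry import Coordinate

from dacapo.experiments.datasplits.datasets.arrays import ResampledArrayConfig

raw_config = ...  # hypothetical: an existing source array config

resampled = ResampledArrayConfig(
    name="example_raw_resampled",   # illustrative
    source_array_config=raw_config,
    upsample=Coordinate(2, 2, 2),   # assumed factor fields
    downsample=Coordinate(1, 1, 1),
    interp_order=False,
)
```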
@@ -959,23 +963,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
                 constant=1,
             )
 
-        if len(target_images) > 1:
-            gt_config = ConcatArrayConfig(
-                name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
-                channels=[organelle for organelle in current_targets],
-                # source_array_configs={k: gt for k, gt in target_images.items()},
-                source_array_configs={k: target_images[k] for k in current_targets},
-            )
-            mask_config = ConcatArrayConfig(
-                name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
-                channels=[organelle for organelle in current_targets],
-                # source_array_configs={k: mask for k, mask in target_masks.items()},
-                # to be sure to have the same order
-                source_array_configs={k: target_masks[k] for k in current_targets},
-            )
-        else:
-            gt_config = list(target_images.values())[0]
-            mask_config = list(target_masks.values())[0]
+        # if len(target_images) > 1:
+        gt_config = ConcatArrayConfig(
+            name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
+            channels=[organelle for organelle in current_targets],
+            # source_array_configs={k: gt for k, gt in target_images.items()},
+            source_array_configs={k: target_images[k] for k in current_targets},
+        )
+        mask_config = ConcatArrayConfig(
+            name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
+            channels=[organelle for organelle in current_targets],
+            # source_array_configs={k: mask for k, mask in target_masks.items()},
+            # to be sure to have the same order
+            source_array_configs={k: target_masks[k] for k in current_targets},
+        )
+        # else:
+        #     gt_config = list(target_images.values())[0]
+        #     mask_config = list(target_masks.values())[0]
 
         return raw_config, gt_config, mask_config
 
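Always going through ConcatArrayConfig guarantees that gt and mask carry a channel dimension even for a single target, which is what the evaluator change below relies on. A sketch for the single-channel case; `gt_mito` and the import path are illustrative:

```python
from dacapo.experiments.datasplits.datasets.arrays import ConcatArrayConfig

gt_mito = ...  # hypothetical: an existing source array config for "mito"

# Even one organelle gets wrapped, so the gt array keeps a "c" axis.
gt_config = ConcatArrayConfig(
    name="example_mito_8nm_gt",  # illustrative
    channels=["mito"],
    source_array_configs={"mito": gt_mito},
)
```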
@@ -127,8 +127,9 @@ def evaluate(self, output_array_identifier, evaluation_array):
             This function is used to evaluate the output array against the evaluation array.
         """
         output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
-        evaluation_data = evaluation_array[evaluation_array.roi].squeeze()
-        output_data = output_array[output_array.roi].squeeze()
+        # removed the .squeeze() because it was used for batch size and now we are feeding 4d c, z, y, x
+        evaluation_data = evaluation_array[evaluation_array.roi]
+        output_data = output_array[output_array.roi]
         print(
             f"Evaluating binary segmentations on evaluation_data of shape: {evaluation_data.shape}"
         )
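Dropping `.squeeze()` matters because squeezing 4D (c, z, y, x) data with a single channel silently removes the channel axis the rest of the pipeline now expects:

```python
import numpy as np

data = np.zeros((1, 64, 64, 64))  # 4d c, z, y, x with one channel
print(data.squeeze().shape)       # (64, 64, 64) -- channel axis lost
print(data.shape)                 # (1, 64, 64, 64) -- what evaluate now sees
```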