
local blockwise
mzouink committed Sep 23, 2024
1 parent 8d0dbf6 commit a1e46bd
Showing 12 changed files with 375 additions and 325 deletions.
7 changes: 6 additions & 1 deletion dacapo/blockwise/scheduler.py
@@ -74,7 +74,12 @@ def run_blockwise(
)
print("Running blockwise with worker_file: ", worker_file)
print(f"Using compute context: {create_compute_context()}")
success = daisy.run_blockwise([task])
compute_context = create_compute_context()
print(f"Using compute context: {compute_context}")

multiprocessing = compute_context.distribute_workers

success = daisy.run_blockwise([task], multiprocessing=multiprocessing)
return success


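
The scheduler now resolves the compute context once and forwards its `distribute_workers` flag to `daisy.run_blockwise` as `multiprocessing`. A minimal sketch of the same pattern, assuming a daisy version that accepts the `multiprocessing` keyword (as the new call implies); `LocalContext`, `process_block`, and `demo_task` are illustrative stand-ins, not DaCapo names:

```python
from dataclasses import dataclass

import daisy
from funlib.geometry import Roi


@dataclass
class LocalContext:
    # Stand-in for DaCapo's compute context; only the flag the scheduler
    # change reads is modelled here.
    distribute_workers: bool = False


def process_block(block: daisy.Block) -> None:
    # A real worker would read block.read_roi, compute, and write block.write_roi.
    print("processed", block.write_roi)


task = daisy.Task(
    "demo_task",
    total_roi=Roi((0, 0, 0), (100, 100, 100)),
    read_roi=Roi((0, 0, 0), (50, 50, 50)),
    write_roi=Roi((0, 0, 0), (50, 50, 50)),
    process_function=process_block,
    num_workers=2,
    fit="overhang",
)

context = LocalContext(distribute_workers=False)
success = daisy.run_blockwise([task], multiprocessing=context.distribute_workers)
```

With `multiprocessing=False`, daisy executes every block in the calling process, which is what allows closure-based workers like the threshold post-processor further down in this commit.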
15 changes: 13 additions & 2 deletions dacapo/cli.py
@@ -76,8 +76,19 @@ def cli(log_level):
@click.option(
"-r", "--run-name", required=True, type=str, help="The NAME of the run to train."
)
def train(run_name):
dacapo.train(run_name) # TODO: run with compute_context
@click.option(
"--no-validation", is_flag=True, help="Disable validation after training."
)
def train(run_name, no_validation):
"""
Train a model with the specified run name.
Args:
run_name (str): The name of the run to train.
no_validation (bool): Flag to disable validation after training.
"""
do_validate = not no_validation
dacapo.train(run_name, do_validate=do_validate)


@cli.command()
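
The new `--no-validation` option is a plain click boolean flag, inverted into `do_validate` before the call into `dacapo.train`. A self-contained sketch of the same click pattern; the command body is a stub rather than DaCapo's implementation, and it assumes `dacapo.train` accepts a `do_validate` keyword as the diff suggests:

```python
import click


@click.command()
@click.option(
    "-r", "--run-name", required=True, type=str, help="The NAME of the run to train."
)
@click.option("--no-validation", is_flag=True, help="Disable validation after training.")
def train(run_name, no_validation):
    # is_flag=True makes the option default to False unless "--no-validation"
    # is passed, so negating it yields the positive do_validate switch.
    do_validate = not no_validation
    click.echo(f"training {run_name}, validate={do_validate}")


if __name__ == "__main__":
    train()
```

Invoked as, for example, `python cli_sketch.py -r my_run --no-validation` (hypothetical file name), this prints `validate=False`.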
10 changes: 7 additions & 3 deletions dacapo/experiments/datasplits/datasets/arrays/resampled_array.py
@@ -171,7 +171,9 @@ def roi(self) -> Roi:
This method returns the region of interest of the resampled array.
"""
return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink")
return self._source_array.roi.snap_to_grid(
np.lcm(self._source_array.voxel_size, self.voxel_size), mode="shrink"
)

@property
def writable(self) -> bool:
@@ -281,7 +283,9 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
Note:
This method returns the data of the resampled array within the given region of interest.
"""
snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow")
snapped_roi = roi.snap_to_grid(
np.lcm(self._source_array.voxel_size, self.voxel_size), mode="grow"
)
resampled_array = funlib.persistence.Array(
rescale(
self._source_array[snapped_roi].astype(np.float32),
@@ -352,4 +356,4 @@ def _source_name(self):
Note:
This method returns the name of the source array.
"""
return self._source_array._source_name()
return self._source_array._source_name()
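
Both `roi` and `__getitem__` now snap to the least common multiple of the source and target voxel sizes, so the snapped ROI falls on whole voxels in both grids rather than only one of them. A small, self-contained illustration of the arithmetic with funlib's `Roi.snap_to_grid` (the voxel sizes and ROI are made up):

```python
import numpy as np
from funlib.geometry import Coordinate, Roi

source_voxel_size = Coordinate((4, 4, 4))
target_voxel_size = Coordinate((6, 6, 6))

# Snapping to the LCM (12 here) guarantees a whole number of voxels in BOTH
# arrays; snapping to either voxel size alone does not.
grid = Coordinate(np.lcm(source_voxel_size, target_voxel_size))

roi = Roi((5, 5, 5), (100, 100, 100))
print(roi.snap_to_grid(grid, mode="shrink"))  # "shrink": keep the readable ROI inside
print(roi.snap_to_grid(grid, mode="grow"))    # "grow": read enough source data to cover it
```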
81 changes: 19 additions & 62 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -14,6 +14,7 @@
from collections import OrderedDict
import logging
from upath import UPath as Path
import os
import json
from typing import Dict, Tuple, Any, Optional, List

@@ -273,7 +274,9 @@ def roi(self) -> Roi:
This method is used to return the region of interest of the array.
"""
if self.snap_to_grid is not None:
return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink")
return self._daisy_array.roi.snap_to_grid(
np.lcm(self.voxel_size, self.snap_to_grid), mode="shrink"
)
else:
return self._daisy_array.roi

@@ -426,33 +429,12 @@ def create_from_array_identifier(
num_channels,
voxel_size,
dtype,
mode="a",
write_size=None,
name=None,
overwrite=False,
):
"""
Create a new ZarrArray given an array identifier. It is assumed that
this array_identifier points to a dataset that does not yet exist.
Args:
array_identifier (ArrayIdentifier): The array identifier.
axes (List[str]): The axes of the array.
roi (Roi): The region of interest.
num_channels (int): The number of channels.
voxel_size (Coordinate): The voxel size.
dtype (Any): The data type.
write_size (Optional[Coordinate]): The write size.
name (Optional[str]): The name of the array.
overwrite (bool): The boolean value to overwrite the array.
Returns:
ZarrArray: The ZarrArray.
Raises:
NotImplementedError
Examples:
>>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False)
Notes:
This method is used to create a new ZarrArray given an array identifier.
this array_identifier points to a dataset that does not yet exist
"""
if write_size is None:
# total storage per block is approx c*x*y*z*dtype_size
@@ -469,11 +451,6 @@
write_size = Coordinate((axis_length,) * voxel_size.dims) * voxel_size
write_size = Coordinate((min(a, b) for a, b in zip(write_size, roi.shape)))
zarr_container = zarr.open(array_identifier.container, "a")
if num_channels is None or num_channels == 1:
axes = [axis for axis in axes if "c" not in axis]
num_channels = None
else:
axes = ["c"] + [axis for axis in axes if "c" not in axis]
try:
funlib.persistence.prepare_ds(
f"{array_identifier.container}",
@@ -483,41 +460,21 @@
dtype,
num_channels=num_channels,
write_size=write_size,
delete=overwrite,
force_exact_write_size=True,
)
zarr_dataset = zarr_container[array_identifier.dataset]
if array_identifier.container.name.endswith("n5"):
zarr_dataset.attrs["offset"] = roi.offset[::-1]
zarr_dataset.attrs["resolution"] = voxel_size[::-1]
zarr_dataset.attrs["axes"] = axes[::-1]
# to make display right in neuroglancer: TODO ADD CHANNELS
zarr_dataset.attrs["dimension_units"] = [
f"{size} nm" for size in voxel_size[::-1]
]
zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [
a if a != "c" else "c^" for a in axes[::-1]
]
else:
zarr_dataset.attrs["offset"] = roi.offset
zarr_dataset.attrs["resolution"] = voxel_size
zarr_dataset.attrs["axes"] = axes
# to make display right in neuroglancer: TODO ADD CHANNELS
zarr_dataset.attrs["dimension_units"] = [
f"{size} nm" for size in voxel_size
]
zarr_dataset.attrs["_ARRAY_DIMENSIONS"] = [
a if a != "c" else "c^" for a in axes
]
if "c" in axes:
if axes.index("c") == 0:
zarr_dataset.attrs["dimension_units"] = [
str(num_channels)
] + zarr_dataset.attrs["dimension_units"]
else:
zarr_dataset.attrs["dimension_units"] = zarr_dataset.attrs[
"dimension_units"
] + [str(num_channels)]
zarr_dataset.attrs["offset"] = (
roi.offset[::-1]
if array_identifier.container.name.endswith("n5")
else roi.offset
)
zarr_dataset.attrs["resolution"] = (
voxel_size[::-1]
if array_identifier.container.name.endswith("n5")
else voxel_size
)
zarr_dataset.attrs["axes"] = (
axes[::-1] if array_identifier.container.name.endswith("n5") else axes
)
except zarr.errors.ContainsArrayError:
zarr_dataset = zarr_container[array_identifier.dataset]
assert (
@@ -733,4 +690,4 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None:
"""
dataset = zarr.open(self.file_name, mode="a")[self.dataset]
for k, v in metadata.items():
dataset.attrs[k] = v
dataset.attrs[k] = v
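
In `create_from_array_identifier`, the duplicated zarr/n5 branches for writing `offset`, `resolution`, and `axes` collapse into single conditional expressions, and the channel-axis and `dimension_units` bookkeeping is dropped; `prepare_ds` also gains `delete=overwrite` and `force_exact_write_size=True`. A hedged sketch of the simplified metadata write, where `write_geometry_attrs` is a hypothetical helper name and the arguments mirror the values used in the diff:

```python
def write_geometry_attrs(zarr_dataset, container, roi, voxel_size, axes):
    # N5 stores dimensions in reverse order relative to zarr, so each
    # attribute is flipped exactly once, in one place.
    is_n5 = container.name.endswith("n5")
    zarr_dataset.attrs["offset"] = roi.offset[::-1] if is_n5 else roi.offset
    zarr_dataset.attrs["resolution"] = voxel_size[::-1] if is_n5 else voxel_size
    zarr_dataset.attrs["axes"] = axes[::-1] if is_n5 else axes
```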
34 changes: 17 additions & 17 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -913,23 +913,23 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
constant=1,
)

if len(target_images) > 1:
gt_config = ConcatArrayConfig(
name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
channels=[organelle for organelle in current_targets],
# source_array_configs={k: gt for k, gt in target_images.items()},
source_array_configs={k: target_images[k] for k in current_targets},
)
mask_config = ConcatArrayConfig(
name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
channels=[organelle for organelle in current_targets],
# source_array_configs={k: mask for k, mask in target_masks.items()},
# to be sure to have the same order
source_array_configs={k: target_masks[k] for k in current_targets},
)
else:
gt_config = list(target_images.values())[0]
mask_config = list(target_masks.values())[0]
# if len(target_images) > 1:
gt_config = ConcatArrayConfig(
name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_gt",
channels=[organelle for organelle in current_targets],
# source_array_configs={k: gt for k, gt in target_images.items()},
source_array_configs={k: target_images[k] for k in current_targets},
)
mask_config = ConcatArrayConfig(
name=f"{dataset}_{targets_str}_{self.output_resolution[0]}nm_mask",
channels=[organelle for organelle in current_targets],
# source_array_configs={k: mask for k, mask in target_masks.items()},
# to be sure to have the same order
source_array_configs={k: target_masks[k] for k in current_targets},
)
# else:
# gt_config = list(target_images.values())[0]
# mask_config = list(target_masks.values())[0]

return raw_config, gt_config, mask_config

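
With the single-target shortcut commented out, ground truth and mask are always wrapped in a `ConcatArrayConfig`, so every dataset keeps an explicit channel dimension and the GT and mask channels stay in the same order. A tiny, runnable illustration of why both dict comprehensions index by `current_targets` (plain strings stand in for the actual array configs):

```python
# Indexing by current_targets pins one ordering for gt and mask channels,
# independent of the insertion order of the two dicts.
current_targets = ["mito", "er"]
target_images = {"er": "er_gt", "mito": "mito_gt"}      # placeholder configs
target_masks = {"er": "er_mask", "mito": "mito_mask"}   # placeholder configs

gt_channels = {k: target_images[k] for k in current_targets}
mask_channels = {k: target_masks[k] for k in current_targets}

print(list(gt_channels))    # ['mito', 'er']
print(list(mask_channels))  # ['mito', 'er'] -- channel i of gt matches channel i of mask
```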
@@ -1,12 +1,12 @@
from upath import UPath as Path
from dacapo.blockwise.scheduler import run_blockwise
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from .threshold_post_processor_parameters import ThresholdPostProcessorParameters
from dacapo.store.array_store import LocalArrayIdentifier
from .post_processor import PostProcessor
import dacapo.blockwise
import numpy as np
import daisy
from daisy import Roi, Coordinate
from dacapo.utils.array_utils import to_ndarray, save_ndarray
from funlib.persistence import open_ds

from typing import Iterable

@@ -43,7 +43,7 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]:
Note:
This method should return a generator of instances of ``ThresholdPostProcessorParameters``.
"""
for i, threshold in enumerate([100, 127, 150]):
for i, threshold in enumerate([127]):
yield ThresholdPostProcessorParameters(id=i, threshold=threshold)

def set_prediction(self, prediction_array_identifier):
@@ -117,28 +117,31 @@ def process(
self.prediction_array.num_channels,
self.prediction_array.voxel_size,
np.uint8,
write_size,
)


read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
# run blockwise post-processing
sucess = run_blockwise(
worker_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py")
),
input_array = open_ds(self.prediction_array_identifier.container.path,self.prediction_array_identifier.dataset)

def process_block(block):
print("Predicting block", block.read_roi)
data = to_ndarray(input_array,block.read_roi) > parameters.threshold
if int(data.max()) == 0:
print("No data in block", block.read_roi)
return
save_ndarray(data, block.write_roi, output_array)

task = daisy.Task(
f"threshold_{output_array.dataset}",
total_roi=self.prediction_array.roi,
read_roi=read_roi,
write_roi=read_roi,
num_workers=num_workers,
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
input_array_identifier=self.prediction_array_identifier,
output_array_identifier=output_array_identifier,
threshold=parameters.threshold,
process_function=process_block,
check_function=None,
read_write_conflict=False,
fit="overhang",
max_retries=0,
timeout=None,
)

if not sucess:
raise RuntimeError("Blockwise post-processing failed.")

return output_array
return daisy.run_blockwise([task], multiprocessing=False)
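
This is the core of the "local blockwise" change: instead of launching `threshold_worker.py` through `run_blockwise`, the threshold is applied by a `process_block` closure handed straight to `daisy.Task` and run with `multiprocessing=False` (and only a single threshold, 127, is enumerated now). A self-contained sketch of the same pattern on in-memory numpy arrays; the diff itself reads and writes `funlib.persistence` datasets via `to_ndarray`/`save_ndarray`, which are replaced here by plain array slicing:

```python
import numpy as np
import daisy
from funlib.geometry import Roi

total_roi = Roi((0, 0, 0), (64, 64, 64))
block_roi = Roi((0, 0, 0), (32, 32, 32))

prediction = np.random.randint(0, 256, size=(64, 64, 64), dtype=np.uint8)
output = np.zeros_like(prediction)
threshold = 127


def process_block(block: daisy.Block) -> None:
    # Threshold the block's write ROI and skip blocks that come out empty,
    # mirroring the early return in the diff.
    slices = tuple(
        slice(b, e) for b, e in zip(block.write_roi.begin, block.write_roi.end)
    )
    data = prediction[slices] > threshold
    if not data.any():
        return
    output[slices] = data.astype(np.uint8)


task = daisy.Task(
    "threshold_example",
    total_roi=total_roi,
    read_roi=block_roi,
    write_roi=block_roi,
    process_function=process_block,
    check_function=None,
    read_write_conflict=False,
    fit="overhang",
    num_workers=1,
    max_retries=0,
    timeout=None,
)

# multiprocessing=False keeps the closure (and the arrays it captures) in this
# process, which is what makes a local, worker-file-free blockwise run possible.
success = daisy.run_blockwise([task], multiprocessing=False)
```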
22 changes: 17 additions & 5 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -111,7 +111,7 @@ def create_optimizer(self, model):
optimizer = torch.optim.RAdam(
lr=self.learning_rate,
params=model.parameters(),
decoupled_weight_decay=True,
# decoupled_weight_decay=True,
)
self.scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
@@ -161,6 +161,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder = gp.ArrayKey("MASK_PLACEHOLDER")

target_key = gp.ArrayKey("TARGET")
dataset_weight_key = gp.ArrayKey("DATASET_WEIGHT")
datasets_weight_key = gp.ArrayKey("DATASETS_WEIGHT")
weight_key = gp.ArrayKey("WEIGHT")
sample_points_key = gp.GraphKey("SAMPLE_POINTS")

@@ -207,9 +209,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
mask_placeholder,
drop_channels=True,
)
+ gp.Pad(raw_key, None)
+ gp.Pad(gt_key, None)
+ gp.Pad(mask_key, None)
+ gp.Pad(raw_key, None, mode="constant", value=0)
+ gp.Pad(gt_key, None, mode="constant", value=0)
+ gp.Pad(mask_key, None, mode="constant", value=0)
+ gp.RandomLocation(
ensure_nonempty=(
sample_points_key if points_source is not None else None
@@ -225,6 +227,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
for augment in self.augments:
dataset_source += augment.node(raw_key, gt_key, mask_key)

# Add predictor nodes to dataset_source
dataset_source += DaCapoTargetFilter(
task.predictor,
gt_key=gt_key,
weights_key=dataset_weight_key,
mask_key=mask_key,
)

dataset_sources.append(dataset_source)
pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)

@@ -233,10 +243,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
task.predictor,
gt_key=gt_key,
target_key=target_key,
weights_key=weight_key,
weights_key=datasets_weight_key,
mask_key=mask_key,
)

pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

# Trainer attributes:
if self.num_data_fetchers > 1:
pipeline += gp.PreCache(num_workers=self.num_data_fetchers)
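
The trainer now asks each dataset's `DaCapoTargetFilter` for a per-dataset weight (`DATASET_WEIGHT`), keeps the cross-dataset weight in `DATASETS_WEIGHT`, and multiplies the two into the final `WEIGHT` array with a `Product` node. `Product` is a DaCapo helper rather than a stock gunpowder node; the sketch below shows one plausible implementation as a `gp.BatchFilter`, written from the usage in the diff rather than from the actual source:

```python
import gunpowder as gp


class Product(gp.BatchFilter):
    # Multiplies two arrays element-wise into a third key (here: the
    # per-dataset weight times the cross-dataset weight).
    def __init__(self, x1_key: gp.ArrayKey, x2_key: gp.ArrayKey, y_key: gp.ArrayKey):
        self.x1_key = x1_key
        self.x2_key = x2_key
        self.y_key = y_key

    def setup(self):
        self.enable_autoskip()
        self.provides(self.y_key, self.spec[self.x1_key].copy())

    def prepare(self, request):
        # Request both factors over the ROI asked for the product.
        deps = gp.BatchRequest()
        deps[self.x1_key] = request[self.y_key].copy()
        deps[self.x2_key] = request[self.y_key].copy()
        return deps

    def process(self, batch, request):
        outputs = gp.Batch()
        spec = batch[self.x1_key].spec.copy()
        outputs[self.y_key] = gp.Array(
            batch[self.x1_key].data * batch[self.x2_key].data, spec
        )
        return outputs
```

Elsewhere in this file, the Pad nodes gain an explicit `mode="constant", value=0`, and RAdam's `decoupled_weight_decay` argument is commented out, presumably for compatibility with PyTorch versions that do not accept it.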
