Skip to content

Commit

Permalink
Funlib persistence update (#322)
Browse files Browse the repository at this point in the history
Upgrade to funlib.persistence `0.5`.

This update makes one big improvement:
Custom `Array` class no longer needed. We used this mostly just to apply
preprocessing lazily to large arrays. New `funlib` `Array` class uses
`dask` internally which comes with much better support for lazy array
operations than we built for ourselves. The `ZarrArray` and `NumpyArray`
class which were used extensively throughout `DaCapo` have now been
replaced with simple `funlib.persistence.Array`s.

A minor incompatibility:
`funlib.persistence.Array` has a convention (for now) that all axes have
names, but non-spatial axes have a "^" in their name. This will be fixed
in the near future. For now, DaCapo convention needed to change a little
bit to adapt to this. We now have to use "c^" and "b^" for channel and
batch dimensions instead of just "c" and "b".

TODOs:
This pull request is not quite ready to merge. The tests run with
`pytest` pass, and the `minimal_tutorial` notebook executes. But there is a
lot of code that is not tested. Specifically many of the `ArrayConfig`
subclasses are not yet tested and some are missing implementations.

Here are the Preprocessing array configs, whether or not their
implementation is complete, and their code coverage:
- [X] BinarizeArrayConfig 96%
- [X] ConcatArrayConfig 60%
- [X] ConstantArrayConfig 57%
- [X] CropArrayConfig 69%
- [X] DummyArrayConfig 91%
- [ ] DVIDArrayConfig 90% (misleading, only skeleton implementation so
not much to test)
- [X] IntensitiesArrayConfig 75%
- [X] LogicalOrArrayConfig 60%
- [x] MergeInstancesArrayConfig 100% (misleading, no implementation so
nothing to test)
- [x] MissingAnnotationsMaskConfig 100% (misleading)
- [x] OnesArrayConfig 100% (misleading)
- [ ] ResampledArrayConfig 100% (misleading)
- [x] SumArrayConfig 100% (misleading)
- [x] TiffArrayConfig 0%
- [X] ZarrArrayConfig 70%

Best practice would be to add tests before merging, but I want to put
this here so others can test it.
  • Loading branch information
mzouink authored Nov 6, 2024
2 parents 108db88 + 1d5c501 commit aeb77a6
Show file tree
Hide file tree
Showing 81 changed files with 1,039 additions and 8,622 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
pip install .[docs]
- name: parse notebooks
run: jupytext --to notebook --execute ./docs/source/notebooks/*.py
# continue-on-error: true
run: |
jupytext --to notebook --execute ./examples/starter_tutorial/minimal_tutorial.py --output ./docs/source/notebooks/minimal_tutorial.ipynb
- name: remove notebook scripts
run: rm ./docs/source/notebooks/*.py
- name: Build and Commit
Expand Down
9 changes: 5 additions & 4 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from typing import Optional
from funlib.geometry import Roi, Coordinate
from funlib.persistence import open_ds
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays.array import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run

Expand All @@ -12,7 +12,6 @@
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store.create_store import (
create_config_store,
create_weights_store,
Expand Down Expand Up @@ -164,7 +163,9 @@ def apply(

# make array identifiers for input, predictions and outputs
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_ds(
f"{input_array_identifier.container}/{input_array_identifier.dataset}"
)
if roi is None:
_roi = input_array.roi
else:
Expand Down Expand Up @@ -226,7 +227,7 @@ def apply_run(
output_dtype (np.dtype | str, optional): The output data type. Defaults to np.uint8.
overwrite (bool, optional): Whether to overwrite existing output. Defaults to True.
Raises:
ValueError: If the input array is not a ZarrArray.
ValueError: If the input array is not a zarr array.
Examples:
>>> apply_run(
... run=run,
Expand Down
9 changes: 5 additions & 4 deletions dacapo/blockwise/argmax_worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from upath import UPath as Path
import sys
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray

from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.compute_context import create_compute_context
from dacapo.tmp import open_from_identifier

import daisy

Expand Down Expand Up @@ -82,12 +83,12 @@ def start_worker_fn(
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_from_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
output_array = open_from_identifier(output_array_identifier)

def io_loop():
# wait for blocks to run pipeline
Expand All @@ -102,7 +103,7 @@ def io_loop():
# write to output array
output_array[block.write_roi] = np.argmax(
input_array[block.write_roi],
axis=input_array.axes.index("c"),
axis=input_array.axis_names.index("c^"),
)

if return_io_loop:
Expand Down
2 changes: 1 addition & 1 deletion dacapo/blockwise/empanada_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def start_consensus_worker(trackers_dict):
assert image.ndim in [3, 4], "Only 3D and 4D input images can be handled!"
if image.ndim == 4:
# channel dimensions are commonly 1, 3 and 4
# check for dimensions on zeroeth and last axes
# check for dimensions on zeroeth and last axis_names
shape = image.shape
if shape[0] in [1, 3, 4]:
image = image[0]
Expand Down
10 changes: 5 additions & 5 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Optional

import torch
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.gp import DaCapoArraySource

from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.experiments import Run
from dacapo.compute_context import create_compute_context
from dacapo.tmp import open_from_identifier
import gunpowder as gp
import gunpowder.torch as gp_torch

Expand Down Expand Up @@ -134,12 +134,12 @@ def io_loop():
input_array_identifier = LocalArrayIdentifier(
Path(input_container), input_dataset
)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)
raw_array = open_from_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
output_array = open_from_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True
Expand All @@ -163,7 +163,7 @@ def io_loop():
# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
pipeline = gp.ArraySource(raw, raw_array)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
Expand Down
7 changes: 4 additions & 3 deletions dacapo/blockwise/segment_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import numpy as np
import yaml
from dacapo.compute_context import create_compute_context
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.tmp import open_from_identifier


from dacapo.store.array_store import LocalArrayIdentifier

Expand Down Expand Up @@ -93,13 +94,13 @@ def start_worker_fn(
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
print(f"Opening input array {input_array_identifier}")
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_from_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
print(f"Opening output array {output_array_identifier}")
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
output_array = open_from_identifier(output_array_identifier)

# Load segment function
function_name = Path(function_path).stem
Expand Down
7 changes: 4 additions & 3 deletions dacapo/blockwise/threshold_worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from upath import UPath as Path
import sys
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray

from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.compute_context import create_compute_context
from dacapo.tmp import open_from_identifier

import daisy

Expand Down Expand Up @@ -82,12 +83,12 @@ def start_worker_fn(
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_from_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
output_array = open_from_identifier(output_array_identifier)

def io_loop():
# wait for blocks to run pipeline
Expand Down
20 changes: 11 additions & 9 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import click
import logging
from funlib.geometry import Roi, Coordinate
from funlib.persistence import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
PostProcessorParameters,
Expand All @@ -16,7 +17,8 @@
segment_blockwise as _segment_blockwise,
)
from dacapo.store.local_array_store import LocalArrayIdentifier
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.tmp import open_from_identifier, create_from_identifier

from dacapo.options import DaCapoConfig
import os

Expand Down Expand Up @@ -474,7 +476,7 @@ def run_blockwise(
parameters = unpack_ctx(ctx)

input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_from_identifier(input_array_identifier)

_total_roi, read_roi, write_roi, _ = get_rois(
total_roi, read_roi_size, write_roi_size, input_array
Expand All @@ -485,9 +487,9 @@ def run_blockwise(
Path(output_container), output_dataset
)

ZarrArray.create_from_array_identifier(
create_from_identifier(
output_array_identifier,
input_array.axes,
input_array.axis_names,
_total_roi,
channels_out,
input_array.voxel_size,
Expand Down Expand Up @@ -652,7 +654,7 @@ def segment_blockwise(
parameters = unpack_ctx(ctx)

input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
input_array = open_from_identifier(input_array_identifier)

_total_roi, read_roi, write_roi, _context = get_rois(
total_roi, read_roi_size, write_roi_size, input_array
Expand All @@ -668,9 +670,9 @@ def segment_blockwise(
Path(output_container), output_dataset
)

ZarrArray.create_from_array_identifier(
create_from_identifier(
output_array_identifier,
input_array.axes,
input_array.axis_names,
_total_roi,
channels_out,
input_array.voxel_size,
Expand Down Expand Up @@ -845,15 +847,15 @@ def unpack_ctx(ctx):
return kwargs


def get_rois(total_roi, read_roi_size, write_roi_size, input_array):
def get_rois(total_roi, read_roi_size, write_roi_size, input_array: Array):
"""
Get the ROIs for processing.
Args:
total_roi (str): The total ROI to be processed.
read_roi_size (str): The size of the ROI to be read for each block.
write_roi_size (str): The size of the ROI to be written for each block.
input_array (ZarrArray): The input array.
input_array: The input array.
Returns:
tuple: A tuple containing the total ROI, read ROI, write ROI, and context.
Raises:
Expand Down
28 changes: 12 additions & 16 deletions dacapo/experiments/datasplits/datasets/arrays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
from .array import Array # noqa
from .array_config import ArrayConfig # noqa

# configurable arrays
from .dummy_array_config import DummyArray, DummyArrayConfig # noqa
from .zarr_array_config import ZarrArray, ZarrArrayConfig # noqa
from .binarize_array_config import BinarizeArray, BinarizeArrayConfig # noqa
from .resampled_array_config import ResampledArray, ResampledArrayConfig # noqa
from .intensity_array_config import IntensitiesArray, IntensitiesArrayConfig # noqa
from .missing_annotations_mask import MissingAnnotationsMask # noqa
from .dummy_array_config import DummyArrayConfig # noqa
from .zarr_array_config import ZarrArrayConfig # noqa
from .binarize_array_config import BinarizeArrayConfig # noqa
from .resampled_array_config import ResampledArrayConfig # noqa
from .intensity_array_config import IntensitiesArrayConfig # noqa
from .missing_annotations_mask_config import MissingAnnotationsMaskConfig # noqa
from .ones_array_config import OnesArray, OnesArrayConfig # noqa
from .concat_array_config import ConcatArray, ConcatArrayConfig # noqa
from .logical_or_array_config import LogicalOrArray, LogicalOrArrayConfig # noqa
from .crop_array_config import CropArray, CropArrayConfig # noqa
from .ones_array_config import OnesArrayConfig # noqa
from .concat_array_config import ConcatArrayConfig # noqa
from .logical_or_array_config import LogicalOrArrayConfig # noqa
from .crop_array_config import CropArrayConfig # noqa
from .merge_instances_array_config import (
MergeInstancesArray,
MergeInstancesArrayConfig,
) # noqa
from .dvid_array_config import DVIDArray, DVIDArrayConfig
from .sum_array_config import SumArray, SumArrayConfig
from .dvid_array_config import DVIDArrayConfig
from .sum_array_config import SumArrayConfig

# nonconfigurable arrays (helpers)
from .numpy_array import NumpyArray # noqa
from .constant_array_config import ConstantArray, ConstantArrayConfig # noqa
from .constant_array_config import ConstantArrayConfig # noqa
Loading

0 comments on commit aeb77a6

Please sign in to comment.