From bd1d104b9d52b269eb58da06f918435ee242d4ef Mon Sep 17 00:00:00 2001
From: William Patton
Date: Tue, 26 Nov 2024 13:14:09 -0800
Subject: [PATCH 1/7] Add downsampling support and a test for it

---
 .../datasets/arrays/resampled_array_config.py | 39 +++++++++++++++++--
 tests/components/test_preprocessing.py        | 32 +++++++++++++++
 2 files changed, 68 insertions(+), 3 deletions(-)
 create mode 100644 tests/components/test_preprocessing.py

diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
index 7a03f89eb..9333b5b41 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
@@ -5,6 +5,26 @@
 from funlib.geometry import Coordinate
 from funlib.persistence import Array
 
+from xarray_multiscale.multiscale import downscale_dask
+from xarray_multiscale import windowed_mean
+import numpy as np
+import dask.array as da
+
+from typing import Sequence
+
+
+def adjust_shape(array: da.Array, scale_factors: Sequence(int)) -> da.Array:
+    """
+    Crop array to a shape that is a multiple of the scale factors.
+    This allows for clean downsampling.
+    """
+    misalignment = np.any(np.mod(array.shape, scale_factors))
+    if misalignment:
+        new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors))
+        slices = tuple(slice(0, s) for s in new_shape)
+        array = array[slices]
+    return array
+
 
 @attr.s
 class ResampledArrayConfig(ArrayConfig):
@@ -38,6 +58,19 @@ class ResampledArrayConfig(ArrayConfig):
     )
 
     def array(self, mode: str = "r") -> Array:
-        # This is non trivial. We want to upsample or downsample the source
-        # array lazily. Not entirely sure how to do this with dask arrays.
-        raise NotImplementedError()
+        source_array = self.source_array_config.array(mode)
+
+        if self.downsample is not None:
+            return Array(
+                data=downscale_dask(
+                    adjust_shape(source_array.data, self.downsample),
+                    windowed_mean,
+                    scale_factors=self.downsample,
+                ),
+                offset=source_array.offset,
+                voxel_size=source_array.voxel_size * 2,
+                axis_names=source_array.axis_names,
+                units=source_array.units,
+            )
+        elif self.upsample is not None:
+            raise NotImplementedError("Upsampling not yet implemented")
diff --git a/tests/components/test_preprocessing.py b/tests/components/test_preprocessing.py
new file mode 100644
index 000000000..a3f526940
--- /dev/null
+++ b/tests/components/test_preprocessing.py
@@ -0,0 +1,32 @@
+from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import (
+    ResampledArrayConfig,
+)
+
+import numpy as np
+from funlib.persistence import Array
+from funlib.geometry import Coordinate
+
+
+def test_resample():
+    # test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works
+    for top in [11,12]:
+        arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,))
+        resample_config = ResampledArrayConfig(
+            "test_resample", None, upsample=None, downsample=(2,), interp_order=1
+        )
+        resampled = resample_config.preprocess(arr)
+        assert resampled.voxel_size == Coordinate((6,))
+        assert resampled.shape == (5,)
+        assert np.allclose(resampled[:], np.array([1.5, 3.5, 5.5, 7.5, 9.5]))
+
+    # test 2D array
+    arr = Array(
+        np.array(np.arange(1, 11).reshape(5, 2).T), offset=(0, 0), voxel_size=(3, 3)
+    )
+    resample_config = ResampledArrayConfig(
+        "test_resample", None, upsample=None, downsample=(2, 1), interp_order=1
+    )
+    resampled = resample_config.preprocess(arr)
+    assert resampled.voxel_size == Coordinate(6, 3)
+    assert resampled.shape == (1, 5)
+    assert np.allclose(resampled[:], np.array([[1.5, 3.5, 5.5, 7.5, 9.5]]))
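
[Illustration, not part of the commit] A minimal sketch of what the new cropping helper does
before the windowed mean is taken, using the same xarray-multiscale calls as the patch;
`adjust_shape` refers to the helper added above, and the input values and chunking are only
an example:

    import dask.array as da
    import numpy as np
    from xarray_multiscale import windowed_mean
    from xarray_multiscale.multiscale import downscale_dask

    data = da.from_array(np.arange(1, 12), chunks=2)   # shape (11,), not divisible by 2
    cropped = adjust_shape(data, (2,))                 # cropped to shape (10,)
    down = downscale_dask(cropped, windowed_mean, scale_factors=(2,))
    print(down.compute())                              # [1.5 3.5 5.5 7.5 9.5]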
From 3ed2d55a3052ada1cbebe8909f825674d73b76a8 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Tue, 26 Nov 2024 14:18:00 -0800
Subject: [PATCH 2/7] pass tests

---
 .../datasets/arrays/resampled_array_config.py | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
index 9333b5b41..e2889d01f 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
@@ -13,7 +13,7 @@ from typing import Sequence
 
 
-def adjust_shape(array: da.Array, scale_factors: Sequence(int)) -> da.Array:
+def adjust_shape(array: da.Array, scale_factors: Sequence[int]) -> da.Array:
     """
     Crop array to a shape that is a multiple of the scale factors.
     This allows for clean downsampling.
     """
@@ -57,20 +57,27 @@ class ResampledArrayConfig(ArrayConfig):
         metadata={"help_text": "The order of the interpolation!"}
     )
 
-    def array(self, mode: str = "r") -> Array:
-        source_array = self.source_array_config.array(mode)
-
+    def preprocess(self, array: Array) -> Array:
+        """
+        Preprocess an array by resampling it to the desired voxel size.
+        """
         if self.downsample is not None:
+            downsample = Coordinate(self.downsample)
             return Array(
                 data=downscale_dask(
-                    adjust_shape(source_array.data, self.downsample),
+                    adjust_shape(array.data, downsample),
                     windowed_mean,
-                    scale_factors=self.downsample,
+                    scale_factors=downsample,
                 ),
-                offset=source_array.offset,
-                voxel_size=source_array.voxel_size * 2,
-                axis_names=source_array.axis_names,
-                units=source_array.units,
+                offset=array.offset,
+                voxel_size=array.voxel_size * downsample,
+                axis_names=array.axis_names,
+                units=array.units,
             )
         elif self.upsample is not None:
             raise NotImplementedError("Upsampling not yet implemented")
+
+    def array(self, mode: str = "r") -> Array:
+        source_array = self.source_array_config.array(mode)
+
+        return self.preprocess(source_array)
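
[Illustration, not part of the commit] A minimal usage sketch of the refactored preprocess()
entry point, mirroring the new test; the config name "downsample_x2" is arbitrary, and the
returned data should remain a lazy dask graph until it is read:

    import numpy as np
    from funlib.persistence import Array
    from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import (
        ResampledArrayConfig,
    )

    arr = Array(np.arange(1, 11), offset=(0,), voxel_size=(3,))
    config = ResampledArrayConfig(
        "downsample_x2", None, upsample=None, downsample=(2,), interp_order=1
    )
    resampled = config.preprocess(arr)
    print(resampled.voxel_size)  # equals Coordinate((6,))
    print(resampled[:])          # [1.5 3.5 5.5 7.5 9.5] once the graph is computed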
From 72212e0b72ab26935b85cbfda0099c70a8459d7d Mon Sep 17 00:00:00 2001
From: William Patton
Date: Tue, 26 Nov 2024 14:18:25 -0800
Subject: [PATCH 3/7] format preprocessing test

---
 tests/components/test_preprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/components/test_preprocessing.py b/tests/components/test_preprocessing.py
index a3f526940..14aa7d1df 100644
--- a/tests/components/test_preprocessing.py
+++ b/tests/components/test_preprocessing.py
@@ -9,7 +9,7 @@
 
 def test_resample():
     # test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works
-    for top in [11,12]:
+    for top in [11, 12]:
         arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,))
         resample_config = ResampledArrayConfig(
             "test_resample", None, upsample=None, downsample=(2,), interp_order=1

From 7f8554ce9dcc0199726c0b3f16bcf13bb22287fa Mon Sep 17 00:00:00 2001
From: William Patton
Date: Tue, 26 Nov 2024 14:19:59 -0800
Subject: [PATCH 4/7] add xarray-multiscale dependency

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 012def526..7f6bc2bcf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
     "upath",
     "boto3",
     "matplotlib",
+    "xarray-multiscale",
 ]
 
 # extras

From fd5e1172c2ab416b14de3b0627fba744435c7e11 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Thu, 5 Dec 2024 10:23:18 -0800
Subject: [PATCH 5/7] Fix some potential memory leaks in the array configs

---
 .../datasets/arrays/binarize_array_config.py   | 13 +++++++------
 .../datasets/arrays/concat_array_config.py     | 12 ++++--------
 .../datasets/arrays/logical_or_array_config.py |  4 ++--
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py
index 570739f63..4df39abb4 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py
@@ -55,12 +55,13 @@ def array(self, mode="r") -> Array:
         assert num_channels is None, "Input labels cannot have a channel dimension"
 
         def group_array(data):
-            out = da.zeros((len(self.groupings), *array.physical_shape), dtype=np.uint8)
-            for i, (_, group_ids) in enumerate(self.groupings):
-                if len(group_ids) == 0:
-                    out[i] = data != self.background
-                else:
-                    out[i] = da.isin(data, group_ids)
+            groups = [
+                da.isin(data, group_ids)
+                if len(group_ids) > 0
+                else data != self.background
+                for _, group_ids in self.groupings
+            ]
+            out = da.stack(groups, axis=0)
             return out
 
         data = group_array(array.data)
diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py
index 4de730b18..1de31ccdf 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py
@@ -4,7 +4,6 @@
 from typing import List, Dict, Optional
 
 from funlib.persistence import Array
-import numpy as np
 import dask.array as da
 
 
@@ -45,18 +44,15 @@ class ConcatArrayConfig(ArrayConfig):
     def array(self, mode: str = "r") -> Array:
         arrays = [config.array(mode) for _, config in self.source_array_configs.items()]
+        out_data = da.stack([array.data for array in arrays], axis=0)
         out_array = Array(
-            da.zeros(len(arrays), *arrays[0].physical_shape, dtype=arrays[0].dtype),
+            out_data,
             offset=arrays[0].offset,
             voxel_size=arrays[0].voxel_size,
             axis_names=["c^"] + arrays[0].axis_names,
             units=arrays[0].units,
         )
 
-        def set_channels(data):
-            for i, array in enumerate(arrays):
-                data[i] = array.data[:]
-            return data
-
-        out_array.lazy_op(set_channels)
+        # callable lazy op so funlib.persistence doesn't try to recoginize this data as writable
+        out_array.lazy_op(lambda data: data)
 
         return out_array
diff --git a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py
index a9cde5daa..432e70d3f 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py
@@ -30,14 +30,14 @@ def array(self, mode: str = "r") -> Array:
         assert num_channels_from_array(array) is not None
 
         out_array = Array(
-            da.zeros(*array.physical_shape, dtype=array.dtype),
+            da.zeros(array.physical_shape, dtype=array.dtype),
             offset=array.offset,
             voxel_size=array.voxel_size,
             axis_names=array.axis_names[1:],
             units=array.units,
         )
 
-        out_array.data = da.maximum(array.data, axis=0)
+        out_array.data = da.max(array.data, axis=0)
 
         # mark data as non-writable
         out_array.lazy_op(lambda data: data)
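
[Illustration, not part of the commit] A sketch of the stacking pattern these fixes move to,
with toy inputs; rather than preallocating a dask array and writing channels into it, each
channel is built as its own lazy expression and stacked along a new axis:

    import dask.array as da
    import numpy as np

    data = da.from_array(np.array([[0, 1], [2, 3]]), chunks=1)
    groupings = [("a", [1, 2]), ("b", [3])]

    # one lazy boolean mask per group, stacked into shape (2, 2, 2); nothing is
    # materialized until compute() is called
    channels = [da.isin(data, ids) for _, ids in groupings]
    stacked = da.stack(channels, axis=0)
    print(stacked.compute().astype(np.uint8))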
From bf04fb30ddceefb32a508bac72a5827137446696 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Mon, 9 Dec 2024 19:51:51 -0800
Subject: [PATCH 6/7] remove unnecessary assert and handle boolean labels

---
 dacapo/experiments/tasks/predictors/hot_distance_predictor.py | 2 +-
 tests/components/test_gp_arraysource.py                       | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
index f2ec4f874..ec27e3346 100644
--- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py
@@ -1,7 +1,6 @@
 from dacapo.experiments.arraytypes.probabilities import ProbabilityArray
 from .predictor import Predictor
 from dacapo.experiments import Model
-from dacapo.experiments.arraytypes import DistanceArray
 from dacapo.tmp import np_to_funlib_array
 from dacapo.utils.balance_weights import balance_weights
 
@@ -394,6 +393,7 @@ def __find_boundaries(self, labels):
        #     bound.: 00000001000100000001000      2n - 1
 
         logger.debug(f"computing boundaries for {labels.shape}")
+        labels = labels.astype(np.uint8)
         dims = len(labels.shape)
         in_shape = labels.shape
 
diff --git a/tests/components/test_gp_arraysource.py b/tests/components/test_gp_arraysource.py
index 58a4b23ba..68a5b5ae2 100644
--- a/tests/components/test_gp_arraysource.py
+++ b/tests/components/test_gp_arraysource.py
@@ -30,6 +30,5 @@ def test_gp_dacapo_array_source(array_config):
     batch = source_node.request_batch(request)
     data = batch[key].data
     if data.dtype == bool:
-        raise ValueError("Data should not be bools")
         data = data.astype(np.uint8)
     assert (data - array[array.roi]).sum() == 0
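
[Illustration, not part of the commit] The hunk above only shows where boolean labels get cast
to uint8; as one hypothetical motivation, arithmetic-style label comparisons are awkward on
boolean arrays, e.g. numpy rejects the `-` operator on them:

    import numpy as np

    labels = np.array([True, True, False, False])
    try:
        labels[1:] - labels[:-1]                  # TypeError: boolean subtract not supported
    except TypeError as err:
        print(err)

    cast = labels.astype(np.uint8)
    print(np.flatnonzero(cast[1:] != cast[:-1]))  # [1] -> boundary between index 1 and 2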
From 2467cd88cba89ce42c6b6a4dea063016c87edcac Mon Sep 17 00:00:00 2001
From: William Patton
Date: Mon, 9 Dec 2024 19:56:21 -0800
Subject: [PATCH 7/7] add xarray_multiscale to mypy linting ignore list

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 7f6bc2bcf..9070e4b37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -202,6 +202,7 @@ module = [
     "napari.*",
     "empanada.*",
     "IPython.*",
+    "xarray_multiscale.*"
 ]
 ignore_missing_imports = true