From bd1d104b9d52b269eb58da06f918435ee242d4ef Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 26 Nov 2024 13:14:09 -0800 Subject: [PATCH 1/4] Add downsampling support and a test for it --- .../datasets/arrays/resampled_array_config.py | 39 +++++++++++++++++-- tests/components/test_preprocessing.py | 32 +++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 tests/components/test_preprocessing.py diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py index 7a03f89eb..9333b5b41 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py @@ -5,6 +5,26 @@ from funlib.geometry import Coordinate from funlib.persistence import Array +from xarray_multiscale.multiscale import downscale_dask +from xarray_multiscale import windowed_mean +import numpy as np +import dask.array as da + +from typing import Sequence + + +def adjust_shape(array: da.Array, scale_factors: Sequence(int)) -> da.Array: + """ + Crop array to a shape that is a multiple of the scale factors. + This allows for clean downsampling. + """ + misalignment = np.any(np.mod(array.shape, scale_factors)) + if misalignment: + new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors)) + slices = tuple(slice(0, s) for s in new_shape) + array = array[slices] + return array + @attr.s class ResampledArrayConfig(ArrayConfig): @@ -38,6 +58,19 @@ class ResampledArrayConfig(ArrayConfig): ) def array(self, mode: str = "r") -> Array: - # This is non trivial. We want to upsample or downsample the source - # array lazily. Not entirely sure how to do this with dask arrays. - raise NotImplementedError() + source_array = self.source_array_config.array(mode) + + if self.downsample is not None: + return Array( + data=downscale_dask( + adjust_shape(source_array.data, self.downsample), + windowed_mean, + scale_factors=self.downsample, + ), + offset=source_array.offset, + voxel_size=source_array.voxel_size * 2, + axis_names=source_array.axis_names, + units=source_array.units, + ) + elif self.upsample is not None: + raise NotImplementedError("Upsampling not yet implemented") diff --git a/tests/components/test_preprocessing.py b/tests/components/test_preprocessing.py new file mode 100644 index 000000000..a3f526940 --- /dev/null +++ b/tests/components/test_preprocessing.py @@ -0,0 +1,32 @@ +from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import ( + ResampledArrayConfig, +) + +import numpy as np +from funlib.persistence import Array +from funlib.geometry import Coordinate + + +def test_resample(): + # test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works + for top in [11,12]: + arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,)) + resample_config = ResampledArrayConfig( + "test_resample", None, upsample=None, downsample=(2,), interp_order=1 + ) + resampled = resample_config.preprocess(arr) + assert resampled.voxel_size == Coordinate((6,)) + assert resampled.shape == (5,) + assert np.allclose(resampled[:], np.array([1.5, 3.5, 5.5, 7.5, 9.5])) + + # test 2D array + arr = Array( + np.array(np.arange(1, 11).reshape(5, 2).T), offset=(0, 0), voxel_size=(3, 3) + ) + resample_config = ResampledArrayConfig( + "test_resample", None, upsample=None, downsample=(2, 1), interp_order=1 + ) + resampled = resample_config.preprocess(arr) + assert resampled.voxel_size == Coordinate(6, 3) + assert resampled.shape == (1, 5) + assert np.allclose(resampled[:], np.array([[1.5, 3.5, 5.5, 7.5, 9.5]])) From 3ed2d55a3052ada1cbebe8909f825674d73b76a8 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 26 Nov 2024 14:18:00 -0800 Subject: [PATCH 2/4] pass tests --- .../datasets/arrays/resampled_array_config.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py index 9333b5b41..e2889d01f 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py @@ -13,7 +13,7 @@ from typing import Sequence -def adjust_shape(array: da.Array, scale_factors: Sequence(int)) -> da.Array: +def adjust_shape(array: da.Array, scale_factors: Sequence[int]) -> da.Array: """ Crop array to a shape that is a multiple of the scale factors. This allows for clean downsampling. @@ -57,20 +57,27 @@ class ResampledArrayConfig(ArrayConfig): metadata={"help_text": "The order of the interpolation!"} ) - def array(self, mode: str = "r") -> Array: - source_array = self.source_array_config.array(mode) - + def preprocess(self, array: Array) -> Array: + """ + Preprocess an array by resampling it to the desired voxel size. + """ if self.downsample is not None: + downsample = Coordinate(self.downsample) return Array( data=downscale_dask( - adjust_shape(source_array.data, self.downsample), + adjust_shape(array.data, downsample), windowed_mean, - scale_factors=self.downsample, + scale_factors=downsample, ), - offset=source_array.offset, - voxel_size=source_array.voxel_size * 2, - axis_names=source_array.axis_names, - units=source_array.units, + offset=array.offset, + voxel_size=array.voxel_size * downsample, + axis_names=array.axis_names, + units=array.units, ) elif self.upsample is not None: raise NotImplementedError("Upsampling not yet implemented") + + def array(self, mode: str = "r") -> Array: + source_array = self.source_array_config.array(mode) + + return self.preprocess(source_array) From 72212e0b72ab26935b85cbfda0099c70a8459d7d Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 26 Nov 2024 14:18:25 -0800 Subject: [PATCH 3/4] format preprocessing test --- tests/components/test_preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/components/test_preprocessing.py b/tests/components/test_preprocessing.py index a3f526940..14aa7d1df 100644 --- a/tests/components/test_preprocessing.py +++ b/tests/components/test_preprocessing.py @@ -9,7 +9,7 @@ def test_resample(): # test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works - for top in [11,12]: + for top in [11, 12]: arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,)) resample_config = ResampledArrayConfig( "test_resample", None, upsample=None, downsample=(2,), interp_order=1 From 7f8554ce9dcc0199726c0b3f16bcf13bb22287fa Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 26 Nov 2024 14:19:59 -0800 Subject: [PATCH 4/4] add xarray-multiscale dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 012def526..7f6bc2bcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "upath", "boto3", "matplotlib", + "xarray-multiscale", ] # extras