diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
index 7a03f89eb..9333b5b41 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py
@@ -5,6 +5,30 @@
 from funlib.geometry import Coordinate
 from funlib.persistence import Array
 
+from xarray_multiscale.multiscale import downscale_dask
+from xarray_multiscale import windowed_mean
+import numpy as np
+import dask.array as da
+
+from typing import Sequence
+
+
+def adjust_shape(array: da.Array, scale_factors: Sequence[int]) -> da.Array:
+    """
+    Crop array to a shape that is a multiple of the scale factors.
+    This allows for clean downsampling.
+
+    Trailing elements are dropped along every axis whose length is not
+    evenly divisible by the matching scale factor; an already aligned
+    array is returned unchanged.
+    """
+    misalignment = np.any(np.mod(array.shape, scale_factors))
+    if misalignment:
+        new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors))
+        slices = tuple(slice(0, s) for s in new_shape)
+        array = array[slices]
+    return array
+
 
 @attr.s
 class ResampledArrayConfig(ArrayConfig):
@@ -38,6 +62,28 @@ class ResampledArrayConfig(ArrayConfig):
     )
 
     def array(self, mode: str = "r") -> Array:
-        # This is non trivial. We want to upsample or downsample the source
-        # array lazily. Not entirely sure how to do this with dask arrays.
-        raise NotImplementedError()
+        """
+        Return the source array downsampled by ``self.downsample``.
+
+        The source data is cropped with ``adjust_shape`` and then reduced
+        with ``windowed_mean``. Raises NotImplementedError for upsampling.
+        """
+        source_array = self.source_array_config.array(mode)
+
+        if self.downsample is not None:
+            return Array(
+                data=downscale_dask(
+                    adjust_shape(source_array.data, self.downsample),
+                    windowed_mean,
+                    scale_factors=self.downsample,
+                ),
+                offset=source_array.offset,
+                # Scale each axis by its own factor rather than a fixed 2.
+                voxel_size=source_array.voxel_size * Coordinate(self.downsample),
+                axis_names=source_array.axis_names,
+                units=source_array.units,
+            )
+        elif self.upsample is not None:
+            raise NotImplementedError("Upsampling not yet implemented")
+        # NOTE(review): falls through and returns None when neither
+        # downsample nor upsample is set - confirm this is intended.
diff --git a/tests/components/test_preprocessing.py b/tests/components/test_preprocessing.py
new file mode 100644
index 000000000..a3f526940
--- /dev/null
+++ b/tests/components/test_preprocessing.py
@@ -0,0 +1,32 @@
+from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import (
+    ResampledArrayConfig,
+)
+
+import numpy as np
+from funlib.persistence import Array
+from funlib.geometry import Coordinate
+
+
+def test_resample():
+    # test downsampling arrays with shape 10 and 11 by a factor of 2 to test cropping works
+    for top in [11, 12]:
+        arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,))
+        resample_config = ResampledArrayConfig(
+            "test_resample", None, upsample=None, downsample=(2,), interp_order=1
+        )
+        resampled = resample_config.preprocess(arr)
+        assert resampled.voxel_size == Coordinate((6,))
+        assert resampled.shape == (5,)
+        assert np.allclose(resampled[:], np.array([1.5, 3.5, 5.5, 7.5, 9.5]))
+
+    # test 2D array
+    arr = Array(
+        np.array(np.arange(1, 11).reshape(5, 2).T), offset=(0, 0), voxel_size=(3, 3)
+    )
+    resample_config = ResampledArrayConfig(
+        "test_resample", None, upsample=None, downsample=(2, 1), interp_order=1
+    )
+    resampled = resample_config.preprocess(arr)
+    assert resampled.voxel_size == Coordinate(6, 3)
+    assert resampled.shape == (1, 5)
+    assert np.allclose(resampled[:], np.array([[1.5, 3.5, 5.5, 7.5, 9.5]]))