Skip to content

Commit

Permalink
Add downsampling support and a test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Nov 26, 2024
1 parent 1df45db commit bd1d104
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
from funlib.geometry import Coordinate
from funlib.persistence import Array

from xarray_multiscale.multiscale import downscale_dask
from xarray_multiscale import windowed_mean
import numpy as np
import dask.array as da

from typing import Sequence


def adjust_shape(array: da.Array, scale_factors: Sequence(int)) -> da.Array:
"""
Crop array to a shape that is a multiple of the scale factors.
This allows for clean downsampling.
"""
misalignment = np.any(np.mod(array.shape, scale_factors))
if misalignment:
new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors))
slices = tuple(slice(0, s) for s in new_shape)
array = array[slices]
return array


@attr.s
class ResampledArrayConfig(ArrayConfig):
Expand Down Expand Up @@ -38,6 +58,19 @@ class ResampledArrayConfig(ArrayConfig):
)

def array(self, mode: str = "r") -> Array:
# This is non trivial. We want to upsample or downsample the source
# array lazily. Not entirely sure how to do this with dask arrays.
raise NotImplementedError()
source_array = self.source_array_config.array(mode)

if self.downsample is not None:
return Array(
data=downscale_dask(
adjust_shape(source_array.data, self.downsample),
windowed_mean,
scale_factors=self.downsample,
),
offset=source_array.offset,
voxel_size=source_array.voxel_size * 2,
axis_names=source_array.axis_names,
units=source_array.units,
)
elif self.upsample is not None:
raise NotImplementedError("Upsampling not yet implemented")
32 changes: 32 additions & 0 deletions tests/components/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import (
ResampledArrayConfig,
)

import numpy as np
from funlib.persistence import Array
from funlib.geometry import Coordinate


def test_resample():
# test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works
for top in [11,12]:
arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,))
resample_config = ResampledArrayConfig(
"test_resample", None, upsample=None, downsample=(2,), interp_order=1
)
resampled = resample_config.preprocess(arr)
assert resampled.voxel_size == Coordinate((6,))
assert resampled.shape == (5,)
assert np.allclose(resampled[:], np.array([1.5, 3.5, 5.5, 7.5, 9.5]))

# test 2D array
arr = Array(
np.array(np.arange(1, 11).reshape(5, 2).T), offset=(0, 0), voxel_size=(3, 3)
)
resample_config = ResampledArrayConfig(
"test_resample", None, upsample=None, downsample=(2, 1), interp_order=1
)
resampled = resample_config.preprocess(arr)
assert resampled.voxel_size == Coordinate(6, 3)
assert resampled.shape == (1, 5)
assert np.allclose(resampled[:], np.array([[1.5, 3.5, 5.5, 7.5, 9.5]]))

0 comments on commit bd1d104

Please sign in to comment.