Skip to content

Commit

Permalink
Upsample (#354)
Browse files Browse the repository at this point in the history
Add support for upsample.

Unfortunately this implementation seems to accumulate memory as a result
of repeated querying. See this memory plot while randomly reading 1000
small (100x100x100) cubes

![upscale_mem](https://github.com/user-attachments/assets/27fb6137-0a78-426e-b682-26f82e4fa2d6)

Its a fairly direct usage of the `dask` `map_overlap` function so I am
guessing there is a way to avoid any caching, I just haven't found it
yet
  • Loading branch information
mzouink authored Jan 2, 2025
2 parents e9f255c + 59c8d64 commit c123a44
Showing 1 changed file with 46 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xarray_multiscale import windowed_mean
import numpy as np
import dask.array as da
from skimage.transform import rescale

from typing import Sequence

Expand Down Expand Up @@ -47,10 +48,10 @@ class ResampledArrayConfig(ArrayConfig):
metadata={"help_text": "The Array that you want to upsample or downsample."}
)

upsample: Coordinate = attr.ib(
_upsample: Coordinate = attr.ib(
metadata={"help_text": "The amount by which to upsample!"}
)
downsample: Coordinate = attr.ib(
_downsample: Coordinate = attr.ib(
metadata={"help_text": "The amount by which to downsample!"}
)
interp_order: bool = attr.ib(
Expand All @@ -62,20 +63,60 @@ def preprocess(self, array: Array) -> Array:
Preprocess an array by resampling it to the desired voxel size.
"""
if self.downsample is not None:
downsample = Coordinate(self.downsample)
downsample = list(self.downsample)
for i, axis_name in enumerate(array.axis_names):
if "^" in axis_name:
downsample = downsample[:i] + [1] + downsample[i:]
return Array(
data=downscale_dask(
adjust_shape(array.data, downsample),
windowed_mean,
scale_factors=downsample,
),
offset=array.offset,
voxel_size=array.voxel_size * downsample,
voxel_size=array.voxel_size * self.downsample,
axis_names=array.axis_names,
units=array.units,
)
elif self.upsample is not None:
raise NotImplementedError("Upsampling not yet implemented")
upsample = list(self.upsample)
for i, axis_name in enumerate(array.axis_names):
if "^" in axis_name:
upsample = upsample[:i] + [1] + upsample[i:]

depth = [int(x > 1) for x in upsample]
trim_slicing = tuple(
slice(d * s, (-d * s)) if d > 1 else slice(None)
for d, s in zip(depth, upsample)
)

rescaled_arr = da.map_overlap(
lambda x: rescale(
x, upsample, order=int(self.interp_order), preserve_range=True
)[trim_slicing],
array.data,
depth=depth,
boundary="reflect",
trim=False,
dtype=array.data.dtype,
chunks=tuple(c * u for c, u in zip(array.data.chunksize, upsample)),
)

return Array(
data=rescaled_arr,
offset=array.offset,
voxel_size=array.voxel_size / self.upsample,
axis_names=array.axis_names,
units=array.units,
)

@property
def upsample(self) -> Coordinate:
return Coordinate(self._upsample)

@property
def downsample(self) -> Coordinate:
return Coordinate(self._downsample)

def array(self, mode: str = "r") -> Array:
source_array = self.source_array_config.array(mode)
Expand Down

0 comments on commit c123a44

Please sign in to comment.