From 7373f4b80a6b1764fc6ab2a9726373207a92b2f4 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Mon, 19 Aug 2024 14:13:55 +0100 Subject: [PATCH] Fix map_overlap with new_axis (#11128) --- dask/array/overlap.py | 21 +++++++++++++++++++++ dask/array/tests/test_overlap.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/dask/array/overlap.py b/dask/array/overlap.py index ff4e2de03c6..c996831a6b3 100644 --- a/dask/array/overlap.py +++ b/dask/array/overlap.py @@ -741,6 +741,27 @@ def assert_int_chunksize(xs): # note that keys are relabeled to match values in range(x.ndim) depth = {n: depth[ax] for n, ax in enumerate(kept_axes)} boundary = {n: boundary[ax] for n, ax in enumerate(kept_axes)} + + # add any new axes to depth and boundary variables + new_axis = kwargs.pop("new_axis", None) + if new_axis is not None: + if isinstance(new_axis, Number): + new_axis = [new_axis] + + # convert negative new_axis to equivalent positive value + ndim_out = max(a.ndim for a in args if isinstance(a, Array)) + new_axis = [d % ndim_out for d in new_axis] + + for axis in new_axis: + for existing_axis in list(depth.keys()): + if existing_axis >= axis: + # Shuffle existing axis forward to give room to insert new_axis + depth[existing_axis + 1] = depth[existing_axis] + boundary[existing_axis + 1] = boundary[existing_axis] + + depth[axis] = 0 + boundary[axis] = "none" + return trim_internal(x, depth, boundary) else: return x diff --git a/dask/array/tests/test_overlap.py b/dask/array/tests/test_overlap.py index 65908592320..42a5c87e85c 100644 --- a/dask/array/tests/test_overlap.py +++ b/dask/array/tests/test_overlap.py @@ -826,3 +826,17 @@ def test_sliding_window_errors(window_shape, axis): arr = da.zeros((4, 3)) with pytest.raises(ValueError): sliding_window_view(arr, window_shape, axis) + + +def test_map_overlap_new_axis(): + arr = da.arange(6, chunks=2) + assert arr.shape == (6,) + assert arr.chunks == ((2, 2, 2),) + + actual = arr.map_overlap(lambda x: np.stack([x, x + 0.5]), depth=1, new_axis=[0]) + expected = np.stack([np.arange(6), np.arange(6) + 0.5]) + + assert actual.chunks == ((1,), (2, 2, 2)) + # Shape and chunks aren't known until array is computed, + # so don't expclitly check shape or chunks in assert_eq + assert_eq(expected, actual, check_shape=False, check_chunks=False)