Skip to content

Commit

Permalink
Fix map_overlap with new_axis (dask#11128)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby authored Aug 19, 2024
1 parent b7d9bf4 commit 7373f4b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
21 changes: 21 additions & 0 deletions dask/array/overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,27 @@ def assert_int_chunksize(xs):
# note that keys are relabeled to match values in range(x.ndim)
depth = {n: depth[ax] for n, ax in enumerate(kept_axes)}
boundary = {n: boundary[ax] for n, ax in enumerate(kept_axes)}

# add any new axes to depth and boundary variables
new_axis = kwargs.pop("new_axis", None)
if new_axis is not None:
if isinstance(new_axis, Number):
new_axis = [new_axis]

# convert negative new_axis to equivalent positive value
ndim_out = max(a.ndim for a in args if isinstance(a, Array))
new_axis = [d % ndim_out for d in new_axis]

for axis in new_axis:
for existing_axis in list(depth.keys()):
if existing_axis >= axis:
# Shuffle existing axis forward to give room to insert new_axis
depth[existing_axis + 1] = depth[existing_axis]
boundary[existing_axis + 1] = boundary[existing_axis]

depth[axis] = 0
boundary[axis] = "none"

return trim_internal(x, depth, boundary)
else:
return x
Expand Down
14 changes: 14 additions & 0 deletions dask/array/tests/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,17 @@ def test_sliding_window_errors(window_shape, axis):
arr = da.zeros((4, 3))
with pytest.raises(ValueError):
sliding_window_view(arr, window_shape, axis)


def test_map_overlap_new_axis():
arr = da.arange(6, chunks=2)
assert arr.shape == (6,)
assert arr.chunks == ((2, 2, 2),)

actual = arr.map_overlap(lambda x: np.stack([x, x + 0.5]), depth=1, new_axis=[0])
expected = np.stack([np.arange(6), np.arange(6) + 0.5])

assert actual.chunks == ((1,), (2, 2, 2))
# Shape and chunks aren't known until array is computed,
# so don't expclitly check shape or chunks in assert_eq
assert_eq(expected, actual, check_shape=False, check_chunks=False)

0 comments on commit 7373f4b

Please sign in to comment.