Skip to content

Commit

Permalink
Allow rechunk to take a dict for chunks (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jun 27, 2023
1 parent 19ec275 commit 647a68b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 19 deletions.
7 changes: 5 additions & 2 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,19 @@ def wrap(*a, block_id=None, **kw):


def rechunk(x, chunks, target_store=None):
if x.chunks == normalize_chunks(chunks, x.shape, dtype=x.dtype):
normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
if x.chunks == normalized_chunks:
return x
# normalizing takes care of dict args for chunks
target_chunks = to_chunksize(normalized_chunks)
name = gensym()
spec = x.spec
if target_store is None:
target_store = new_temp_path(name=name, spec=spec)
temp_store = new_temp_path(name=f"{name}-intermediate", spec=spec)
pipeline = primitive_rechunk(
x.zarray_maybe_lazy,
target_chunks=chunks,
target_chunks=target_chunks,
allowed_mem=spec.allowed_mem,
reserved_mem=spec.reserved_mem,
target_store=target_store,
Expand Down
16 changes: 1 addition & 15 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from cubed.runtime.pipeline import spec_to_pipeline
from cubed.storage.zarr import lazy_empty
from cubed.vendor.rechunker.algorithm import rechunking_plan
from cubed.vendor.rechunker.api import (
_get_dims_from_zarr_array,
_shape_dict_to_tuple,
_validate_options,
)
from cubed.vendor.rechunker.api import _validate_options


def rechunk(
Expand Down Expand Up @@ -119,16 +115,6 @@ def _setup_array_rechunk(
# this is just a pass-through copy
target_chunks = source_chunks

if isinstance(target_chunks, dict):
array_dims = _get_dims_from_zarr_array(source_array)
try:
target_chunks = _shape_dict_to_tuple(array_dims, target_chunks)
except KeyError:
raise KeyError(
"You must explicitly specify each dimension size in target_chunks. "
f"Got array_dims {array_dims}, target_chunks {target_chunks}."
)

# TODO: rewrite to avoid the hard dependency on dask
max_mem = cubed.vendor.dask.utils.parse_bytes(max_mem)

Expand Down
5 changes: 3 additions & 2 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ def test_multiple_ops(spec, executor):
)


def test_rechunk(spec, executor):
@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}])
def test_rechunk(spec, executor, new_chunks):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
b = a.rechunk((1, 2))
b = a.rechunk(new_chunks)
assert_array_equal(
b.compute(executor=executor),
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
Expand Down

0 comments on commit 647a68b

Please sign in to comment.