Skip to content

Commit

Permalink
Handle kwargs better in store (#14)
Browse files Browse the repository at this point in the history
* Handle kwargs better in store

* Add a test

* Only run with processes executor on Python 3.11

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tomwhite and pre-commit-ci[bot] authored Aug 7, 2024
1 parent 0b35a57 commit 495bcf3
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
1 change: 0 additions & 1 deletion cubed_xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from importlib.metadata import version


try:
__version__ = version("cubed-xarray")
except Exception:
Expand Down
28 changes: 23 additions & 5 deletions cubed_xarray/cubedmanager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Union

import numpy as np

from tlz import partition

from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint


if TYPE_CHECKING:
from xarray.core.types import T_Chunks, T_NormalizedChunks
from cubed import Array as CubedArray
from xarray.core.types import T_Chunks, T_NormalizedChunks


class CubedManager(ChunkManagerEntrypoint["CubedArray"]):
Expand Down Expand Up @@ -204,6 +201,27 @@ def store(
"""Used when writing to any backend."""
from cubed.core.ops import store

compute = kwargs.pop("compute", True)
if not compute:
raise NotImplementedError("Delayed compute is not supported.")

lock = kwargs.pop("lock", None)
if lock:
raise NotImplementedError("Locking is not supported.")

regions = kwargs.pop("regions", None)
if regions:
# regions is either a tuple of slices or a collection of tuples of slices
if isinstance(regions, tuple):
regions = [regions]
for t in regions:
if not all(r == slice(None) for r in t):
raise NotImplementedError(
"Only whole slices are supported for regions."
)

kwargs.pop("flush", None) # not used

return store(
sources,
targets,
Expand Down
53 changes: 47 additions & 6 deletions cubed_xarray/tests/test_wrapping.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
import sys

import cubed
import pytest
import xarray as xr
from cubed.runtime.create import create_executor
from xarray.namedarray.parallelcompat import list_chunkmanagers
import cubed
from xarray.tests import assert_allclose, create_test_data

from cubed_xarray.cubedmanager import CubedManager

EXECUTORS = [create_executor("single-threaded")]

if sys.version_info >= (3, 11):
EXECUTORS.append(create_executor("processes"))


@pytest.fixture(
scope="module",
params=EXECUTORS,
ids=[executor.name for executor in EXECUTORS],
)
def executor(request):
return request.param


class TestDiscoverCubedManager:
def test_list_cubedmanager(self):
chunkmanagers = list_chunkmanagers()
assert 'cubed' in chunkmanagers
assert isinstance(chunkmanagers['cubed'], CubedManager)
assert "cubed" in chunkmanagers
assert isinstance(chunkmanagers["cubed"], CubedManager)

def test_chunk(self):
da = xr.DataArray([1, 2], dims='x')
chunked = da.chunk(x=1, chunked_array_type='cubed')
da = xr.DataArray([1, 2], dims="x")
chunked = da.chunk(x=1, chunked_array_type="cubed")
assert isinstance(chunked.data, cubed.Array)
assert chunked.chunksizes == {'x': (1, 1)}
assert chunked.chunksizes == {"x": (1, 1)}

# TODO test cubed is default when dask not installed

# TODO test dask is default over cubed when both installed


def test_to_zarr(tmpdir, executor):
spec = cubed.Spec(allowed_mem="200MB", executor=executor)

original = create_test_data().chunk(
chunked_array_type="cubed", from_array_kwargs={"spec": spec}
)

filename = tmpdir / "out.zarr"
original.to_zarr(filename)

with xr.open_dataset(
filename,
chunks="auto",
engine="zarr",
chunked_array_type="cubed",
from_array_kwargs={"spec": spec},
) as restored:
assert isinstance(restored.var1.data, cubed.Array)
computed = restored.compute()
assert_allclose(original, computed)

0 comments on commit 495bcf3

Please sign in to comment.