From 6d705355fa786474076eb2147718913c7e7933de Mon Sep 17 00:00:00 2001
From: Ilan Gold
Date: Tue, 23 Jul 2024 10:39:07 +0200
Subject: [PATCH] (feat): `read_elem_as_dask` method (#1469)

* (feat): `read_elem_lazy` method
* (revert): error message
* (refactor): declare `is_csc` reading elem directly in h5
* (chore): `read_elem_lazy` -> `read_elem_as_dask`
* (chore): remove string handling
* (refactor): use `elem` for h5 where possible
* (chore): remove invalid syntax
* (fix): put dask import inside function
* (refactor): try maybe open?
* (fix): revert `encoding-version`
* (chore): document `create_sparse_store` test function
* (chore): sort indices to prevent warning
* (fix): remove utility function `make_dask_array`
* (chore): `read_sparse_as_dask_h5` -> `read_sparse_as_dask`
* (feat): make params of `h5_chunks` and `stride`
* (chore): add distributed test
* (fix): `TypeVar` bind
* (chore): release note
* (chore): `0.10.8` -> `0.11.0`
* (fix): `ruff` for default `pytest.fixture` `scope`
* Apply suggestions from code review

Co-authored-by: Philipp A.

* (fix): `Any` to `DaskArray`
* (fix): type `make_index` + fix undeclared
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix rest
* (fix): use `chunks` kwarg
* (feat): expose `chunks` as an option to `read_elem_as_dask` via `dataset_kwargs`
* (fix): `test_read_dispatched_null_case` test
* (fix): disallowed spread syntax?
* (refactor): reuse `compute_chunk_layout_for_axis_shape` functionality
* (fix): remove unneeded `slice` arguments
* (fix): revert message
* (refactor): `make_index` -> `make_block_indexer`
* (fix): export from `experimental`
* (fix): `callback` signature for `test_read_dispatched_null_case`
* (chore): `get_elem_name` helper
* (chore): use `H5Group` consistently
* (refactor): make `chunks` public facing API instead of `dataset_kwargs`
* (fix): register for group not array
* (chore): add warning test
* (chore): make arg order consistent
* (feat): add `callback` typing for `read_dispatched`
* (chore): use `npt.NDArray`
* (fix): remove unnecessary union
* (chore): release note
* (fix): try protocol docs
* (feat): create `InMemoryElem` + `DictElemType` to remove `Any`
* (chore): refactor `DictElemType` -> `InMemoryArrayOrScalarType` for reuse
* (fix): use `Union`
* (fix): more `Union`
* (refactor): `InMemoryElem` -> `InMemoryReadElem`
* (chore): add needed types to public export + docs fix
* (chore): type `write_elem` functions
* (chore): create `write_callback` protocol
* (chore): export + docs
* (fix): add string descriptions
* (fix): try sphinx protocol doc
* (fix): try ignoring exports
* (fix): remap callback internal usages
* (fix): add docstring
* Discard changes to pyproject.toml
* re-add dep
* Fix docs
* Almost works
* works!
* (chore): use pascal-case
* (feat): type read/write funcs in callback
* (fix): use generic for `Read` as well.
* (fix): need more aliases
* Split table, format
* (refactor): move to `_types` file
* bump scanpydoc
* Some basic syntax fixes
* (fix): change `Read{Callback}` type for kwargs
* (chore): test `chunks` argument
* (fix): type `read_recarray`
* (fix): `GroupStorageType` not `StorageType`
* (fix): little type fixes
* (fix): clarify `H5File` typing
* (fix): dask doc
* (fix): dask docs
* (fix): typing
* (fix): handle case when `chunks` is `None`
* (feat): add string-array reading
* (fix): remove `string-array` because it is not tested
* (refactor): clean up tests
* (fix): overfetching problem
* Fix circular import
* add some typing
* fix mapping types
* Fix Read/Write
* Fix one more
* unify names
* clarify ReadCallback signature
* Fix type aliases
* (fix): clean up typing to use `RWAble`
* (fix): use `Union`
* (fix): add qualname override
* (fix): ignore dask and masked array
* (fix): ignore erroneous class warning
* (fix): upgrade `scanpydoc`
* (fix): use `MutableMapping` instead of `dict` due to broken docstring
* Add data docs
* Revert "(fix): use `MutableMapping` instead of `dict` due to broken docstring"

This reverts commit 79d3fdc54c775b88f6ac9c65e83fed08049c5484.

* (fix): add clarification
* Simplify
* (fix): remove double `dask` intersphinx
* (fix): remove `_types.DaskArray` from type checking block
* (refactor): use `block_info` for resolving fetch location
* (fix): dtype for reading
* (fix): ignore import cycle problem (why??)
* (fix): add issue
* (fix): subclass `Reader` to remove `dataset_kwargs`
* (fix): add message to error
* Update tests/test_io_elementwise.py

Co-authored-by: Isaac Virshup

* (fix): correct `self.callback` check
* (fix): erroneous diffs
* (fix): extra `read_elem` `dataset_kwargs`
* (fix): remove more `dataset_kwargs` nonsense
* (chore): add docs
* (fix): use `block_info` for dense
* (fix): more erroneous diffs
* (fix): use context again
* (fix): change size by dimension in tests
* (refactor): clean up `get_elem_name`
* (fix): try new sphinx for error
* (fix): return type
* (fix): protocol for reading
* (fix): bring back ignored warning
* Fix docs
* almost fix typing
* add wrapper
* move into type checking
* (fix): small type fixes
* block info types
* simplify
* rename
* simplify more

---------

Co-authored-by: Philipp A.
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Isaac Virshup
---
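A minimal usage sketch of the new API (illustrative only: the file name, element
key, and chunk sizes below are hypothetical; per the `read_elem_as_dask`
docstring in this patch, a sparse element may only be chunked along its major
axis, so the minor-axis chunk must span that axis in full):

    import h5py
    from anndata.experimental import read_elem_as_dask

    with h5py.File("data.h5ad", "r") as f:
        # "X" is assumed to be a csr_matrix-encoded group written by anndata
        n_obs, n_var = f["X"].attrs["shape"]
        # chunk 1000 rows at a time; the minor (column) axis stays whole
        X = read_elem_as_dask(f["X"], chunks=(1000, n_var))
        head = X[:100].compute()  # only the blocks covering rows 0-99 are read
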
 docs/api.md                           |   1 +
 docs/release-notes/0.11.0.md          |   1 +
 pyproject.toml                        |   2 +-
 src/anndata/_core/file_backing.py     |  17 ++-
 src/anndata/_io/specs/__init__.py     |   6 +-
 src/anndata/_io/specs/lazy_methods.py | 164 ++++++++++++++++++++++
 src/anndata/_io/specs/registry.py     |  96 ++++++++-----
 src/anndata/_types.py                 |  50 +++++--
 src/anndata/experimental/__init__.py  |   3 +-
 tests/test_io_elementwise.py          | 188 +++++++++++++++++++++++---
 10 files changed, 459 insertions(+), 69 deletions(-)
 create mode 100644 src/anndata/_io/specs/lazy_methods.py

diff --git a/docs/api.md b/docs/api.md
index 36ebeac88..92139fe06 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -121,6 +121,7 @@ Low level methods for reading and writing elements of an `AnnData` object to a s
 
     experimental.read_elem
     experimental.write_elem
+    experimental.read_elem_as_dask
 ```
 
 Utilities for customizing the IO process:
diff --git a/docs/release-notes/0.11.0.md b/docs/release-notes/0.11.0.md
index 8bb61de99..618d4f549 100644
--- a/docs/release-notes/0.11.0.md
+++ b/docs/release-notes/0.11.0.md
@@ -8,6 +8,7 @@
 * Add `should_remove_unused_categories` option to `anndata.settings` to override current behavior. Default is `True` (i.e., previous behavior). Please refer to the [documentation](https://anndata.readthedocs.io/en/latest/generated/anndata.settings.html) for usage. {pr}`1340` {user}`ilan-gold`
 * `scipy.sparse.csr_array` and `scipy.sparse.csc_array` are now supported when constructing `AnnData` objects {pr}`1028` {user}`ilan-gold` {user}`isaac-virshup`
 * Add `should_check_uniqueness` option to `anndata.settings` to override current behavior. Default is `True` (i.e., previous behavior). Please refer to the [documentation](https://anndata.readthedocs.io/en/latest/generated/anndata.settings.html) for usage. {pr}`1507` {user}`ilan-gold`
+* Add {func}`~anndata.experimental.read_elem_as_dask` function to handle i/o with sparse and dense arrays {pr}`1469` {user}`ilan-gold`
 * Add functionality to write from GPU {class}`dask.array.Array` to disk {pr}`1550` {user}`ilan-gold`
 
 #### Bugfix
diff --git a/pyproject.toml b/pyproject.toml
index 43e5ab416..ef97699f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,7 @@ dev = [
     "pytest-xdist",
 ]
 doc = [
-    "sphinx>=4.4",
+    "sphinx>=7.4.6",
     "sphinx-book-theme>=1.1.0",
     "sphinx-autodoc-typehints>=2.2.0",
     "sphinx-issues",
diff --git a/src/anndata/_core/file_backing.py b/src/anndata/_core/file_backing.py
index 6346100ba..dbef41d5d 100644
--- a/src/anndata/_core/file_backing.py
+++ b/src/anndata/_core/file_backing.py
@@ -3,7 +3,7 @@
 import weakref
 from collections.abc import Mapping
 from functools import singledispatch
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 from typing import TYPE_CHECKING
 
 import h5py
@@ -175,3 +175,18 @@ def _(x):
 @filename.register(ZarrGroup)
 def _(x):
     return x.store.path
+
+
+@singledispatch
+def get_elem_name(x):
+    raise NotImplementedError(f"Not implemented for {type(x)}")
+
+
+@get_elem_name.register(h5py.Group)
+def _(x):
+    return x.name
+
+
+@get_elem_name.register(ZarrGroup)
+def _(x):
+    return PurePosixPath(x.path).name
diff --git a/src/anndata/_io/specs/__init__.py b/src/anndata/_io/specs/__init__.py
index ceff8b3d6..5eadfdb50 100644
--- a/src/anndata/_io/specs/__init__.py
+++ b/src/anndata/_io/specs/__init__.py
@@ -1,21 +1,25 @@
 from __future__ import annotations
 
-from . import methods
+from . import lazy_methods, methods
 from .registry import (
+    _LAZY_REGISTRY,  # noqa: F401
     _REGISTRY,  # noqa: F401
     IOSpec,
     Reader,
     Writer,
     get_spec,
     read_elem,
+    read_elem_as_dask,
     write_elem,
 )
 
 __all__ = [
     "methods",
+    "lazy_methods",
     "write_elem",
     "get_spec",
     "read_elem",
+    "read_elem_as_dask",
     "Reader",
     "Writer",
     "IOSpec",
diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py
new file mode 100644
index 000000000..8a1b31e6b
--- /dev/null
+++ b/src/anndata/_io/specs/lazy_methods.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import h5py
+import numpy as np
+from scipy import sparse
+
+import anndata as ad
+
+from ..._core.file_backing import filename, get_elem_name
+from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup
+from .registry import _LAZY_REGISTRY, IOSpec
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Generator, Mapping, Sequence
+    from typing import Literal, ParamSpec, TypeVar
+
+    from ..._core.sparse_dataset import CSCDataset, CSRDataset
+    from ..._types import ArrayStorageType, StorageType
+    from ...compat import DaskArray
+    from .registry import DaskReader
+
+    BlockInfo = Mapping[
+        Literal[None],
+        dict[str, Sequence[tuple[int, int]]],
+    ]
+
+    P = ParamSpec("P")
+    R = TypeVar("R")
+
+
+@contextmanager
+def maybe_open_h5(
+    path_or_group: Path | ZarrGroup, elem_name: str
+) -> Generator[StorageType, None, None]:
+    if not isinstance(path_or_group, Path):
+        yield path_or_group
+        return
+    file = h5py.File(path_or_group, "r")
+    try:
+        yield file[elem_name]
+    finally:
+        file.close()
+
+
+_DEFAULT_STRIDE = 1000
+
+
+def compute_chunk_layout_for_axis_shape(
+    chunk_axis_shape: int, full_axis_shape: int
+) -> tuple[int, ...]:
+    n_strides, rest = np.divmod(full_axis_shape, chunk_axis_shape)
+    chunk = (chunk_axis_shape,) * n_strides
+    if rest > 0:
+        chunk += (rest,)
+    return chunk
+
+
+def make_dask_chunk(
+    path_or_group: Path | ZarrGroup,
+    elem_name: str,
+    block_info: BlockInfo | None = None,
+    *,
+    wrap: Callable[[ArrayStorageType], ArrayStorageType]
+    | Callable[[H5Group | ZarrGroup], CSRDataset | CSCDataset] = lambda g: g,
+):
+    if block_info is None:
+        msg = "Block info is required"
+        raise ValueError(msg)
+    # We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
+    # https://github.com/scverse/anndata/issues/1105
+    with maybe_open_h5(path_or_group, elem_name) as f:
+        mtx = wrap(f)
+        idx = tuple(
+            slice(start, stop) for start, stop in block_info[None]["array-location"]
+        )
+        chunk = mtx[idx]
+    return chunk
+
+
+@_LAZY_REGISTRY.register_read(H5Group, IOSpec("csc_matrix", "0.1.0"))
+@_LAZY_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
+@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
+@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
+def read_sparse_as_dask(
+    elem: H5Group | ZarrGroup,
+    *,
+    _reader: DaskReader,
+    chunks: tuple[int, ...] | None = None,  # only tuple[int, int] is supported here
+) -> DaskArray:
+    import dask.array as da
+
+    path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem
+    elem_name = get_elem_name(elem)
+    shape: tuple[int, int] = tuple(elem.attrs["shape"])
+    dtype = elem["data"].dtype
+    is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix"
+
+    stride: int = _DEFAULT_STRIDE
+    major_dim, minor_dim = (1, 0) if is_csc else (0, 1)
+    if chunks is not None:
+        if len(chunks) != 2:
+            raise ValueError("`chunks` must be a tuple of two integers")
+        if chunks[minor_dim] != shape[minor_dim]:
+            raise ValueError(
+                "Only the major axis can be chunked. "
+                f"Try setting chunks to {((-1, _DEFAULT_STRIDE) if is_csc else (_DEFAULT_STRIDE, -1))}"
+            )
+        stride = chunks[major_dim]
+
+    shape_minor, shape_major = shape if is_csc else shape[::-1]
+    chunks_major = compute_chunk_layout_for_axis_shape(stride, shape_major)
+    chunks_minor = (shape_minor,)
+    chunk_layout = (
+        (chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor)
+    )
+    memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
+    make_chunk = partial(
+        make_dask_chunk, path_or_group, elem_name, wrap=ad.experimental.sparse_dataset
+    )
+    da_mtx = da.map_blocks(
+        make_chunk,
+        dtype=dtype,
+        chunks=chunk_layout,
+        meta=memory_format((0, 0), dtype=dtype),
+    )
+    return da_mtx
+
+
+@_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0"))
+def read_h5_array(
+    elem: H5Array, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
+) -> DaskArray:
+    import dask.array as da
+
+    path = Path(elem.file.filename)
+    elem_name: str = elem.name
+    shape = tuple(elem.shape)
+    dtype = elem.dtype
+    chunks: tuple[int, ...] = (
+        chunks if chunks is not None else (_DEFAULT_STRIDE,) * len(shape)
+    )
+
+    chunk_layout = tuple(
+        compute_chunk_layout_for_axis_shape(chunks[i], shape[i])
+        for i in range(len(shape))
+    )
+
+    make_chunk = partial(make_dask_chunk, path, elem_name)
+    return da.map_blocks(make_chunk, dtype=dtype, chunks=chunk_layout)
+
+
+@_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0"))
+def read_zarr_array(
+    elem: ZarrArray, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
+) -> DaskArray:
+    chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks
+    import dask.array as da
+
+    return da.from_zarr(elem, chunks=chunks)
diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py
index 6268f5ea7..e3003cc52 100644
--- a/src/anndata/_io/specs/registry.py
+++ b/src/anndata/_io/specs/registry.py
@@ -1,27 +1,27 @@
 from __future__ import annotations
 
+import warnings
 from collections.abc import Mapping
 from dataclasses import dataclass
 from functools import partial, singledispatch, wraps
 from types import MappingProxyType
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, TypeVar
 
 from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
+from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal
 from anndata.compat import DaskArray, _read_attr
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Generator, Iterable
-    from typing import Any, TypeVar
+    from typing import Any
 
     from anndata._core.storage import StorageType
     from anndata._types import (
         GroupStorageType,
         InMemoryElem,
-        Read,
         ReadCallback,
         Write,
         WriteCallback,
-        _ReadInternal,
         _WriteInternal,
     )
 
@@ -78,9 +78,13 @@ def wrapper(g: GroupStorageType, k: str, *args, **kwargs):
     return decorator
 
 
-class IORegistry:
+_R = TypeVar("_R", _ReadInternal, _ReadDaskInternal)
+R = TypeVar("R", Read, ReadDask)
+
+
+class IORegistry(Generic[_R, R]):
     def __init__(self):
-        self.read: dict[tuple[type, IOSpec, frozenset[str]], _ReadInternal] = {}
+        self.read: dict[tuple[type, IOSpec, frozenset[str]], _R] = {}
         self.read_partial: dict[tuple[type, IOSpec, frozenset[str]], Callable] = {}
         self.write: dict[
             tuple[type, type | tuple[type, str], frozenset[str]], _WriteInternal
@@ -145,7 +149,7 @@ def register_read(
         src_type: type,
         spec: IOSpec | Mapping[str, str],
         modifiers: Iterable[str] = frozenset(),
-    ) -> Callable[[_ReadInternal[T]], _ReadInternal[T]]:
+    ) -> Callable[[_R], _R]:
         spec = proc_spec(spec)
         modifiers = frozenset(modifiers)
 
@@ -162,11 +166,9 @@ def get_read(
         modifiers: frozenset[str] = frozenset(),
         *,
         reader: Reader,
-    ) -> Read:
+    ) -> R:
         if (src_type, spec, modifiers) not in self.read:
-            raise IORegistryError._from_read_parts(
-                "read", _REGISTRY.read, src_type, spec
-            )
+            raise IORegistryError._from_read_parts("read", self.read, src_type, spec)
         internal = self.read[(src_type, spec, modifiers)]
         return partial(internal, _reader=reader)
 
@@ -197,7 +199,7 @@ def get_partial_read(
             return self.read_partial[(src_type, spec, modifiers)]
         else:
             raise IORegistryError._from_read_parts(
-                "read_partial", _REGISTRY.read_partial, src_type, spec
+                "read_partial", self.read_partial, src_type, spec
             )
 
     def get_spec(self, elem: Any) -> IOSpec:
@@ -210,7 +212,8 @@ def get_spec(self, elem: Any) -> IOSpec:
         return self.write_specs[type(elem)]
 
 
-_REGISTRY = IORegistry()
+_REGISTRY: IORegistry[_ReadInternal, Read] = IORegistry()
+_LAZY_REGISTRY: IORegistry[_ReadDaskInternal, ReadDask] = IORegistry()
 
 
 @singledispatch
@@ -271,12 +274,34 @@ def read_elem(
         """Read an element from a store. See exported function for more details."""
         iospec = get_spec(elem)
-        read_func = self.registry.get_read(type(elem), iospec, modifiers, reader=self)
+        read_func: Read = self.registry.get_read(
+            type(elem), iospec, modifiers, reader=self
+        )
         if self.callback is None:
             return read_func(elem)
         return self.callback(read_func, elem.name, elem, iospec=iospec)
 
 
+class DaskReader(Reader):
+    @report_read_key_on_error
+    def read_elem(
+        self,
+        elem: StorageType,
+        modifiers: frozenset[str] = frozenset(),
+        chunks: tuple[int, ...] | None = None,
+    ) -> DaskArray:
+        """Read a dask element from a store. See exported function for more details."""
+
+        iospec = get_spec(elem)
+        read_func: ReadDask = self.registry.get_read(
+            type(elem), iospec, modifiers, reader=self
+        )
+        if self.callback is not None:
+            msg = "Dask reading does not use a callback. Ignoring callback."
+            warnings.warn(msg, stacklevel=2)
+        return read_func(elem, chunks=chunks)
+
+
 class Writer:
     def __init__(self, registry: IORegistry, callback: WriteCallback | None = None):
         self.registry = registry
@@ -353,6 +378,31 @@ def read_elem(elem: StorageType) -> InMemoryElem:
     return Reader(_REGISTRY).read_elem(elem)
 
 
+def read_elem_as_dask(
+    elem: StorageType, chunks: tuple[int, ...] | None = None
+) -> DaskArray:
+    """
+    Read an element from a store lazily.
+
+    Assumes that the element is encoded using the anndata encoding. This function will
+    determine the encoded type using the encoding metadata stored in elem's attributes.
+
+    Parameters
+    ----------
+    elem
+        The stored element.
+    chunks, optional
+        A tuple of length `n`, where `n` is the number of dimensions of the underlying array.
+        Note that for sparse matrices, the chunk size along the minor axis must equal the
+        full extent of that axis.
+
+    Returns
+    -------
+    DaskArray
+    """
+    return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks)
+
+
 def write_elem(
     store: GroupStorageType,
     k: str,
@@ -393,21 +443,3 @@ def read_elem_partial(
         type(elem), get_spec(elem), frozenset(modifiers)
     )
     return read_partial(elem, items=items, indices=indices)
-
-
-@singledispatch
-def elem_key(elem) -> str:
-    return elem.name
-
-
-# raise NotImplementedError()
-
-# @elem_key.register(ZarrGroup)
-# @elem_key.register(ZarrArray)
-# def _(elem):
-#     return elem.name
-
-# @elem_key.register(H5Array)
-# @elem_key.register(H5Group)
-# def _(elem):
-#     re
diff --git a/src/anndata/_types.py b/src/anndata/_types.py
index e0b663f16..3549152f5 100644
--- a/src/anndata/_types.py
+++ b/src/anndata/_types.py
@@ -31,8 +31,9 @@
     from collections.abc import Mapping
     from typing import Any, TypeAlias
 
+    from anndata._io.specs.registry import DaskReader
+
     from ._io.specs.registry import IOSpec, Reader, Writer
-    from .compat import H5File
 
 __all__ = [
     "ArrayStorageType",
@@ -80,28 +81,28 @@
 )
 InvariantInMemoryType = TypeVar("InvariantInMemoryType", bound="InMemoryElem")
 
+SCo = TypeVar("SCo", covariant=True, bound=StorageType)
+SCon = TypeVar("SCon", contravariant=True, bound=StorageType)
 
-class _ReadInternal(Protocol[CovariantInMemoryType]):
-    def __call__(
-        self,
-        elem: StorageType | H5File,
-        *,
-        _reader: Reader,
-    ) -> CovariantInMemoryType: ...
+
+class _ReadInternal(Protocol[SCon, CovariantInMemoryType]):
+    def __call__(self, elem: SCon, *, _reader: Reader) -> CovariantInMemoryType: ...
 
 
-class Read(Protocol[CovariantInMemoryType]):
+class _ReadDaskInternal(Protocol[SCon]):
     def __call__(
-        self,
-        elem: StorageType | H5File,
-    ) -> CovariantInMemoryType:
+        self, elem: SCon, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
+    ) -> DaskArray: ...
+
+
+class Read(Protocol[SCon, CovariantInMemoryType]):
+    def __call__(self, elem: SCon) -> CovariantInMemoryType:
         """Low-level reading function for an element.
 
         Parameters
         ----------
         elem
             The element to read from.
-
         Returns
         -------
         The element read from the store.
@@ -109,6 +110,25 @@ def __call__(
         ...
 
 
+class ReadDask(Protocol[SCon]):
+    def __call__(
+        self, elem: SCon, *, chunks: tuple[int, ...] | None = None
+    ) -> DaskArray:
+        """Low-level reading function for a dask element.
+
+        Parameters
+        ----------
+        elem
+            The element to read from.
+        chunks
+            The chunk size to be used.
+        Returns
+        -------
+        The dask element read from the store.
+        """
+        ...
+
+
 class _WriteInternal(Protocol[ContravariantInMemoryType]):
     def __call__(
         self,
@@ -146,11 +166,11 @@ def __call__(
         ...
 
 
-class ReadCallback(Protocol[InvariantInMemoryType]):
+class ReadCallback(Protocol[SCo, InvariantInMemoryType]):
     def __call__(
         self,
         /,
-        read_func: Read[InvariantInMemoryType],
+        read_func: Read[SCo, InvariantInMemoryType],
         elem_name: str,
         elem: StorageType,
         *,
diff --git a/src/anndata/experimental/__init__.py b/src/anndata/experimental/__init__.py
index 904dd5807..7ef8f6adc 100644
--- a/src/anndata/experimental/__init__.py
+++ b/src/anndata/experimental/__init__.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from anndata._core.sparse_dataset import CSCDataset, CSRDataset, sparse_dataset
-from anndata._io.specs import IOSpec, read_elem, write_elem
+from anndata._io.specs import IOSpec, read_elem, read_elem_as_dask, write_elem
 
 from .._core.storage import StorageType
 from .._types import InMemoryElem as _InMemoryElem
@@ -23,6 +23,7 @@
     "AnnLoader",
     "read_elem",
     "write_elem",
+    "read_elem_as_dask",
     "read_dispatched",
     "write_dispatched",
     "IOSpec",
diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py
index c2d34223d..62284a0c9 100644
--- a/tests/test_io_elementwise.py
+++ b/tests/test_io_elementwise.py
@@ -16,7 +16,14 @@
 from scipy import sparse
 
 import anndata as ad
-from anndata._io.specs import _REGISTRY, IOSpec, get_spec, read_elem, write_elem
+from anndata._io.specs import (
+    _REGISTRY,
+    IOSpec,
+    get_spec,
+    read_elem,
+    read_elem_as_dask,
+    write_elem,
+)
 from anndata._io.specs.registry import IORegistryError
 from anndata.compat import ZarrGroup, _read_attr
 from anndata.tests.helpers import (
@@ -28,8 +35,12 @@
 )
 
 if TYPE_CHECKING:
+    from typing import Literal, TypeVar
+
     from anndata.compat import H5Group
 
+    G = TypeVar("G", H5Group, ZarrGroup)
+
 
 @pytest.fixture(params=["h5ad", "zarr"])
 def diskfmt(request):
@@ -53,6 +64,55 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:
         file.close()
 
 
+sparse_formats = ["csr", "csc"]
+SIZE = 2500
+
+
+@pytest.fixture(params=sparse_formats)
+def sparse_format(request):
+    return request.param
+
+
+def create_dense_store(store, n_dims: int = 2):
+    X = np.random.randn(*[SIZE * (i + 1) for i in range(n_dims)])
+
+    write_elem(store, "X", X)
+    return store
+
+
+def create_sparse_store(
+    sparse_format: Literal["csc", "csr"], store: G, shape=(SIZE, SIZE * 2)
+) -> G:
+    """Create a store holding a sparse matrix under two keys.
+
+    Parameters
+    ----------
+    sparse_format
+        The sparse format ("csr" or "csc") of the matrix to write.
+    store
+        The store to write into.
+
+    Returns
+    -------
+    A store with a key `X` that is a sparse matrix, and `X_dask`, the same matrix wrapped by dask.
+    """
+    import dask.array as da
+
+    X = sparse.random(
+        shape[0],
+        shape[1],
+        format=sparse_format,
+        density=0.01,
+        random_state=np.random.default_rng(),
+    )
+    X_dask = da.from_array(
+        X,
+        chunks=(
+            100 if sparse_format == "csr" else SIZE,
+            SIZE * 2 if sparse_format == "csr" else 100,
+        ),
+    )
+
+    write_elem(store, "X", X)
+    write_elem(store, "X_dask", X_dask)
+    return store
+
+
 @pytest.mark.parametrize(
     ("value", "encoding_type"),
     [
@@ -156,30 +216,122 @@ def test_io_spec_cupy(store, value, encoding_type, as_dask):
     assert get_spec(store[key]) == _REGISTRY.get_spec(value)
 
 
-@pytest.mark.parametrize("sparse_format", ["csr", "csc"])
-def test_dask_write_sparse(store, sparse_format):
-    import dask.array as da
+def test_dask_write_sparse(sparse_format, store):
+    x_sparse_store = create_sparse_store(sparse_format, store)
+    X_from_disk = read_elem(x_sparse_store["X"])
+    X_dask_from_disk = read_elem(x_sparse_store["X_dask"])
 
-    X = sparse.random(
-        1000,
-        1000,
-        format=sparse_format,
-        density=0.01,
-        random_state=np.random.default_rng(),
+    assert_equal(X_from_disk, X_dask_from_disk)
+    assert_equal(dict(x_sparse_store["X"].attrs), dict(x_sparse_store["X_dask"].attrs))
+
+    assert x_sparse_store["X_dask/indptr"].dtype == np.int64
+    assert x_sparse_store["X_dask/indices"].dtype == np.int64
+
+
+def test_read_lazy_2d_dask(sparse_format, store):
+    arr_store = create_sparse_store(sparse_format, store)
+    X_dask_from_disk = read_elem_as_dask(arr_store["X"])
+    X_from_disk = read_elem(arr_store["X"])
+
+    assert_equal(X_from_disk, X_dask_from_disk)
+    random_int_indices = np.random.randint(0, SIZE, (SIZE // 10,))
+    random_int_indices.sort()
+    index_slice = slice(0, SIZE // 10)
+    for index in [random_int_indices, index_slice]:
+        assert_equal(X_from_disk[index, :], X_dask_from_disk[index, :])
+        assert_equal(X_from_disk[:, index], X_dask_from_disk[:, index])
+    random_bool_mask = np.random.randn(SIZE) > 0
+    assert_equal(
+        X_from_disk[random_bool_mask, :], X_dask_from_disk[random_bool_mask, :]
+    )
+    random_bool_mask = np.random.randn(SIZE * 2) > 0
+    assert_equal(
+        X_from_disk[:, random_bool_mask], X_dask_from_disk[:, random_bool_mask]
     )
-    X_dask = da.from_array(X, chunks=(100, 100))
 
-    write_elem(store, "X", X)
-    write_elem(store, "X_dask", X_dask)
+    assert arr_store["X_dask/indptr"].dtype == np.int64
+    assert arr_store["X_dask/indices"].dtype == np.int64
+
+
+@pytest.mark.parametrize(
+    ("n_dims", "chunks"),
+    [
+        (1, (100,)),
+        (1, (400,)),
+        (2, (100, 100)),
+        (2, (400, 400)),
+        (2, (200, 400)),
+        (1, None),
+        (2, None),
+    ],
+)
+def test_read_lazy_subsets_nd_dask(store, n_dims, chunks):
+    arr_store = create_dense_store(store, n_dims)
+    X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks)
+    X_from_disk = read_elem(arr_store["X"])
+    assert_equal(X_from_disk, X_dask_from_disk)
+
+    random_int_indices = np.random.randint(0, SIZE, (SIZE // 10,))
+    random_int_indices.sort()
+    random_bool_mask = np.random.randn(SIZE) > 0
+    index_slice = slice(0, SIZE // 10)
+    for index in [random_int_indices, index_slice, random_bool_mask]:
+        assert_equal(X_from_disk[index], X_dask_from_disk[index])
+
+
+def test_read_lazy_h5_cluster(sparse_format, tmp_path):
+    import dask.distributed as dd
+
+    with h5py.File(tmp_path / "test.h5", "w") as file:
+        store = file["/"]
+        arr_store = create_sparse_store(sparse_format, store)
+        X_dask_from_disk = read_elem_as_dask(arr_store["X"])
+        X_from_disk = read_elem(arr_store["X"])
+    with (
+        dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
+        dd.Client(cluster) as _client,
+    ):
+        assert_equal(X_from_disk, X_dask_from_disk)
 
-    X_from_disk = read_elem(store["X"])
-    X_dask_from_disk = read_elem(store["X_dask"])
+
+@pytest.mark.parametrize(
+    ("arr_type", "chunks"),
+    [
+        ("dense", (100, 100)),
+        ("csc", (SIZE, 10)),
+        ("csr", (10, SIZE * 2)),
+        ("csc", None),
+        ("csr", None),
+    ],
+)
+def test_read_lazy_2d_chunk_kwargs(store, arr_type, chunks):
+    if arr_type == "dense":
+        arr_store = create_dense_store(store)
+        X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks)
+    else:
+        arr_store = create_sparse_store(arr_type, store)
+        X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks)
+    if chunks is not None:
+        assert X_dask_from_disk.chunksize == chunks
+    else:
+        minor_index = int(arr_type == "csr")
+        # assert that sparse chunks are set correctly by default
+        assert X_dask_from_disk.chunksize[minor_index] == SIZE * (1 + minor_index)
+    X_from_disk = read_elem(arr_store["X"])
     assert_equal(X_from_disk, X_dask_from_disk)
-    assert_equal(dict(store["X"].attrs), dict(store["X_dask"].attrs))
 
-    assert store["X_dask/indptr"].dtype == np.int64
-    assert store["X_dask/indices"].dtype == np.int64
+
+def test_read_lazy_bad_chunk_kwargs(tmp_path):
+    arr_type = "csr"
+    with h5py.File(tmp_path / "test.h5", "w") as file:
+        store = file["/"]
+        arr_store = create_sparse_store(arr_type, store)
+        with pytest.raises(
+            ValueError, match=r"`chunks` must be a tuple of two integers"
+        ):
+            read_elem_as_dask(arr_store["X"], chunks=(SIZE,))
+        with pytest.raises(ValueError, match=r"Only the major axis can be chunked"):
+            read_elem_as_dask(arr_store["X"], chunks=(SIZE, 10))
 
 
 @pytest.mark.parametrize("sparse_format", ["csr", "csc"])