Commit

(feat): read_elem_as_dask method (#1469)
* (feat): `read_elem_lazy` method

* (revert): error message

* (refactor): declare `is_csc` when reading elem directly in h5

* (chore): `read_elem_lazy` -> `read_elem_as_dask`

* (chore): remove string handling

* (refactor): use `elem` for h5 where possible

* (chore): remove invalid syntax

* (fix): put dask import inside function

* (refactor): try maybe open?

* (fix): revert `encoding-version`

* (chore): document `create_sparse_store` test function

* (chore): sort indices to prevent warning

* (fix): remove utility function `make_dask_array`

* (chore): `read_sparse_as_dask_h5` -> `read_sparse_as_dask`

* (feat): make `h5_chunks` and `stride` params

* (chore): add distributed test

* (fix): `TypeVar` bind

* (chore): release note

* (chore): `0.10.8` -> `0.11.0`

* (fix): `ruff` for default `pytest.fixture` `scope`

* Apply suggestions from code review

Co-authored-by: Philipp A. <[email protected]>

* (fix): `Any` to `DaskArray`

* (fix): type `make_index` + fix undeclared

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix rest

* (fix): use `chunks` kwarg

* (feat): expose `chunks` as an option to `read_elem_as_dask` via `dataset_kwargs`

* (fix): `test_read_dispatched_null_case` test

* (fix): disallowed spread syntax?

* (refactor): reuse `compute_chunk_layout_for_axis_shape` functionality

* (fix): remove unneeded `slice` arguments

* (fix): revert message

* (refactor): `make_index` -> `make_block_indexer`

* (fix): export from `experimental`

* (fix): `callback` signature for `test_read_dispatched_null_case`

* (chore): `get_elem_name` helper

* (chore): use `H5Group` consistently

* (refactor): make `chunks` public facing API instead of `dataset_kwargs`

* (fix): register for group not array

* (chore): add warning test

* (chore): make arg order consistent

* (feat): add `callback` typing for `read_dispatched`

* (chore): use `npt.NDArray`

* (fix): remove unnecessary union

* (chore): release note

* (fix): try protocol docs

* (feat): create `InMemoryElem` + `DictElemType` to remove `Any`

* (chore): refactor `DictElemType` -> `InMemoryArrayOrScalarType` for reuse

* (fix): use `Union`

* (fix): more `Union`

* (refactor): `InMemoryElem` -> `InMemoryReadElem`

* (chore): add needed types to public export + docs fix

* (chore): type `write_elem` functions

* (chore): create `write_callback` protocol

* (chore): export + docs

* (fix): add string descriptions

* (fix): try sphinx protocol doc

* (fix): try ignoring exports

* (fix): remap callback internal usages

* (fix): add docstring

* Discard changes to pyproject.toml

* re-add dep

* Fix docs

* Almost works

* works!

* (chore): use pascal-case

* (feat): type read/write funcs in callback

* (fix): use generic for `Read` as well.

* (fix): need more aliases

* Split table, format

* (refactor): move to `_types` file

* bump scanpydoc

* Some basic syntax fixes

* (fix): change `Read{Callback}` type for kwargs

* (chore): test `chunks` argument

* (fix): type `read_recarray`

* (fix): `GroupStorageType` not `StorageType`

* (fix): little type fixes

* (fix): clarify `H5File` typing

* (fix): dask doc

* (fix): dask docs

* (fix): typing

* (fix): handle case when `chunks` is `None`

* (feat): add string-array reading

* (fix): remove `string-array` because it is not tested

* (refactor): clean up tests

* (fix): overfetching problem

* Fix circular import

* add some typing

* fix mapping types

* Fix Read/Write

* Fix one more

* unify names

* clarify `ReadCallback` signature

* Fix type aliases

* (fix): clean up typing to use `RWAble`

* (fix): use `Union`

* (fix): add qualname override

* (fix): ignore dask and masked array

* (fix): ignore erroneous class warning

* (fix): upgrade `scanpydoc`

* (fix): use `MutableMapping` instead of `dict` due to broken docstring

* Add data docs

* Revert "(fix): use `MutableMapping` instead of `dict` due to broken docstring"

This reverts commit 79d3fdc.

* (fix): add clarification

* Simplify

* (fix): remove double `dask` intersphinx

* (fix): remove `_types.DaskArray` from type checking block

* (refactor): use `block_info` for resolving fetch location

* (fix): dtype for reading

* (fix): ignore import cycle problem (why??)

* (fix): add issue

* (fix): subclass `Reader` to remove `datasetkwargs`

* (fix): add message to error

* Update tests/test_io_elementwise.py

Co-authored-by: Isaac Virshup <[email protected]>

* (fix): correct `self.callback` check

* (fix): erroneous diffs

* (fix): extra `read_elem` `dataset_kwargs`

* (fix): remove more `dataset_kwargs` nonsense

* (chore): add docs

* (fix): use `block_info` for dense

* (fix): more erroneous diffs

* (fix): use context again

* (fix): change size by dimension in tests

* (refactor): clean up `get_elem_name`

* (fix): try new sphinx for error

* (fix): return type

* (fix): protocol for reading

* (fix): bring back ignored warning

* Fix docs

* almost fix typing

* add wrapper

* move into type checking

* (fix): small type fixes

* block info types

* simplify

* rename

* simplify more

---------

Co-authored-by: Philipp A. <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Isaac Virshup <[email protected]>
4 people authored Jul 23, 2024
1 parent 2e016cf commit 6d70535
Showing 10 changed files with 459 additions and 69 deletions.
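
Below, a minimal usage sketch of the new API (the file name is hypothetical and `X` is assumed to be a dense array; `read_elem_as_dask` is exported from `anndata.experimental` by this commit):

import h5py

from anndata.experimental import read_elem_as_dask

with h5py.File("data.h5ad", "r") as f:
    # Lazily wrap the on-disk element; no data is read yet.
    x = read_elem_as_dask(f["X"])
    # Optionally control chunking: 1000-row blocks spanning all columns.
    x_chunked = read_elem_as_dask(f["X"], chunks=(1000, f["X"].shape[1]))
    # Data is fetched chunk by chunk only on compute().
    means = x.mean(axis=0).compute()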
1 change: 1 addition & 0 deletions docs/api.md
@@ -121,6 +121,7 @@ Low level methods for reading and writing elements of an `AnnData` object to a s
experimental.read_elem
experimental.write_elem
experimental.read_elem_as_dask
```

Utilities for customizing the IO process:
1 change: 1 addition & 0 deletions docs/release-notes/0.11.0.md
@@ -8,6 +8,7 @@
* Add `should_remove_unused_categories` option to `anndata.settings` to override current behavior. Default is `True` (i.e., previous behavior). Please refer to the [documentation](https://anndata.readthedocs.io/en/latest/generated/anndata.settings.html) for usage. {pr}`1340` {user}`ilan-gold`
* `scipy.sparse.csr_array` and `scipy.sparse.csc_array` are now supported when constructing `AnnData` objects {pr}`1028` {user}`ilan-gold` {user}`isaac-virshup`
* Add `should_check_uniqueness` option to `anndata.settings` to override current behavior. Default is `True` (i.e., previous behavior). Please refer to the [documentation](https://anndata.readthedocs.io/en/latest/generated/anndata.settings.html) for usage. {pr}`1507` {user}`ilan-gold`
* Add {func}`~anndata.experimental.read_elem_as_dask` function to handle I/O with sparse and dense arrays {pr}`1469` {user}`ilan-gold`
* Add functionality to write from GPU {class}`dask.array.Array` to disk {pr}`1550` {user}`ilan-gold`

#### Bugfix
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -66,7 +66,7 @@ dev = [
"pytest-xdist",
]
doc = [
"sphinx>=4.4",
"sphinx>=7.4.6",
"sphinx-book-theme>=1.1.0",
"sphinx-autodoc-typehints>=2.2.0",
"sphinx-issues",
17 changes: 16 additions & 1 deletion src/anndata/_core/file_backing.py
@@ -3,7 +3,7 @@
import weakref
from collections.abc import Mapping
from functools import singledispatch
-from pathlib import Path
+from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING

import h5py
@@ -175,3 +175,18 @@ def _(x):
@filename.register(ZarrGroup)
def _(x):
    return x.store.path


@singledispatch
def get_elem_name(x):
    raise NotImplementedError(f"Not implemented for {type(x)}")


@get_elem_name.register(h5py.Group)
def _(x):
    return x.name


@get_elem_name.register(ZarrGroup)
def _(x):
    return PurePosixPath(x.path).name
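# Note the asymmetry as committed: h5py's `.name` is the absolute path within the
# file (e.g. "/X"), while the zarr overload returns only the final path component.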
6 changes: 5 additions & 1 deletion src/anndata/_io/specs/__init__.py
@@ -1,21 +1,25 @@
from __future__ import annotations

-from . import methods
+from . import lazy_methods, methods
from .registry import (
    _LAZY_REGISTRY,  # noqa: F401
    _REGISTRY,  # noqa: F401
    IOSpec,
    Reader,
    Writer,
    get_spec,
    read_elem,
    read_elem_as_dask,
    write_elem,
)

__all__ = [
    "methods",
    "lazy_methods",
    "write_elem",
    "get_spec",
    "read_elem",
    "read_elem_as_dask",
    "Reader",
    "Writer",
    "IOSpec",
164 changes: 164 additions & 0 deletions src/anndata/_io/specs/lazy_methods.py
@@ -0,0 +1,164 @@
from __future__ import annotations

from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING

import h5py
import numpy as np
from scipy import sparse

import anndata as ad

from ..._core.file_backing import filename, get_elem_name
from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup
from .registry import _LAZY_REGISTRY, IOSpec

if TYPE_CHECKING:
    from collections.abc import Callable, Generator, Mapping, Sequence
    from typing import Literal, ParamSpec, TypeVar

    from ..._core.sparse_dataset import CSCDataset, CSRDataset
    from ..._types import ArrayStorageType, StorageType
    from ...compat import DaskArray
    from .registry import DaskReader

    BlockInfo = Mapping[
        Literal[None],
        dict[str, Sequence[tuple[int, int]]],
    ]
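    # `BlockInfo` mirrors the mapping that `dask.array.map_blocks` passes to its
    # function via `block_info`: block_info[None]["array-location"] holds the
    # (start, stop) bounds of the current block along each axis of the output array.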

P = ParamSpec("P")
R = TypeVar("R")


@contextmanager
def maybe_open_h5(
    path_or_group: Path | ZarrGroup, elem_name: str
) -> Generator[StorageType, None, None]:
    if not isinstance(path_or_group, Path):
        yield path_or_group
        return
    file = h5py.File(path_or_group, "r")
    try:
        yield file[elem_name]
    finally:
        file.close()
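# Zarr groups can be handed to dask tasks directly, but h5py handles cannot be
# shared across processes, so HDF5 elements arrive as a path plus element name
# and are re-opened per task (see the issue linked in `make_dask_chunk` below).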


_DEFAULT_STRIDE = 1000


def compute_chunk_layout_for_axis_shape(
    chunk_axis_shape: int, full_axis_shape: int
) -> tuple[int, ...]:
    n_strides, rest = np.divmod(full_axis_shape, chunk_axis_shape)
    chunk = (chunk_axis_shape,) * n_strides
    if rest > 0:
        chunk += (rest,)
    return chunk
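# Illustrative: compute_chunk_layout_for_axis_shape(1000, 2500) == (1000, 1000, 500),
# i.e. two full strides plus the remainder.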


def make_dask_chunk(
    path_or_group: Path | ZarrGroup,
    elem_name: str,
    block_info: BlockInfo | None = None,
    *,
    wrap: Callable[[ArrayStorageType], ArrayStorageType]
    | Callable[[H5Group | ZarrGroup], CSRDataset | CSCDataset] = lambda g: g,
):
    if block_info is None:
        msg = "Block info is required"
        raise ValueError(msg)
    # We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
    # https://github.com/scverse/anndata/issues/1105
    with maybe_open_h5(path_or_group, elem_name) as f:
        mtx = wrap(f)
        idx = tuple(
            slice(start, stop) for start, stop in block_info[None]["array-location"]
        )
        chunk = mtx[idx]
    return chunk


@_LAZY_REGISTRY.register_read(H5Group, IOSpec("csc_matrix", "0.1.0"))
@_LAZY_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0"))
@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
def read_sparse_as_dask(
    elem: H5Group | ZarrGroup,
    *,
    _reader: DaskReader,
    chunks: tuple[int, ...] | None = None,  # only tuple[int, int] is supported here
) -> DaskArray:
    import dask.array as da

    path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem
    elem_name = get_elem_name(elem)
    shape: tuple[int, int] = tuple(elem.attrs["shape"])
    dtype = elem["data"].dtype
    is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix"

    stride: int = _DEFAULT_STRIDE
    major_dim, minor_dim = (1, 0) if is_csc else (0, 1)
    if chunks is not None:
        if len(chunks) != 2:
            raise ValueError("`chunks` must be a tuple of two integers")
        if chunks[minor_dim] != shape[minor_dim]:
            raise ValueError(
                "Only the major axis can be chunked. "
                f"Try setting chunks to {((-1, _DEFAULT_STRIDE) if is_csc else (_DEFAULT_STRIDE, -1))}"
            )
        stride = chunks[major_dim]

    shape_minor, shape_major = shape if is_csc else shape[::-1]
    chunks_major = compute_chunk_layout_for_axis_shape(stride, shape_major)
    chunks_minor = (shape_minor,)
    chunk_layout = (
        (chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor)
    )
    memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
    make_chunk = partial(
        make_dask_chunk, path_or_group, elem_name, wrap=ad.experimental.sparse_dataset
    )
    da_mtx = da.map_blocks(
        make_chunk,
        dtype=dtype,
        chunks=chunk_layout,
        meta=memory_format((0, 0), dtype=dtype),
    )
    return da_mtx
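# Illustrative call (names hypothetical): for a CSR element of shape (n_obs, n_vars),
# read_elem_as_dask(f["X"], chunks=(500, n_vars)) yields blocks of 500 rows each;
# the minor axis must stay whole, so only the major axis is ever chunked.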


@_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0"))
def read_h5_array(
    elem: H5Array, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
) -> DaskArray:
    import dask.array as da

    path = Path(elem.file.filename)
    elem_name: str = elem.name
    shape = tuple(elem.shape)
    dtype = elem.dtype
    chunks: tuple[int, ...] = (
        chunks if chunks is not None else (_DEFAULT_STRIDE,) * len(shape)
    )

    chunk_layout = tuple(
        compute_chunk_layout_for_axis_shape(chunks[i], shape[i])
        for i in range(len(shape))
    )

    make_chunk = partial(make_dask_chunk, path, elem_name)
    return da.map_blocks(make_chunk, dtype=dtype, chunks=chunk_layout)
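# Illustrative: a (10_000, 50) dataset with the default stride produces dask chunks
# ((1000,) * 10, (50,)); axes smaller than the stride become a single chunk.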


@_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0"))
def read_zarr_array(
    elem: ZarrArray, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None
) -> DaskArray:
    chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks
    import dask.array as da

    return da.from_zarr(elem, chunks=chunks)
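# For zarr, the on-disk chunking is reused when `chunks` is not given, so the result
# matches plain `da.from_zarr(elem)` chunk for chunk.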
