diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 2d09106b0..360e7b146 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -16,7 +16,7 @@ from .registry import _LAZY_REGISTRY, IOSpec if TYPE_CHECKING: - from collections.abc import Callable, Generator, Mapping, Sequence + from collections.abc import Callable, Iterator, Mapping, Sequence from typing import Literal, ParamSpec, TypeVar from ..._core.sparse_dataset import _CSCDataset, _CSRDataset @@ -36,7 +36,7 @@ @contextmanager def maybe_open_h5( path_or_group: Path | ZarrGroup, elem_name: str -) -> Generator[StorageType, None, None]: +) -> Callable[[], Iterator[StorageType]]: if not isinstance(path_or_group, Path): yield path_or_group return @@ -67,13 +67,18 @@ def make_dask_chunk( *, wrap: Callable[[ArrayStorageType], ArrayStorageType] | Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g, + reopen: None | Callable[[], Iterator[StorageType]] = None, ): if block_info is None: msg = "Block info is required" raise ValueError(msg) # We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed` # https://github.com/scverse/anndata/issues/1105 - with maybe_open_h5(path_or_group, elem_name) as f: + with ( + contextmanager(reopen)() + if reopen is not None + else maybe_open_h5(path_or_group, elem_name) + ) as f: mtx = wrap(f) idx = tuple( slice(start, stop) for start, stop in block_info[None]["array-location"] @@ -91,6 +96,7 @@ def read_sparse_as_dask( *, _reader: DaskReader, chunks: tuple[int, ...] | None = None, # only tuple[int, int] is supported here + reopen: None | Callable[[], Iterator[StorageType]] = None, ) -> DaskArray: import dask.array as da @@ -120,7 +126,7 @@ def read_sparse_as_dask( ) memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix make_chunk = partial( - make_dask_chunk, path_or_group, elem_name, wrap=ad.sparse_dataset + make_dask_chunk, path_or_group, elem_name, wrap=ad.sparse_dataset, reopen=reopen ) da_mtx = da.map_blocks( make_chunk, @@ -133,7 +139,11 @@ def read_sparse_as_dask( @_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0")) def read_h5_array( - elem: H5Array, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None + elem: H5Array, + *, + _reader: DaskReader, + chunks: tuple[int, ...] | None = None, + reopen: None | Callable[[], Iterator[StorageType]] = None, ) -> DaskArray: import dask.array as da @@ -156,7 +166,11 @@ def read_h5_array( @_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0")) def read_zarr_array( - elem: ZarrArray, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None + elem: ZarrArray, + *, + _reader: DaskReader, + chunks: tuple[int, ...] | None = None, + reopen: None | Callable[[], Iterator[StorageType]] = None, ) -> DaskArray: chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks import dask.array as da diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 2cd21b5fc..d3f32dc3e 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -12,7 +12,7 @@ from anndata.compat import DaskArray, _read_attr if TYPE_CHECKING: - from collections.abc import Callable, Generator, Iterable + from collections.abc import Callable, Generator, Iterable, Iterator from typing import Any from anndata._types import ( @@ -289,6 +289,7 @@ def read_elem( elem: StorageType, modifiers: frozenset[str] = frozenset(), chunks: tuple[int, ...] | None = None, + reopen: None | Callable[[], Iterator[StorageType]] = None, ) -> DaskArray: """Read a dask element from a store. See exported function for more details.""" @@ -299,7 +300,7 @@ def read_elem( if self.callback is not None: msg = "Dask reading does not use a callback. Ignoring callback." warnings.warn(msg, stacklevel=2) - return read_func(elem, chunks=chunks) + return read_func(elem, chunks=chunks, reopen=reopen) class Writer: @@ -379,7 +380,9 @@ def read_elem(elem: StorageType) -> RWAble: def read_elem_as_dask( - elem: StorageType, chunks: tuple[int, ...] | None = None + elem: StorageType, + chunks: tuple[int, ...] | None = None, + reopen: None | Callable[[], Iterator[StorageType]] = None, ) -> DaskArray: """ Read an element from a store lazily. @@ -395,12 +398,13 @@ def read_elem_as_dask( chunks, optional length `n`, the same `n` as the size of the underlying array. Note that the minor axis dimension must match the shape for sparse. - + reopen, optional + A custom function for re-opening your store in the dask reader. Returns ------- DaskArray """ - return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks) + return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks, reopen=reopen) def write_elem(