diff --git a/docs/release-notes/1744.bugfix.md b/docs/release-notes/1744.bugfix.md new file mode 100644 index 000000000..2a425d43c --- /dev/null +++ b/docs/release-notes/1744.bugfix.md @@ -0,0 +1 @@ +Cache accesses to the `data` and `indices` arrays in {class}`~anndata.abc.CSRDataset` and {class}`~anndata.abc.CSCDataset` {user}`ilan-gold` diff --git a/src/anndata/_core/sparse_dataset.py b/src/anndata/_core/sparse_dataset.py index ae6b47c7f..a155352af 100644 --- a/src/anndata/_core/sparse_dataset.py +++ b/src/anndata/_core/sparse_dataset.py @@ -38,6 +38,7 @@ from scipy.sparse._compressed import _cs_matrix from .._types import GroupStorageType + from ..compat import H5Array from .index import Index else: from scipy.sparse import spmatrix as _cs_matrix @@ -380,7 +381,7 @@ def backend(self) -> Literal["zarr", "hdf5"]: @property def dtype(self) -> np.dtype: """The :class:`numpy.dtype` of the `data` attribute of the sparse matrix.""" - return self.group["data"].dtype + return self._data.dtype @classmethod def _check_group_format(cls, group): @@ -545,9 +546,6 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None indptr[orig_data_size:] = ( sparse_matrix.indptr[1:].astype(np.int64) + indptr_offset ) - # Clear cached property - if hasattr(self, "indptr"): - del self._indptr # indices indices = self.group["indices"] @@ -555,6 +553,11 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None indices.resize((orig_data_size + sparse_matrix.indices.shape[0],)) indices[orig_data_size:] = sparse_matrix.indices + # Clear cached property + for attr in ["_indptr", "_indices", "_data"]: + if hasattr(self, attr): + delattr(self, attr) + @cached_property def _indptr(self) -> np.ndarray: """\ @@ -565,11 +568,25 @@ def _indptr(self) -> np.ndarray: arr = self.group["indptr"][...] return arr + @cached_property + def _indices(self) -> H5Array | ZarrArray: + """\ + Cache access to the indices to prevent unnecessary reads of the zarray + """ + return self.group["indices"] + + @cached_property + def _data(self) -> H5Array | ZarrArray: + """\ + Cache access to the data to prevent unnecessary reads of the zarray + """ + return self.group["data"] + def _to_backed(self) -> BackedSparseMatrix: format_class = get_backed_class(self.format) mtx = format_class(self.shape, dtype=self.dtype) - mtx.data = self.group["data"] - mtx.indices = self.group["indices"] + mtx.data = self._data + mtx.indices = self._indices mtx.indptr = self._indptr return mtx @@ -578,8 +595,8 @@ def to_memory(self) -> ss.csr_matrix | ss.csc_matrix | SpArray: self.format, use_sparray_in_io=settings.use_sparse_array_on_read ) mtx = format_class(self.shape, dtype=self.dtype) - mtx.data = self.group["data"][...] - mtx.indices = self.group["indices"][...] + mtx.data = self._data[...] + mtx.indices = self._indices[...] mtx.indptr = self._indptr return mtx diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index a34f627e7..1092f6fca 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -3,25 +3,24 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload import h5py import numpy as np from scipy import sparse import anndata as ad +from anndata.abc import CSCDataset, CSRDataset from ..._core.file_backing import filename, get_elem_name from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup from .registry import _LAZY_REGISTRY, IOSpec if TYPE_CHECKING: - from collections.abc import Callable, Generator, Mapping, Sequence + from collections.abc import Generator, Mapping, Sequence from typing import Literal, ParamSpec, TypeVar - from ..._core.sparse_dataset import _CSCDataset, _CSRDataset - from ..._types import ArrayStorageType, StorageType - from ...compat import DaskArray + from ...compat import DaskArray, H5File, SpArray from .registry import DaskReader BlockInfo = Mapping[ @@ -31,16 +30,25 @@ P = ParamSpec("P") R = TypeVar("R") + D = TypeVar("D") +@overload @contextmanager def maybe_open_h5( - path_or_group: Path | ZarrGroup, elem_name: str -) -> Generator[StorageType, None, None]: - if not isinstance(path_or_group, Path): - yield path_or_group + path_or_other: Path, elem_name: str +) -> Generator[H5File, None, None]: ... +@overload +@contextmanager +def maybe_open_h5(path_or_other: D, elem_name: str) -> Generator[D, None, None]: ... +@contextmanager +def maybe_open_h5( + path_or_other: H5File | D, elem_name: str +) -> Generator[H5File | D, None, None]: + if not isinstance(path_or_other, Path): + yield path_or_other return - file = h5py.File(path_or_group, "r") + file = h5py.File(path_or_other, "r") try: yield file[elem_name] finally: @@ -61,20 +69,17 @@ def compute_chunk_layout_for_axis_shape( def make_dask_chunk( - path_or_group: Path | ZarrGroup, + path_or_sparse_dataset: Path | D, elem_name: str, block_info: BlockInfo | None = None, - *, - wrap: Callable[[ArrayStorageType], ArrayStorageType] - | Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g, -): +) -> sparse.csr_matrix | sparse.csc_matrix | SpArray: if block_info is None: msg = "Block info is required" raise ValueError(msg) # We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed` # https://github.com/scverse/anndata/issues/1105 - with maybe_open_h5(path_or_group, elem_name) as f: - mtx = wrap(f) + with maybe_open_h5(path_or_sparse_dataset, elem_name) as f: + mtx = ad.io.sparse_dataset(f) if isinstance(f, H5Group) else f idx = tuple( slice(start, stop) for start, stop in block_info[None]["array-location"] ) @@ -94,10 +99,17 @@ def read_sparse_as_dask( ) -> DaskArray: import dask.array as da - path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem + path_or_sparse_dataset = ( + Path(filename(elem)) + if isinstance(elem, H5Group) + else ad.io.sparse_dataset(elem) + ) elem_name = get_elem_name(elem) shape: tuple[int, int] = tuple(elem.attrs["shape"]) - dtype = elem["data"].dtype + if isinstance(path_or_sparse_dataset, CSRDataset | CSCDataset): + dtype = path_or_sparse_dataset.dtype + else: + dtype = elem["data"].dtype is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix" stride: int = _DEFAULT_STRIDE @@ -123,9 +135,7 @@ def read_sparse_as_dask( (chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor) ) memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix - make_chunk = partial( - make_dask_chunk, path_or_group, elem_name, wrap=ad.io.sparse_dataset - ) + make_chunk = partial(make_dask_chunk, path_or_sparse_dataset, elem_name) da_mtx = da.map_blocks( make_chunk, dtype=dtype, diff --git a/tests/test_backed_sparse.py b/tests/test_backed_sparse.py index 2778c76bb..03155d0a3 100644 --- a/tests/test_backed_sparse.py +++ b/tests/test_backed_sparse.py @@ -13,7 +13,8 @@ import anndata as ad from anndata._core.anndata import AnnData from anndata._core.sparse_dataset import sparse_dataset -from anndata.compat import CAN_USE_SPARSE_ARRAY, SpArray +from anndata._io.specs.registry import read_elem_as_dask +from anndata.compat import CAN_USE_SPARSE_ARRAY, DaskArray, SpArray from anndata.experimental import read_dispatched from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func @@ -26,6 +27,9 @@ from numpy.typing import ArrayLike, NDArray from pytest_mock import MockerFixture + from anndata.abc import CSCDataset, CSRDataset + from anndata.compat import ZarrGroup + Idx = slice | int | NDArray[np.integer] | NDArray[np.bool_] @@ -281,6 +285,25 @@ def test_dataset_append_memory( assert_equal(fromdisk, frommem) +def test_append_array_cache_bust(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]): + path = tmp_path / f"test.{diskfmt.replace('ad', '')}" + a = sparse.random(100, 100, format="csr") + if diskfmt == "zarr": + f = zarr.open_group(path, "a") + else: + f = h5py.File(path, "a") + ad.io.write_elem(f, "mtx", a) + ad.io.write_elem(f, "mtx_2", a) + diskmtx = sparse_dataset(f["mtx"]) + old_array_shapes = {} + array_names = ["indptr", "indices", "data"] + for name in array_names: + old_array_shapes[name] = getattr(diskmtx, f"_{name}").shape + diskmtx.append(sparse_dataset(f["mtx_2"])) + for name in array_names: + assert old_array_shapes[name] != getattr(diskmtx, f"_{name}").shape + + @pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix]) @pytest.mark.parametrize( ("subset_func", "subset_func2"), @@ -354,16 +377,18 @@ def test_dataset_append_disk( @pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix]) -def test_indptr_cache( +def test_lazy_array_cache( tmp_path: Path, sparse_format: Callable[[ArrayLike], sparse.spmatrix], ): + elems = {"indptr", "indices", "data"} path = tmp_path / "test.zarr" a = sparse_format(sparse.random(10, 10)) f = zarr.open_group(path, "a") ad.io.write_elem(f, "X", a) store = AccessTrackingStore(path) - store.initialize_key_trackers(["X/indptr"]) + for elem in elems: + store.initialize_key_trackers([f"X/{elem}"]) f = zarr.open_group(store, "a") a_disk = sparse_dataset(f["X"]) a_disk[:1] @@ -372,6 +397,14 @@ def test_indptr_cache( a_disk[8:9] # one each for .zarray and actual access assert store.get_access_count("X/indptr") == 2 + for elem_not_indptr in elems - {"indptr"}: + assert ( + sum( + ".zarray" in key_accessed + for key_accessed in store.get_accessed_keys(f"X/{elem_not_indptr}") + ) + == 1 + ) Kind = Literal["slice", "int", "array", "mask"] @@ -421,27 +454,38 @@ def width_idx_kinds( ( [0], slice(None, None), - ["X/data/.zarray", "X/data/.zarray", "X/data/0"], + ["X/data/.zarray", "X/data/0"], ), ( [0], slice(None, 3), - ["X/data/.zarray", "X/data/.zarray", "X/data/0"], + ["X/data/.zarray", "X/data/0"], ), ( [3, 4, 5], slice(None, None), - ["X/data/.zarray", "X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"], + ["X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"], ), l=10, ), ) +@pytest.mark.parametrize( + "open_func", + [ + sparse_dataset, + lambda x: read_elem_as_dask( + x, chunks=(1, -1) if x.attrs["encoding-type"] == "csr_matrix" else (-1, 1) + ), + ], + ids=["sparse_dataset", "read_elem_as_dask"], +) def test_data_access( tmp_path: Path, sparse_format: Callable[[ArrayLike], sparse.spmatrix], idx_maj: Idx, idx_min: Idx, exp: Sequence[str], + open_func: Callable[[ZarrGroup], CSRDataset | CSCDataset | DaskArray], ): path = tmp_path / "test.zarr" a = sparse_format(np.eye(10, 10)) @@ -454,19 +498,19 @@ def test_data_access( store = AccessTrackingStore(path) store.initialize_key_trackers(["X/data"]) f = zarr.open_group(store) - a_disk = sparse_dataset(f["X"]) - - # Do the slicing with idx - store.reset_key_trackers() - if a_disk.format == "csr": - a_disk[idx_maj, idx_min] + a_disk = AnnData(X=open_func(f["X"])) + if a.format == "csr": + subset = a_disk[idx_maj, idx_min] else: - a_disk[idx_min, idx_maj] + subset = a_disk[idx_min, idx_maj] + if isinstance(subset.X, DaskArray): + subset.X.compute(scheduler="single-threaded") assert store.get_access_count("X/data") == len(exp), store.get_accessed_keys( "X/data" ) - assert store.get_accessed_keys("X/data") == exp + # dask access order is not guaranteed so need to sort + assert sorted(store.get_accessed_keys("X/data")) == sorted(exp) @pytest.mark.parametrize( diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index e034debd2..d9f399dd6 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1044,7 +1044,9 @@ def gen_list(n): def gen_sparse(n): - return sparse.random(np.random.randint(1, 100), np.random.randint(1, 100)) + return sparse.random( + np.random.randint(1, 100), np.random.randint(1, 100), format="csr" + ) def gen_something(n):