Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): cache arrays in BaseCompressedSparseDataset #1744

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
be4be30
(fix): lazy chunking respects -1
ilan-gold Nov 8, 2024
2115298
(fix): cache arrays in `BaseCompressedSparseDataset`
ilan-gold Nov 8, 2024
2edabe2
(fix): clean up typing
ilan-gold Nov 8, 2024
2860116
(fix): doctest double >>>
ilan-gold Nov 8, 2024
fa96348
(chore): add tests
ilan-gold Nov 8, 2024
a0e2d52
(fix): more typing updates
ilan-gold Nov 8, 2024
dc01a3a
(chore): add tests
ilan-gold Nov 8, 2024
1ba4920
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
37aba1b
(fix): remove extra >>>
ilan-gold Nov 8, 2024
5eb3a6c
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
32fbef9
(fix): spelling
ilan-gold Nov 8, 2024
fcebbf7
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
0d3278e
Merge branch 'main' into ig/fix_chunking
ilan-gold Nov 8, 2024
ceb70b4
(chore): release note
ilan-gold Nov 8, 2024
e7d14ae
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
fedd827
(chore): release note
ilan-gold Nov 8, 2024
5960331
(fix): support `None` and `-1`
ilan-gold Nov 8, 2024
f59e5ca
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
e652c44
(chore): typing
ilan-gold Nov 8, 2024
fc8495f
Merge branch 'ig/fix_chunking' into ig/cache_arrays
ilan-gold Nov 8, 2024
59849a8
(chore): add cache bust test
ilan-gold Nov 8, 2024
41bd62e
(chore): type
ilan-gold Nov 8, 2024
0304d31
(chore): types
ilan-gold Nov 8, 2024
76ecda5
(chore): better name
ilan-gold Nov 8, 2024
e538a12
(Fix): overload type
ilan-gold Nov 8, 2024
b513217
(chore): bring back test comment
ilan-gold Nov 8, 2024
0cca2fe
Merge branch 'main' into ig/cache_arrays
ilan-gold Nov 11, 2024
d2d9f55
Update 1744.bugfix.md
flying-sheep Nov 11, 2024
3dc0ddd
(fix): revert erroneous change
ilan-gold Nov 21, 2024
05c16ff
Merge branch 'main' into ig/cache_arrays
ilan-gold Nov 21, 2024
1dcf7ad
(fix): dont generate coo matrices
ilan-gold Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1744.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache accesses to the `data` and `indices` arrays in {class}`~anndata.abc.CSRDataset` and {class}`~anndata.abc.CSCDataset` {user}`ilan-gold`
33 changes: 25 additions & 8 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from scipy.sparse._compressed import _cs_matrix

from .._types import GroupStorageType
from ..compat import H5Array
from .index import Index
else:
from scipy.sparse import spmatrix as _cs_matrix
Expand Down Expand Up @@ -380,7 +381,7 @@ def backend(self) -> Literal["zarr", "hdf5"]:
@property
def dtype(self) -> np.dtype:
"""The :class:`numpy.dtype` of the `data` attribute of the sparse matrix."""
return self.group["data"].dtype
return self._data.dtype

@classmethod
def _check_group_format(cls, group):
Expand Down Expand Up @@ -545,16 +546,18 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
indptr[orig_data_size:] = (
sparse_matrix.indptr[1:].astype(np.int64) + indptr_offset
)
# Clear cached property
if hasattr(self, "indptr"):
del self._indptr

# indices
indices = self.group["indices"]
orig_data_size = indices.shape[0]
indices.resize((orig_data_size + sparse_matrix.indices.shape[0],))
indices[orig_data_size:] = sparse_matrix.indices

# Clear cached property
for attr in ["_indptr", "_indices", "_data"]:
if hasattr(self, attr):
delattr(self, attr)
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved

@cached_property
def _indptr(self) -> np.ndarray:
"""\
Expand All @@ -565,11 +568,25 @@ def _indptr(self) -> np.ndarray:
arr = self.group["indptr"][...]
return arr

@cached_property
def _indices(self) -> H5Array | ZarrArray:
"""\
Cache access to the indices to prevent unnecessary reads of the zarray
"""
return self.group["indices"]

@cached_property
def _data(self) -> H5Array | ZarrArray:
"""\
Cache access to the data to prevent unnecessary reads of the zarray
"""
return self.group["data"]

def _to_backed(self) -> BackedSparseMatrix:
format_class = get_backed_class(self.format)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"]
mtx.indices = self.group["indices"]
mtx.data = self._data
mtx.indices = self._indices
mtx.indptr = self._indptr
return mtx

Expand All @@ -578,8 +595,8 @@ def to_memory(self) -> ss.csr_matrix | ss.csc_matrix | SpArray:
self.format, use_sparray_in_io=settings.use_sparse_array_on_read
)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.data = self._data[...]
mtx.indices = self._indices[...]
mtx.indptr = self._indptr
return mtx

Expand Down
54 changes: 32 additions & 22 deletions src/anndata/_io/specs/lazy_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, overload

import h5py
import numpy as np
from scipy import sparse

import anndata as ad
from anndata.abc import CSCDataset, CSRDataset

from ..._core.file_backing import filename, get_elem_name
from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup
from .registry import _LAZY_REGISTRY, IOSpec

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from typing import Literal, ParamSpec, TypeVar

from ..._core.sparse_dataset import _CSCDataset, _CSRDataset
from ..._types import ArrayStorageType, StorageType
from ...compat import DaskArray
from ...compat import DaskArray, H5File, SpArray
from .registry import DaskReader

BlockInfo = Mapping[
Expand All @@ -31,16 +30,25 @@

P = ParamSpec("P")
R = TypeVar("R")
D = TypeVar("D")


@overload
@contextmanager
def maybe_open_h5(
path_or_group: Path | ZarrGroup, elem_name: str
) -> Generator[StorageType, None, None]:
if not isinstance(path_or_group, Path):
yield path_or_group
path_or_other: Path, elem_name: str
) -> Generator[H5File, None, None]: ...
@overload
@contextmanager
def maybe_open_h5(path_or_other: D, elem_name: str) -> Generator[D, None, None]: ...
@contextmanager
def maybe_open_h5(
path_or_other: H5File | D, elem_name: str
) -> Generator[H5File | D, None, None]:
if not isinstance(path_or_other, Path):
yield path_or_other
return
file = h5py.File(path_or_group, "r")
file = h5py.File(path_or_other, "r")
try:
yield file[elem_name]
finally:
Expand All @@ -61,20 +69,17 @@ def compute_chunk_layout_for_axis_shape(


def make_dask_chunk(
path_or_group: Path | ZarrGroup,
path_or_sparse_dataset: Path | D,
elem_name: str,
block_info: BlockInfo | None = None,
*,
wrap: Callable[[ArrayStorageType], ArrayStorageType]
| Callable[[H5Group | ZarrGroup], _CSRDataset | _CSCDataset] = lambda g: g,
):
) -> sparse.csr_matrix | sparse.csc_matrix | SpArray:
if block_info is None:
msg = "Block info is required"
raise ValueError(msg)
# We need to open the file in each task since `dask` cannot share h5py objects when using `dask.distributed`
# https://github.com/scverse/anndata/issues/1105
with maybe_open_h5(path_or_group, elem_name) as f:
mtx = wrap(f)
with maybe_open_h5(path_or_sparse_dataset, elem_name) as f:
mtx = ad.io.sparse_dataset(f) if isinstance(f, H5Group) else f
idx = tuple(
slice(start, stop) for start, stop in block_info[None]["array-location"]
)
Expand All @@ -94,10 +99,17 @@ def read_sparse_as_dask(
) -> DaskArray:
import dask.array as da

path_or_group = Path(filename(elem)) if isinstance(elem, H5Group) else elem
path_or_sparse_dataset = (
Path(filename(elem))
if isinstance(elem, H5Group)
else ad.io.sparse_dataset(elem)
)
elem_name = get_elem_name(elem)
shape: tuple[int, int] = tuple(elem.attrs["shape"])
dtype = elem["data"].dtype
if isinstance(path_or_sparse_dataset, CSRDataset | CSCDataset):
dtype = path_or_sparse_dataset.dtype
else:
dtype = elem["data"].dtype
is_csc: bool = elem.attrs["encoding-type"] == "csc_matrix"

stride: int = _DEFAULT_STRIDE
Expand All @@ -123,9 +135,7 @@ def read_sparse_as_dask(
(chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor)
)
memory_format = sparse.csc_matrix if is_csc else sparse.csr_matrix
make_chunk = partial(
make_dask_chunk, path_or_group, elem_name, wrap=ad.io.sparse_dataset
)
make_chunk = partial(make_dask_chunk, path_or_sparse_dataset, elem_name)
da_mtx = da.map_blocks(
make_chunk,
dtype=dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ class AccessTrackingStore(DirectoryStore):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._access_count = Counter()
self._accessed_keys = {}
self._accessed_keys = dict()
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved

def __getitem__(self, key: str) -> object:
for tracked in self._access_count:
Expand Down
72 changes: 58 additions & 14 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import anndata as ad
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.compat import CAN_USE_SPARSE_ARRAY, SpArray
from anndata._io.specs.registry import read_elem_as_dask
from anndata.compat import CAN_USE_SPARSE_ARRAY, DaskArray, SpArray
from anndata.experimental import read_dispatched
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

Expand All @@ -26,6 +27,9 @@
from numpy.typing import ArrayLike, NDArray
from pytest_mock import MockerFixture

from anndata.abc import CSCDataset, CSRDataset
from anndata.compat import ZarrGroup

Idx = slice | int | NDArray[np.integer] | NDArray[np.bool_]


Expand Down Expand Up @@ -281,6 +285,25 @@ def test_dataset_append_memory(
assert_equal(fromdisk, frommem)


def test_append_array_cache_bust(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = tmp_path / f"test.{diskfmt.replace('ad', '')}"
a = sparse.random(100, 100, format="csr")
if diskfmt == "zarr":
f = zarr.open_group(path, "a")
else:
f = h5py.File(path, "a")
ad.io.write_elem(f, "mtx", a)
ad.io.write_elem(f, "mtx_2", a)
diskmtx = sparse_dataset(f["mtx"])
old_array_shapes = {}
array_names = ["indptr", "indices", "data"]
for name in array_names:
old_array_shapes[name] = getattr(diskmtx, f"_{name}").shape
diskmtx.append(sparse_dataset(f["mtx_2"]))
for name in array_names:
assert old_array_shapes[name] != getattr(diskmtx, f"_{name}").shape


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
@pytest.mark.parametrize(
("subset_func", "subset_func2"),
Expand Down Expand Up @@ -354,16 +377,18 @@ def test_dataset_append_disk(


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
def test_indptr_cache(
def test_lazy_array_cache(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
elems = {"indptr", "indices", "data"}
path = tmp_path / "test.zarr"
a = sparse_format(sparse.random(10, 10))
f = zarr.open_group(path, "a")
ad.io.write_elem(f, "X", a)
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/indptr"])
for elem in elems:
store.initialize_key_trackers([f"X/{elem}"])
f = zarr.open_group(store, "a")
a_disk = sparse_dataset(f["X"])
a_disk[:1]
Expand All @@ -372,6 +397,14 @@ def test_indptr_cache(
a_disk[8:9]
# one each for .zarray and actual access
assert store.get_access_count("X/indptr") == 2
for elem_not_indptr in elems - {"indptr"}:
assert (
sum(
".zarray" in key_accessed
for key_accessed in store.get_accessed_keys(f"X/{elem_not_indptr}")
)
== 1
)


Kind = Literal["slice", "int", "array", "mask"]
Expand Down Expand Up @@ -421,27 +454,38 @@ def width_idx_kinds(
(
[0],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[0],
slice(None, 3),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
["X/data/.zarray", "X/data/0"],
),
(
[3, 4, 5],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
["X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
),
l=10,
),
)
@pytest.mark.parametrize(
"open_func",
[
sparse_dataset,
lambda x: read_elem_as_dask(
x, chunks=(1, -1) if x.attrs["encoding-type"] == "csr_matrix" else (-1, 1)
),
],
ids=["sparse_dataset", "read_elem_as_dask"],
)
def test_data_access(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
idx_maj: Idx,
idx_min: Idx,
exp: Sequence[str],
open_func: Callable[[ZarrGroup], CSRDataset | CSCDataset | DaskArray],
):
path = tmp_path / "test.zarr"
a = sparse_format(np.eye(10, 10))
Expand All @@ -454,19 +498,19 @@ def test_data_access(
store = AccessTrackingStore(path)
store.initialize_key_trackers(["X/data"])
f = zarr.open_group(store)
a_disk = sparse_dataset(f["X"])

# Do the slicing with idx
store.reset_key_trackers()
if a_disk.format == "csr":
a_disk[idx_maj, idx_min]
a_disk = AnnData(X=open_func(f["X"]))
if a.format == "csr":
subset = a_disk[idx_maj, idx_min]
else:
a_disk[idx_min, idx_maj]
subset = a_disk[idx_min, idx_maj]
if isinstance(subset.X, DaskArray):
subset.X.compute(scheduler="single-threaded")

assert store.get_access_count("X/data") == len(exp), store.get_accessed_keys(
"X/data"
)
assert store.get_accessed_keys("X/data") == exp
# dask access order is not guaranteed so need to sort
assert sorted(store.get_accessed_keys("X/data")) == sorted(exp)


@pytest.mark.parametrize(
Expand Down