Skip to content

Commit

Permalink
Fix sparse dataset 2D slicing (#1523)
Browse files Browse the repository at this point in the history
Co-authored-by: Ilan Gold <[email protected]>
  • Loading branch information
flying-sheep and ilan-gold authored Jul 4, 2024
1 parent 9664ce4 commit 6d3392b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 30 deletions.
25 changes: 12 additions & 13 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,13 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csr_matrix:
)
if out_shape[0] == 1:
return self._get_intXslice(slice_as_int(row, self.shape[0]), col)
elif out_shape[1] == self.shape[1] and out_shape[0] < self.shape[0]:
if row.step == 1:
return ss.csr_matrix(
self._get_contiguous_compressed_slice(row), shape=out_shape
)
if row.step != 1:
return self._get_arrayXslice(np.arange(*row.indices(self.shape[0])), col)
return super()._get_sliceXslice(row, col)
res = ss.csr_matrix(
self._get_contiguous_compressed_slice(row),
shape=(out_shape[0], self.shape[1]),
)
return res if out_shape[1] == self.shape[1] else res[:, col]

def _get_arrayXslice(self, row: Sequence[int], col: slice) -> ss.csr_matrix:
idxs = np.asarray(row)
Expand All @@ -208,16 +208,15 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csc_matrix:
slice_len(row, self.shape[0]),
slice_len(col, self.shape[1]),
)

if out_shape[1] == 1:
return self._get_sliceXint(row, slice_as_int(col, self.shape[1]))
elif out_shape[0] == self.shape[0] and out_shape[1] < self.shape[1]:
if col.step == 1:
return ss.csc_matrix(
self._get_contiguous_compressed_slice(col), shape=out_shape
)
if col.step != 1:
return self._get_sliceXarray(row, np.arange(*col.indices(self.shape[1])))
return super()._get_sliceXslice(row, col)
res = ss.csc_matrix(
self._get_contiguous_compressed_slice(col),
shape=(self.shape[0], out_shape[1]),
)
return res if out_shape[0] == self.shape[0] else res[row, :]

def _get_sliceXarray(self, row: slice, col: Sequence[int]) -> ss.csc_matrix:
idxs = np.asarray(col)
Expand Down
104 changes: 87 additions & 17 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal
from itertools import product
from typing import TYPE_CHECKING, Callable, Literal, get_args

import h5py
import numpy as np
Expand All @@ -17,11 +18,16 @@
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from pathlib import Path

from numpy.typing import ArrayLike
from _pytest.mark import ParameterSet
from numpy.typing import ArrayLike, NDArray
from pytest_mock import MockerFixture

Idx = slice | int | NDArray[np.integer] | NDArray[np.bool_]


subset_func2 = subset_func


Expand Down Expand Up @@ -302,7 +308,7 @@ def test_indptr_cache(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr
path = tmp_path / "test.zarr"
a = sparse_format(sparse.random(10, 10))
f = zarr.open_group(path, "a")
ad._io.specs.write_elem(f, "X", a)
Expand All @@ -318,12 +324,76 @@ def test_indptr_cache(
assert store.get_access_count("X/indptr") == 2


# Names of the index flavours that `mk_idx_kind` dispatches on via its `kind` argument.
Kind = Literal["slice", "int", "array", "mask"]


def mk_idx_kind(idx: Sequence[int], *, kind: Kind, l: int) -> Idx | None:
    """Express a run of consecutive integers as the requested kind of index.

    Parameters
    ----------
    idx
        Integer positions, expected to be consecutive (e.g. a step-1 range).
    kind
        Which flavour of index to build: ``"slice"``, ``"int"``, ``"array"``
        or ``"mask"``.
    l
        Length of the axis being indexed; used to size the boolean mask and
        to decide whether a slice can be left open-ended.

    Returns
    -------
    The index object, or ``None`` when ``idx`` cannot be represented as the
    requested kind (several positions as an ``"int"``, non-consecutive
    positions as a ``"slice"``, or an unrecognised ``kind``).
    """
    if kind == "array":
        return np.asarray(idx)
    if kind == "mask":
        return np.isin(np.arange(l), idx)
    if kind == "int":
        # A scalar index can only stand in for exactly one position.
        return idx[0] if len(idx) == 1 else None
    if kind != "slice":
        return None

    first, last = idx[0], idx[-1]
    # A bound touching the edge of the axis is left open (``None``).
    lower = first if first > 0 else None
    if len(idx) == 1:
        return slice(lower, first + 1)
    if not (np.diff(idx) == 1).all():
        return None
    upper = None if last >= l - 1 else last + 1
    return slice(lower, upper)


def idify(x: object) -> str:
    """Render ``x`` as a short string suitable for a pytest parameter id.

    Slices are written in subscript notation (``start:stop[:step]``) with
    ``None`` bounds left blank and a step of ``1`` or ``None`` omitted; any
    other object is passed through ``str``.
    """
    if not isinstance(x, slice):
        return str(x)
    lo = "" if x.start is None else str(x.start)
    hi = "" if x.stop is None else str(x.stop)
    step_suffix = "" if x.step in (1, None) else f":{x.step}"
    return f"{lo}:{hi}{step_suffix}"


def width_idx_kinds(
    *idxs: tuple[Sequence[int], Idx, Sequence[str]], l: int
) -> Generator[ParameterSet, None, None]:
    """Expand each case into one pytest parameter set per usable index kind.

    Every entry of ``idxs`` is a ``(major positions, minor index, expected)``
    triple. The major positions are converted with ``mk_idx_kind`` into each
    kind listed in ``Kind``; kinds for which the conversion returns ``None``
    are skipped. ``l`` is forwarded to ``mk_idx_kind`` unchanged.
    """
    kinds = get_args(Kind)
    for idx_maj_raw, idx_min, exp in idxs:
        for maj_kind in kinds:
            idx_maj = mk_idx_kind(idx_maj_raw, kind=maj_kind, l=l)
            if idx_maj is None:
                continue
            # The id records the raw positions, the minor index and the kind.
            id_parts = (idify(part) for part in (idx_maj_raw, idx_min, maj_kind))
            yield pytest.param(idx_maj, idx_min, exp, id="-".join(id_parts))


@pytest.mark.parametrize("sparse_format", [sparse.csr_matrix, sparse.csc_matrix])
@pytest.mark.parametrize(
("idx_maj", "idx_min", "exp"),
width_idx_kinds(
(
[0],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
),
(
[0],
slice(None, 3),
["X/data/.zarray", "X/data/.zarray", "X/data/0"],
),
(
[3, 4, 5],
slice(None, None),
["X/data/.zarray", "X/data/.zarray", "X/data/3", "X/data/4", "X/data/5"],
),
l=10,
),
)
def test_data_access(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
idx_maj: Idx,
idx_min: Idx,
exp: Sequence[str],
):
path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr
path = tmp_path / "test.zarr"
a = sparse_format(np.eye(10, 10))
f = zarr.open_group(path, "a")
ad._io.specs.write_elem(f, "X", a)
Expand All @@ -335,18 +405,18 @@ def test_data_access(
store.initialize_key_trackers(["X/data"])
f = zarr.open_group(store)
a_disk = sparse_dataset(f["X"])
for idx in [slice(0, 1), 0, np.array([0]), np.array([True] + [False] * 9)]:
store.reset_key_trackers()
if a_disk.format == "csr":
a_disk[idx, :]
else:
a_disk[:, idx]
assert store.get_access_count("X/data") == 3
assert store.get_accessed_keys("X/data") == [
"X/data/.zarray",
"X/data/.zarray",
"X/data/0",
]

# Do the slicing with idx
store.reset_key_trackers()
if a_disk.format == "csr":
a_disk[idx_maj, idx_min]
else:
a_disk[idx_min, idx_maj]

assert store.get_access_count("X/data") == len(exp), store.get_accessed_keys(
"X/data"
)
assert store.get_accessed_keys("X/data") == exp


@pytest.mark.parametrize(
Expand Down Expand Up @@ -382,7 +452,7 @@ def test_wrong_shape(


def test_reset_group(tmp_path: Path):
path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr
path = tmp_path / "test.zarr"
base = sparse.random(100, 100, format="csr")

if diskfmt == "zarr":
Expand Down

0 comments on commit 6d3392b

Please sign in to comment.