Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): {csr,csc}_array read support #6

Draft
wants to merge 11 commits into
base: sparse-arrays
Choose a base branch
from
55 changes: 30 additions & 25 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from anndata._core.index import _fix_slice_bounds
from anndata.compat import H5Group, ZarrArray, ZarrGroup

from .._settings import settings
from ..compat import SpArray, _read_attr

try:
Expand Down Expand Up @@ -155,12 +156,12 @@ def _get_contiguous_compressed_slice(


class backed_csr(BackedSparseMatrix):
def _get_intXslice(self, row: int, col: slice) -> ss.csr_matrix:
return ss.csr_matrix(
def _get_intXslice(self, row: int, col: slice) -> ss.spmatrix | SpArray:
return get_memory_class(self.format)(
get_compressed_vector(self, row), shape=(1, self.shape[1])
)[:, col]

def _get_sliceXslice(self, row: slice, col: slice) -> ss.csr_matrix:
def _get_sliceXslice(self, row: slice, col: slice) -> ss.spmatrix | SpArray:
row = _fix_slice_bounds(row, self.shape[0])
col = _fix_slice_bounds(col, self.shape[1])

Expand All @@ -172,30 +173,30 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csr_matrix:
return self._get_intXslice(slice_as_int(row, self.shape[0]), col)
elif out_shape[1] == self.shape[1] and out_shape[0] < self.shape[0]:
if row.step == 1:
return ss.csr_matrix(
return get_memory_class(self.format)(
self._get_contiguous_compressed_slice(row), shape=out_shape
)
return self._get_arrayXslice(np.arange(*row.indices(self.shape[0])), col)
return super()._get_sliceXslice(row, col)

def _get_arrayXslice(self, row: Sequence[int], col: slice) -> ss.csr_matrix:
def _get_arrayXslice(self, row: Sequence[int], col: slice) -> ss.spmatrix | SpArray:
idxs = np.asarray(row)
if len(idxs) == 0:
return ss.csr_matrix((0, self.shape[1]))
return get_memory_class(self.format)((0, self.shape[1]))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csr_matrix(
return get_memory_class(self.format)(
get_compressed_vectors(self, idxs), shape=(len(idxs), self.shape[1])
)[:, col]


class backed_csc(BackedSparseMatrix):
def _get_sliceXint(self, row: slice, col: int) -> ss.csc_matrix:
return ss.csc_matrix(
def _get_sliceXint(self, row: slice, col: int) -> ss.spmatrix | SpArray:
return get_memory_class(self.format)(
get_compressed_vector(self, col), shape=(self.shape[0], 1)
)[row, :]

def _get_sliceXslice(self, row: slice, col: slice) -> ss.csc_matrix:
def _get_sliceXslice(self, row: slice, col: slice) -> ss.spmatrix | SpArray:
row = _fix_slice_bounds(row, self.shape[0])
col = _fix_slice_bounds(col, self.shape[1])

Expand All @@ -208,19 +209,19 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csc_matrix:
return self._get_sliceXint(row, slice_as_int(col, self.shape[1]))
elif out_shape[0] == self.shape[0] and out_shape[1] < self.shape[1]:
if col.step == 1:
return ss.csc_matrix(
return get_memory_class(self.format)(
self._get_contiguous_compressed_slice(col), shape=out_shape
)
return self._get_sliceXarray(row, np.arange(*col.indices(self.shape[1])))
return super()._get_sliceXslice(row, col)

def _get_sliceXarray(self, row: slice, col: Sequence[int]) -> ss.csc_matrix:
def _get_sliceXarray(self, row: slice, col: Sequence[int]) -> ss.spmatrix | SpArray:
idxs = np.asarray(col)
if len(idxs) == 0:
return ss.csc_matrix((self.shape[0], 0))
return get_memory_class(self.format)((self.shape[0], 0))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csc_matrix(
return get_memory_class(self.format)(
get_compressed_vectors(self, idxs), shape=(self.shape[0], len(idxs))
)[row, :]

Expand Down Expand Up @@ -327,22 +328,26 @@ def get_format(data: ss.spmatrix) -> str:
raise ValueError(f"Data type {type(data)} is not supported.")


def get_memory_class(format: str, use_sparray_in_io=False) -> type[ss.spmatrix]:
def get_memory_class(format: str) -> type[ss.spmatrix | SpArray]:
for fmt, _, memory_class in FORMATS:
if format == fmt:
if use_sparray_in_io and issubclass(memory_class, SpArray):
if settings.use_sparse_array_in_io and issubclass(memory_class, SpArray):
return memory_class
elif not use_sparray_in_io and issubclass(memory_class, ss.spmatrix):
elif not settings.use_sparse_array_in_io and issubclass(
memory_class, ss.spmatrix
):
return memory_class
raise ValueError(f"Format string {format} is not supported.")


def get_backed_class(format: str, use_sparray_in_io=False) -> type[BackedSparseMatrix]:
def get_backed_class(format: str) -> type[BackedSparseMatrix]:
for fmt, backed_class, _ in FORMATS:
if format == fmt:
if use_sparray_in_io and issubclass(backed_class, SpArray):
if settings.use_sparse_array_in_io and issubclass(backed_class, SpArray):
return backed_class
elif not use_sparray_in_io and issubclass(backed_class, ss.spmatrix):
elif not settings.use_sparse_array_in_io and issubclass(
backed_class, ss.spmatrix
):
return backed_class
raise ValueError(f"Format string {format} is not supported.")

Expand Down Expand Up @@ -433,18 +438,18 @@ def value(self) -> ss.spmatrix:
def __repr__(self) -> str:
return f"{type(self).__name__}: backend {self.backend}, shape {self.shape}, data_dtype {self.dtype}"

def __getitem__(self, index: Index | tuple[()]) -> float | ss.spmatrix:
def __getitem__(self, index: Index | tuple[()]) -> float | ss.spmatrix | SpArray:
indices = self._normalize_index(index)
row, col = indices
mtx = self._to_backed()

memory_class = get_memory_class(self.format)
# Handle masked indexing along major axis
if self.format == "csr" and np.array(row).dtype == bool:
sub = ss.csr_matrix(
sub = memory_class(
subset_by_major_axis_mask(mtx, row), shape=(row.sum(), mtx.shape[1])
)[:, col]
elif self.format == "csc" and np.array(col).dtype == bool:
sub = ss.csc_matrix(
sub = memory_class(
subset_by_major_axis_mask(mtx, col), shape=(mtx.shape[0], col.sum())
)[row, :]
else:
Expand All @@ -453,7 +458,7 @@ def __getitem__(self, index: Index | tuple[()]) -> float | ss.spmatrix:
# If indexing is array x array it returns a backed_sparse_matrix
# Not sure what the performance is on that operation
if isinstance(sub, BackedSparseMatrix):
return get_memory_class(self.format)(sub)
return memory_class(sub)
else:
return sub

Expand Down
13 changes: 13 additions & 0 deletions anndata/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,18 @@ def validate_bool(val) -> bool:
get_from_env=check_and_get_bool,
)


sparray_option = "use_sparse_array_in_io"
sparray_default_value = False
sparray_description = "Whether or not to use :class:`~scipy.sparse.sparray` as the sparse class when reading in sparse data."

settings.register(
sparray_option,
sparray_default_value,
sparray_description,
validate_bool,
get_from_env=check_and_get_bool,
)

##################################################################################
##################################################################################
9 changes: 9 additions & 0 deletions anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@
if not CAN_USE_SPARSE_ARRAY:

class SpArray:
def __init__(self, *args, **kwargs) -> None:
pass

@staticmethod
def __repr__():
return "mock scipy.sparse.sparray"

class CsrArray:
def __init__(self, *args, **kwargs) -> None:
pass

@staticmethod
def __repr__():
return "mock scipy.sparse.csr_array"

class CscArray:
def __init__(self, *args, **kwargs) -> None:
pass

@staticmethod
def __repr__():
return "mock scipy.sparse.csc_array"
Expand Down
100 changes: 66 additions & 34 deletions anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@
import anndata as ad
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata._settings import settings
from anndata.compat import SpArray
from anndata.experimental import read_dispatched, write_elem
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func
from anndata.tests.helpers import (
AccessTrackingStore,
assert_equal,
subset_func,
)

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -29,21 +35,28 @@ def diskfmt(request):
return request.param


@pytest.fixture(params=[sparse.csr_matrix, sparse.csr_array])
def sparse_csr_format(request):
return request.param


M = 50
N = 50


@pytest.fixture(scope="function")
def ondisk_equivalent_adata(
tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]
) -> tuple[AnnData, AnnData, AnnData, AnnData]:
tmp_path: Path, diskfmt: Literal["h5ad", "zarr"], sparse_csr_format
) -> tuple[AnnData, AnnData, AnnData, AnnData, AnnData]:
csr_path = tmp_path / f"csr.{diskfmt}"
csc_path = tmp_path / f"csc.{diskfmt}"
dense_path = tmp_path / f"dense.{diskfmt}"

write = lambda x, pth, **kwargs: getattr(x, f"write_{diskfmt}")(pth, **kwargs)

csr_mem = ad.AnnData(X=sparse.random(M, N, format="csr", density=0.1))
csr_mem = ad.AnnData(
X=sparse_csr_format(sparse.random(M, N, format="csr", density=0.1))
)
csc_mem = ad.AnnData(X=csr_mem.X.tocsc())
dense_mem = ad.AnnData(X=csr_mem.X.toarray())

Expand Down Expand Up @@ -80,20 +93,22 @@ def callback(func, elem_name, elem, iospec):
csc_disk = read_zarr_backed(csc_path)
dense_disk = read_zarr_backed(dense_path)

return csr_mem, csr_disk, csc_disk, dense_disk
return csr_mem, csr_disk, csc_mem, csc_disk, dense_disk


@pytest.mark.parametrize(
"empty_mask", [[], np.zeros(M, dtype=bool)], ids=["empty_list", "empty_bool_mask"]
)
def test_empty_backed_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData, AnnData],
empty_mask,
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert_equal(csr_mem.X[empty_mask], csr_disk.X[empty_mask])
assert_equal(csr_mem.X[:, empty_mask], csc_disk.X[:, empty_mask])
csr_mem, csr_disk, csc_mem, csc_disk, _ = ondisk_equivalent_adata
with settings.override(use_sparse_array_in_io=isinstance(csr_mem.X, SpArray)):
assert_equal(csr_mem.X[empty_mask], csr_disk.X[empty_mask])
assert isinstance(csr_mem.X[empty_mask], type(csr_disk.X[empty_mask]))
assert_equal(csr_mem.X[:, empty_mask], csc_disk.X[:, empty_mask])
assert isinstance(csc_mem.X[:, empty_mask], type(csc_disk.X[:, empty_mask]))

# The following do not work because of https://github.com/scipy/scipy/issues/19919
# Our implementation returns a (0,0) sized matrix but scipy does (1,0).
Expand All @@ -103,21 +118,28 @@ def test_empty_backed_indexing(


def test_backed_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData, AnnData],
subset_func,
subset_func2,
):
csr_mem, csr_disk, csc_disk, dense_disk = ondisk_equivalent_adata
csr_mem, csr_disk, csc_mem, csc_disk, dense_disk = ondisk_equivalent_adata

obs_idx = subset_func(csr_mem.obs_names)
var_idx = subset_func2(csr_mem.var_names)

assert_equal(csr_mem[obs_idx, var_idx].X, csr_disk[obs_idx, var_idx].X)
assert_equal(csr_mem[obs_idx, var_idx].X, csc_disk[obs_idx, var_idx].X)
assert_equal(csr_mem.X[...], csc_disk.X[...])
assert_equal(csr_mem[obs_idx, :].X, dense_disk[obs_idx, :].X)
assert_equal(csr_mem[obs_idx].X, csr_disk[obs_idx].X)
assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X)
with settings.override(use_sparse_array_in_io=isinstance(csr_mem.X, SpArray)):
assert_equal(csr_mem[obs_idx, var_idx].X, csr_disk[obs_idx, var_idx].X)
assert_equal(csr_mem[obs_idx, var_idx].X, csc_disk[obs_idx, var_idx].X)
assert_equal(csr_mem.X[...], csc_disk.X[...])
assert_equal(csr_mem[obs_idx, :].X, dense_disk[obs_idx, :].X)
assert_equal(csr_mem[obs_idx].X, csr_disk[obs_idx].X)
assert isinstance(csr_mem[obs_idx].X[...], type(csr_disk[obs_idx].X[...]))
assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X)
assert isinstance(csr_mem[:, var_idx].X[...], type(csr_disk[:, var_idx].X[...]))
assert isinstance(csc_mem[obs_idx].X[...], type(csc_disk[obs_idx].X[...]))
assert isinstance(
csc_mem[obs_idx, var_idx].X[...], type(csc_disk[obs_idx, var_idx].X[...])
)
assert isinstance(csc_mem[:, var_idx].X[...], type(csc_disk[:, var_idx].X[...]))


def make_randomized_mask(size: int) -> np.ndarray:
Expand Down Expand Up @@ -167,7 +189,7 @@ def make_one_elem_mask(size: int) -> np.ndarray:
)
def test_consecutive_bool(
mocker: MockerFixture,
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData, AnnData],
make_bool_mask: Callable[[int], np.ndarray],
should_trigger_optimization: bool | None,
):
Expand All @@ -184,7 +206,7 @@ def test_consecutive_bool(
should_trigger_optimization
Whether or not a given mask should trigger the optimized behavior.
"""
_, csr_disk, csc_disk, _ = ondisk_equivalent_adata
_, csr_disk, _, csc_disk, _ = ondisk_equivalent_adata
mask = make_bool_mask(csr_disk.shape[0])

# indexing needs to be on `X` directly to trigger the optimization.
Expand Down Expand Up @@ -269,6 +291,8 @@ def test_dataset_append_memory(
[
pytest.param(sparse.csr_matrix, sparse.vstack),
pytest.param(sparse.csc_matrix, sparse.hstack),
pytest.param(sparse.csr_array, sparse.vstack),
pytest.param(sparse.csc_array, sparse.hstack),
],
)
def test_dataset_append_disk(
Expand Down Expand Up @@ -305,6 +329,8 @@ def test_dataset_append_disk(
[
pytest.param(sparse.csr_matrix),
pytest.param(sparse.csc_matrix),
pytest.param(sparse.csr_array),
pytest.param(sparse.csc_array),
],
)
def test_indptr_cache(
Expand Down Expand Up @@ -408,33 +434,39 @@ def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
assert not np.any((pre_checks != post_checks).toarray())


def test_anndata_sparse_compat(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
def test_anndata_sparse_compat(
tmp_path: Path, diskfmt: Literal["h5ad", "zarr"], sparse_csr_format
):
path = (
tmp_path / f"test.{diskfmt.replace('ad', '')}"
) # diskfmt is either h5ad or zarr
base = sparse.random(100, 100, format="csr")
base = sparse_csr_format(sparse.random(100, 100, format="csr"))

if diskfmt == "zarr":
f = zarr.open_group(path, "a")
else:
f = h5py.File(path, "a")

ad._io.specs.write_elem(f, "/", base)
adata = ad.AnnData(sparse_dataset(f["/"]))
assert_equal(adata.X, base)
ad._io.specs.write_elem(f, "X", base)
adata = ad.AnnData(sparse_dataset(f["X"]))
with settings.override(use_sparse_array_in_io=isinstance(base, SpArray)):
assert_equal(adata.X, base)
assert isinstance(base, type(adata.X[...]))


def test_backed_sizeof(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData, AnnData],
diskfmt: Literal["h5ad", "zarr"],
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert csr_mem.__sizeof__() == csr_disk.__sizeof__(with_disk=True)
assert csr_mem.__sizeof__() == csc_disk.__sizeof__(with_disk=True)
assert csr_disk.__sizeof__(with_disk=True) == csc_disk.__sizeof__(with_disk=True)
assert csr_mem.__sizeof__() > csr_disk.__sizeof__()
assert csr_mem.__sizeof__() > csc_disk.__sizeof__()
csr_mem, csr_disk, _, csc_disk, _ = ondisk_equivalent_adata
with settings.override(use_sparse_array_in_io=isinstance(csr_mem, SpArray)):
assert csr_mem.__sizeof__() == csr_disk.__sizeof__(with_disk=True)
assert csr_mem.__sizeof__() == csc_disk.__sizeof__(with_disk=True)
assert csr_disk.__sizeof__(with_disk=True) == csc_disk.__sizeof__(
with_disk=True
)
assert csr_mem.__sizeof__() > csr_disk.__sizeof__()
assert csr_mem.__sizeof__() > csc_disk.__sizeof__()


@pytest.mark.parametrize(
Expand Down