diff --git a/anndata/h5py/h5sparse.py b/anndata/h5py/h5sparse.py index 73d2922b7..586a28239 100644 --- a/anndata/h5py/h5sparse.py +++ b/anndata/h5py/h5sparse.py @@ -1,37 +1,58 @@ # TODO: # - think about making all of the below subclasses # - think about supporting the COO format -from typing import Optional, Union, KeysView +from typing import Optional, Union, KeysView, NamedTuple -import six import h5py import numpy as np import scipy.sparse as ss +from scipy.sparse import _sparsetools from ..utils import unpack_index from ..compat import PathLike from .utils import _chunked_rows -FORMAT_DICT = { - 'csr': ss.csr_matrix, - 'csc': ss.csc_matrix, -} + +class BackedFormat(NamedTuple): + format_str: str + backed_type: type + memory_type: type + + +class backed_csr_matrix(ss.csr_matrix): + pass + + +class backed_csc_matrix(ss.csc_matrix): + pass + + +FORMATS = [ + BackedFormat("csr", backed_csr_matrix, ss.csr_matrix), + BackedFormat("csc", backed_csc_matrix, ss.csc_matrix), +] def get_format_str(data): - for format_str, format_class in six.viewitems(FORMAT_DICT): - if isinstance(data, format_class): - return format_str + for fmt, backed_class, memory_class in FORMATS: + if isinstance(data, memory_class): + return fmt raise ValueError("Data type {} is not supported.".format(type(data))) -def get_format_class(format_str): - format_class = FORMAT_DICT.get(format_str, None) - if format_class is None: - raise ValueError("Format string {} is not supported." - .format(format_str)) - return format_class +def get_memory_class(format_str): + for fmt, backed_class, memory_class in FORMATS: + if format_str == fmt: + return memory_class + raise ValueError(f"Format string {format_str} is not supported.") + + +def get_backed_class(format_str): + for fmt, backed_class, memory_class in FORMATS: + if format_str == fmt: + return backed_class + raise ValueError(f"Format string {format_str} is not supported.") def _load_h5_dataset_as_sparse(sds, chunk_size=6000): @@ -40,7 +61,7 @@ def _load_h5_dataset_as_sparse(sds, chunk_size=6000): raise ValueError('sds should be a h5py Dataset') if 'sparse_format' in sds.attrs: - sparse_class = get_format_class(sds.attrs['sparse_format']) + sparse_class = get_memory_class(sds.attrs['sparse_format']) else: sparse_class = ss.csr_matrix @@ -163,10 +184,6 @@ def filename(self): File.__init__.__doc__ = h5py.File.__init__.__doc__ -from scipy.sparse.compressed import _cs_matrix -from scipy.sparse import _sparsetools - - def _set_many(self, i, j, x): """Sets value at each (i, j) to x @@ -175,7 +192,10 @@ def _set_many(self, i, j, x): """ i, j, M, N = self._prepare_indices(i, j) - n_samples = len(x) + if np.isscalar(x): # Scipy 1.3+ compat + n_samples = 1 + else: + n_samples = len(x) offsets = np.empty(n_samples, dtype=self.indices.dtype) ret = _sparsetools.csr_sample_offsets(M, N, self.indptr, self.indices, n_samples, i, j, offsets) @@ -194,21 +214,26 @@ def _set_many(self, i, j, x): return else: - # raise ValueError( - # 'Currently, you cannot change the sparsity structure of a SparseDataset.') + raise ValueError( + 'Currently, you cannot change the sparsity structure of a SparseDataset.') # replace where possible - mask = offsets > -1 - self.data[offsets[mask]] = x[mask] - # only insertions remain - mask = ~mask - i = i[mask] - i[i < 0] += M - j = j[mask] - j[j < 0] += N - self._insert_many(i, j, x[mask]) + # mask = offsets > -1 + # # offsets[mask] + # bool_data_mask = np.zeros(len(self.data), dtype=bool) + # bool_data_mask[offsets[mask]] = True + # self.data[bool_data_mask] = x[mask] + # # self.data[offsets[mask]] = x[mask] + # # only insertions remain + # mask = ~mask + # i = i[mask] + # i[i < 0] += M + # j = j[mask] + # j[j < 0] += N + # self._insert_many(i, j, x[mask]) -_cs_matrix._set_many = _set_many +backed_csr_matrix._set_many = _set_many +backed_csc_matrix._set_many = _set_many def _zero_many(self, i, j): @@ -233,7 +258,8 @@ def _zero_many(self, i, j): self.data[list(offsets[offsets > -1])] = 0 -_cs_matrix._zero_many = _zero_many +backed_csr_matrix._zero_many = _zero_many +backed_csc_matrix._zero_many = _zero_many class SparseDataset: @@ -256,7 +282,7 @@ def format_str(self): def __getitem__(self, index): if index == (): index = slice(None) row, col = unpack_index(index) - format_class = get_format_class(self.format_str) + format_class = get_backed_class(self.format_str) mock_matrix = format_class(self.shape, dtype=self.dtype) mock_matrix.data = self.h5py_group['data'] mock_matrix.indices = self.h5py_group['indices'] @@ -266,7 +292,7 @@ def __getitem__(self, index): def __setitem__(self, index, value): if index == (): index = slice(None) row, col = unpack_index(index) - format_class = get_format_class(self.format_str) + format_class = get_backed_class(self.format_str) mock_matrix = format_class(self.shape, dtype=self.dtype) mock_matrix.data = self.h5py_group['data'] mock_matrix.indices = self.h5py_group['indices'] @@ -283,7 +309,7 @@ def dtype(self): @property def value(self): - format_class = get_format_class(self.format_str) + format_class = get_memory_class(self.format_str) object = self.h5py_group data_array = format_class(self.shape, dtype=self.dtype) data_array.data = np.empty(object['data'].shape, object['data'].dtype) diff --git a/anndata/tests/test_hdf5_backing.py b/anndata/tests/test_hdf5_backing.py index 405d636d7..01846ee26 100644 --- a/anndata/tests/test_hdf5_backing.py +++ b/anndata/tests/test_hdf5_backing.py @@ -44,6 +44,12 @@ def adata(): }, dtype='int32' ) + + +@pytest.fixture(params=[sparse.csr_matrix, sparse.csc_matrix]) +def sparse_format(request): + return request.param + # ------------------------------------------------------------------------------- # The test functions # ------------------------------------------------------------------------------- @@ -121,3 +127,48 @@ def test_return_to_memory_mode(adata, backing_h5ad): bdata.filename = backing_h5ad # close the file bdata.filename = None + + +def test_backed_modification(adata, backing_h5ad): + adata.X[:, 1] = 0 # Make it a little sparse + adata.X = sparse.csr_matrix(adata.X) + assert not adata.isbacked + + # While this currently makes the file backed, it doesn't write it as sparse + adata.filename = backing_h5ad + adata.write() + assert not adata.file.isopen + assert adata.isbacked + + adata.X[0, [0, 2]] = 10 + adata.X[1, [0, 2]] = [11, 12] + adata.X[2, 1] = 13 # If it were written as sparse, this should fail + + assert adata.isbacked + + assert np.all(adata.X[0, :] == np.array([10, 0, 10])) + assert np.all(adata.X[1, :] == np.array([11, 0, 12])) + assert np.all(adata.X[2, :] == np.array([7, 13, 9])) + + +def test_backed_modification_sparse(adata, backing_h5ad, sparse_format): + adata.X[:, 1] = 0 # Make it a little sparse + adata.X = sparse_format(adata.X) + assert not adata.isbacked + + adata.write(backing_h5ad) + adata = ad.read_h5ad(backing_h5ad, backed="r+") + + assert adata.filename == backing_h5ad + assert adata.isbacked + + adata.X[0, [0, 2]] = 10 + adata.X[1, [0, 2]] = [11, 12] + with pytest.raises(ValueError): + adata.X[2, 1] = 13 + + assert adata.isbacked + + assert np.all(adata.X[0, :] == np.array([10, 0, 10])) + assert np.all(adata.X[1, :] == np.array([11, 0, 12])) + assert np.all(adata.X[2, :] == np.array([7, 0, 9]))