Skip to content

Commit

Permalink
Merge pull request #166 from theislab/backed-sparse
Browse files Browse the repository at this point in the history
Backed sparse compat for scipy 1.3
  • Loading branch information
flying-sheep authored Jun 19, 2019
2 parents e9ccfc3 + 6ee7e3e commit fd26586
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 37 deletions.
100 changes: 63 additions & 37 deletions anndata/h5py/h5sparse.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,58 @@
# TODO:
# - think about making all of the below subclasses
# - think about supporting the COO format
from typing import Optional, Union, KeysView
from typing import Optional, Union, KeysView, NamedTuple

import six
import h5py
import numpy as np
import scipy.sparse as ss
from scipy.sparse import _sparsetools

from ..utils import unpack_index
from ..compat import PathLike

from .utils import _chunked_rows

FORMAT_DICT = {
'csr': ss.csr_matrix,
'csc': ss.csc_matrix,
}

class BackedFormat(NamedTuple):
format_str: str
backed_type: type
memory_type: type


class backed_csr_matrix(ss.csr_matrix):
pass


class backed_csc_matrix(ss.csc_matrix):
pass


FORMATS = [
BackedFormat("csr", backed_csr_matrix, ss.csr_matrix),
BackedFormat("csc", backed_csc_matrix, ss.csc_matrix),
]


def get_format_str(data):
for format_str, format_class in six.viewitems(FORMAT_DICT):
if isinstance(data, format_class):
return format_str
for fmt, backed_class, memory_class in FORMATS:
if isinstance(data, memory_class):
return fmt
raise ValueError("Data type {} is not supported.".format(type(data)))


def get_format_class(format_str):
format_class = FORMAT_DICT.get(format_str, None)
if format_class is None:
raise ValueError("Format string {} is not supported."
.format(format_str))
return format_class
def get_memory_class(format_str):
for fmt, backed_class, memory_class in FORMATS:
if format_str == fmt:
return memory_class
raise ValueError(f"Format string {format_str} is not supported.")


def get_backed_class(format_str):
for fmt, backed_class, memory_class in FORMATS:
if format_str == fmt:
return backed_class
raise ValueError(f"Format string {format_str} is not supported.")


def _load_h5_dataset_as_sparse(sds, chunk_size=6000):
Expand All @@ -40,7 +61,7 @@ def _load_h5_dataset_as_sparse(sds, chunk_size=6000):
raise ValueError('sds should be a h5py Dataset')

if 'sparse_format' in sds.attrs:
sparse_class = get_format_class(sds.attrs['sparse_format'])
sparse_class = get_memory_class(sds.attrs['sparse_format'])
else:
sparse_class = ss.csr_matrix

Expand Down Expand Up @@ -163,10 +184,6 @@ def filename(self):
File.__init__.__doc__ = h5py.File.__init__.__doc__


from scipy.sparse.compressed import _cs_matrix
from scipy.sparse import _sparsetools


def _set_many(self, i, j, x):
"""Sets value at each (i, j) to x
Expand All @@ -175,7 +192,10 @@ def _set_many(self, i, j, x):
"""
i, j, M, N = self._prepare_indices(i, j)

n_samples = len(x)
if np.isscalar(x): # Scipy 1.3+ compat
n_samples = 1
else:
n_samples = len(x)
offsets = np.empty(n_samples, dtype=self.indices.dtype)
ret = _sparsetools.csr_sample_offsets(M, N, self.indptr, self.indices,
n_samples, i, j, offsets)
Expand All @@ -194,21 +214,26 @@ def _set_many(self, i, j, x):
return

else:
# raise ValueError(
# 'Currently, you cannot change the sparsity structure of a SparseDataset.')
raise ValueError(
'Currently, you cannot change the sparsity structure of a SparseDataset.')
# replace where possible
mask = offsets > -1
self.data[offsets[mask]] = x[mask]
# only insertions remain
mask = ~mask
i = i[mask]
i[i < 0] += M
j = j[mask]
j[j < 0] += N
self._insert_many(i, j, x[mask])
# mask = offsets > -1
# # offsets[mask]
# bool_data_mask = np.zeros(len(self.data), dtype=bool)
# bool_data_mask[offsets[mask]] = True
# self.data[bool_data_mask] = x[mask]
# # self.data[offsets[mask]] = x[mask]
# # only insertions remain
# mask = ~mask
# i = i[mask]
# i[i < 0] += M
# j = j[mask]
# j[j < 0] += N
# self._insert_many(i, j, x[mask])


_cs_matrix._set_many = _set_many
backed_csr_matrix._set_many = _set_many
backed_csc_matrix._set_many = _set_many


def _zero_many(self, i, j):
Expand All @@ -233,7 +258,8 @@ def _zero_many(self, i, j):
self.data[list(offsets[offsets > -1])] = 0


_cs_matrix._zero_many = _zero_many
backed_csr_matrix._zero_many = _zero_many
backed_csc_matrix._zero_many = _zero_many


class SparseDataset:
Expand All @@ -256,7 +282,7 @@ def format_str(self):
def __getitem__(self, index):
if index == (): index = slice(None)
row, col = unpack_index(index)
format_class = get_format_class(self.format_str)
format_class = get_backed_class(self.format_str)
mock_matrix = format_class(self.shape, dtype=self.dtype)
mock_matrix.data = self.h5py_group['data']
mock_matrix.indices = self.h5py_group['indices']
Expand All @@ -266,7 +292,7 @@ def __getitem__(self, index):
def __setitem__(self, index, value):
if index == (): index = slice(None)
row, col = unpack_index(index)
format_class = get_format_class(self.format_str)
format_class = get_backed_class(self.format_str)
mock_matrix = format_class(self.shape, dtype=self.dtype)
mock_matrix.data = self.h5py_group['data']
mock_matrix.indices = self.h5py_group['indices']
Expand All @@ -283,7 +309,7 @@ def dtype(self):

@property
def value(self):
format_class = get_format_class(self.format_str)
format_class = get_memory_class(self.format_str)
object = self.h5py_group
data_array = format_class(self.shape, dtype=self.dtype)
data_array.data = np.empty(object['data'].shape, object['data'].dtype)
Expand Down
51 changes: 51 additions & 0 deletions anndata/tests/test_hdf5_backing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def adata():
},
dtype='int32'
)


@pytest.fixture(params=[sparse.csr_matrix, sparse.csc_matrix])
def sparse_format(request):
return request.param

# -------------------------------------------------------------------------------
# The test functions
# -------------------------------------------------------------------------------
Expand Down Expand Up @@ -121,3 +127,48 @@ def test_return_to_memory_mode(adata, backing_h5ad):
bdata.filename = backing_h5ad
# close the file
bdata.filename = None


def test_backed_modification(adata, backing_h5ad):
adata.X[:, 1] = 0 # Make it a little sparse
adata.X = sparse.csr_matrix(adata.X)
assert not adata.isbacked

# While this currently makes the file backed, it doesn't write it as sparse
adata.filename = backing_h5ad
adata.write()
assert not adata.file.isopen
assert adata.isbacked

adata.X[0, [0, 2]] = 10
adata.X[1, [0, 2]] = [11, 12]
adata.X[2, 1] = 13 # If it were written as sparse, this should fail

assert adata.isbacked

assert np.all(adata.X[0, :] == np.array([10, 0, 10]))
assert np.all(adata.X[1, :] == np.array([11, 0, 12]))
assert np.all(adata.X[2, :] == np.array([7, 13, 9]))


def test_backed_modification_sparse(adata, backing_h5ad, sparse_format):
adata.X[:, 1] = 0 # Make it a little sparse
adata.X = sparse_format(adata.X)
assert not adata.isbacked

adata.write(backing_h5ad)
adata = ad.read_h5ad(backing_h5ad, backed="r+")

assert adata.filename == backing_h5ad
assert adata.isbacked

adata.X[0, [0, 2]] = 10
adata.X[1, [0, 2]] = [11, 12]
with pytest.raises(ValueError):
adata.X[2, 1] = 13

assert adata.isbacked

assert np.all(adata.X[0, :] == np.array([10, 0, 10]))
assert np.all(adata.X[1, :] == np.array([11, 0, 12]))
assert np.all(adata.X[2, :] == np.array([7, 0, 9]))

0 comments on commit fd26586

Please sign in to comment.