Skip to content

Commit

Permalink
Sparse dense test (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
Koncopd authored and flying-sheep committed Jan 11, 2019
1 parent 2ab6e8e commit 622503a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
24 changes: 17 additions & 7 deletions anndata/readwrite/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def postprocess_reading(key, value):
return


def read_h5ad(filename, backed: Union[bool, str] = False):
def read_h5ad(filename, backed: Union[bool, str] = False, chunk_size: int = 6000):
"""Read ``.h5ad``-formatted hdf5 file.
Parameters
Expand All @@ -420,18 +420,28 @@ def read_h5ad(filename, backed: Union[bool, str] = False):
loading it into memory (`memory` mode). `True` and 'r' are
equivalent. If you want to modify backed attributes of the AnnData
object, you need to choose 'r+'.
chunk_size
Used only when loading sparse dataset that is stored as dense.
Loading iterates through chunks of the dataset of this row size
until it reads the whole dataset.
Higher size means higher memory consumption and higher loading speed.
"""

if backed:
# open in backed-mode
return AnnData(filename=filename, filemode=backed)
else:
# load everything into memory
d = _read_h5ad(filename=filename)
d = _read_h5ad(filename=filename, chunk_size=chunk_size)
return AnnData(d)


def _read_h5ad(adata: AnnData = None, filename: Optional[PathLike] = None, mode: str = None):
def _read_h5ad(
adata: AnnData = None,
filename: Optional[PathLike] = None,
mode: str = None,
chunk_size: int = 6000
):
"""Return a dict with arrays for initializing AnnData.
Parameters
Expand All @@ -458,7 +468,7 @@ def _read_h5ad(adata: AnnData = None, filename: Optional[PathLike] = None, mode:
if backed and key in AnnData._BACKED_ATTRS:
d[key] = None
else:
_read_key_value_from_h5(f, d, key)
_read_key_value_from_h5(f, d, key, chunk_size=chunk_size)
# backwards compat: save X with the correct name
if 'X' not in d:
if backed == 'r+':
Expand All @@ -477,18 +487,18 @@ def _read_h5ad(adata: AnnData = None, filename: Optional[PathLike] = None, mode:
return d


def _read_key_value_from_h5(f, d, key, key_write=None):
def _read_key_value_from_h5(f, d, key, key_write=None, chunk_size=6000):
if key_write is None: key_write = key
if isinstance(f[key], h5py.Group):
d[key_write] = OrderedDict() if key == 'uns' else {}
for k in f[key].keys():
_read_key_value_from_h5(f, d[key_write], key + '/' + k, k)
_read_key_value_from_h5(f, d[key_write], key + '/' + k, k, chunk_size)
return

ds = f[key]

if isinstance(ds, h5py.Dataset) and 'sparse_format' in ds.attrs:
value = h5py._load_h5_dataset_as_sparse(ds)
value = h5py._load_h5_dataset_as_sparse(ds, chunk_size)
elif isinstance(ds, h5py.Dataset):
value = np.empty(ds.shape, ds.dtype)
if 0 not in ds.shape:
Expand Down
15 changes: 14 additions & 1 deletion anndata/tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix, issparse
import anndata as ad


Expand All @@ -24,6 +24,13 @@ def tmp_path(request, tmp_path_factory):
# -------------------------------------------------------------------------------
# Some test data
# -------------------------------------------------------------------------------
X_sp = csr_matrix([
[1, 0, 0],
[3, 0, 0],
[5, 6, 0],
[0, 0, 0],
[0, 0, 0]
])

X_list = [ # data matrix of shape n_obs x n_vars
[1, 0],
Expand Down Expand Up @@ -69,6 +76,12 @@ def test_readwrite_h5ad(typ, tmp_path):
assert adata.obs['oanno1'].cat.categories.tolist() == ['cat1', 'cat2']
assert pd.api.types.is_categorical(adata.raw.var['vanno2'])

def test_readwrite_sparse_as_dense(tmp_path):
adata = ad.AnnData(X_sp)
adata.write(tmp_path / 'test.h5ad', force_dense=True)
adata = ad.read(tmp_path / 'test.h5ad', chunk_size=2)
assert issparse(adata.X)
assert np.allclose(X_sp.toarray(), adata.X.toarray())

@pytest.mark.parametrize('typ', [np.array, csr_matrix])
def test_readwrite_h5ad_one_dimensino(typ, tmp_path):
Expand Down

0 comments on commit 622503a

Please sign in to comment.