Skip to content

Commit

Permalink
Fix shape inference for X=None (#1121)
Browse files Browse the repository at this point in the history
* add basic view tests

* Allow passing dicts as obs/var

* simplify

* add tests for errors

* fix test

* really fix tests

* clearer tests

* add source

* fix remaining tests

* fix 3.8 compat

* fix msg

* annots

* fix raw

* clearer tests

* Release note

* fix release note index

* no 0.10.1 yet

* suggestions
  • Loading branch information
flying-sheep authored Sep 8, 2023
1 parent ddba0bc commit 88dd129
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 75 deletions.
143 changes: 90 additions & 53 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,75 @@ def _check_2d_shape(X):
)


def _mk_df_error(
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    expected: int,
    actual: int,
) -> ValueError:
    """Build (without raising) the error for a mis-sized ``obs``/``var`` annotation.

    Parameters
    ----------
    source
        Where the expected length came from: the data matrix (``"X"``)
        or an explicitly passed ``shape``.
    attr
        The annotation being checked, ``"obs"`` or ``"var"``.
    expected
        Length dictated by ``source``.
    actual
        Number of rows the annotation really has.

    Returns
    -------
    A :class:`ValueError` describing the mismatch; the caller raises it.
    """
    # Both message variants name the matrix dimension, so bind it before
    # branching (it used to be defined only in the "X" branch).
    what = "row" if attr == "obs" else "column"
    if source == "X":
        label = "Observations" if attr == "obs" else "Variables"
        msg = (
            f"{label} annot. `{attr}` must have as many rows as `X` has {what}s "
            f"({expected}), but has {actual} rows."
        )
    else:
        # Both literals must be f-strings, otherwise the placeholders are
        # emitted verbatim in the error message.
        msg = (
            f"`shape` is inconsistent with `{attr}` "
            f"({actual} {what}s instead of {expected})"
        )
    return ValueError(msg)


@singledispatch
def _gen_dataframe(
    anno: Mapping[str, Any],
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
) -> pd.DataFrame:
    """Turn a mapping-like annotation into a string-indexed DataFrame.

    This is the fallback ``singledispatch`` implementation; other types are
    handled by the overloads registered on this function.

    Parameters
    ----------
    anno
        Mapping of column name to values. ``None`` or an empty mapping
        yields a DataFrame without columns.
    index_names
        Keys that, if present in ``anno``, supply the index instead of a
        column. The first one found wins.
    source
        Origin of ``length`` (``"X"`` or ``"shape"``); only used to word
        the error message.
    attr
        Which annotation is being built (``"obs"`` or ``"var"``); only used
        to word the error message.
    length
        Expected number of rows, or ``None`` to infer it from ``anno``.

    Returns
    -------
    The annotation as a DataFrame.

    Raises
    ------
    ValueError
        If ``length`` is given and the resulting DataFrame has a different
        number of rows.
    """
    if anno is None or len(anno) == 0:
        anno = {}

    def mk_index(n: int) -> pd.Index:
        # Default names are the stringified positions "0", "1", ...
        return pd.RangeIndex(0, n, name=None).astype(str)

    for index_name in index_names:
        if index_name not in anno:
            continue
        df = pd.DataFrame(
            anno,
            index=anno[index_name],
            columns=[k for k in anno.keys() if k != index_name],
        )
        break
    else:
        df = pd.DataFrame(
            anno,
            index=None if length is None else mk_index(length),
            columns=None if len(anno) else [],
        )
        # Generate default names only here: doing it after the loop as well
        # would overwrite names the caller passed via one of `index_names`
        # whenever `length` is None.
        if length is None:
            df.index = mk_index(len(df))

    if length is not None and length != len(df):
        raise _mk_df_error(source, attr, length, len(df))
    return df


@_gen_dataframe.register(pd.DataFrame)
def _(anno, length, index_names):
def _gen_dataframe_df(
anno: pd.DataFrame,
index_names: Iterable[str],
*,
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
length: int | None = None,
):
if length is not None and length != len(anno):
raise _mk_df_error(source, attr, length, len(anno))
anno = anno.copy(deep=False)
if not is_string_dtype(anno.index):
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
Expand All @@ -133,8 +182,15 @@ def _(anno, length, index_names):

@_gen_dataframe.register(pd.Series)
@_gen_dataframe.register(pd.Index)
def _gen_dataframe_1d(
    anno: pd.Series | pd.Index,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Reject a ``pd.Series`` or ``pd.Index`` passed as an annotation.

    The signature mirrors the other ``_gen_dataframe`` overloads; every call
    raises.

    Raises
    ------
    ValueError
        Always; the message names the offending type and ``attr``.
    """
    message = f"Cannot convert {type(anno)} to {attr} DataFrame"
    raise ValueError(message)


class AnnData(metaclass=utils.DeprecationMixinMeta):
Expand Down Expand Up @@ -356,8 +412,6 @@ def _init_as_view(self, adata_ref: "AnnData", oidx: Index, vidx: Index):
self._obs = DataFrameView(obs_sub, view_args=(self, "obs"))
self._var = DataFrameView(var_sub, view_args=(self, "var"))
self._uns = uns
self._n_obs = len(self.obs)
self._n_vars = len(self.var)

# set data
if self.isbacked:
Expand Down Expand Up @@ -473,27 +527,20 @@ def _init_as_actual(
X = np.array(X, dtype, copy=False)
# data matrix and shape
self._X = X
self._n_obs, self._n_vars = self._X.shape
n_obs, n_vars = X.shape
source = "X"
else:
self._X = None
self._n_obs = len([] if obs is None else obs)
self._n_vars = len([] if var is None else var)
# check consistency with shape
if shape is not None:
if self._n_obs == 0:
self._n_obs = shape[0]
else:
if self._n_obs != shape[0]:
raise ValueError("`shape` is inconsistent with `obs`")
if self._n_vars == 0:
self._n_vars = shape[1]
else:
if self._n_vars != shape[1]:
raise ValueError("`shape` is inconsistent with `var`")
n_obs, n_vars = (None, None) if shape is None else shape
source = "shape"

# annotations
self._obs = _gen_dataframe(obs, self._n_obs, ["obs_names", "row_names"])
self._var = _gen_dataframe(var, self._n_vars, ["var_names", "col_names"])
self._obs = _gen_dataframe(
obs, ["obs_names", "row_names"], source=source, attr="obs", length=n_obs
)
self._var = _gen_dataframe(
var, ["var_names", "col_names"], source=source, attr="var", length=n_vars
)

# now we can verify if indices match!
for attr_name, x_name, idx in x_indices:
Expand Down Expand Up @@ -783,12 +830,12 @@ def raw(self):
@property
def n_obs(self) -> int:
"""Number of observations."""
return self._n_obs
return len(self.obs_names)

@property
def n_vars(self) -> int:
"""Number of variables/features."""
return self._n_vars
return len(self.var_names)

def _set_dim_df(self, value: pd.DataFrame, attr: str):
if not isinstance(value, pd.DataFrame):
Expand Down Expand Up @@ -1855,38 +1902,28 @@ def __contains__(self, key: Any):

def _check_dimensions(self, key=None):
if key is None:
key = {"obs", "var", "obsm", "varm"}
key = {"obsm", "varm"}
else:
key = {key}
if "obs" in key and len(self._obs) != self._n_obs:
raise ValueError(
"Observations annot. `obs` must have number of rows of `X`"
f" ({self._n_obs}), but has {self._obs.shape[0]} rows."
)
if "var" in key and len(self._var) != self._n_vars:
raise ValueError(
"Variables annot. `var` must have number of columns of `X`"
f" ({self._n_vars}), but has {self._var.shape[0]} rows."
)
if "obsm" in key:
obsm = self._obsm
if (
not all([dim_len(o, 0) == self._n_obs for o in obsm.values()])
and len(obsm.dim_names) != self._n_obs
not all([dim_len(o, 0) == self.n_obs for o in obsm.values()])
and len(obsm.dim_names) != self.n_obs
):
raise ValueError(
"Observations annot. `obsm` must have number of rows of `X`"
f" ({self._n_obs}), but has {len(obsm)} rows."
f" ({self.n_obs}), but has {len(obsm)} rows."
)
if "varm" in key:
varm = self._varm
if (
not all([dim_len(v, 0) == self._n_vars for v in varm.values()])
and len(varm.dim_names) != self._n_vars
not all([dim_len(v, 0) == self.n_vars for v in varm.values()])
and len(varm.dim_names) != self.n_vars
):
raise ValueError(
"Variables annot. `varm` must have number of columns of `X`"
f" ({self._n_vars}), but has {len(varm)} rows."
f" ({self.n_vars}), but has {len(varm)} rows."
)

def write_h5ad(
Expand Down
5 changes: 4 additions & 1 deletion anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(
self._X = X.get()
else:
self._X = X
self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"])
n_var = None if self._X is None else self._X.shape[1]
self._var = _gen_dataframe(
var, ["var_names"], source="X", attr="var", length=n_var
)
self._varm = AxisArrays(self, 1, varm)
elif X is None: # construct from adata
# Move from GPU to CPU since it's large and not always used
Expand Down
45 changes: 42 additions & 3 deletions anndata/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from itertools import product
import re
import warnings

import numpy as np
Expand Down Expand Up @@ -39,9 +42,6 @@ def test_creation():
assert adata.raw.X.tolist() == X.tolist()
assert adata.raw.var_names.tolist() == ["a", "b", "c"]

with pytest.raises(ValueError):
AnnData(np.array([[1, 2], [3, 4]]), dict(TooLong=[1, 2, 3, 4]))

# init with empty data matrix
shape = (3, 5)
adata = AnnData(None, uns=dict(test=np.array((3, 3))), shape=shape)
Expand All @@ -50,6 +50,45 @@ def test_creation():
assert "test" in adata.uns


@pytest.mark.parametrize(
    ("src", "src_arg", "dim_msg"),
    [
        pytest.param(
            "X",
            adata_dense.X,
            "`{dim}` must have as many rows as `X` has {mat_dim}s",
            id="x",
        ),
        pytest.param(
            "shape", (2, 2), "`shape` is inconsistent with `{dim}`", id="shape"
        ),
    ],
)
@pytest.mark.parametrize("dim", ["obs", "var"])
@pytest.mark.parametrize(
    ("dim_arg", "msg"),
    [
        pytest.param(
            lambda _: dict(TooLong=[1, 2, 3, 4]),
            "Length of values (4) does not match length of index (2)",
            id="too_long_col",
        ),
        pytest.param(
            lambda dim: {f"{dim}_names": ["a", "b", "c"]}, None, id="too_many_names"
        ),
        pytest.param(
            lambda _: pd.DataFrame(index=["a", "b", "c"]), None, id="too_long_df"
        ),
    ],
)
def test_creation_error(src, src_arg, dim_msg, dim, dim_arg, msg: str | None):
    """An ``obs``/``var`` argument of the wrong length must raise ``ValueError``.

    ``msg`` is the literal message expected for a case; when it is ``None``
    the expectation is built from the ``dim_msg`` template of the current
    size source (``X`` or ``shape``).
    """
    expected = msg
    if expected is None:
        matrix_axis = {"obs": "row", "var": "column"}[dim]
        expected = dim_msg.format(dim=dim, mat_dim=matrix_axis)
    # The message is matched literally, hence the escaping.
    pattern = re.escape(expected)
    with pytest.raises(ValueError, match=pattern):
        AnnData(**{src: src_arg, dim: dim_arg(dim)})


def test_create_with_dfs():
X = np.ones((6, 3))
obs = pd.DataFrame(dict(cat_anno=pd.Categorical(["a", "a", "a", "a", "b", "a"])))
Expand Down
18 changes: 9 additions & 9 deletions anndata/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


@pytest.fixture
def adata_raw():
def adata_raw() -> ad.AnnData:
adata = ad.AnnData(
np.array(data, dtype="int32"), obs=obs_dict, var=var_dict, uns=uns_dict
)
Expand All @@ -48,18 +48,18 @@ def adata_raw():
# -------------------------------------------------------------------------------


def test_raw_init(adata_raw):
def test_raw_init(adata_raw: ad.AnnData):
assert adata_raw.var_names.tolist() == ["var1", "var2"]
assert adata_raw.raw.var_names.tolist() == ["var1", "var2", "var3"]
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_del(adata_raw):
def test_raw_del(adata_raw: ad.AnnData):
del adata_raw.raw
assert adata_raw.raw is None


def test_raw_set_as_none(adata_raw):
def test_raw_set_as_none(adata_raw: ad.AnnData):
# Test for scverse/anndata#445
a = adata_raw
b = adata_raw.copy()
Expand All @@ -70,15 +70,15 @@ def test_raw_set_as_none(adata_raw):
assert_equal(a, b)


def test_raw_of_view(adata_raw):
def test_raw_of_view(adata_raw: ad.AnnData):
adata_view = adata_raw[adata_raw.obs["oanno1"] == "cat2"]
assert adata_view.raw.X.tolist() == [
[4, 5, 6],
[7, 8, 9],
]


def test_raw_rw(adata_raw, backing_h5ad):
def test_raw_rw(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.write(backing_h5ad)
adata_read = ad.read(backing_h5ad)

Expand All @@ -89,7 +89,7 @@ def test_raw_rw(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_view_rw(adata_raw, backing_h5ad):
def test_raw_view_rw(adata_raw: ad.AnnData, backing_h5ad):
# Make sure it still writes correctly if the object is a view
adata_raw_view = adata_raw[:, adata_raw.var_names]
assert_equal(adata_raw_view, adata_raw)
Expand All @@ -104,7 +104,7 @@ def test_raw_view_rw(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_backed(adata_raw, backing_h5ad):
def test_raw_backed(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.filename = backing_h5ad

assert adata_raw.var_names.tolist() == ["var1", "var2"]
Expand All @@ -114,7 +114,7 @@ def test_raw_backed(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X[:].tolist() == [[1], [4], [7]]


def test_raw_view_backed(adata_raw, backing_h5ad):
def test_raw_view_backed(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.filename = backing_h5ad

assert adata_raw.var_names.tolist() == ["var1", "var2"]
Expand Down
11 changes: 11 additions & 0 deletions anndata/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,14 @@ def test_copy_X_dtype():
adata = ad.AnnData(sparse.eye(50, dtype=np.float64, format="csr"))
adata_c = adata[::2].copy()
assert adata_c.X.dtype == adata.X.dtype


def test_x_none():
    """Views and their copies keep shape and names when ``X`` is ``None``."""

    def check_rows_two_and_three(adata):
        assert adata.shape == (2, 0)
        assert adata.obs_names.tolist() == ["2", "3"]

    # Only `obs` is given, so the number of observations comes from it.
    source = ad.AnnData(obs=pd.DataFrame(index=np.arange(50)))
    assert source.shape == (50, 0)

    sliced = source[2:4]
    check_rows_two_and_three(sliced)
    check_rows_two_and_three(sliced.copy())
Loading

0 comments on commit 88dd129

Please sign in to comment.