Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shape inference for X=None #1121

Merged
merged 19 commits into from
Sep 8, 2023
143 changes: 90 additions & 53 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,75 @@ def _check_2d_shape(X):
)


def _mk_df_error(
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
expected: int,
actual: int,
):
if source == "X":
what = "row" if attr == "obs" else "column"
msg = (
f"Observations annot. `{attr}` must have as many rows as `X` has {what}s "
f"({expected}), but has {actual} rows."
)
else:
msg = (
f"`shape` is inconsistent with `{attr}` "
"({actual} {what}s instead of {expected})"
)
return ValueError(msg)


@singledispatch
def _gen_dataframe(anno, length, index_names):
def _gen_dataframe(
anno: Mapping[str, Any],
index_names: Iterable[str],
*,
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
length: int | None = None,
) -> pd.DataFrame:
if anno is None or len(anno) == 0:
anno = {}

def mk_index(l: int) -> pd.Index:
return pd.RangeIndex(0, l, name=None).astype(str)

for index_name in index_names:
if index_name in anno:
return pd.DataFrame(
anno,
index=anno[index_name],
columns=[k for k in anno.keys() if k != index_name],
)
return pd.DataFrame(
anno,
index=pd.RangeIndex(0, length, name=None).astype(str),
columns=None if len(anno) else [],
)
if index_name not in anno:
continue
df = pd.DataFrame(
anno,
index=anno[index_name],
columns=[k for k in anno.keys() if k != index_name],
)
break
else:
df = pd.DataFrame(
anno,
index=None if length is None else mk_index(length),
columns=None if len(anno) else [],
)

if length is None:
df.index = mk_index(len(df))
elif length != len(df):
raise _mk_df_error(source, attr, length, len(df))
return df


@_gen_dataframe.register(pd.DataFrame)
def _(anno, length, index_names):
def _gen_dataframe_df(
anno: pd.DataFrame,
index_names: Iterable[str],
*,
source: Literal["X", "shape"],
attr: Literal["obs", "var"],
length: int | None = None,
):
if length is not None and length != len(anno):
raise _mk_df_error(source, attr, length, len(anno))
anno = anno.copy(deep=False)
if not is_string_dtype(anno.index):
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
Expand All @@ -133,8 +182,15 @@ def _(anno, length, index_names):

@_gen_dataframe.register(pd.Series)
@_gen_dataframe.register(pd.Index)
def _gen_dataframe_1d(
    anno: pd.Series | pd.Index,
    index_names: Iterable[str],
    *,
    source: Literal["X", "shape"],
    attr: Literal["obs", "var"],
    length: int | None = None,
):
    """Reject 1-D pandas objects passed as an annotation.

    The signature mirrors the other `_gen_dataframe` overloads so that a
    call using ``source=``/``attr=``/``length=`` keywords reaches the
    ``ValueError`` below instead of failing on argument binding.

    Raises
    ------
    ValueError
        Always; the message names the offending type and the annotation
        (`attr`) it was passed for.
    """
    raise ValueError(f"Cannot convert {type(anno)} to {attr} DataFrame")


class AnnData(metaclass=utils.DeprecationMixinMeta):
Expand Down Expand Up @@ -356,8 +412,6 @@ def _init_as_view(self, adata_ref: "AnnData", oidx: Index, vidx: Index):
self._obs = DataFrameView(obs_sub, view_args=(self, "obs"))
self._var = DataFrameView(var_sub, view_args=(self, "var"))
self._uns = uns
self._n_obs = len(self.obs)
self._n_vars = len(self.var)

# set data
if self.isbacked:
Expand Down Expand Up @@ -473,27 +527,20 @@ def _init_as_actual(
X = np.array(X, dtype, copy=False)
# data matrix and shape
self._X = X
self._n_obs, self._n_vars = self._X.shape
n_obs, n_vars = self._X.shape
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
source = "X"
else:
self._X = None
self._n_obs = len([] if obs is None else obs)
self._n_vars = len([] if var is None else var)
# check consistency with shape
if shape is not None:
if self._n_obs == 0:
self._n_obs = shape[0]
else:
if self._n_obs != shape[0]:
raise ValueError("`shape` is inconsistent with `obs`")
if self._n_vars == 0:
self._n_vars = shape[1]
else:
if self._n_vars != shape[1]:
raise ValueError("`shape` is inconsistent with `var`")
n_obs, n_vars = (None, None) if shape is None else shape
source = "shape"

# annotations
self._obs = _gen_dataframe(obs, self._n_obs, ["obs_names", "row_names"])
self._var = _gen_dataframe(var, self._n_vars, ["var_names", "col_names"])
self._obs = _gen_dataframe(
obs, ["obs_names", "row_names"], source=source, attr="obs", length=n_obs
)
self._var = _gen_dataframe(
var, ["var_names", "col_names"], source=source, attr="var", length=n_vars
)

# now we can verify if indices match!
for attr_name, x_name, idx in x_indices:
Expand Down Expand Up @@ -783,12 +830,12 @@ def raw(self):
@property
def n_obs(self) -> int:
    """Number of observations."""
    # Derived from `obs` on every access; there is no separately stored
    # count that could drift out of sync with the annotation frame.
    return len(self.obs)
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved

@property
def n_vars(self) -> int:
    """Number of variables/features."""
    # Derived from `var` on every access; there is no separately stored
    # count that could drift out of sync with the annotation frame.
    return len(self.var)

def _set_dim_df(self, value: pd.DataFrame, attr: str):
if not isinstance(value, pd.DataFrame):
Expand Down Expand Up @@ -1855,38 +1902,28 @@ def __contains__(self, key: Any):

def _check_dimensions(self, key=None):
if key is None:
key = {"obs", "var", "obsm", "varm"}
key = {"obsm", "varm"}
else:
key = {key}
if "obs" in key and len(self._obs) != self._n_obs:
raise ValueError(
"Observations annot. `obs` must have number of rows of `X`"
f" ({self._n_obs}), but has {self._obs.shape[0]} rows."
)
if "var" in key and len(self._var) != self._n_vars:
raise ValueError(
"Variables annot. `var` must have number of columns of `X`"
f" ({self._n_vars}), but has {self._var.shape[0]} rows."
)
if "obsm" in key:
obsm = self._obsm
if (
not all([dim_len(o, 0) == self._n_obs for o in obsm.values()])
and len(obsm.dim_names) != self._n_obs
not all([dim_len(o, 0) == self.n_obs for o in obsm.values()])
and len(obsm.dim_names) != self.n_obs
):
raise ValueError(
"Observations annot. `obsm` must have number of rows of `X`"
f" ({self._n_obs}), but has {len(obsm)} rows."
f" ({self.n_obs}), but has {len(obsm)} rows."
)
if "varm" in key:
varm = self._varm
if (
not all([dim_len(v, 0) == self._n_vars for v in varm.values()])
and len(varm.dim_names) != self._n_vars
not all([dim_len(v, 0) == self.n_vars for v in varm.values()])
and len(varm.dim_names) != self.n_vars
):
raise ValueError(
"Variables annot. `varm` must have number of columns of `X`"
f" ({self._n_vars}), but has {len(varm)} rows."
f" ({self.n_vars}), but has {len(varm)} rows."
)

def write_h5ad(
Expand Down
5 changes: 4 additions & 1 deletion anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def __init__(
self._X = X.get()
else:
self._X = X
self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"])
n_var = None if self._X is None else self._X.shape[1]
self._var = _gen_dataframe(
var, ["var_names"], source="X", attr="var", length=n_var
)
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
self._varm = AxisArrays(self, 1, varm)
elif X is None: # construct from adata
# Move from GPU to CPU since it's large and not always used
Expand Down
45 changes: 42 additions & 3 deletions anndata/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from itertools import product
import re
import warnings

import numpy as np
Expand Down Expand Up @@ -39,9 +42,6 @@ def test_creation():
assert adata.raw.X.tolist() == X.tolist()
assert adata.raw.var_names.tolist() == ["a", "b", "c"]

with pytest.raises(ValueError):
AnnData(np.array([[1, 2], [3, 4]]), dict(TooLong=[1, 2, 3, 4]))

# init with empty data matrix
shape = (3, 5)
adata = AnnData(None, uns=dict(test=np.array((3, 3))), shape=shape)
Expand All @@ -50,6 +50,45 @@ def test_creation():
assert "test" in adata.uns


@pytest.mark.parametrize(
    ("src", "src_arg", "dim_msg"),
    [
        pytest.param(
            "X",
            adata_dense.X,
            "`{dim}` must have as many rows as `X` has {mat_dim}s",
            id="x",
        ),
        pytest.param(
            "shape",
            (2, 2),
            "`shape` is inconsistent with `{dim}`",
            id="shape",
        ),
    ],
)
@pytest.mark.parametrize("dim", ["obs", "var"])
@pytest.mark.parametrize(
    ("dim_arg", "msg"),
    [
        pytest.param(
            lambda _: dict(TooLong=[1, 2, 3, 4]),
            "Length of values (4) does not match length of index (2)",
            id="too_long_col",
        ),
        pytest.param(
            lambda dim: {f"{dim}_names": ["a", "b", "c"]},
            None,
            id="too_many_names",
        ),
        pytest.param(
            lambda _: pd.DataFrame(index=["a", "b", "c"]),
            None,
            id="too_long_df",
        ),
    ],
)
def test_creation_error(src, src_arg, dim_msg, dim, dim_arg, msg: str | None):
    """An annotation whose length disagrees with `X`/`shape` is rejected.

    `msg` is the exact text expected in the error; when it is ``None`` the
    expectation is built from the source-specific template `dim_msg`.
    """
    expected = msg
    if expected is None:
        axis_word = {"obs": "row", "var": "column"}[dim]
        expected = dim_msg.format(dim=dim, mat_dim=axis_word)
    pattern = re.escape(expected)
    with pytest.raises(ValueError, match=pattern):
        AnnData(**{src: src_arg, dim: dim_arg(dim)})


def test_create_with_dfs():
X = np.ones((6, 3))
obs = pd.DataFrame(dict(cat_anno=pd.Categorical(["a", "a", "a", "a", "b", "a"])))
Expand Down
18 changes: 9 additions & 9 deletions anndata/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


@pytest.fixture
def adata_raw():
def adata_raw() -> ad.AnnData:
adata = ad.AnnData(
np.array(data, dtype="int32"), obs=obs_dict, var=var_dict, uns=uns_dict
)
Expand All @@ -48,18 +48,18 @@ def adata_raw():
# -------------------------------------------------------------------------------


def test_raw_init(adata_raw):
def test_raw_init(adata_raw: ad.AnnData):
assert adata_raw.var_names.tolist() == ["var1", "var2"]
assert adata_raw.raw.var_names.tolist() == ["var1", "var2", "var3"]
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_del(adata_raw):
def test_raw_del(adata_raw: ad.AnnData):
del adata_raw.raw
assert adata_raw.raw is None


def test_raw_set_as_none(adata_raw):
def test_raw_set_as_none(adata_raw: ad.AnnData):
# Test for scverse/anndata#445
a = adata_raw
b = adata_raw.copy()
Expand All @@ -70,15 +70,15 @@ def test_raw_set_as_none(adata_raw):
assert_equal(a, b)


def test_raw_of_view(adata_raw):
def test_raw_of_view(adata_raw: ad.AnnData):
adata_view = adata_raw[adata_raw.obs["oanno1"] == "cat2"]
assert adata_view.raw.X.tolist() == [
[4, 5, 6],
[7, 8, 9],
]


def test_raw_rw(adata_raw, backing_h5ad):
def test_raw_rw(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.write(backing_h5ad)
adata_read = ad.read(backing_h5ad)

Expand All @@ -89,7 +89,7 @@ def test_raw_rw(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_view_rw(adata_raw, backing_h5ad):
def test_raw_view_rw(adata_raw: ad.AnnData, backing_h5ad):
# Make sure it still writes correctly if the object is a view
adata_raw_view = adata_raw[:, adata_raw.var_names]
assert_equal(adata_raw_view, adata_raw)
Expand All @@ -104,7 +104,7 @@ def test_raw_view_rw(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X.tolist() == [[1], [4], [7]]


def test_raw_backed(adata_raw, backing_h5ad):
def test_raw_backed(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.filename = backing_h5ad

assert adata_raw.var_names.tolist() == ["var1", "var2"]
Expand All @@ -114,7 +114,7 @@ def test_raw_backed(adata_raw, backing_h5ad):
assert adata_raw.raw[:, 0].X[:].tolist() == [[1], [4], [7]]


def test_raw_view_backed(adata_raw, backing_h5ad):
def test_raw_view_backed(adata_raw: ad.AnnData, backing_h5ad):
adata_raw.filename = backing_h5ad

assert adata_raw.var_names.tolist() == ["var1", "var2"]
Expand Down
11 changes: 11 additions & 0 deletions anndata/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,14 @@ def test_copy_X_dtype():
adata = ad.AnnData(sparse.eye(50, dtype=np.float64, format="csr"))
adata_c = adata[::2].copy()
assert adata_c.X.dtype == adata.X.dtype


def test_x_none():
    """Without `X`, a row slice and its copy keep their shape and obs names."""

    def assert_rows_two_and_three(adata):
        assert adata.shape == (2, 0)
        assert adata.obs_names.tolist() == ["2", "3"]

    full = ad.AnnData(obs=pd.DataFrame(index=np.arange(50)))
    assert full.shape == (50, 0)

    sliced = full[2:4]
    assert_rows_two_and_three(sliced)

    materialized = sliced.copy()
    assert_rows_two_and_three(materialized)
Loading
Loading