Fix/get df with default args #1231

Merged
merged 5 commits on Apr 12, 2024
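Not part of the original PR page: a schematic sketch of the root cause, inferred from the diff below. Before this change, `get_df` evaluated `keys[0]` before handling the default `keys=None`, so calling it with only an AnnData object raised a TypeError; the fix moves the indexing behind the `keys is None` check. The function names below are made up for illustration.

# Schematic only; not the actual scvelo source.
from typing import List, Optional, Union


def before_fix(keys: Optional[Union[str, List[str]]] = None) -> str:
    keys = [keys] if isinstance(keys, str) else keys
    return keys[0]  # TypeError: 'NoneType' object is not subscriptable when keys is None


def after_fix(keys: Optional[Union[str, List[str]]] = None) -> str:
    if keys is None:
        return "to_df"  # default path: fall back to the full DataFrame
    keys = [keys] if isinstance(keys, str) else keys
    return keys[0]  # only reached when keys were actually provided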
124 changes: 68 additions & 56 deletions scvelo/core/_anndata.py
@@ -1,4 +1,5 @@
import re
import warnings
from typing import List, Literal, Optional, Union

import numpy as np
@@ -178,6 +179,13 @@ def get_df(
:class:`pd.DataFrame`
A dataframe.
"""
warnings.warn(
"`get_df` is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please `AnnData::get_df` or Scanpy's `scanpy.get.obs_df` or `scanpy.get.var_df`.",
DeprecationWarning,
stacklevel=2,
)

if precision is not None:
pd.set_option("display.precision", precision)

@@ -188,8 +196,6 @@
keys, key_add = (
keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None)
)
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
d_keys = [
@@ -207,62 +213,68 @@

if keys is None:
df = data.to_df()
elif key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.")
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

if key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(
f"'{key}' not found in any of {', '.join(s_keys)}."
)
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
elif isinstance(data, pd.DataFrame):
if isinstance(keys, str) and "*" in keys:
keys, keys_split = keys.split("*")
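For context, not part of the diff: a rough usage sketch of the behavior this PR restores, together with the replacements pointed to by the new DeprecationWarning. The toy AnnData object and the `from scvelo.core import get_df` import path are illustrative assumptions.

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData

from scvelo.core import get_df  # assumed import path

# Tiny illustrative AnnData object (2 cells x 3 genes).
adata = AnnData(
    X=np.arange(6, dtype=float).reshape(2, 3),
    obs=pd.DataFrame(index=["cell_0", "cell_1"]),
    var=pd.DataFrame(index=["gene_0", "gene_1", "gene_2"]),
)

# With this PR, the default call no longer indexes `keys[0]` on None and
# falls back to the full expression matrix (while emitting the warning).
df = get_df(adata)
np.testing.assert_equal(adata.to_df().values, df.values)

# Replacements suggested by the deprecation message:
df_x = adata.to_df()                            # AnnData's own DataFrame export
df_obs = sc.get.obs_df(adata, keys=["gene_0"])  # Scanpy: per-observation values
df_var = sc.get.var_df(adata, keys=["cell_0"])  # Scanpy: per-variable values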
14 changes: 14 additions & 0 deletions tests/core/test_anndata.py
@@ -573,6 +573,20 @@ def test_data_as_array(
else:
assert (df.columns == ["col_1", "col_2"]).all()

@given(
adata=get_adata(
max_obs=5,
max_vars=5,
layer_keys=["layer_1", "layer_2"],
),
modality=st.sampled_from([None, "X", "layer_1", "layer_2"]),
)
def test_default(self, adata: AnnData, modality: Optional[str]):
df = get_df(adata, layer=modality)

assert isinstance(df, pd.DataFrame)
np.testing.assert_equal(adata.to_df().values, df.values)


class TestGetInitialSize(TestBase):
@given(
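A side note, not part of the PR: the new test is property-based. A minimal, self-contained sketch of the same Hypothesis pattern (`@given` plus `st.sampled_from`), with a made-up function name, showing how the `modality` values above are drawn — including None, so the default path is always exercised.

from typing import Optional

from hypothesis import given
from hypothesis import strategies as st


@given(modality=st.sampled_from([None, "X", "layer_1", "layer_2"]))
def check_modality_draws(modality: Optional[str]) -> None:
    # Hypothesis calls this body once per generated example; each draw is
    # one of the listed candidates.
    assert modality in (None, "X", "layer_1", "layer_2")


check_modality_draws()  # calling the decorated function runs all examples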