Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(chore): ensure views of anndata produce distinguishable view classes #1637

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
28 changes: 14 additions & 14 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,64 +294,64 @@ def as_view(obj, view_args):


@as_view.register(np.ndarray)
def as_view_array(array, view_args):
def as_view_array(array, view_args) -> ArrayView:
return ArrayView(array, view_args=view_args)


@as_view.register(DaskArray)
def as_view_dask_array(array, view_args):
def as_view_dask_array(array, view_args) -> DaskArrayView:
return DaskArrayView(array, view_args=view_args)


@as_view.register(pd.DataFrame)
def as_view_df(df, view_args):
def as_view_df(df, view_args) -> DataFrameView:
return DataFrameView(df, view_args=view_args)


@as_view.register(sparse.csr_matrix)
def as_view_csr_matrix(mtx, view_args):
def as_view_csr_matrix(mtx, view_args) -> SparseCSRMatrixView:
return SparseCSRMatrixView(mtx, view_args=view_args)


@as_view.register(sparse.csc_matrix)
def as_view_csc_matrix(mtx, view_args):
def as_view_csc_matrix(mtx, view_args) -> SparseCSCMatrixView:
return SparseCSCMatrixView(mtx, view_args=view_args)


@as_view.register(sparse.csr_array)
def as_view_csr_array(mtx, view_args):
def as_view_csr_array(mtx, view_args) -> SparseCSRArrayView:
return SparseCSRArrayView(mtx, view_args=view_args)


@as_view.register(sparse.csc_array)
def as_view_csc_array(mtx, view_args):
def as_view_csc_array(mtx, view_args) -> SparseCSCArrayView:
return SparseCSCArrayView(mtx, view_args=view_args)


@as_view.register(dict)
def as_view_dict(d, view_args):
def as_view_dict(d, view_args) -> DictView:
return DictView(d, view_args=view_args)


@as_view.register(ZappyArray)
def as_view_zappy(z, view_args):
def as_view_zappy(z, view_args) -> ZappyArray:
# Previous code says ZappyArray works as view,
# but as far as I can tell they’re immutable.
return z


@as_view.register(CupyArray)
def as_view_cupy(array, view_args):
def as_view_cupy(array, view_args) -> CupyArrayView:
return CupyArrayView(array, view_args=view_args)


@as_view.register(CupyCSRMatrix)
def as_view_cupy_csr(mtx, view_args):
def as_view_cupy_csr(mtx, view_args) -> CupySparseCSRView:
return CupySparseCSRView(mtx, view_args=view_args)


@as_view.register(CupyCSCMatrix)
def as_view_cupy_csc(mtx, view_args):
def as_view_cupy_csc(mtx, view_args) -> CupySparseCSCView:
return CupySparseCSCView(mtx, view_args=view_args)


Expand All @@ -373,7 +373,7 @@ def _view_args(self):
to be attached as "behavior". These "behaviors" cannot take any additional parameters (as we do
for other data types to store `_view_args`). Therefore, we need to store `_view_args` using awkward's
parameter mechanism. These parameters need to be json-serializable, which is why we can't store
ElementRef directly, but need to replace the reference to the parent AnnDataView container with a weak
ElementRef directly, but need to replace the reference to the parent AnnData container with a weak
reference.
"""
parent_key, attrname, keys = self.layout.parameter(_PARAM_NAME)
Expand All @@ -394,7 +394,7 @@ def __copy__(self) -> AwkArray:
return array

@as_view.register(AwkArray)
def as_view_awkarray(array, view_args):
def as_view_awkarray(array, view_args) -> AwkwardArrayView:
parent, attrname, keys = view_args
parent_key = f"target-{id(parent)}"
_registry[parent_key] = parent
Expand Down
37 changes: 37 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import ExitStack
from copy import deepcopy
from operator import mul
from typing import get_type_hints

import joblib
import numpy as np
Expand Down Expand Up @@ -786,6 +787,42 @@ def test_dataframe_view_index_setting():
assert a2.obs.index.values.tolist() == ["a", "b"]


def test_elem_view_class():
    """
    Ensure that:

    (a) AnnData views actually produce view classes
    (b) Produced view classes are subtypes of their original type
        which then allows distinguishing views from non-views.

    Together these checks aim to guarantee that
    `my_adata.is_view and isinstance(my_adata.obsm['my_array'], BaseArrayClass)`
    tells a user that they are working with a view class of `obsm['my_array']`
    that inherits from the base class (and has its methods).
    """
    orig = gen_adata((10, 10))
    subset = orig[:8, :8]
    assert subset.is_view
    registry = ad._core.views.as_view.registry
    # Use set membership to ensure the *actual* view class is used.
    # Only registered functions that declare a return annotation contribute.
    view_types = {
        get_type_hints(func)["return"]
        for func in registry.values()
        if "return" in func.__annotations__
    }
    # `singledispatch` always keeps its fallback implementation under `object`.
    # Leaving `object` in this tuple would make the `isinstance` check below
    # succeed for every value, so it is excluded.
    base_types = tuple(cls for cls in registry if cls is not object)
    assert type(subset.obs) in view_types
    assert type(subset.var) in view_types
    for view_data in (
        *subset.obsm.values(),
        *subset.layers.values(),
        *subset.obsp.values(),
    ):
        assert type(view_data) in view_types
        assert isinstance(view_data, base_types)


# @pytest.mark.parametrize("dim", ["obs", "var"])
# @pytest.mark.parametrize(
# ("idx", "pat"),
Expand Down
Loading