Skip to content

Commit

Permalink
(chore): ensure views of anndata produce view classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Sep 2, 2024
1 parent 0bc2b39 commit 7d24f88
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
4 changes: 4 additions & 0 deletions src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ..compat import (
AwkArray,
CupyArray,
CupyCSCMatrix,
CupyCSRMatrix,
CupySparseMatrix,
DaskArray,
H5Array,
Expand Down Expand Up @@ -43,6 +45,8 @@
CSCDataset,
DaskArray,
CupyArray,
CupyCSCMatrix,
CupyCSRMatrix,
CupySparseMatrix,
]

Expand Down
28 changes: 14 additions & 14 deletions src/anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,64 +294,64 @@ def as_view(obj, view_args):


@as_view.register(np.ndarray)
def as_view_array(array, view_args):
def as_view_array(array, view_args) -> ArrayView:
return ArrayView(array, view_args=view_args)


@as_view.register(DaskArray)
def as_view_dask_array(array, view_args):
def as_view_dask_array(array, view_args) -> DaskArrayView:
return DaskArrayView(array, view_args=view_args)


@as_view.register(pd.DataFrame)
def as_view_df(df, view_args):
def as_view_df(df, view_args) -> DataFrameView:
return DataFrameView(df, view_args=view_args)


@as_view.register(sparse.csr_matrix)
def as_view_csr_matrix(mtx, view_args):
def as_view_csr_matrix(mtx, view_args) -> SparseCSRMatrixView:
return SparseCSRMatrixView(mtx, view_args=view_args)


@as_view.register(sparse.csc_matrix)
def as_view_csc_matrix(mtx, view_args):
def as_view_csc_matrix(mtx, view_args) -> SparseCSCMatrixView:
return SparseCSCMatrixView(mtx, view_args=view_args)


@as_view.register(sparse.csr_array)
def as_view_csr_array(mtx, view_args):
def as_view_csr_array(mtx, view_args) -> SparseCSRArrayView:
return SparseCSRArrayView(mtx, view_args=view_args)


@as_view.register(sparse.csc_array)
def as_view_csc_array(mtx, view_args):
def as_view_csc_array(mtx, view_args) -> SparseCSCArrayView:
return SparseCSCArrayView(mtx, view_args=view_args)


@as_view.register(dict)
def as_view_dict(d, view_args):
def as_view_dict(d, view_args) -> DictView:
return DictView(d, view_args=view_args)


@as_view.register(ZappyArray)
def as_view_zappy(z, view_args):
def as_view_zappy(z, view_args) -> ZappyArray:
# Previous code says ZappyArray works as view,
# but as far as I can tell they’re immutable.
return z


@as_view.register(CupyArray)
def as_view_cupy(array, view_args):
def as_view_cupy(array, view_args) -> CupyArrayView:
return CupyArrayView(array, view_args=view_args)


@as_view.register(CupyCSRMatrix)
def as_view_cupy_csr(mtx, view_args):
def as_view_cupy_csr(mtx, view_args) -> CupySparseCSRView:
return CupySparseCSRView(mtx, view_args=view_args)


@as_view.register(CupyCSCMatrix)
def as_view_cupy_csc(mtx, view_args):
def as_view_cupy_csc(mtx, view_args) -> CupySparseCSCView:
return CupySparseCSCView(mtx, view_args=view_args)


Expand All @@ -373,7 +373,7 @@ def _view_args(self):
to be attached as "behavior". These "behaviors" cannot take any additional parameters (as we do
for other data types to store `_view_args`). Therefore, we need to store `_view_args` using awkward's
parameter mechanism. These parameters need to be json-serializable, which is why we can't store
ElementRef directly, but need to replace the reference to the parent AnnDataView container with a weak
ElementRef directly, but need to replace the reference to the parent AnnData container with a weak
reference.
"""
parent_key, attrname, keys = self.layout.parameter(_PARAM_NAME)
Expand All @@ -394,7 +394,7 @@ def __copy__(self) -> AwkArray:
return array

@as_view.register(AwkArray)
def as_view_awkarray(array, view_args):
def as_view_awkarray(array, view_args) -> AwkwardArrayView:
parent, attrname, keys = view_args
parent_key = f"target-{id(parent)}"
_registry[parent_key] = parent
Expand Down
37 changes: 37 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import ExitStack
from copy import deepcopy
from operator import mul
from typing import get_type_hints

import joblib
import numpy as np
Expand Down Expand Up @@ -786,6 +787,42 @@ def test_dataframe_view_index_setting():
assert a2.obs.index.values.tolist() == ["a", "b"]


def test_elem_view_class():
    """
    Ensure that:
        (a) AnnData views actually produce view classes
        (b) Produced view classes are subtypes of their original type
    which then allows distinguishing views from non-views.

    This test tries to then guarantee that
    `my_adata.is_view and isinstance(my_adata.obsm['my_array'], BaseArrayClass)`
    tells a user that they are working with a view class of `obsm['my_array']`
    that inherits from the base class (and has its methods).
    """
    orig = gen_adata((10, 10))
    subset = orig[:8, :8]
    assert subset.is_view
    registry = ad._core.views.as_view.registry
    # Use set membership to ensure the *actual* view class is used
    view_types = {
        get_type_hints(func)["return"]
        for func in registry.values()
        if "return" in func.__annotations__
    }
    # `singledispatch` always registers the fallback implementation under
    # `object`. Leave it out: with `object` among the base types, the
    # `isinstance` check below would pass for every value and verify nothing.
    base_types = tuple(cls for cls in registry if cls is not object)
    assert type(subset.obs) in view_types
    assert type(subset.var) in view_types
    for view_data in (
        *subset.obsm.values(),
        *subset.layers.values(),
        *subset.obsp.values(),
    ):
        assert type(view_data) in view_types
        assert isinstance(view_data, base_types)


# @pytest.mark.parametrize("dim", ["obs", "var"])
# @pytest.mark.parametrize(
# ("idx", "pat"),
Expand Down

0 comments on commit 7d24f88

Please sign in to comment.