From 7d24f8866501ad119cdabaa4008869b35ad77ec4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 2 Sep 2024 15:54:08 +0200 Subject: [PATCH] (chore): ensure views of anndata produce view classes --- docs/tutorials/notebooks | 2 +- src/anndata/_core/storage.py | 4 ++++ src/anndata/_core/views.py | 28 +++++++++++++-------------- tests/test_views.py | 37 ++++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 15 deletions(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 9e186c5c6..a81e79650 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 9e186c5c694793bb04ea1397721d154d6e0b7069 +Subproject commit a81e79650daeba3f10b2bc5ff2b5da938838a4ec diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index adf51077d..bd4625515 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -12,6 +12,8 @@ from ..compat import ( AwkArray, CupyArray, + CupyCSCMatrix, + CupyCSRMatrix, CupySparseMatrix, DaskArray, H5Array, @@ -43,6 +45,8 @@ CSCDataset, DaskArray, CupyArray, + CupyCSCMatrix, + CupyCSRMatrix, CupySparseMatrix, ] diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index ca9af9164..67cd4bfb7 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -294,64 +294,64 @@ def as_view(obj, view_args): @as_view.register(np.ndarray) -def as_view_array(array, view_args): +def as_view_array(array, view_args) -> ArrayView: return ArrayView(array, view_args=view_args) @as_view.register(DaskArray) -def as_view_dask_array(array, view_args): +def as_view_dask_array(array, view_args) -> DaskArrayView: return DaskArrayView(array, view_args=view_args) @as_view.register(pd.DataFrame) -def as_view_df(df, view_args): +def as_view_df(df, view_args) -> DataFrameView: return DataFrameView(df, view_args=view_args) @as_view.register(sparse.csr_matrix) -def as_view_csr_matrix(mtx, view_args): +def as_view_csr_matrix(mtx, view_args) -> SparseCSRMatrixView: return SparseCSRMatrixView(mtx, view_args=view_args) @as_view.register(sparse.csc_matrix) -def as_view_csc_matrix(mtx, view_args): +def as_view_csc_matrix(mtx, view_args) -> SparseCSCMatrixView: return SparseCSCMatrixView(mtx, view_args=view_args) @as_view.register(sparse.csr_array) -def as_view_csr_array(mtx, view_args): +def as_view_csr_array(mtx, view_args) -> SparseCSRArrayView: return SparseCSRArrayView(mtx, view_args=view_args) @as_view.register(sparse.csc_array) -def as_view_csc_array(mtx, view_args): +def as_view_csc_array(mtx, view_args) -> SparseCSCArrayView: return SparseCSCArrayView(mtx, view_args=view_args) @as_view.register(dict) -def as_view_dict(d, view_args): +def as_view_dict(d, view_args) -> DictView: return DictView(d, view_args=view_args) @as_view.register(ZappyArray) -def as_view_zappy(z, view_args): +def as_view_zappy(z, view_args) -> ZappyArray: # Previous code says ZappyArray works as view, # but as far as I can tell they’re immutable. return z @as_view.register(CupyArray) -def as_view_cupy(array, view_args): +def as_view_cupy(array, view_args) -> CupyArrayView: return CupyArrayView(array, view_args=view_args) @as_view.register(CupyCSRMatrix) -def as_view_cupy_csr(mtx, view_args): +def as_view_cupy_csr(mtx, view_args) -> CupySparseCSRView: return CupySparseCSRView(mtx, view_args=view_args) @as_view.register(CupyCSCMatrix) -def as_view_cupy_csc(mtx, view_args): +def as_view_cupy_csc(mtx, view_args) -> CupySparseCSCView: return CupySparseCSCView(mtx, view_args=view_args) @@ -373,7 +373,7 @@ def _view_args(self): to be attached as "behavior". These "behaviors" cannot take any additional parameters (as we do for other data types to store `_view_args`). Therefore, we need to store `_view_args` using awkward's parameter mechanism. These parameters need to be json-serializable, which is why we can't store - ElementRef directly, but need to replace the reference to the parent AnnDataView container with a weak + ElementRef directly, but need to replace the reference to the parent AnnData container with a weak reference. """ parent_key, attrname, keys = self.layout.parameter(_PARAM_NAME) @@ -394,7 +394,7 @@ def __copy__(self) -> AwkArray: return array @as_view.register(AwkArray) - def as_view_awkarray(array, view_args): + def as_view_awkarray(array, view_args) -> AwkwardArrayView: parent, attrname, keys = view_args parent_key = f"target-{id(parent)}" _registry[parent_key] = parent diff --git a/tests/test_views.py b/tests/test_views.py index 2d4a0a78d..bcb1062cc 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,6 +3,7 @@ from contextlib import ExitStack from copy import deepcopy from operator import mul +from typing import get_type_hints import joblib import numpy as np @@ -786,6 +787,42 @@ def test_dataframe_view_index_setting(): assert a2.obs.index.values.tolist() == ["a", "b"] +def test_elem_view_class(): + """ + Ensure that: + + (a) AnnData views actually produce view classes + (b) Produced view classes are subtypes of their original type + which then allows distinguishing views from non-views. + + This test tries to then guarantee that `my_adata.is_view and isinstance(my_adata.obsm['my_array'], BaseArrayClass)` + tells a user that they are working with a view class of `obsm['my_array']` that inherits from the base class (and has its methods). + """ + orig = gen_adata((10, 10)) + subset = orig[:8, :8] + assert subset.is_view + registry = ad._core.views.as_view.registry + as_view_funcs = registry.values() + base_classes = registry.keys() + # Use set membership to ensure the *actual* view class is used + view_types = set( + get_type_hints(func)["return"] + for func in as_view_funcs + if "return" in func.__annotations__ + ) + base_types = tuple(base_classes) + assert type(subset.obs) in view_types + assert type(subset.var) in view_types + for view_data in ( + *subset.obsm.values(), + *subset.layers.values(), + *subset.obsp.values(), + ): + view_data_type = type(view_data) + assert view_data_type in view_types + assert isinstance(view_data, base_types) + + # @pytest.mark.parametrize("dim", ["obs", "var"]) # @pytest.mark.parametrize( # ("idx", "pat"),