diff --git a/.azure-pipelines.yml b/.azure-pipelines.yml index afc1b4153..4464d6d66 100644 --- a/.azure-pipelines.yml +++ b/.azure-pipelines.yml @@ -53,7 +53,7 @@ jobs: - script: printf "llvmlite>=0.43\nscanpy>=1.10.0rc1" | tee /tmp/constraints.txt displayName: "Create constraints file for `pre-release` and `latest` jobs" - - script: uv pip install --system --compile "anndata[dev,test] @ ." -c /tmp/constraints.txt + - script: uv pip install --system --compile "anndata[dev,test-full] @ ." -c /tmp/constraints.txt displayName: "Install dependencies" condition: eq(variables['DEPENDENCIES_VERSION'], 'latest') @@ -65,7 +65,7 @@ jobs: displayName: "Install minimum dependencies" condition: eq(variables['DEPENDENCIES_VERSION'], 'minimum') - - script: uv pip install -v --system --compile --pre "anndata[dev,test] @ ." -c /tmp/constraints.txt + - script: uv pip install -v --system --compile --pre "anndata[dev,test-full] @ ." -c /tmp/constraints.txt displayName: "Install dependencies release candidates" condition: eq(variables['DEPENDENCIES_VERSION'], 'pre-release') @@ -76,6 +76,10 @@ jobs: displayName: "PyTest" condition: eq(variables['TEST_TYPE'], 'standard') + - script: pytest + displayName: "PyTest (minimum)" + condition: eq(variables['DEPENDENCIES_VERSION'], 'minimum') + - script: pytest --cov --cov-report=xml --cov-context=test displayName: "PyTest (coverage)" condition: eq(variables['TEST_TYPE'], 'coverage') diff --git a/.gitignore b/.gitignore index 8124b9450..710b149e0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ /*cache/ /node_modules/ /data/ +/venv/ # Distribution / packaging /dist/ diff --git a/docs/api.md b/docs/api.md index 60cbbf61c..951786f81 100644 --- a/docs/api.md +++ b/docs/api.md @@ -131,7 +131,8 @@ Low level methods for reading and writing elements of an {class}`AnnData` object .. 
autosummary:: :toctree: generated/ - experimental.read_elem_as_dask + experimental.read_elem_lazy + experimental.read_lazy ``` Utilities for customizing the IO process: @@ -156,6 +157,9 @@ Types used by the former: experimental.ReadCallback experimental.WriteCallback experimental.StorageType + experimental.backed._lazy_arrays.MaskedArray + experimental.backed._lazy_arrays.CategoricalArray + experimental.backed._xarray.Dataset2D ``` ## Errors and warnings diff --git a/docs/conf.py b/docs/conf.py index f98fe5ba7..9325d91c3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -106,8 +106,6 @@ ("py:class", "anndata.compat.CupyArray"), ("py:class", "anndata.compat.CupySparseMatrix"), ("py:class", "numpy.ma.core.MaskedArray"), - ("py:class", "dask.array.core.Array"), - ("py:class", "awkward.highlevel.Array"), ("py:class", "anndata._core.sparse_dataset.BaseCompressedSparseDataset"), ("py:obj", "numpy._typing._array_like._ScalarType_co"), # https://github.com/sphinx-doc/sphinx/issues/10974 @@ -134,6 +132,7 @@ def setup(app: Sphinx): zarr=("https://zarr.readthedocs.io/en/stable", None), xarray=("https://docs.xarray.dev/en/stable", None), dask=("https://docs.dask.org/en/stable", None), + ak=("https://awkward-array.org/doc/stable/", None), ) qualname_overrides = { "h5py._hl.group.Group": "h5py.Group", @@ -144,6 +143,7 @@ def setup(app: Sphinx): "anndata._types.WriteCallback": "anndata.experimental.WriteCallback", "anndata._types.Read": "anndata.experimental.Read", "anndata._types.Write": "anndata.experimental.Write", + "awkward.highlevel.Array": "ak.Array", } autodoc_type_aliases = dict( NDArray=":data:`~numpy.typing.NDArray`", diff --git a/docs/release-notes/0.11.0rc1.md b/docs/release-notes/0.11.0rc1.md index f5a98086d..0bcf842b7 100644 --- a/docs/release-notes/0.11.0rc1.md +++ b/docs/release-notes/0.11.0rc1.md @@ -20,8 +20,8 @@ - `scipy.sparse.csr_array` and `scipy.sparse.csc_array` are now supported when constructing `AnnData` objects {user}`ilan-gold` {user}`isaac-virshup` ({pr}`1028`) - Allow `axis` parameter of e.g. 
{func}`anndata.concat` to accept `'obs'` and `'var'` {user}`flying-sheep` ({pr}`1244`) - Add `settings` object with methods for altering internally-used options, like checking for uniqueness on `obs`' index {user}`ilan-gold` ({pr}`1270`) +- Add {func}`~anndata.experimental.read_elem_lazy` function to handle i/o with sparse and dense arrays {user}`ilan-gold` ({pr}`1469`) - Add {attr}`~anndata.settings.remove_unused_categories` option to {attr}`anndata.settings` to override current behavior {user}`ilan-gold` ({pr}`1340`) -- Add {func}`~anndata.experimental.read_elem_as_dask` function to handle i/o with sparse and dense arrays {user}`ilan-gold` ({pr}`1469`) - Add ability to convert strings to categoricals on write in {meth}`~anndata.AnnData.write_h5ad` and {meth}`~anndata.AnnData.write_zarr` via `convert_strings_to_categoricals` parameter {user}` falexwolf` ({pr}`1474`) - Add {attr}`~anndata.settings.check_uniqueness` option to {attr}`anndata.settings` to override current behavior {user}`ilan-gold` ({pr}`1507`) - Add functionality to write from GPU {class}`dask.array.Array` to disk {user}`ilan-gold` ({pr}`1550`) diff --git a/docs/release-notes/1247.feature.md b/docs/release-notes/1247.feature.md new file mode 100644 index 000000000..c19ccf9fa --- /dev/null +++ b/docs/release-notes/1247.feature.md @@ -0,0 +1 @@ +Add {func}`~anndata.experimental.read_elem_lazy` (in place of `read_elem_as_dask`) to handle backed dataframes, sparse arrays, and dense arrays, as well as a {func}`~anndata.experimental.read_lazy` to handle reading in as much of the on-disk data as possible to produce a {class}`~anndata.AnnData` object {user}`ilan-gold` diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index f62e7967c..29bc1f1fc 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -14,4 +14,5 @@ notebooks/anncollection-annloader notebooks/anndata_dask_array notebooks/awkward-arrays notebooks/{read,write}_dispatched +notebooks/read_lazy ``` diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 9e186c5c6..0af6cf336 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 9e186c5c694793bb04ea1397721d154d6e0b7069 +Subproject commit 0af6cf3363aed1cafd317516c8393136ee6287ae diff --git a/pyproject.toml b/pyproject.toml index 127ae3dc1..245e3dc31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,9 +80,11 @@ doc = [ "sphinx_design>=0.5.0", "readthedocs-sphinx-search", # for unreleased changes - "anndata[dev-doc]", + "anndata[dev-doc,dask]", + "awkward>=2.3" ] dev-doc = ["towncrier>=24.8.0"] # release notes tool +test-full = ["anndata[test,lazy]"] test = [ "loompy>=3.0.5", "pytest>=8.2", @@ -108,6 +110,7 @@ cu12 = ["cupy-cuda12x"] cu11 = ["cupy-cuda11x"] # https://github.com/dask/dask/issues/11290 dask = ["dask[array]>=2022.09.2,<2024.8.0"] +lazy = ["xarray>=2024.06.0", "aiohttp", "requests", "zarr<3.0.0a0", "anndata[dask]"] [tool.hatch.version] source = "vcs" diff --git a/src/anndata/_core/aligned_df.py b/src/anndata/_core/aligned_df.py index 321264886..ea66a7fda 100644 --- a/src/anndata/_core/aligned_df.py +++ b/src/anndata/_core/aligned_df.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from collections.abc import Mapping from functools import singledispatch from typing import TYPE_CHECKING @@ -10,13 +11,26 @@ from .._warnings import ImplicitModificationWarning if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Iterable from typing import Any, Literal 
@singledispatch def _gen_dataframe( - anno: Mapping[str, Any], + anno: Any, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +) -> pd.DataFrame: # pragma: no cover + raise ValueError(f"Cannot convert {type(anno)} to {attr} DataFrame") + + +@_gen_dataframe.register(Mapping) +@_gen_dataframe.register(type(None)) +def _gen_dataframe_mapping( + anno: Mapping[str, Any] | None, index_names: Iterable[str], *, source: Literal["X", "shape"], diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 7ef9f8ac4..2d48733c1 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -8,7 +8,7 @@ from collections import OrderedDict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy -from functools import partial +from functools import partial, singledispatch from pathlib import Path from textwrap import dedent from typing import TYPE_CHECKING @@ -41,7 +41,6 @@ from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset from .storage import coerce_array from .views import ( - DataFrameView, DictView, _resolve_idxs, as_view, @@ -53,7 +52,7 @@ from typing import Any, Literal from ..compat import Index1D - from ..typing import ArrayDataStructureType + from ..typing import XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView from .index import Index @@ -302,8 +301,8 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index): self._remove_unused_categories(adata_ref.obs, obs_sub, uns) self._remove_unused_categories(adata_ref.var, var_sub, uns) # set attributes - self._obs = DataFrameView(obs_sub, view_args=(self, "obs")) - self._var = DataFrameView(var_sub, view_args=(self, "var")) + self._obs = as_view(obs_sub, view_args=(self, "obs")) + self._var = as_view(var_sub, view_args=(self, "var")) self._uns = uns # set data @@ -542,7 +541,7 @@ def shape(self) -> tuple[int, int]: return self.n_obs, self.n_vars @property - def X(self) -> ArrayDataStructureType | None: + def X(self) -> XDataType | None: """Data matrix of shape :attr:`n_obs` × :attr:`n_vars`.""" if self.isbacked: if not self.file.is_open: @@ -1022,8 +1021,10 @@ def __getitem__(self, index: Index) -> AnnData: oidx, vidx = self._normalize_indices(index) return AnnData(self, oidx=oidx, vidx=vidx, asview=True) + @staticmethod + @singledispatch def _remove_unused_categories( - self, df_full: pd.DataFrame, df_sub: pd.DataFrame, uns: dict[str, Any] + df_full: pd.DataFrame, df_sub: pd.DataFrame, uns: dict[str, Any] ): for k in df_full: if not isinstance(df_full[k].dtype, pd.CategoricalDtype): diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index f1d72ce0d..6d2997289 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -44,10 +44,13 @@ def _normalize_index( | pd.Index, index: pd.Index, ) -> slice | int | np.ndarray: # ndarray of int or bool - if not isinstance(index, pd.RangeIndex): - msg = "Don’t call _normalize_index with non-categorical/string names" - assert index.dtype != float, msg - assert index.dtype != int, msg + from ..experimental.backed._compat import DataArray + + # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. 
+ # if not isinstance(index, pd.RangeIndex): + # msg = "Don’t call _normalize_index with non-categorical/string names and non-range index" + # assert index.dtype != float, msg + # assert index.dtype != int, msg # the following is insanely slow for sequences, # we replaced it using pandas below @@ -101,6 +104,10 @@ def name_idx(i): "are not valid obs/ var names or indices." ) return positions # np.ndarray[int] + elif isinstance(indexer, DataArray): + if isinstance(indexer.data, DaskArray): + return indexer.data.compute() + return indexer.data raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}") diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 0dfa5dab2..086880bed 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -4,13 +4,12 @@ from __future__ import annotations -import typing from collections import OrderedDict from collections.abc import Callable, Mapping, MutableSet from functools import partial, reduce, singledispatch from itertools import repeat from operator import and_, or_, sub -from typing import Literal, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar from warnings import warn import numpy as np @@ -19,6 +18,7 @@ from scipy import sparse from scipy.sparse import spmatrix +from anndata._core.file_backing import to_memory from anndata._warnings import ExperimentalFeatureWarning from ..compat import ( @@ -35,12 +35,15 @@ from .anndata import AnnData from .index import _subset, make_slice -if typing.TYPE_CHECKING: - from collections.abc import Collection, Iterable, Sequence +if TYPE_CHECKING: + from collections.abc import Collection, Generator, Iterable, Sequence from typing import Any from pandas.api.extensions import ExtensionDtype + from anndata._types import Join_T + from anndata.experimental.backed._compat import DataArray, Dataset2D + T = TypeVar("T") ################### @@ -206,6 +209,8 @@ def equal_awkward(a, b) -> bool: def as_sparse(x, use_sparse_array=False): + if isinstance(x, DaskArray): + x = x.compute() if not isinstance(x, sparse.spmatrix | SpArray): if CAN_USE_SPARSE_ARRAY and use_sparse_array: return sparse.csr_array(x) @@ -225,7 +230,9 @@ def as_cp_sparse(x) -> CupySparseMatrix: return cpsparse.csr_matrix(x) -def unify_dtypes(dfs: Iterable[pd.DataFrame]) -> list[pd.DataFrame]: +def unify_dtypes( + dfs: Iterable[pd.DataFrame | Dataset2D], +) -> list[pd.DataFrame | Dataset2D]: """ Attempts to unify datatypes from multiple dataframes. @@ -302,7 +309,7 @@ def try_unifying_dtype( return None -def check_combinable_cols(cols: list[pd.Index], join: Literal["inner", "outer"]): +def check_combinable_cols(cols: list[pd.Index], join: Join_T): """Given columns for a set of dataframes, checks if the can be combined. Looks for if there are duplicated column names that would show up in the result. 
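A minimal sketch (not part of the patch) of what the new `DataArray` branch added to `_normalize_index` above does, assuming `xarray` and `dask` are installed: a dask-backed xarray indexer is materialized with `.compute()`, while a numpy-backed one is passed through via `.data`.

```python
import dask.array as da
import numpy as np
import xarray as xr

# Hypothetical indexer wrapped in an xarray DataArray backed by a dask array
idx = xr.DataArray(da.from_array(np.array([0, 2, 3]), chunks=2))

# Mirror of the branch in _normalize_index: compute dask data, pass numpy through
resolved = idx.data.compute() if isinstance(idx.data, da.Array) else idx.data
print(resolved)  # array([0, 2, 3]) -- plain positions usable for indexing
```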
@@ -706,9 +713,7 @@ def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None): return el[self.old_idx.get_indexer(self.new_idx)] -def merge_indices( - inds: Iterable[pd.Index], join: Literal["inner", "outer"] -) -> pd.Index: +def merge_indices(inds: Iterable[pd.Index], join: Join_T) -> pd.Index: if join == "inner": return reduce(lambda x, y: x.intersection(y), inds) elif join == "outer": @@ -765,10 +770,20 @@ def np_bool_to_pd_bool_array(df: pd.DataFrame): def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None): + from anndata.experimental.backed._compat import Dataset2D + arrays = list(arrays) if fill_value is None: fill_value = default_fill_value(arrays) + if any(isinstance(a, Dataset2D) for a in arrays): + if any(isinstance(a, pd.DataFrame) for a in arrays): + arrays = [to_memory(a) if isinstance(a, Dataset2D) else a for a in arrays] + elif not all(isinstance(a, Dataset2D) for a in arrays): + msg = f"Cannot concatenate a Dataset2D with other array types {[type(a) for a in arrays if not isinstance(a, Dataset2D)]}." + raise ValueError(msg) + else: + return concat_dataset2d_on_annot_axis(arrays, join="outer") if any(isinstance(a, pd.DataFrame) for a in arrays): # TODO: This is hacky, 0 is a sentinel for outer_concat_aligned_mapping if not all( @@ -1060,11 +1075,162 @@ def concat_Xs(adatas, reindexers, axis, fill_value): return concat_arrays(Xs, reindexers, axis=axis, fill_value=fill_value) +def make_dask_col_from_extension_dtype( + col: DataArray, use_only_object_dtype: bool = False +) -> DaskArray: + """ + Creates dask arrays from :class:`pandas.api.extensions.ExtensionArray` dtype :class:`xarray.DataArray`s. + + Parameters + ---------- + col + The columns to be converted + use_only_object_dtype, optional + Whether or not to cast all :class:`pandas.api.extensions.ExtensionArray` dtypes to `object` type, by default False + + Returns + ------- + A :class:`dask.Array`: representation of the column. + """ + import dask.array as da + + from anndata._io.specs.lazy_methods import ( + compute_chunk_layout_for_axis_size, + maybe_open_h5, + ) + from anndata.experimental import read_lazy + from anndata.experimental.backed._compat import DataArray + from anndata.experimental.backed._compat import xarray as xr + + base_path_or_zarr_group = col.attrs.get("base_path_or_zarr_group") + elem_name = col.attrs.get("elem_name") + dims = col.dims + coords = col.coords.copy() + + def get_chunk(block_info=None): + with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: + v = read_lazy(f) + variable = xr.Variable( + data=xr.core.indexing.LazilyIndexedArray(v), dims=dims + ) + data_array = DataArray( + variable, + coords=coords, + dims=dims, + ) + idx = tuple( + slice(start, stop) for start, stop in block_info[None]["array-location"] + ) + chunk = np.array(data_array.data[idx].array) + return chunk + + if col.dtype == "category" or use_only_object_dtype: + dtype = "object" + else: + dtype = col.dtype.numpy_dtype + # TODO: get good chunk size? + return da.map_blocks( + get_chunk, + chunks=(compute_chunk_layout_for_axis_size(1000, col.shape[0]),), + meta=np.array([], dtype=dtype), + dtype=dtype, + ) + + +def make_xarray_extension_dtypes_dask( + annotations: Iterable[Dataset2D], use_only_object_dtype: bool = False +) -> Generator[Dataset2D, None, None]: + """ + Creates a generator of Dataset2D objects with dask arrays in place of :class:`pandas.api.extensions.ExtensionArray` dtype columns. 
+ + Parameters + ---------- + annotations + The datasets to be altered + use_only_object_dtype, optional + Whether or not to cast all :class:`pandas.api.extensions.ExtensionArray` dtypes to `object` type, by default False + + Yields + ------ + An altered dataset. + """ + for a in annotations: + extension_cols = set( + filter(lambda col: pd.api.types.is_extension_array_dtype(a[col]), a.columns) + ) + + yield a.copy( + data={ + name: ( + make_dask_col_from_extension_dtype(col, use_only_object_dtype) + if name in extension_cols + else col + ) + for name, col in a.items() + } + ) + + +def get_attrs(annotations: Iterable[Dataset2D]) -> dict: + """Generate the `attrs` from `annotations`. + + Parameters + ---------- + annotations + The datasets with `attrs`. + + Returns + ------- + `attrs`. + """ + index_names = np.unique([a.index.name for a in annotations]) + assert len(index_names) == 1, "All annotations must have the same index name." + if any(a.index.dtype == "int64" for a in annotations): + msg = "Concatenating with a pandas numeric index among the indices. Index may likely not be unique." + warn(msg, UserWarning) + index_keys = [ + a.attrs["indexing_key"] for a in annotations if "indexing_key" in a.attrs + ] + attrs = {} + if len(np.unique(index_keys)) == 1: + attrs["indexing_key"] = index_keys[0] + return attrs + + +def concat_dataset2d_on_annot_axis( + annotations: Iterable[Dataset2D], + join: Join_T, +): + """Create a concatenate dataset from a list of :class:`~anndata.experimental.backed._xarray.Dataset2D` objects. + + Parameters + ---------- + annotations + The :class:`~anndata.experimental.backed._xarray.Dataset2D` objects to be concatenated. + join + Type of join operation + + Returns + ------- + Concatenated :class:`~anndata.experimental.backed._xarray.Dataset2D` + """ + from anndata.experimental.backed._compat import Dataset2D + from anndata.experimental.backed._compat import xarray as xr + + annotations_with_only_dask = list(make_xarray_extension_dtypes_dask(annotations)) + attrs = get_attrs(annotations_with_only_dask) + index_name = np.unique([a.index.name for a in annotations])[0] + [index_name] = {a.index.name for a in annotations} + return Dataset2D( + xr.concat(annotations_with_only_dask, join=join, dim=index_name), attrs=attrs + ) + + def concat( - adatas: Collection[AnnData] | typing.Mapping[str, AnnData], + adatas: Collection[AnnData] | Mapping[str, AnnData], *, axis: Literal["obs", 0, "var", 1] = "obs", - join: Literal["inner", "outer"] = "inner", + join: Join_T = "inner", merge: StrategiesLiteral | Callable | None = None, uns_merge: StrategiesLiteral | Callable | None = None, label: str | None = None, @@ -1097,6 +1263,8 @@ def concat( * `"unique"`: Elements for which there is only one possible value. * `"first"`: The first element seen at each from each position. * `"only"`: Elements that show up in only one of the objects. + + For :class:`xarray.Dataset` objects, we use their :func:`xarray.merge` with `override` to stay lazy. uns_merge How the elements of `.uns` are selected. Uses the same set of strategies as the `merge` argument, except applied recursively. 
@@ -1261,6 +1429,10 @@ def concat( >>> dict(ad.concat([a, b, c], uns_merge="first").uns) {'a': 1, 'b': 2, 'c': {'c.a': 3, 'c.b': 4, 'c.c': 5}} """ + + from anndata.experimental.backed._compat import Dataset2D + from anndata.experimental.backed._compat import xarray as xr + # Argument normalization merge = resolve_merge_strategy(merge) uns_merge = resolve_merge_strategy(uns_merge) @@ -1306,19 +1478,49 @@ def concat( # Annotation for concatenation axis check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join) - concat_annot = pd.concat( - unify_dtypes(getattr(a, axis_name) for a in adatas), - join=join, - ignore_index=True, + annotations = [getattr(a, axis_name) for a in adatas] + are_any_annotations_dataframes = any( + isinstance(a, pd.DataFrame) for a in annotations ) + if are_any_annotations_dataframes: + annotations_in_memory = ( + a.to_pandas() if isinstance(a, Dataset2D) else a for a in annotations + ) + concat_annot = pd.concat( + unify_dtypes(annotations_in_memory), + join=join, + ignore_index=True, + ) + else: + concat_annot = concat_dataset2d_on_annot_axis(annotations, join) concat_annot.index = concat_indices if label is not None: concat_annot[label] = label_col # Annotation for other axis - alt_annot = merge_dataframes( - [getattr(a, alt_axis_name) for a in adatas], alt_indices, merge + alt_annotations = [getattr(a, alt_axis_name) for a in adatas] + are_any_alt_annotations_dataframes = any( + isinstance(a, pd.DataFrame) for a in alt_annotations ) + if are_any_alt_annotations_dataframes: + alt_annotations_in_memory = [ + a.to_pandas() if isinstance(a, Dataset2D) else a for a in alt_annotations + ] + alt_annot = merge_dataframes(alt_annotations_in_memory, alt_indices, merge) + else: + # TODO: figure out mapping of our merge to theirs instead of just taking first, although this appears to be + # the only "lazy" setting so I'm not sure we really want that. + # Because of xarray's merge upcasting, it's safest to simply assume that all dtypes are objects. + annotations_with_only_dask = list( + make_xarray_extension_dtypes_dask( + alt_annotations, use_only_object_dtype=True + ) + ) + attrs = get_attrs(annotations_with_only_dask) + alt_annot = Dataset2D( + xr.merge(annotations_with_only_dask, join=join, compat="override"), + attrs=attrs, + ) X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) diff --git a/src/anndata/_core/storage.py b/src/anndata/_core/storage.py index 9e036ba44..047ef4f02 100644 --- a/src/anndata/_core/storage.py +++ b/src/anndata/_core/storage.py @@ -25,6 +25,15 @@ def coerce_array( allow_df: bool = False, allow_array_like: bool = False, ): + try: + from anndata.experimental.backed._compat import Dataset2D + except ImportError: + + class Dataset2D: + @staticmethod + def __repr__(): + return "mock anndata.experimental.backed._xarray." + """Coerce arrays stored in layers/X, and aligned arrays ({obs,var}{m,p}).""" from ..typing import ArrayDataStructureType @@ -33,7 +42,7 @@ def coerce_array( return value # If value is one of the allowed types, return it array_data_structure_types = get_args(ArrayDataStructureType) - if isinstance(value, array_data_structure_types): + if isinstance(value, (*array_data_structure_types, Dataset2D)): if isinstance(value, np.matrix): msg = f"{name} should not be a np.matrix, use np.ndarray instead." 
warnings.warn(msg, ImplicitModificationWarning) diff --git a/src/anndata/_core/views.py b/src/anndata/_core/views.py index ca9af9164..e8a214fac 100644 --- a/src/anndata/_core/views.py +++ b/src/anndata/_core/views.py @@ -13,6 +13,7 @@ from anndata._warnings import ImplicitModificationWarning +from .._settings import settings from ..compat import ( AwkArray, CupyArray, @@ -305,6 +306,11 @@ def as_view_dask_array(array, view_args): @as_view.register(pd.DataFrame) def as_view_df(df, view_args): + if settings.remove_unused_categories: + for col in df.columns: + if isinstance(df[col].dtype, pd.CategoricalDtype): + with pd.option_context("mode.chained_assignment", None): + df[col] = df[col].cat.remove_unused_categories() return DataFrameView(df, view_args=view_args) diff --git a/src/anndata/_io/specs/__init__.py b/src/anndata/_io/specs/__init__.py index 5eadfdb50..8fd9898a3 100644 --- a/src/anndata/_io/specs/__init__.py +++ b/src/anndata/_io/specs/__init__.py @@ -9,7 +9,7 @@ Writer, get_spec, read_elem, - read_elem_as_dask, + read_elem_lazy, write_elem, ) @@ -19,7 +19,7 @@ "write_elem", "get_spec", "read_elem", - "read_elem_as_dask", + "read_elem_lazy", "Reader", "Writer", "IOSpec", diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 48770be9c..4c0583577 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from contextlib import contextmanager from functools import partial from pathlib import Path @@ -7,22 +8,25 @@ import h5py import numpy as np +import pandas as pd from scipy import sparse import anndata as ad +from anndata._core.file_backing import filename, get_elem_name +from anndata.compat import DaskArray, H5Array, H5Group, ZarrArray, ZarrGroup -from ..._core.file_backing import filename, get_elem_name -from ...compat import H5Array, H5Group, ZarrArray, ZarrGroup from .registry import _LAZY_REGISTRY, IOSpec if TYPE_CHECKING: from collections.abc import Callable, Generator, Mapping, Sequence from typing import Literal, ParamSpec, TypeVar + from anndata.experimental.backed._compat import DataArray, Dataset2D + from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray + from ..._core.sparse_dataset import _CSCDataset, _CSRDataset from ..._types import ArrayStorageType, StorageType - from ...compat import DaskArray - from .registry import DaskReader + from .registry import DaskReader, LazyDataStructures, LazyReader BlockInfo = Mapping[ Literal[None], @@ -50,11 +54,11 @@ def maybe_open_h5( _DEFAULT_STRIDE = 1000 -def compute_chunk_layout_for_axis_shape( - chunk_axis_shape: int, full_axis_shape: int +def compute_chunk_layout_for_axis_size( + chunk_axis_size: int, full_axis_size: int ) -> tuple[int, ...]: - n_strides, rest = np.divmod(full_axis_shape, chunk_axis_shape) - chunk = (chunk_axis_shape,) * n_strides + n_strides, rest = np.divmod(full_axis_size, chunk_axis_size) + chunk = (chunk_axis_size,) * n_strides if rest > 0: chunk += (rest,) return chunk @@ -113,7 +117,7 @@ def read_sparse_as_dask( stride = chunks[major_dim] shape_minor, shape_major = shape if is_csc else shape[::-1] - chunks_major = compute_chunk_layout_for_axis_shape(stride, shape_major) + chunks_major = compute_chunk_layout_for_axis_size(stride, shape_major) chunks_minor = (shape_minor,) chunk_layout = ( (chunks_minor, chunks_major) if is_csc else (chunks_major, chunks_minor) @@ -131,9 +135,26 @@ def read_sparse_as_dask( return da_mtx 
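The renamed helper `compute_chunk_layout_for_axis_size` (formerly `compute_chunk_layout_for_axis_shape`) used by the sparse and dense readers above covers an axis with full-size strides plus one remainder chunk; a minimal standalone sketch of the same arithmetic:

```python
import numpy as np

def chunk_layout(chunk_size: int, full_size: int) -> tuple[int, ...]:
    # Same computation as compute_chunk_layout_for_axis_size in lazy_methods.py
    n_full, rest = np.divmod(full_size, chunk_size)
    layout = (chunk_size,) * n_full
    if rest > 0:
        layout += (rest,)
    return layout

assert chunk_layout(1000, 2500) == (1000, 1000, 500)
assert chunk_layout(1000, 2000) == (1000, 1000)
```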
+@_LAZY_REGISTRY.register_read(H5Array, IOSpec("string-array", "0.2.0")) +def read_h5_string_array( + elem: H5Array, + *, + _reader: LazyReader, + chunks: tuple[int, int] | None = None, +) -> DaskArray: + import dask.array as da + + from anndata._io.h5ad import read_dataset + + return da.from_array( + read_dataset(elem), + chunks=chunks if chunks is not None else (_DEFAULT_STRIDE,) * len(elem.shape), + ) + + @_LAZY_REGISTRY.register_read(H5Array, IOSpec("array", "0.2.0")) def read_h5_array( - elem: H5Array, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None + elem: H5Array, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None ) -> DaskArray: import dask.array as da @@ -146,7 +167,7 @@ def read_h5_array( ) chunk_layout = tuple( - compute_chunk_layout_for_axis_shape(chunks[i], shape[i]) + compute_chunk_layout_for_axis_size(chunks[i], shape[i]) for i in range(len(shape)) ) @@ -154,11 +175,156 @@ def read_h5_array( return da.map_blocks(make_chunk, dtype=dtype, chunks=chunk_layout) +@_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("string-array", "0.2.0")) @_LAZY_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0")) def read_zarr_array( - elem: ZarrArray, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None + elem: ZarrArray, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None ) -> DaskArray: chunks: tuple[int, ...] = chunks if chunks is not None else elem.chunks import dask.array as da return da.from_zarr(elem, chunks=chunks) + + +DUMMY_RANGE_INDEX_KEY = "_anndata_dummy_range_index" + + +def _gen_xarray_dict_iterator_from_elems( + elem_dict: dict[str, LazyDataStructures], + index_label: str, + index_key: str, + index: np.NDArray, +) -> Generator[tuple[str, DataArray], None, None]: + from anndata.experimental.backed._compat import DataArray + from anndata.experimental.backed._compat import xarray as xr + from anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray + + for k, v in elem_dict.items(): + data_array_name = k + if isinstance(v, DaskArray) and k != index_key: + data_array = DataArray(v, coords=[index], dims=[index_label], name=k) + elif isinstance(v, CategoricalArray | MaskedArray) and k != index_key: + variable = xr.Variable( + data=xr.core.indexing.LazilyIndexedArray(v), dims=[index_label] + ) + data_array = DataArray( + variable, + coords=[index], + dims=[index_label], + name=k, + attrs={ + "base_path_or_zarr_group": v.base_path_or_zarr_group, + "elem_name": v.elem_name, + }, + ) + elif k == index_key: + data_array = DataArray( + index, coords=[index], dims=[index_label], name=index_label + ) + data_array_name = index_label + else: + raise ValueError(f"Could not read {k}: {v} from into xarray Dataset2D") + yield data_array_name, data_array + if index_key == DUMMY_RANGE_INDEX_KEY: + yield ( + index_label, + DataArray(index, coords=[index], dims=[index_label], name=index_label), + ) + + +@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("dataframe", "0.2.0")) +@_LAZY_REGISTRY.register_read(H5Group, IOSpec("dataframe", "0.2.0")) +def read_dataframe( + elem: H5Group | ZarrGroup, + *, + _reader: LazyReader, + use_range_index: bool = False, +) -> Dataset2D: + from anndata.experimental.backed._compat import Dataset2D + + elem_dict = { + k: _reader.read_elem(elem[k]) + for k in [*elem.attrs["column-order"], elem.attrs["_index"]] + } + elem_name = get_elem_name(elem) + # Determine whether we can use label based indexing i.e., is the elem `obs` or `var` + obs_var_matches = re.findall(r"(obs|var)", elem_name) + if not 
len(obs_var_matches) == 1: + label_based_indexing_key = "index" + else: + label_based_indexing_key = f"{obs_var_matches[0]}_names" + # If we are not using a range index, the underlying on disk label for the index + # could be different than {obs,var}_names - otherwise we use a dummy value. + if not use_range_index: + index_label = label_based_indexing_key + index_key = elem.attrs["_index"] + # no sense in reading this in multiple times + index = elem_dict[index_key].compute() + else: + index_label = DUMMY_RANGE_INDEX_KEY + index_key = DUMMY_RANGE_INDEX_KEY + index = pd.RangeIndex(len(elem_dict[elem.attrs["_index"]])) + elem_xarray_dict = dict( + _gen_xarray_dict_iterator_from_elems(elem_dict, index_label, index_key, index) + ) + ds = Dataset2D(elem_xarray_dict, attrs={"indexing_key": label_based_indexing_key}) + if use_range_index: + return ds.rename_vars({elem.attrs["_index"]: label_based_indexing_key}) + return ds + + +@_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("categorical", "0.2.0")) +@_LAZY_REGISTRY.register_read(H5Group, IOSpec("categorical", "0.2.0")) +def read_categorical( + elem: H5Group | ZarrGroup, + *, + _reader: LazyReader, +) -> CategoricalArray: + from anndata.experimental.backed._lazy_arrays import CategoricalArray + + base_path_or_zarr_group = ( + Path(filename(elem)) if isinstance(elem, H5Group) else elem + ) + elem_name = get_elem_name(elem) + return CategoricalArray( + codes=elem["codes"], + categories=elem["categories"], + ordered=elem.attrs["ordered"], + base_path_or_zarr_group=base_path_or_zarr_group, + elem_name=elem_name, + ) + + +def read_nullable( + elem: H5Group | ZarrGroup, + *, + encoding_type: Literal["nullable-integer", "nullable-boolean"], + _reader: LazyReader, +) -> MaskedArray: + from anndata.experimental.backed._lazy_arrays import MaskedArray + + base_path_or_zarr_group = ( + Path(filename(elem)) if isinstance(elem, H5Group) else elem + ) + elem_name = get_elem_name(elem) + return MaskedArray( + values=elem["values"], + mask=elem["mask"] if "mask" in elem else None, + dtype_str=encoding_type, + base_path_or_zarr_group=base_path_or_zarr_group, + elem_name=elem_name, + ) + + +_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("nullable-integer", "0.1.0"))( + partial(read_nullable, encoding_type="nullable-integer") +) +_LAZY_REGISTRY.register_read(H5Group, IOSpec("nullable-integer", "0.1.0"))( + partial(read_nullable, encoding_type="nullable-integer") +) +_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("nullable-boolean", "0.1.0"))( + partial(read_nullable, encoding_type="nullable-boolean") +) +_LAZY_REGISTRY.register_read(H5Group, IOSpec("nullable-boolean", "0.1.0"))( + partial(read_nullable, encoding_type="nullable-boolean") +) diff --git a/src/anndata/_io/specs/registry.py b/src/anndata/_io/specs/registry.py index 3b43def7c..db463449c 100644 --- a/src/anndata/_io/specs/registry.py +++ b/src/anndata/_io/specs/registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import warnings from collections.abc import Mapping from dataclasses import dataclass @@ -8,7 +9,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from anndata._io.utils import report_read_key_on_error, report_write_key_on_error -from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal +from anndata._types import Read, ReadLazy, _ReadInternal, _ReadLazyInternal from anndata.compat import DaskArray, _read_attr if TYPE_CHECKING: @@ -23,10 +24,13 @@ WriteCallback, _WriteInternal, ) + from anndata.experimental.backed._compat import Dataset2D + from 
anndata.experimental.backed._lazy_arrays import CategoricalArray, MaskedArray from anndata.typing import RWAble T = TypeVar("T") W = TypeVar("W", bound=_WriteInternal) + LazyDataStructures = DaskArray | Dataset2D | CategoricalArray | MaskedArray # TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-" @@ -78,8 +82,8 @@ def wrapper(g: GroupStorageType, k: str, *args, **kwargs): return decorator -_R = TypeVar("_R", _ReadInternal, _ReadDaskInternal) -R = TypeVar("R", Read, ReadDask) +_R = TypeVar("_R", _ReadInternal, _ReadLazyInternal) +R = TypeVar("R", Read, ReadLazy) class IORegistry(Generic[_R, R]): @@ -213,7 +217,7 @@ def get_spec(self, elem: Any) -> IOSpec: _REGISTRY: IORegistry[_ReadInternal, Read] = IORegistry() -_LAZY_REGISTRY: IORegistry[_ReadDaskInternal, ReadDask] = IORegistry() +_LAZY_REGISTRY: IORegistry[_ReadLazyInternal, ReadLazy] = IORegistry() @singledispatch @@ -282,24 +286,35 @@ def read_elem( return self.callback(read_func, elem.name, elem, iospec=iospec) -class DaskReader(Reader): +class LazyReader(Reader): @report_read_key_on_error def read_elem( self, elem: StorageType, modifiers: frozenset[str] = frozenset(), chunks: tuple[int, ...] | None = None, - ) -> DaskArray: + **kwargs, + ) -> LazyDataStructures: """Read a dask element from a store. See exported function for more details.""" iospec = get_spec(elem) - read_func: ReadDask = self.registry.get_read( + read_func: ReadLazy = self.registry.get_read( type(elem), iospec, modifiers, reader=self ) if self.callback is not None: msg = "Dask reading does not use a callback. Ignoring callback." warnings.warn(msg, stacklevel=2) - return read_func(elem, chunks=chunks) + read_params = inspect.signature(read_func).parameters + for kwarg in kwargs: + if kwarg not in read_params: + msg = ( + f"Keyword argument {kwarg} passed to read_elem_lazy are not supported by the " + "registered read function." + ) + raise ValueError(msg) + if "chunks" in read_params: + kwargs["chunks"] = chunks + return read_func(elem, **kwargs) class Writer: @@ -378,9 +393,9 @@ def read_elem(elem: StorageType) -> RWAble: return Reader(_REGISTRY).read_elem(elem) -def read_elem_as_dask( - elem: StorageType, chunks: tuple[int, ...] | None = None -) -> DaskArray: +def read_elem_lazy( + elem: StorageType, chunks: tuple[int, ...] | None = None, **kwargs +) -> LazyDataStructures: """ Read an element from a store lazily. @@ -423,18 +438,16 @@ def read_elem_as_dask( Reading a sparse matrix from a zarr store lazily, with custom chunk size and default: >>> g = zarr.open(zarr_path) - >>> adata.X = ad.experimental.read_elem_as_dask(g["X"]) + >>> adata.X = ad.experimental.read_elem_lazy(g["X"]) >>> adata.X dask.array - >>> adata.X = ad.experimental.read_elem_as_dask( - ... g["X"], chunks=(500, adata.shape[1]) - ... ) + >>> adata.X = ad.experimental.read_elem_lazy(g["X"], chunks=(500, adata.shape[1])) >>> adata.X dask.array Reading a dense matrix from a zarr store lazily: - >>> adata.layers["dense"] = ad.experimental.read_elem_as_dask(g["layers/dense"]) + >>> adata.layers["dense"] = ad.experimental.read_elem_lazy(g["layers/dense"]) >>> adata.layers["dense"] dask.array @@ -447,12 +460,10 @@ def read_elem_as_dask( ... obsm=ad.io.read_elem(g["obsm"]), ... varm=ad.io.read_elem(g["varm"]), ... ) - >>> adata.X = ad.experimental.read_elem_as_dask( - ... g["X"], chunks=(500, adata.shape[1]) - ... 
) - >>> adata.layers["dense"] = ad.experimental.read_elem_as_dask(g["layers/dense"]) + >>> adata.X = ad.experimental.read_elem_lazy(g["X"], chunks=(500, adata.shape[1])) + >>> adata.layers["dense"] = ad.experimental.read_elem_lazy(g["layers/dense"]) """ - return DaskReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks) + return LazyReader(_LAZY_REGISTRY).read_elem(elem, chunks=chunks, **kwargs) def write_elem( diff --git a/src/anndata/_io/zarr.py b/src/anndata/_io/zarr.py index 2564738ad..2a690c1b8 100644 --- a/src/anndata/_io/zarr.py +++ b/src/anndata/_io/zarr.py @@ -48,6 +48,7 @@ def callback(func, s, k, elem, dataset_kwargs, iospec): func(s, k, elem, dataset_kwargs=dataset_kwargs) write_dispatched(f, "/", adata, callback=callback, dataset_kwargs=ds_kwargs) + zarr.convenience.consolidate_metadata(f.store) def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> AnnData: diff --git a/src/anndata/_types.py b/src/anndata/_types.py index 66f8a9e29..92a758dad 100644 --- a/src/anndata/_types.py +++ b/src/anndata/_types.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, TypeVar +from typing import TYPE_CHECKING, Literal, Protocol, TypeVar from .compat import ( H5Array, @@ -18,8 +18,13 @@ from collections.abc import Mapping from typing import Any, TypeAlias - from ._io.specs.registry import DaskReader, IOSpec, Reader, Writer - from .compat import DaskArray + from ._io.specs.registry import ( + IOSpec, + LazyDataStructures, + LazyReader, + Reader, + Writer, + ) __all__ = [ "ArrayStorageType", @@ -44,10 +49,10 @@ class _ReadInternal(Protocol[SCon, CovariantRWAble]): def __call__(self, elem: SCon, *, _reader: Reader) -> CovariantRWAble: ... -class _ReadDaskInternal(Protocol[SCon]): +class _ReadLazyInternal(Protocol[SCon]): def __call__( - self, elem: SCon, *, _reader: DaskReader, chunks: tuple[int, ...] | None = None - ) -> DaskArray: ... + self, elem: SCon, *, _reader: LazyReader, chunks: tuple[int, ...] | None = None + ) -> LazyDataStructures: ... class Read(Protocol[SCon, CovariantRWAble]): @@ -60,16 +65,16 @@ def __call__(self, elem: SCon) -> CovariantRWAble: The element to read from. Returns ------- - The element read from the store. + The element read from the store. """ ... -class ReadDask(Protocol[SCon]): +class ReadLazy(Protocol[SCon]): def __call__( self, elem: SCon, *, chunks: tuple[int, ...] | None = None - ) -> DaskArray: - """Low-level reading function for a dask element. + ) -> LazyDataStructures: + """Low-level reading function for a lazy element. Parameters ---------- @@ -79,7 +84,7 @@ def __call__( The chunk size to be used. Returns ------- - The dask element read from the store. + The lazy element read from the store. """ ... @@ -147,7 +152,7 @@ def __call__( Returns ------- - The element read from the store. + The element read from the store. """ ... @@ -183,3 +188,19 @@ def __call__( Keyword arguments to be passed to a library-level io function, like `chunks` for :doc:`zarr:index`. """ ... 
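A short usage sketch for the renamed `read_elem_lazy` (formerly `read_elem_as_dask`), mirroring the doctests above; the `adata.zarr` path is a placeholder for a store previously written with `AnnData.write_zarr`:

```python
import zarr

import anndata as ad

g = zarr.open("adata.zarr")  # placeholder store
x = ad.experimental.read_elem_lazy(g["X"])  # lazy dask array, as before
x_chunked = ad.experimental.read_elem_lazy(g["X"], chunks=(500, x.shape[1]))
obs = ad.experimental.read_elem_lazy(g["obs"])  # new: dataframes are read as a lazy Dataset2D
```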
+ + +AnnDataElem = Literal[ + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "X", + "raw", + "uns", +] + +Join_T = Literal["inner", "outer"] diff --git a/src/anndata/experimental/__init__.py b/src/anndata/experimental/__init__.py index 90e83a87e..2c233c6b6 100644 --- a/src/anndata/experimental/__init__.py +++ b/src/anndata/experimental/__init__.py @@ -3,10 +3,11 @@ from types import MappingProxyType from typing import TYPE_CHECKING -from .._io.specs import IOSpec, read_elem_as_dask +from .._io.specs import IOSpec, read_elem_lazy from .._types import Read, ReadCallback, StorageType, Write, WriteCallback from ..utils import module_get_attr_redirect from ._dispatch_io import read_dispatched, write_dispatched +from .backed import read_lazy from .merge import concat_on_disk from .multi_files import AnnCollection from .pytorch import AnnLoader @@ -14,7 +15,6 @@ if TYPE_CHECKING: from typing import Any - # Map old name in `anndata.experimental` to new name in `anndata` _DEPRECATED = MappingProxyType( dict( @@ -41,12 +41,13 @@ def __getattr__(attr_name: str) -> Any: __all__ = [ "AnnCollection", "AnnLoader", - "read_elem_as_dask", + "read_elem_lazy", "read_dispatched", "write_dispatched", "IOSpec", "concat_on_disk", "Read", + "read_lazy", "Write", "ReadCallback", "WriteCallback", diff --git a/src/anndata/experimental/backed/__init__.py b/src/anndata/experimental/backed/__init__.py new file mode 100644 index 000000000..9c8acba50 --- /dev/null +++ b/src/anndata/experimental/backed/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ._io import read_lazy + +__all__ = ["read_lazy"] diff --git a/src/anndata/experimental/backed/_compat.py b/src/anndata/experimental/backed/_compat.py new file mode 100644 index 000000000..6c69cb051 --- /dev/null +++ b/src/anndata/experimental/backed/_compat.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from importlib.util import find_spec +from typing import TYPE_CHECKING + +if find_spec("xarray") or TYPE_CHECKING: + import xarray + from xarray import DataArray + from xarray.backends import BackendArray + from xarray.backends.zarr import ZarrArrayWrapper + + +else: + + class DataArray: + def __repr__(self) -> str: + return "mock DataArray" + + xarray = None + + class ZarrArrayWrapper: + def __repr__(self) -> str: + return "mock ZarrArrayWrapper" + + class BackendArray: + def __repr__(self) -> str: + return "mock BackendArray" + + +from ._xarray import Dataset, Dataset2D # noqa: F401 diff --git a/src/anndata/experimental/backed/_io.py b/src/anndata/experimental/backed/_io.py new file mode 100644 index 000000000..4ef86d7bd --- /dev/null +++ b/src/anndata/experimental/backed/_io.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import typing +import warnings +from pathlib import Path +from typing import TYPE_CHECKING + +import h5py + +from anndata._io.specs.registry import read_elem_lazy +from anndata._types import AnnDataElem + +from ..._core.anndata import AnnData +from ..._settings import settings +from .. import read_dispatched + +if TYPE_CHECKING: + from collections.abc import MutableMapping + + from anndata._io.specs.registry import IOSpec + from anndata._types import Read, StorageType + + from ...compat import ZarrGroup + + +def read_lazy( + store: str | Path | MutableMapping | ZarrGroup | h5py.Dataset, + load_annotation_index: bool = True, +) -> AnnData: + """ + Lazily read in on-disk/in-cloud AnnData stores, including `obs` and `var`. 
+ No array data should need to be read into memory with the exception of :class:`ak.Array`, scalars, and some older-encoding arrays. + + Parameters + ---------- + store + A store-like object to be read in. If :class:`zarr.hierarchy.Group`, it is best for it to be consolidated. + load_annotation_index + Whether or not to use a range index for the `{obs,var}` :class:`xarray.Dataset` so as not to load the index into memory. + If `False`, the real `index` will be inserted as `{obs,var}_names` in the object but not be one of the `coords` thereby preventing read operations. + Access to `adata.obs.index` will also only give the dummy index, and not the "real" index that is file-backed. + + Returns + ------- + A lazily read-in :class:`~anndata.AnnData` object. + + Examples + -------- + + Preparing example objects + + >>> import anndata as ad + >>> import httpx + >>> import scanpy as sc + >>> base_url = "https://datasets.cellxgene.cziscience.com" + >>> def get_cellxgene_data(id_: str): + ... out_path = sc.settings.datasetdir / f"{id_}.h5ad" + ... if out_path.exists(): + ... return out_path + ... file_url = f"{base_url}/{id_}.h5ad" + ... sc.settings.datasetdir.mkdir(parents=True, exist_ok=True) + ... with httpx.stream("GET", file_url) as r, out_path.open("wb") as f: + ... r.raise_for_status() + ... for data in r.iter_bytes(): + ... f.write(data) + ... return out_path + >>> path_b_cells = get_cellxgene_data("a93eab58-3d82-4b61-8a2f-d7666dcdb7c4") + >>> path_fetal = get_cellxgene_data("d170ff04-6da0-4156-a719-f8e1bbefbf53") + >>> b_cells_adata = ad.experimental.read_lazy(path_b_cells) + >>> fetal_adata = ad.experimental.read_lazy(path_fetal) + >>> print(b_cells_adata) + AnnData object with n_obs × n_vars = 146 × 33452 + obs: 'donor_id', 'self_reported_ethnicity_ontology_term_id', 'organism_ontology_term_id', 'sample_uuid', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'suspension_uuid', 'suspension_type', 'library_uuid', 'assay_ontology_term_id', 'mapped_reference_annotation', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'sex_ontology_term_id', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'Phase', 'sample', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage' + var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype' + uns: 'default_embedding', 'schema_version', 'title' + obsm: 'X_harmony', 'X_pca', 'X_umap' + >>> print(fetal_adata) + AnnData object with n_obs × n_vars = 344 × 15585 + obs: 'nCount_Spatial', 'nFeature_Spatial', 'Cluster', 'adult_pred_type', 'adult_pred_value', 'fetal_pred_type', 'fetal_pred_value', 'pDCs', 'Cell Cycle', 'Type 3 ILCs', 'DCs', 'Mast', 'Monocytes', 'Naive T-Cells', 'Venous (CP) 1', 'Venous (M) 2', 'Arterial (L)', 'Endothelium G2M-phase', 'Venous (CP) 2', 'Arterial (CP)', 'Arterial (M)', 'Endothelium S-phase', 'Proximal Progenitor', 'Proximal Mature Enterocytes', 'BEST4_OTOP2 Cells', 'Proximal TA', 'Proximal Early Enterocytes', 'Proximal Enterocytes', 'Proximal Stem Cells', 'EECs', 'Distal Enterocytes', 'Goblets', 'Distal TA', 'Distal Absorptive', 'Distal Stem Cells', 'Secretory Progenitors', 'Distal Mature Enterocytes', 'S1', 'S1 COL6A5+', 'S4 CCL21+', 'Proximal S2 (2)', 'S1 IFIT3+', 'Distal S2', 'Fibroblasts S-phase', 'Proximal S2 (1)', 'S3 Progenitor', 'Fibroblasts G2M-phase', 'S4 CXCL14+', 'Fibroblast Progenitor', 'S3 Transitional', 'Erythroid', 'S3 EBF+', 'S3 HAND1+', 'Pericytes 
G2M-phase', 'Pericyte Progenitors', 'Undifferentiated Pericytes', 'ICC PDGFRA+', 'MYOCD+ Muscularis', 'Muscularis S-phase', 'Muscularis G2M-phase', 'HOXP+ Proximal Muscularis', 'FOXF2+ Distal Muscularis', 'FOXF2- Muscularis', 'MORN5+ Distal Muscularis', 'Myofibroblast Progenitors', 'Myofibroblasts', 'Mesothelium SOX6+', 'Myofibroblasts S-phase', 'Myofibroblasts G2M-phase', 'Glial Progenitors', 'Excitory Motor Neuron', 'Interneuron', 'Differentiating Submucosal Glial', 'Inhibitory Motor Neuron Precursor', 'Neuroendocrine (1)', 'max', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'development_stage_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'cell_type_ontology_term_id', 'sex_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'donor_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage' + var: 'sct.detection_rate', 'sct.gmean', 'sct.variance', 'sct.residual_mean', 'sct.residual_variance', 'sct.variable', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype' + uns: 'adult_pred_mat', 'fetal_pred_mat', 'schema_version', 'title' + obsm: 'X_pca', 'X_spatial', 'X_umap' + layers: 'counts', 'scale.data' + + This functionality is compatible with :func:`anndata.concat` + + >>> ad.concat([b_cells_adata, fetal_adata], join="outer") + AnnData object with n_obs × n_vars = 490 × 33452 + obs: 'donor_id', 'self_reported_ethnicity_ontology_term_id', 'organism_ontology_term_id', 'sample_uuid', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'suspension_uuid', 'suspension_type', 'library_uuid', 'assay_ontology_term_id', 'mapped_reference_annotation', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'sex_ontology_term_id', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'Phase', 'sample', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'nCount_Spatial', 'nFeature_Spatial', 'Cluster', 'adult_pred_type', 'adult_pred_value', 'fetal_pred_type', 'fetal_pred_value', 'pDCs', 'Cell Cycle', 'Type 3 ILCs', 'DCs', 'Mast', 'Monocytes', 'Naive T-Cells', 'Venous (CP) 1', 'Venous (M) 2', 'Arterial (L)', 'Endothelium G2M-phase', 'Venous (CP) 2', 'Arterial (CP)', 'Arterial (M)', 'Endothelium S-phase', 'Proximal Progenitor', 'Proximal Mature Enterocytes', 'BEST4_OTOP2 Cells', 'Proximal TA', 'Proximal Early Enterocytes', 'Proximal Enterocytes', 'Proximal Stem Cells', 'EECs', 'Distal Enterocytes', 'Goblets', 'Distal TA', 'Distal Absorptive', 'Distal Stem Cells', 'Secretory Progenitors', 'Distal Mature Enterocytes', 'S1', 'S1 COL6A5+', 'S4 CCL21+', 'Proximal S2 (2)', 'S1 IFIT3+', 'Distal S2', 'Fibroblasts S-phase', 'Proximal S2 (1)', 'S3 Progenitor', 'Fibroblasts G2M-phase', 'S4 CXCL14+', 'Fibroblast Progenitor', 'S3 Transitional', 'Erythroid', 'S3 EBF+', 'S3 HAND1+', 'Pericytes G2M-phase', 'Pericyte Progenitors', 'Undifferentiated Pericytes', 'ICC PDGFRA+', 'MYOCD+ Muscularis', 'Muscularis S-phase', 'Muscularis G2M-phase', 'HOXP+ Proximal Muscularis', 'FOXF2+ Distal Muscularis', 'FOXF2- Muscularis', 'MORN5+ Distal Muscularis', 'Myofibroblast Progenitors', 'Myofibroblasts', 'Mesothelium SOX6+', 'Myofibroblasts S-phase', 'Myofibroblasts G2M-phase', 'Glial Progenitors', 'Excitory Motor Neuron', 'Interneuron', 'Differentiating Submucosal Glial', 'Inhibitory Motor Neuron Precursor', 'Neuroendocrine (1)', 'max' + var: 
'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'sct.detection_rate', 'sct.gmean', 'sct.variance', 'sct.residual_mean', 'sct.residual_variance', 'sct.variable' + obsm: 'X_harmony', 'X_pca', 'X_umap', 'X_spatial' + layers: 'counts', 'scale.data' + """ + try: + import xarray # noqa: F401 + except ImportError: + raise ImportError( + "xarray is required to use the `read_lazy` function. Please install xarray." + ) + is_h5_store = isinstance(store, h5py.Dataset | h5py.File | h5py.Group) + is_h5 = ( + isinstance(store, Path | str) and Path(store).suffix == ".h5ad" + ) or is_h5_store + + has_keys = True # true if consolidated or h5ad + if not is_h5: + import zarr + + if not isinstance(store, zarr.hierarchy.Group): + try: + f = zarr.open_consolidated(store, mode="r") + except KeyError: + msg = "Did not read zarr as consolidated. Consider consolidating your metadata." + warnings.warn(msg) + has_keys = False + f = zarr.open(store, mode="r") + else: + f = store + else: + if is_h5_store: + f = store + else: + f = h5py.File(store, mode="r") + + def callback(func: Read, /, elem_name: str, elem: StorageType, *, iospec: IOSpec): + if iospec.encoding_type in {"anndata", "raw"} or elem_name.endswith("/"): + iter_object = ( + elem.items() + if has_keys + else [(k, elem[k]) for k in typing.get_args(AnnDataElem) if k in elem] + ) + return AnnData(**{k: read_dispatched(v, callback) for k, v in iter_object}) + elif ( + iospec.encoding_type + in { + "csr_matrix", + "csc_matrix", + "array", + "string-array", + "dataframe", + "categorical", + } + or "nullable" in iospec.encoding_type + ): + if "dataframe" == iospec.encoding_type and elem_name in {"/obs", "/var"}: + return read_elem_lazy(elem, use_range_index=not load_annotation_index) + return read_elem_lazy(elem) + elif iospec.encoding_type in {"awkward-array"}: + return read_dispatched(elem, None) + elif iospec.encoding_type == "dict": + return {k: read_dispatched(v, callback=callback) for k, v in elem.items()} + return func(elem) + + with settings.override(check_uniqueness=load_annotation_index): + adata = read_dispatched(f, callback=callback) + + return adata diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py new file mode 100644 index 000000000..6ee4eb404 --- /dev/null +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Generic, TypeVar + +import pandas as pd + +from anndata._core.index import _subset +from anndata._core.views import as_view +from anndata.compat import H5Array, ZarrArray + +from ..._settings import settings +from ._compat import BackendArray, DataArray, ZarrArrayWrapper +from ._compat import xarray as xr + +if TYPE_CHECKING: + from pathlib import Path + from typing import Literal + + import numpy as np + + from anndata._core.index import Index + from anndata.compat import ZarrGroup + + +K = TypeVar("K", H5Array, ZarrArray) + + +class ZarrOrHDF5Wrapper(ZarrArrayWrapper, Generic[K]): + def __init__(self, array: K): + if isinstance(array, ZarrArray): + return super().__init__(array) + self._array = array + self.shape = self._array.shape + self.dtype = self._array.dtype + + def __getitem__(self, key: xr.core.indexing.ExplicitIndexer): + if isinstance(self._array, ZarrArray): + return super().__getitem__(key) + return xr.core.indexing.explicit_indexing_adapter( + key, + self.shape, + xr.core.indexing.IndexingSupport.OUTER_1VECTOR, 
+ lambda key: self._array[key], + ) + + +class CategoricalArray(BackendArray, Generic[K]): + _codes: ZarrOrHDF5Wrapper[K] + _categories: ZarrArray | H5Array + shape: tuple[int, ...] + base_path_or_zarr_group: Path | ZarrGroup + elem_name: str + + def __init__( + self, + codes: K, + categories: ZarrArray | H5Array, + ordered: bool, + base_path_or_zarr_group: Path | ZarrGroup, + elem_name: str, + *args, + **kwargs, + ): + self._categories = categories + self._ordered = ordered + self._codes = ZarrOrHDF5Wrapper(codes) + self.shape = self._codes.shape + self.base_path_or_zarr_group = base_path_or_zarr_group + self.file_format = "zarr" if isinstance(codes, ZarrArray) else "h5" + self.elem_name = elem_name + + @cached_property + def categories(self) -> np.ndarray: + if isinstance(self._categories, ZarrArray): + return self._categories[...] + from ..._io.h5ad import read_dataset + + return read_dataset(self._categories) + + def __getitem__( + self, key: xr.core.indexing.ExplicitIndexer + ) -> xr.core.extension_array.PandasExtensionArray: + codes = self._codes[key] + categorical_array = pd.Categorical.from_codes( + codes=codes, categories=self.categories, ordered=self._ordered + ) + if settings.remove_unused_categories: + categorical_array = categorical_array.remove_unused_categories() + return xr.core.extension_array.PandasExtensionArray(categorical_array) + + @cached_property + def dtype(self): + return pd.CategoricalDtype(categories=self._categories, ordered=self._ordered) + + +class MaskedArray(BackendArray, Generic[K]): + _mask: ZarrOrHDF5Wrapper[K] + _values: ZarrOrHDF5Wrapper[K] + _dtype_str: Literal["nullable-integer", "nullable-boolean"] + shape: tuple[int, ...] + base_path_or_zarr_group: Path | ZarrGroup + elem_name: str + + def __init__( + self, + values: ZarrArray | H5Array, + dtype_str: Literal["nullable-integer", "nullable-boolean"], + mask: ZarrArray | H5Array, + base_path_or_zarr_group: Path | ZarrGroup, + elem_name: str, + ): + self._mask = ZarrOrHDF5Wrapper(mask) + self._values = ZarrOrHDF5Wrapper(values) + self._dtype_str = dtype_str + self.shape = self._values.shape + self.base_path_or_zarr_group = base_path_or_zarr_group + self.file_format = "zarr" if isinstance(mask, ZarrArray) else "h5" + self.elem_name = elem_name + + def __getitem__( + self, key: xr.core.indexing.ExplicitIndexer + ) -> xr.core.extension_array.PandasExtensionArray: + values = self._values[key] + mask = self._mask[key] + if self._dtype_str == "nullable-integer": + # numpy does not support nan ints + extension_array = pd.arrays.IntegerArray(values, mask=mask) + elif self._dtype_str == "nullable-boolean": + extension_array = pd.arrays.BooleanArray(values, mask=mask) + else: + raise RuntimeError(f"Invalid dtype_str {self._dtype_str}") + return xr.core.extension_array.PandasExtensionArray(extension_array) + + @cached_property + def dtype(self): + if self._dtype_str == "nullable-integer": + return pd.array( + [], + dtype=str(pd.api.types.pandas_dtype(self._values.dtype)).capitalize(), + ).dtype + elif self._dtype_str == "nullable-boolean": + return pd.BooleanDtype() + raise RuntimeError(f"Invalid dtype_str {self._dtype_str}") + + +@_subset.register(DataArray) +def _subset_masked(a: DataArray, subset_idx: Index): + return a[subset_idx] + + +@as_view.register(DataArray) +def _view_pd_boolean_array(a: DataArray, view_args): + return a diff --git a/src/anndata/experimental/backed/_xarray.py b/src/anndata/experimental/backed/_xarray.py new file mode 100644 index 000000000..12a1e7b3f --- /dev/null +++ 
b/src/anndata/experimental/backed/_xarray.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas as pd + +from ..._core.anndata import AnnData, _gen_dataframe +from ..._core.file_backing import to_memory +from ..._core.index import _subset +from ..._core.views import as_view + +try: + from xarray import Dataset +except ImportError: + + class Dataset: + def __repr__(self) -> str: + return "mock Dataset" + + +if TYPE_CHECKING: + from collections.abc import Hashable, Iterable + from typing import Any, Literal + + from ..._core.index import Index + from ._compat import xarray as xr + + +def get_index_dim(ds: xr.DataArray) -> Hashable: + if len(ds.sizes) != 1: + msg = f"xarray Dataset should not have more than 1 dims, found {len(ds)}" + raise ValueError(msg) + return list(ds.indexes.keys())[0] + + +class Dataset2D(Dataset): + __slots__ = () + + @property + def index(self) -> pd.Index: + """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.index` so this ensures usability + + Returns + ------- + The index of the of the dataframe as resolved from :attr:`~xarray.Dataset.coords`. + """ + coord = get_index_dim(self) + return self.indexes[coord] + + @index.setter + def index(self, val) -> None: + coord = get_index_dim(self) + self.coords[coord] = val + + @property + def shape(self) -> tuple[int, int]: + """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.shape` so this ensures usability + + Returns + ------- + The (2D) shape of the dataframe resolved from :attr:`~xarray.Dataset.sizes`. + """ + return (self.sizes[get_index_dim(self)], len(self)) + + @property + def iloc(self): + """:attr:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.iloc` so this ensures usability + + Returns + ------- + Handler class for doing the iloc-style indexing using :meth:`~xarray.Dataset.isel`. + """ + + class IlocGetter: + def __init__(self, ds): + self._ds = ds + + def __getitem__(self, idx): + coord = get_index_dim(self._ds) + return self._ds.isel(**{coord: idx}) + + return IlocGetter(self) + + @property + def columns(self) -> pd.Index: + """ + :class:`~anndata.AnnData` internally looks for :attr:`~pandas.DataFrame.columns` so this ensures usability + + Returns + ------- + :class:`pandas.Index` that represents the "columns." + """ + columns_list = list(self.keys()) + return pd.Index(columns_list) + + +@_subset.register(Dataset2D) +def _(a: Dataset2D, subset_idx: Index): + key = a.attrs["indexing_key"] + # xarray seems to have some code looking for a second entry in tuples + if isinstance(subset_idx, tuple) and len(subset_idx) == 1: + subset_idx = subset_idx[0] + return a.isel(**{key: subset_idx}) + + +@as_view.register(Dataset2D) +def _(a: Dataset2D, view_args): + return a + + +@_gen_dataframe.register(Dataset2D) +def _gen_dataframe_xr( + anno: Dataset2D, + index_names: Iterable[str], + *, + source: Literal["X", "shape"], + attr: Literal["obs", "var"], + length: int | None = None, +): + return anno + + +@AnnData._remove_unused_categories.register(Dataset2D) +def _remove_unused_categories_xr( + df_full: Dataset2D, df_sub: Dataset2D, uns: dict[str, Any] +): + pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. 
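The DataFrame-style accessors defined on `Dataset2D` above (`index`, `shape`, `iloc`, `columns`) are thin shims over standard xarray operations. A minimal sketch of that correspondence using plain xarray (illustrative only, not part of the patch; the toy column names are made up):

```python
import pandas as pd
import xarray as xr

df = pd.DataFrame(
    {"celltype": ["B", "T", "B"], "n_counts": [10, 20, 30]},
    index=pd.Index(["c0", "c1", "c2"], name="obs_names"),
)
ds = xr.Dataset.from_dataframe(df)  # one-dimensional Dataset, dim == "obs_names"

ds.indexes["obs_names"]              # ~ Dataset2D.index
(ds.sizes["obs_names"], len(ds))     # ~ Dataset2D.shape -> (3, 2)
list(ds.keys())                      # ~ Dataset2D.columns -> ["celltype", "n_counts"]
ds.isel(obs_names=slice(0, 2))       # ~ Dataset2D.iloc[0:2]
```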
+ + +@to_memory.register(Dataset2D) +def to_memory(ds: Dataset2D, copy=False): + df = ds.to_dataframe() + index_key = ds.attrs.get("indexing_key", None) + if df.index.name != index_key and index_key is not None: + df = df.set_index(index_key) + df.index.name = None # matches old AnnData object + return df diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 6ed637ed8..1e69e5b70 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -4,7 +4,7 @@ import random import re import warnings -from collections import Counter +from collections import Counter, defaultdict from collections.abc import Mapping from contextlib import contextmanager from functools import partial, singledispatch, wraps @@ -37,7 +37,7 @@ from anndata.utils import asarray if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable + from collections.abc import Callable, Collection from typing import Literal, TypeGuard, TypeVar DT = TypeVar("DT") @@ -1052,34 +1052,56 @@ def __init__(self, *_args, **_kwargs) -> None: class AccessTrackingStore(DirectoryStore): _access_count: Counter[str] - _accessed_keys: dict[str, list[str]] + _accessed: defaultdict[str, set] + _accessed_keys: defaultdict[str, list[str]] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._access_count = Counter() - self._accessed_keys = {} + self._accessed = defaultdict(set) + self._accessed_keys = defaultdict(list) def __getitem__(self, key: str) -> object: for tracked in self._access_count: if tracked in key: self._access_count[tracked] += 1 + self._accessed[tracked].add(key) self._accessed_keys[tracked] += [key] return super().__getitem__(key) def get_access_count(self, key: str) -> int: + # access defaultdict when value is not there causes key to be there, + # which causes it to be tracked + if key not in self._access_count: + raise KeyError(f"{key} not found among access count") return self._access_count[key] + def get_subkeys_accessed(self, key: str) -> set[str]: + if key not in self._accessed: + raise KeyError(f"{key} not found among accessed") + return self._accessed[key] + def get_accessed_keys(self, key: str) -> list[str]: + if key not in self._accessed_keys: + raise KeyError(f"{key} not found among accessed keys") return self._accessed_keys[key] - def initialize_key_trackers(self, keys_to_track: Iterable[str]) -> None: + def initialize_key_trackers(self, keys_to_track: Collection[str]) -> None: for k in keys_to_track: self._access_count[k] = 0 self._accessed_keys[k] = [] + self._accessed[k] = set() def reset_key_trackers(self) -> None: self.initialize_key_trackers(self._access_count.keys()) + def assert_access_count(self, key: str, count: int): + keys_accessed = self.get_subkeys_accessed(key) + access_count = self.get_access_count(key) + assert ( + self.get_access_count(key) == count + ), f"Found {access_count} accesses at {keys_accessed}" + def get_multiindex_columns_df(shape: tuple[int, int]) -> pd.DataFrame: return pd.DataFrame( diff --git a/src/anndata/typing.py b/src/anndata/typing.py index d13927bad..ee6ff74fc 100644 --- a/src/anndata/typing.py +++ b/src/anndata/typing.py @@ -31,14 +31,12 @@ Index = _Index """1D or 2D index an :class:`~anndata.AnnData` object can be sliced with.""" - -ArrayDataStructureType: TypeAlias = ( +XDataType: TypeAlias = ( np.ndarray | ma.MaskedArray | sparse.csr_matrix | sparse.csc_matrix | SpArray - | AwkArray | H5Array | ZarrArray | ZappyArray @@ -48,6 +46,7 @@ | CupyArray | CupySparseMatrix ) +ArrayDataStructureType: 
TypeAlias = XDataType | AwkArray InMemoryArrayOrScalarType: TypeAlias = ( diff --git a/src/testing/anndata/_pytest.py b/src/testing/anndata/_pytest.py index 5b0fd60e0..ecca3348d 100644 --- a/src/testing/anndata/_pytest.py +++ b/src/testing/anndata/_pytest.py @@ -9,6 +9,7 @@ from __future__ import annotations +import importlib import re import warnings from typing import TYPE_CHECKING, cast @@ -55,6 +56,9 @@ def _doctest_env( from anndata.utils import import_name assert isinstance(request.node.parent, pytest.Module) + if "experimental/backed" in str(request.node.path): + if importlib.util.find_spec("xarray") is None: + pytest.skip("xarray not installed") # request.node.parent is either a DoctestModule or a DoctestTextFile. # Only DoctestModule has a .obj attribute (the imported module). if request.node.parent.obj: @@ -63,7 +67,6 @@ def _doctest_env( if warning_detail := getattr(func, "__deprecated", None): cat, msg, _ = warning_detail warnings.filterwarnings("ignore", category=cat, message=re.escape(msg)) - old_dd, settings.datasetdir = settings.datasetdir, cache.mkdir("scanpy-data") with chdir(tmp_path): yield diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index e034debd2..8c143d2f6 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -150,9 +150,11 @@ def fix_known_differences(orig, result, backwards_compat=True): result.obs.drop(columns=["batch"], inplace=True) # Possibly need to fix this, ordered categoricals lose orderedness - for k, dtype in orig.obs.dtypes.items(): - if isinstance(dtype, pd.CategoricalDtype) and dtype.ordered: - result.obs[k] = result.obs[k].astype(dtype) + for get_df in [lambda k: getattr(k, "obs"), lambda k: getattr(k, "obsm")["df"]]: + str_to_df_converted = get_df(result) + for k, dtype in get_df(orig).dtypes.items(): + if isinstance(dtype, pd.CategoricalDtype) and dtype.ordered: + str_to_df_converted[k] = str_to_df_converted[k].astype(dtype) return orig, result diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py index e46cd7d81..367511e83 100644 --- a/tests/test_io_elementwise.py +++ b/tests/test_io_elementwise.py @@ -23,7 +23,7 @@ ) from anndata._io.specs.registry import IORegistryError from anndata.compat import CAN_USE_SPARSE_ARRAY, SpArray, ZarrGroup, _read_attr -from anndata.experimental import read_elem_as_dask +from anndata.experimental import read_elem_lazy from anndata.io import read_elem, write_elem from anndata.tests.helpers import ( as_cupy, @@ -251,7 +251,7 @@ def test_dask_write_sparse(sparse_format, store): def test_read_lazy_2d_dask(sparse_format, store): arr_store = create_sparse_store(sparse_format, store) - X_dask_from_disk = read_elem_as_dask(arr_store["X"]) + X_dask_from_disk = read_elem_lazy(arr_store["X"]) X_from_disk = read_elem(arr_store["X"]) assert_equal(X_from_disk, X_dask_from_disk) @@ -288,7 +288,7 @@ def test_read_lazy_2d_dask(sparse_format, store): ) def test_read_lazy_subsets_nd_dask(store, n_dims, chunks): arr_store = create_dense_store(store, n_dims) - X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks) + X_dask_from_disk = read_elem_lazy(arr_store["X"], chunks=chunks) X_from_disk = read_elem(arr_store["X"]) assert_equal(X_from_disk, X_dask_from_disk) @@ -306,7 +306,7 @@ def test_read_lazy_h5_cluster(sparse_format, tmp_path): with h5py.File(tmp_path / "test.h5", "w") as file: store = file["/"] arr_store = create_sparse_store(sparse_format, store) - X_dask_from_disk = read_elem_as_dask(arr_store["X"]) + X_dask_from_disk = read_elem_lazy(arr_store["X"]) 
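For context, the rename from `read_elem_as_dask` to `read_elem_lazy` exercised by these tests changes only the entry-point name; the call pattern is unchanged. A hedged sketch (the file path is a placeholder):

```python
import h5py
from anndata.experimental import read_elem_lazy
from anndata.io import read_elem

with h5py.File("data.h5ad", "r") as f:   # "data.h5ad" is a placeholder path
    X_lazy = read_elem_lazy(f["X"])      # dask-backed; no chunks read yet
    X_eager = read_elem(f["X"])          # eagerly materialized for comparison
    assert X_lazy.shape == X_eager.shape
    X_lazy.compute()                     # materialize while the file is still open
```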
X_from_disk = read_elem(arr_store["X"]) with ( dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster, @@ -328,10 +328,10 @@ def test_read_lazy_h5_cluster(sparse_format, tmp_path): def test_read_lazy_2d_chunk_kwargs(store, arr_type, chunks): if arr_type == "dense": arr_store = create_dense_store(store) - X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks) + X_dask_from_disk = read_elem_lazy(arr_store["X"], chunks=chunks) else: arr_store = create_sparse_store(arr_type, store) - X_dask_from_disk = read_elem_as_dask(arr_store["X"], chunks=chunks) + X_dask_from_disk = read_elem_lazy(arr_store["X"], chunks=chunks) if chunks is not None: assert X_dask_from_disk.chunksize == chunks else: @@ -350,9 +350,9 @@ def test_read_lazy_bad_chunk_kwargs(tmp_path): with pytest.raises( ValueError, match=r"`chunks` must be a tuple of two integers" ): - read_elem_as_dask(arr_store["X"], chunks=(SIZE,)) + read_elem_lazy(arr_store["X"], chunks=(SIZE,)) with pytest.raises(ValueError, match=r"Only the major axis can be chunked"): - read_elem_as_dask(arr_store["X"], chunks=(SIZE, 10)) + read_elem_lazy(arr_store["X"], chunks=(SIZE, 10)) @pytest.mark.parametrize("sparse_format", ["csr", "csc"]) diff --git a/tests/test_read_lazy.py b/tests/test_read_lazy.py new file mode 100644 index 000000000..9603d17c4 --- /dev/null +++ b/tests/test_read_lazy.py @@ -0,0 +1,664 @@ +from __future__ import annotations + +import typing +from contextlib import nullcontext +from importlib.util import find_spec +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pytest +from scipy import sparse + +import anndata as ad +from anndata import AnnData +from anndata._core.file_backing import to_memory +from anndata._types import AnnDataElem +from anndata.compat import DaskArray +from anndata.experimental import read_lazy +from anndata.tests.helpers import ( + AccessTrackingStore, + as_dense_dask_array, + assert_equal, + gen_adata, + gen_typed_df, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + from pathlib import Path + from typing import Literal + + from numpy.typing import NDArray + + from anndata._types import Join_T + +pytestmark = pytest.mark.skipif( + not find_spec("xarray"), reason="Xarray is not installed" +) + +ANNDATA_ELEMS = typing.get_args(AnnDataElem) + + +@pytest.fixture( + params=[sparse.csr_matrix, sparse.csc_matrix, np.array, as_dense_dask_array], + ids=["scipy-csr", "scipy-csc", "np-array", "dask_array"], + scope="session", +) +def mtx_format(request): + return request.param + + +@pytest.fixture( + params=[True, False], ids=["vars_different", "vars_same"], scope="session" +) +def are_vars_different(request): + return request.param + + +@pytest.fixture(params=["zarr", "h5ad"], scope="session") +def diskfmt(request): + return request.param + + +@pytest.fixture( + params=[True, False], + scope="session", + ids=["load-annotation-index", "dont-load-annotation-index"], +) +def load_annotation_index(request): + return request.param + + +@pytest.fixture(params=["outer", "inner"], scope="session") +def join(request): + return request.param + + +@pytest.fixture( + params=[ + pytest.param(lambda x: x, id="full"), + pytest.param(lambda x: x[0:10, :], id="subset"), + ], + scope="session", +) +def simple_subset_func(request): + return request.param + + +@pytest.fixture(scope="session") +def adata_remote_orig_with_path( + tmp_path_factory, + diskfmt: str, + mtx_format, + worker_id: str = "serial", +) -> tuple[Path, AnnData]: + """Create remote fixtures, one 
without a range index and the other with""" + file_name = f"orig_{worker_id}.{diskfmt}" + if diskfmt == "h5ad": + orig_path = tmp_path_factory.mktemp("h5ad_file_dir") / file_name + else: + orig_path = tmp_path_factory.mktemp(file_name) + orig = gen_adata((1000, 1100), mtx_format) + orig.raw = orig.copy() + getattr(orig, f"write_{diskfmt}")(orig_path) + return orig_path, orig + + +@pytest.fixture +def adata_remote( + adata_remote_orig_with_path: tuple[Path, AnnData], load_annotation_index: bool +) -> AnnData: + orig_path, _ = adata_remote_orig_with_path + return read_lazy(orig_path, load_annotation_index=load_annotation_index) + + +@pytest.fixture +def adata_orig(adata_remote_orig_with_path: tuple[Path, AnnData]) -> AnnData: + _, orig = adata_remote_orig_with_path + return orig + + +@pytest.fixture(scope="session") +def adata_remote_with_store_tall_skinny_path( + tmp_path_factory, + mtx_format, + worker_id: str = "serial", +) -> Path: + orig_path = tmp_path_factory.mktemp(f"orig_{worker_id}.zarr") + M = 100_000 # forces zarr to chunk `obs` columns multiple ways - that way 1 access to `int64` below is actually only one access + N = 5 + obs_names = pd.Index(f"cell{i}" for i in range(M)) + var_names = pd.Index(f"gene{i}" for i in range(N)) + obs = gen_typed_df(M, obs_names) + var = gen_typed_df(N, var_names) + orig = AnnData( + obs=obs, + var=var, + X=mtx_format(np.random.binomial(100, 0.005, (M, N)).astype(np.float32)), + ) + orig.raw = orig.copy() + orig.write_zarr(orig_path) + return orig_path + + +@pytest.fixture(scope="session") +def adatas_paths_var_indices_for_concatenation( + tmp_path_factory, are_vars_different: bool, worker_id: str = "serial" +) -> tuple[list[AnnData], list[Path], list[pd.Index]]: + adatas = [] + var_indices = [] + paths = [] + M = 1000 + N = 50 + n_datasets = 3 + for dataset_index in range(n_datasets): + orig_path = tmp_path_factory.mktemp(f"orig_{worker_id}_{dataset_index}.zarr") + paths.append(orig_path) + obs_names = pd.Index(f"cell_{dataset_index}_{i}" for i in range(M)) + var_names = pd.Index( + f"gene_{i}{f'_{dataset_index}_ds' if are_vars_different and (i % 2) else ''}" + for i in range(N) + ) + var_indices.append(var_names) + obs = gen_typed_df(M, obs_names) + var = gen_typed_df(N, var_names) + orig = AnnData( + obs=obs, + var=var, + X=np.random.binomial(100, 0.005, (M, N)).astype(np.float32), + ) + orig.write_zarr(orig_path) + adatas.append(orig) + return adatas, paths, var_indices + + +@pytest.fixture +def var_indices_for_concat( + adatas_paths_var_indices_for_concatenation, +) -> list[pd.Index]: + _, _, var_indices = adatas_paths_var_indices_for_concatenation + return var_indices + + +@pytest.fixture +def adatas_for_concat( + adatas_paths_var_indices_for_concatenation, +) -> list[AnnData]: + adatas, _, _ = adatas_paths_var_indices_for_concatenation + return adatas + + +@pytest.fixture +def stores_for_concat( + adatas_paths_var_indices_for_concatenation, +) -> list[AccessTrackingStore]: + _, paths, _ = adatas_paths_var_indices_for_concatenation + return [AccessTrackingStore(path) for path in paths] + + +@pytest.fixture +def lazy_adatas_for_concat( + stores_for_concat, +) -> list[AnnData]: + return [read_lazy(store) for store in stores_for_concat] + + +@pytest.fixture +def adata_remote_with_store_tall_skinny( + adata_remote_with_store_tall_skinny_path: Path, +) -> tuple[AnnData, AccessTrackingStore]: + store = AccessTrackingStore(adata_remote_with_store_tall_skinny_path) + remote = read_lazy(store) + return remote, store + + +@pytest.fixture +def 
remote_store_tall_skinny( + adata_remote_with_store_tall_skinny_path: Path, +) -> AccessTrackingStore: + return AccessTrackingStore(adata_remote_with_store_tall_skinny_path) + + +@pytest.fixture +def adata_remote_tall_skinny( + remote_store_tall_skinny: AccessTrackingStore, +) -> AnnData: + remote = read_lazy(remote_store_tall_skinny) + return remote + + +def get_key_trackers_for_columns_on_axis( + adata: AnnData, axis: Literal["obs", "var"] +) -> Generator[str, None, None]: + """Generate keys for tracking, using `codes` from categorical columns instead of the column name + + Parameters + ---------- + adata + Object to get keys from + axis + Axis to get keys from + + Yields + ------ + Keys for tracking + """ + for col in getattr(adata, axis).columns: + yield f"{axis}/{col}" if "cat" not in col else f"{axis}/{col}/codes" + + +@pytest.mark.parametrize( + ("elem_key", "sub_key"), + [ + ("raw", "X"), + ("obs", "cat"), + ("obs", "int64"), + *((elem_name, None) for elem_name in ANNDATA_ELEMS), + ], +) +def test_access_count_elem_access( + remote_store_tall_skinny: AccessTrackingStore, + adata_remote_tall_skinny: AnnData, + elem_key: AnnDataElem, + sub_key: str, + simple_subset_func: Callable[[AnnData], AnnData], +): + full_path = f"{elem_key}/{sub_key}" if sub_key is not None else elem_key + remote_store_tall_skinny.initialize_key_trackers({full_path, "X"}) + # a series of methods that should __not__ read in any data + elem = getattr(simple_subset_func(adata_remote_tall_skinny), elem_key) + if sub_key is not None: + getattr(elem, sub_key) + remote_store_tall_skinny.assert_access_count(full_path, 0) + remote_store_tall_skinny.assert_access_count("X", 0) + + +def test_access_count_subset( + remote_store_tall_skinny: AccessTrackingStore, + adata_remote_tall_skinny: AnnData, +): + non_obs_elem_names = filter(lambda e: e != "obs", ANNDATA_ELEMS) + remote_store_tall_skinny.initialize_key_trackers( + ["obs/cat/codes", *non_obs_elem_names] + ) + adata_remote_tall_skinny[adata_remote_tall_skinny.obs["cat"] == "a", :] + # all codes read in for subset (from 1 chunk) + remote_store_tall_skinny.assert_access_count("obs/cat/codes", 1) + for elem_name in non_obs_elem_names: + remote_store_tall_skinny.assert_access_count(elem_name, 0) + + +def test_access_count_subset_column_compute( + remote_store_tall_skinny: AccessTrackingStore, + adata_remote_tall_skinny: AnnData, +): + remote_store_tall_skinny.initialize_key_trackers(["obs/int64"]) + adata_remote_tall_skinny[adata_remote_tall_skinny.shape[0] // 2, :].obs[ + "int64" + ].compute() + # two chunks needed for 0:10 subset + remote_store_tall_skinny.assert_access_count("obs/int64", 1) + + +def test_access_count_index( + remote_store_tall_skinny: AccessTrackingStore, +): + remote_store_tall_skinny.initialize_key_trackers(["obs/_index"]) + read_lazy(remote_store_tall_skinny, load_annotation_index=False) + remote_store_tall_skinny.assert_access_count("obs/_index", 0) + read_lazy(remote_store_tall_skinny) + # 4 is number of chunks + remote_store_tall_skinny.assert_access_count("obs/_index", 4) + + +def test_access_count_dtype( + remote_store_tall_skinny: AccessTrackingStore, + adata_remote_tall_skinny: AnnData, +): + remote_store_tall_skinny.initialize_key_trackers(["obs/cat/categories"]) + remote_store_tall_skinny.assert_access_count("obs/cat/categories", 0) + # This should only cause categories to be read in once + adata_remote_tall_skinny.obs["cat"].dtype + adata_remote_tall_skinny.obs["cat"].dtype + adata_remote_tall_skinny.obs["cat"].dtype + 
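These access-count tests all follow the same pattern: wrap the Zarr directory in an `AccessTrackingStore`, register key prefixes, then assert how many chunk reads a given operation triggers. Roughly (a sketch with a placeholder path; the `int64` column name mirrors the fixtures above, and exact counts depend on chunking):

```python
from anndata.experimental import read_lazy
from anndata.tests.helpers import AccessTrackingStore

store = AccessTrackingStore("adata.zarr")          # placeholder path
store.initialize_key_trackers(["obs/int64", "X"])
adata = read_lazy(store)
adata.obs["int64"]                                 # lazy column: nothing is read
store.assert_access_count("obs/int64", 0)
store.assert_access_count("X", 0)
adata.obs["int64"].compute()                       # now the chunks are actually fetched
assert store.get_access_count("obs/int64") > 0
```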
remote_store_tall_skinny.assert_access_count("obs/cat/categories", 1) + + +def test_uns_uses_dask(adata_remote: AnnData): + assert isinstance(adata_remote.uns["nested"]["nested_further"]["array"], DaskArray) + + +def test_to_memory(adata_remote: AnnData, adata_orig: AnnData): + remote_to_memory = adata_remote.to_memory() + assert_equal(remote_to_memory, adata_orig) + + +def test_view_to_memory(adata_remote: AnnData, adata_orig: AnnData): + subset_obs = adata_orig.obs["obs_cat"] == "a" + assert_equal(adata_orig[subset_obs, :], adata_remote[subset_obs, :].to_memory()) + + subset_var = adata_orig.var["var_cat"] == "a" + assert_equal(adata_orig[:, subset_var], adata_remote[:, subset_var].to_memory()) + + +def test_view_of_view_to_memory(adata_remote: AnnData, adata_orig: AnnData): + subset_obs = (adata_orig.obs["obs_cat"] == "a") | (adata_orig.obs["obs_cat"] == "b") + subsetted_adata = adata_orig[subset_obs, :] + subset_subset_obs = subsetted_adata.obs["obs_cat"] == "b" + subsetted_subsetted_adata = subsetted_adata[subset_subset_obs, :] + assert_equal( + subsetted_subsetted_adata, + adata_remote[subset_obs, :][subset_subset_obs, :].to_memory(), + ) + + subset_var = (adata_orig.var["var_cat"] == "a") | (adata_orig.var["var_cat"] == "b") + subsetted_adata = adata_orig[:, subset_var] + subset_subset_var = subsetted_adata.var["var_cat"] == "b" + subsetted_subsetted_adata = subsetted_adata[:, subset_subset_var] + assert_equal( + subsetted_subsetted_adata, + adata_remote[:, subset_var][:, subset_subset_var].to_memory(), + ) + + +def test_unconsolidated(tmp_path: Path, mtx_format): + adata = gen_adata((1000, 1000), mtx_format) + orig_pth = tmp_path / "orig.zarr" + adata.write_zarr(orig_pth) + (orig_pth / ".zmetadata").unlink() + store = AccessTrackingStore(orig_pth) + store.initialize_key_trackers(["obs/.zgroup", ".zgroup"]) + with pytest.warns(UserWarning, match=r"Did not read zarr as consolidated"): + remote = read_lazy(store) + remote_to_memory = remote.to_memory() + assert_equal(remote_to_memory, adata) + store.assert_access_count("obs/.zgroup", 1) + + +def unify_extension_dtypes( + remote: pd.DataFrame, memory: pd.DataFrame +) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + For concatenated lazy datasets, we send the extension arrays through dask + But this means we lose the pandas dtype, so this function corrects that. 
+ + Parameters + ---------- + remote + The dataset that comes from the concatenated lazy operation + memory + The in-memory, "correct" version + + Returns + ------- + The two dataframes unified + """ + for col in memory.columns: + dtype = memory[col].dtype + if pd.api.types.is_extension_array_dtype(dtype): + remote[col] = remote[col].astype(dtype) + return remote, memory + + +ANNDATA_ELEMS = typing.get_args(AnnDataElem) + + +@pytest.mark.parametrize("join", ["outer", "inner"]) +@pytest.mark.parametrize( + ("elem_key", "sub_key"), + [ + ("obs", "cat"), + ("obs", "int64"), + *((elem_name, None) for elem_name in ANNDATA_ELEMS), + ], +) +def test_concat_access_count( + adatas_for_concat: list[AnnData], + stores_for_concat: list[AccessTrackingStore], + lazy_adatas_for_concat: list[AnnData], + join: Join_T, + elem_key: AnnDataElem, + sub_key: str, + simple_subset_func: Callable[[AnnData], AnnData], +): + # track all elems except codes because they must be read in for concatenation + non_categorical_columns = ( + f"{elem}/{col}" if "cat" not in col else f"{elem}/{col}/codes" + for elem in ["obs", "var"] + for col in adatas_for_concat[0].obs.columns + ) + non_obs_var_keys = filter(lambda e: e not in {"obs", "var"}, ANNDATA_ELEMS) + keys_to_track = [*non_categorical_columns, *non_obs_var_keys] + for store in stores_for_concat: + store.initialize_key_trackers(keys_to_track) + concated_remote = ad.concat(lazy_adatas_for_concat, join=join) + # a series of methods that should __not__ read in any data + elem = getattr(simple_subset_func(concated_remote), elem_key) + if sub_key is not None: + getattr(elem, sub_key) + for store in stores_for_concat: + for elem in keys_to_track: + store.assert_access_count(elem, 0) + + +def test_concat_to_memory_obs_access_count( + adatas_for_concat: list[AnnData], + stores_for_concat: list[AccessTrackingStore], + lazy_adatas_for_concat: list[AnnData], + join: Join_T, + simple_subset_func: Callable[[AnnData], AnnData], +): + concated_remote = simple_subset_func(ad.concat(lazy_adatas_for_concat, join=join)) + concated_remote_subset = simple_subset_func(concated_remote) + n_datasets = len(adatas_for_concat) + obs_keys_to_track = get_key_trackers_for_columns_on_axis( + adatas_for_concat[0], "obs" + ) + for store in stores_for_concat: + store.initialize_key_trackers(obs_keys_to_track) + concated_remote_subset.to_memory() + # check access count for the stores - only the first should be accessed when reading into memory + for col in obs_keys_to_track: + stores_for_concat[0].assert_access_count(col, 1) + for i in range(1, n_datasets): + # if the shapes are the same, data was read in to bring the object into memory; otherwise, not + stores_for_concat[i].assert_access_count( + col, concated_remote_subset.shape[0] == concated_remote.shape[0] + ) + + +def test_concat_to_memory_obs( + adatas_for_concat: list[AnnData], + lazy_adatas_for_concat: list[AnnData], + join: Join_T, + simple_subset_func: Callable[[AnnData], AnnData], +): + concatenated_memory = simple_subset_func(ad.concat(adatas_for_concat, join=join)) + concated_remote = simple_subset_func(ad.concat(lazy_adatas_for_concat, join=join)) + # TODO: name is lost normally, should fix + obs_memory = concatenated_memory.obs + obs_memory.index.name = "obs_names" + assert_equal( + *unify_extension_dtypes( + concated_remote.obs.to_pandas(), concatenated_memory.obs + ) + ) + + +def test_concat_to_memory_obs_dtypes( + lazy_adatas_for_concat: list[AnnData], + join: Join_T, +): + concated_remote = ad.concat(lazy_adatas_for_concat, 
join=join) + # check preservation of non-categorical dtypes on the concat axis + assert concated_remote.obs["int64"].dtype == "int64" + assert concated_remote.obs["uint8"].dtype == "uint8" + assert concated_remote.obs["nullable-int"].dtype == "int32" + assert concated_remote.obs["float64"].dtype == "float64" + assert concated_remote.obs["bool"].dtype == "bool" + assert concated_remote.obs["nullable-bool"].dtype == "bool" + + +def test_concat_to_memory_var( + var_indices_for_concat: list[pd.Index], + adatas_for_concat: list[AnnData], + stores_for_concat: list[AccessTrackingStore], + lazy_adatas_for_concat: list[AnnData], + join: Join_T, + are_vars_different: bool, + simple_subset_func: Callable[[AnnData], AnnData], +): + concated_remote = simple_subset_func(ad.concat(lazy_adatas_for_concat, join=join)) + var_keys_to_track = get_key_trackers_for_columns_on_axis( + adatas_for_concat[0], "var" + ) + for store in stores_for_concat: + store.initialize_key_trackers(var_keys_to_track) + # check non-different variables, taken from first annotation. + pd_index_overlapping = pd.Index( + filter(lambda x: not x.endswith("ds"), var_indices_for_concat[0]) + ) + var_df_overlapping = adatas_for_concat[0][:, pd_index_overlapping].var.copy() + test_cases = [(pd_index_overlapping, var_df_overlapping, 0)] + if are_vars_different and join == "outer": + # check a set of unique variables from the first object since we only take from there if different + pd_index_only_ds_0 = pd.Index( + filter(lambda x: "0_ds" in x, var_indices_for_concat[1]) + ) + var_df_only_ds_0 = adatas_for_concat[0][:, pd_index_only_ds_0].var.copy() + test_cases.append((pd_index_only_ds_0, var_df_only_ds_0, 0)) + for pd_index, var_df, store_idx in test_cases: + var_df.index.name = "var_names" + remote_df = concated_remote[:, pd_index].var.to_pandas() + remote_df_corrected, _ = unify_extension_dtypes(remote_df, var_df) + # TODO:xr.merge always upcasts to float due to NA and you can't downcast? 
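The concatenation tests boil down to one property: apart from the categorical codes noted above, `ad.concat` over `read_lazy` objects does not read data; only `to_memory()` (or an explicit `.compute()`) touches the stores. A hedged usage sketch (the paths are placeholders):

```python
import anndata as ad
from anndata.experimental import read_lazy

lazy_a = read_lazy("a.zarr")                        # placeholder paths
lazy_b = read_lazy("b.zarr")
merged = ad.concat([lazy_a, lazy_b], join="outer")  # still lazy
subset = merged[:1000]                              # positional subsetting stays lazy
in_memory = subset.to_memory()                      # data is read only here
```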
+ for col in remote_df_corrected.columns: + dtype = remote_df_corrected[col].dtype + if dtype in [np.float64, np.float32]: + var_df[col] = var_df[col].astype(dtype) + assert_equal(remote_df_corrected, var_df) + for key in var_keys_to_track: + stores_for_concat[store_idx].assert_access_count(key, 1) + for store in stores_for_concat: + if store != stores_for_concat[store_idx]: + store.assert_access_count(key, 0) + stores_for_concat[store_idx].reset_key_trackers() + + +def test_concat_data_with_cluster_to_memory( + adata_remote: AnnData, join: Join_T, load_annotation_index: bool +): + import dask.distributed as dd + + with ( + dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster, + dd.Client(cluster), + ): + with ( + pytest.warns(UserWarning, match=r"Concatenating with a pandas numeric") + if not load_annotation_index + else nullcontext() + ): + ad.concat([adata_remote, adata_remote], join=join).to_memory() + + +@pytest.mark.parametrize( + "index", + [ + pytest.param( + slice(500, 1500), + id="slice", + ), + pytest.param( + np.arange(950, 1050), + id="consecutive integer array", + ), + pytest.param( + np.random.randint(800, 1100, 500), + id="random integer array", + ), + pytest.param( + np.random.choice([True, False], 2000), + id="boolean array", + ), + pytest.param(slice(None), id="full slice"), + pytest.param("a", id="categorical_subset"), + pytest.param(None, id="No index"), + ], +) +def test_concat_data( + adata_remote: AnnData, + adata_orig: AnnData, + join: Join_T, + index: slice | NDArray | Literal["a"] | None, + load_annotation_index: bool, +): + from anndata.experimental.backed._compat import Dataset2D + + with ( + pytest.warns(UserWarning, match=r"Concatenating with a pandas numeric") + if not load_annotation_index + else nullcontext() + ): + remote_concatenated = ad.concat([adata_remote, adata_remote], join=join) + if index is not None: + if np.isscalar(index) and index == "a": + index = remote_concatenated.obs["obs_cat"] == "a" + remote_concatenated = remote_concatenated[index] + orig_concatenated = ad.concat([adata_orig, adata_orig], join=join) + if index is not None: + orig_concatenated = orig_concatenated[index] + in_memory_remote_concatenated = remote_concatenated.to_memory() + corrected_remote_obs, corrected_memory_obs = unify_extension_dtypes( + in_memory_remote_concatenated.obs, orig_concatenated.obs + ) + assert isinstance(remote_concatenated.obs, Dataset2D) + assert_equal(corrected_remote_obs, corrected_memory_obs) + assert_equal(in_memory_remote_concatenated.X, orig_concatenated.X) + assert ( + in_memory_remote_concatenated.var_names.tolist() + == orig_concatenated.var_names.tolist() + ) + + +@pytest.mark.parametrize( + ("attr", "key"), + ( + pytest.param(param[0], param[1], id="-".join(map(str, param))) + for param in [("obs", None), ("var", None), ("obsm", "df"), ("varm", "df")] + ), +) +def test_concat_df_ds_mixed_types( + adata_remote: AnnData, + adata_orig: AnnData, + load_annotation_index: bool, + join: Join_T, + attr: str, + key: str | None, +): + def with_elem_in_memory(adata: AnnData, attr: str, key: str | None) -> AnnData: + parent_elem = getattr(adata, attr) + if key is not None: + getattr(adata, attr)[key] = to_memory(parent_elem[key]) + return adata + setattr(adata, attr, to_memory(parent_elem)) + return adata + + if not load_annotation_index: + pytest.skip( + "Testing for mixed types is independent of the axis since the indices always have to match." 
+ ) + remote = with_elem_in_memory(adata_remote, attr, key) + in_memory_concatenated = ad.concat([adata_orig, adata_orig], join=join) + mixed_concatenated = ad.concat([remote, adata_orig], join=join) + assert_equal(mixed_concatenated, in_memory_concatenated) + + +def test_concat_bad_mixed_types(tmp_path: str): + orig = gen_adata((100, 200), np.array) + orig.write_zarr(tmp_path) + remote = read_lazy(tmp_path) + orig.obsm["df"] = orig.obsm["array"] + with pytest.raises(ValueError, match=r"Cannot concatenate a Dataset2D*"): + ad.concat([remote, orig], join="outer") diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 518559995..1da55a0cd 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -847,3 +847,11 @@ def test_h5py_attr_limit(tmp_path): np.ones((5, N)), index=a.obs_names, columns=[str(i) for i in range(N)] ) a.write(tmp_path / "tmp.h5ad") + + +@pytest.mark.skipif( + find_spec("xarray"), reason="Xarray is installed so `read_lazy` will not error" +) +def test_read_lazy_import_error(): + with pytest.raises(ImportError, match="xarray"): + ad.experimental.read_lazy("test.zarr")
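Finally, since `read_lazy` now raises a clear `ImportError` when xarray is missing, downstream code that wants to keep the dependency optional can guard for it, e.g. (a sketch; the path is a placeholder):

```python
from importlib.util import find_spec

import anndata as ad

path = "adata.zarr"                            # placeholder path
if find_spec("xarray") is not None:
    adata = ad.experimental.read_lazy(path)    # lazy, xarray/dask-backed
else:
    adata = ad.read_zarr(path)                 # eager fallback without xarray
```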