From b19a5eae55d596a28a796aeaf56fe324e06e102d Mon Sep 17 00:00:00 2001 From: Philipp A Date: Mon, 21 Aug 2023 15:08:48 +0200 Subject: [PATCH 1/8] Fix benchmarks and pandas compat (#1100) --- .github/workflows/benchmark.yml | 23 +++++++++++------------ anndata/tests/test_views.py | 8 ++++++++ benchmarks/asv.conf.json | 4 ++-- pyproject.toml | 4 +++- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 12d4afb82..b6d3919d7 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -11,12 +11,12 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - shell: bash -e {0} # -e to fail on error + shell: bash -el {0} # -e to fail on error, -l for mamba strategy: fail-fast: false matrix: - python: ["3.10"] + python: ["3.11"] os: [ubuntu-latest] env: @@ -33,12 +33,15 @@ jobs: if: ${{ github.ref_name != 'main' }} # Errors on main branch - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v4 + - uses: mamba-org/setup-micromamba@v1 with: - python-version: ${{ matrix.python }} - cache: "pip" - cache-dependency-path: "**/pyproject.toml" + environment-name: asv + cache-environment: true + create-args: >- + python=3.11 + asv + mamba + packaging - name: Cache datasets uses: actions/cache@v3 @@ -47,12 +50,8 @@ jobs: ~/.cache key: benchmark-state-${{ hashFiles('benchmarks/**') }} - - name: Install dependencies - run: | - pip install asv - - name: Quick benchmark run working-directory: ${{ env.ASV_DIR }} run: | asv machine --yes - asv run -qev --strict + asv run --quick --show-stderr --verbose diff --git a/anndata/tests/test_views.py b/anndata/tests/test_views.py index f9e41bef1..6770b2000 100644 --- a/anndata/tests/test_views.py +++ b/anndata/tests/test_views.py @@ -108,6 +108,14 @@ def test_views(): assert adata_subset.obs["foo"].tolist() == list(range(2)) +def test_view_subset_shapes(): + adata = gen_adata((20, 10), **GEN_ADATA_DASK_ARGS) + + view = adata[:, ::2] + assert view.var.shape == (5, 8) + assert {k: v.shape[0] for k, v in view.varm.items()} == {k: 5 for k in view.varm} + + def test_modify_view_component(matrix_type, mapping_name): adata = ad.AnnData( np.zeros((10, 10)), diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d91da2150..0761f6722 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -44,7 +44,7 @@ // If missing or the empty string, the tool will be automatically // determined by looking for tools on the PATH environment // variable. 
- "environment_type": "conda", + "environment_type": "mamba", // timeout in seconds for installing any dependencies in environment // defaults to 10 min @@ -117,7 +117,7 @@ // // additional env for python2.7 // {"python": "2.7", "numpy": "1.8"}, // // additional env if run on windows+conda - // {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""}, + // {"platform": "win32", "environment_type": "mamba", "python": "2.7", "libpython": ""}, // ], // The directory (relative to the current directory) that benchmarks are diff --git a/pyproject.toml b/pyproject.toml index 123f81960..f8e97c1a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Visualization", ] dependencies = [ - "pandas>=1.1.1", # pandas <1.1.1 has pandas/issues/35446 + # pandas <1.1.1 has pandas/issues/35446 + # pandas 2.1.0rc0 has pandas/issues/54622 + "pandas >=1.1.1, !=2.1.0rc0", "numpy>=1.16.5", # required by pandas 1.x "scipy>1.4", "h5py>=3", From 8341754a1cc28c9e72f8788ffcf8cc29ba2adabb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:16:51 +0000 Subject: [PATCH 2/8] [pre-commit.ci] pre-commit autoupdate (#1096) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A --- .pre-commit-config.yaml | 2 +- anndata/tests/test_io_elementwise.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7812f92f..63f056097 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: "v0.0.282" + rev: "v0.0.284" hooks: - id: ruff args: ["--fix"] diff --git a/anndata/tests/test_io_elementwise.py b/anndata/tests/test_io_elementwise.py index 688215fa1..a90f563e2 100644 --- a/anndata/tests/test_io_elementwise.py +++ b/anndata/tests/test_io_elementwise.py @@ -187,10 +187,10 @@ def test_categorical_order_type(store): write_elem(store, "ordered", cat) write_elem(store, "unordered", cat.set_ordered(False)) + assert isinstance(read_elem(store["ordered"]).ordered, bool) assert read_elem(store["ordered"]).ordered is True - assert type(read_elem(store["ordered"]).ordered) == bool + assert isinstance(read_elem(store["unordered"]).ordered, bool) assert read_elem(store["unordered"]).ordered is False - assert type(read_elem(store["unordered"]).ordered) == bool def test_override_specification(): From bec216e22a306493413e5bbfa5f8121a743824e7 Mon Sep 17 00:00:00 2001 From: Severin Dicks <37635888+Intron7@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:18:01 +0200 Subject: [PATCH 3/8] pepy fix (#1097) Co-authored-by: Philipp A --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 72e26f6d6..af784833a 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ [![Coverage](https://codecov.io/gh/scverse/anndata/branch/main/graph/badge.svg?token=IN1mJN1Wi8)](https://codecov.io/gh/scverse/anndata) [![Docs](https://readthedocs.com/projects/icb-anndata/badge/?version=latest)](https://anndata.readthedocs.io) [![PyPI](https://img.shields.io/pypi/v/anndata.svg)](https://pypi.org/project/anndata) -[![PyPIDownloadsMonth](https://img.shields.io/pypi/dm/scanpy?logo=PyPI&color=blue)](https://pypi.org/project/anndata) -[![PyPIDownloadsTotal](https://pepy.tech/badge/anndata)](https://pepy.tech/project/anndata) +[![Downloads](https://static.pepy.tech/badge/anndata/month)](https://pepy.tech/project/anndata) +[![Downloads](https://static.pepy.tech/badge/anndata)](https://pepy.tech/project/anndata) [![Stars](https://img.shields.io/github/stars/scverse/anndata?logo=GitHub&color=yellow)](https://github.com/scverse/anndata/stargazers) [![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](http://numfocus.org) From 2e022bce679e6c89e548cacf7ad6328d84a1dc36 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Mon, 21 Aug 2023 15:35:56 +0200 Subject: [PATCH 4/8] update main CI python version (#1102) --- .azure-pipelines.yml | 12 ++++++------ .github/workflows/test-gpu.yml | 2 +- .readthedocs.yml | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.azure-pipelines.yml b/.azure-pipelines.yml index 7dce7d5f1..c3d9e6c23 100644 --- a/.azure-pipelines.yml +++ b/.azure-pipelines.yml @@ -13,13 +13,13 @@ jobs: vmImage: "ubuntu-22.04" strategy: matrix: - Python310: - python.version: "3.10" + Python3.11: + python.version: "3.11" RUN_COVERAGE: yes - Python38: + Python3.8: python.version: "3.8" PreRelease: - python.version: "3.10" + python.version: "3.11" PRERELEASE_DEPENDENCIES: yes steps: - task: UsePythonVersion@0 @@ -87,8 +87,8 @@ jobs: steps: - task: UsePythonVersion@0 inputs: - versionSpec: "3.10" - displayName: "Use Python 3.10" + versionSpec: "3.11" + displayName: "Use Python 3.11" - script: | python -m pip install --upgrade pip diff --git a/.github/workflows/test-gpu.yml b/.github/workflows/test-gpu.yml index c6a74e57e..c32ce2492 100644 --- a/.github/workflows/test-gpu.yml +++ b/.github/workflows/test-gpu.yml @@ -51,7 +51,7 @@ jobs: micromamba-version: "1.3.1-0" 
environment-name: anndata-gpu-ci create-args: >- - python=3.10 + python=3.11 cupy numba pytest diff --git a/.readthedocs.yml b/.readthedocs.yml index 7b443598a..ec7305492 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,7 +2,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.10" + python: "3.11" sphinx: configuration: docs/conf.py fail_on_warning: true # do not change or you will be fired From c95571c024e876e8731c9289366d725110236563 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Aug 2023 07:28:08 +0000 Subject: [PATCH 5/8] [pre-commit.ci] pre-commit autoupdate (#1104) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 63f056097..658017792 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ repos: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: "v0.0.284" + rev: "v0.0.285" hooks: - id: ruff args: ["--fix"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.1 + rev: v3.0.2 hooks: - id: prettier - repo: https://github.com/pre-commit/pre-commit-hooks From caed9264f63422150294682756b83c3e38292bc7 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Thu, 24 Aug 2023 16:56:12 +0200 Subject: [PATCH 6/8] Use __notes__ for IO exceptions (#1055) * Modify write errors with notes * Modify read errors with notes * Test tests * test error messages * add release note * Tests for writing group encoded types to the root group * Update typing Co-authored-by: Philipp A. * Don't reference 'above error' message in note * Fix annotation usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Philipp A. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- anndata/__init__.py | 7 +++ anndata/_io/specs/methods.py | 14 +++--- anndata/_io/utils.py | 44 ++++++++----------- anndata/compat/__init__.py | 2 + anndata/compat/exceptiongroups.py | 12 +++++ anndata/tests/helpers.py | 33 ++++++++++++++ anndata/tests/test_helpers.py | 29 ++++++++++++ anndata/tests/test_io_elementwise.py | 66 ++++++++++++++++++++++++---- anndata/tests/test_io_utils.py | 44 +++++++++++++++---- anndata/tests/test_readwrite.py | 27 +++++++----- docs/release-notes/0.10.0.md | 1 + pyproject.toml | 1 + 12 files changed, 221 insertions(+), 59 deletions(-) create mode 100644 anndata/compat/exceptiongroups.py diff --git a/anndata/__init__.py b/anndata/__init__.py index f4fda20aa..2b6822363 100644 --- a/anndata/__init__.py +++ b/anndata/__init__.py @@ -12,6 +12,13 @@ "anndata is not correctly installed. Please install it, e.g. with pip." ) +# Allowing notes to be added to exceptions. 
See: https://github.com/scverse/anndata/issues/868 +import sys + +if sys.version_info < (3, 11): + # Backport package for exception groups + import exceptiongroup # noqa: F401 + from ._core.anndata import AnnData from ._core.merge import concat from ._core.raw import Raw diff --git a/anndata/_io/specs/methods.py b/anndata/_io/specs/methods.py index d4c6bb996..ea89d5610 100644 --- a/anndata/_io/specs/methods.py +++ b/anndata/_io/specs/methods.py @@ -270,7 +270,7 @@ def read_anndata(elem, _reader): @_REGISTRY.register_write(H5Group, Raw, IOSpec("raw", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, Raw, IOSpec("raw", "0.1.0")) def write_raw(f, k, raw, _writer, dataset_kwargs=MappingProxyType({})): - g = f.create_group(k) + g = f.require_group(k) _writer.write_elem(g, "X", raw.X, dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "var", raw.var, dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "varm", dict(raw.varm), dataset_kwargs=dataset_kwargs) @@ -290,7 +290,7 @@ def read_mapping(elem, _reader): @_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0")) @_REGISTRY.register_write(ZarrGroup, dict, IOSpec("dict", "0.1.0")) def write_mapping(f, k, v, _writer, dataset_kwargs=MappingProxyType({})): - g = f.create_group(k) + g = f.require_group(k) for sub_k, sub_v in v.items(): _writer.write_elem(g, sub_k, sub_v, dataset_kwargs=dataset_kwargs) @@ -459,7 +459,7 @@ def write_sparse_compressed( fmt: Literal["csr", "csc"], dataset_kwargs=MappingProxyType({}), ): - g = f.create_group(key) + g = f.require_group(key) g.attrs["shape"] = value.shape # Allow resizing for hdf5 @@ -546,7 +546,7 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))) def write_awkward(f, k, v, _writer, dataset_kwargs=MappingProxyType({})): from anndata.compat import awkward as ak - group = f.create_group(k) + group = f.require_group(k) form, length, container = ak.to_buffers(ak.to_packed(v)) group.attrs["length"] = length group.attrs["form"] = form.to_json() @@ -580,7 +580,7 @@ def write_dataframe(f, key, df, _writer, dataset_kwargs=MappingProxyType({})): for reserved in ("_index",): if reserved in df.columns: raise ValueError(f"{reserved!r} is a reserved name for dataframe columns.") - group = f.create_group(key) + group = f.require_group(key) col_names = [check_key(c) for c in df.columns] group.attrs["column-order"] = col_names @@ -699,7 +699,7 @@ def read_partial_dataframe_0_1_0( @_REGISTRY.register_write(H5Group, pd.Categorical, IOSpec("categorical", "0.2.0")) @_REGISTRY.register_write(ZarrGroup, pd.Categorical, IOSpec("categorical", "0.2.0")) def write_categorical(f, k, v, _writer, dataset_kwargs=MappingProxyType({})): - g = f.create_group(k) + g = f.require_group(k) g.attrs["ordered"] = bool(v.ordered) _writer.write_elem(g, "codes", v.codes, dataset_kwargs=dataset_kwargs) @@ -746,7 +746,7 @@ def read_partial_categorical(elem, *, items=None, indices=(slice(None),)): ZarrGroup, pd.arrays.BooleanArray, IOSpec("nullable-boolean", "0.1.0") ) def write_nullable_integer(f, k, v, _writer, dataset_kwargs=MappingProxyType({})): - g = f.create_group(k) + g = f.require_group(k) if v._mask is not None: _writer.write_elem(g, "mask", v._mask, dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "values", v._data, dataset_kwargs=dataset_kwargs) diff --git a/anndata/_io/utils.py b/anndata/_io/utils.py index 4770c231a..9eb6db875 100644 --- a/anndata/_io/utils.py +++ b/anndata/_io/utils.py @@ -1,14 +1,14 @@ from __future__ import annotations from functools import wraps -from typing import Callable 
+from typing import Callable, Literal from warnings import warn from packaging import version import h5py from .._core.sparse_dataset import SparseDataset -from anndata.compat import H5Group, ZarrGroup +from anndata.compat import H5Group, ZarrGroup, add_note # For allowing h5py v3 # https://github.com/scverse/anndata/issues/442 @@ -164,6 +164,21 @@ def _get_parent(elem): return parent +def re_raise_error(e, elem, key, op=Literal["read", "writ"]): + if any( + f"Error raised while {op}ing key" in note + for note in getattr(e, "__notes__", []) + ): + raise + else: + parent = _get_parent(elem) + add_note( + e, + f"Error raised while {op}ing key {key!r} of {type(elem)} to " f"{parent}", + ) + raise e + + def report_read_key_on_error(func): """\ A decorator for zarr element reading which makes keys involved in errors get reported. @@ -179,16 +194,6 @@ def report_read_key_on_error(func): >>> read_arr(z["X"]) # doctest: +SKIP """ - def re_raise_error(e, elem): - if isinstance(e, AnnDataReadError): - raise e - else: - parent = _get_parent(elem) - raise AnnDataReadError( - f"Above error raised while reading key {elem.name!r} of " - f"type {type(elem)} from {parent}." - ) from e - @wraps(func) def func_wrapper(*args, **kwargs): from anndata._io.specs import Reader @@ -200,7 +205,7 @@ def func_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - re_raise_error(e, elem) + re_raise_error(e, elem, elem.name, "read") return func_wrapper @@ -220,17 +225,6 @@ def report_write_key_on_error(func): >>> write_arr(z, "X", X) # doctest: +SKIP """ - def re_raise_error(e, elem, key): - if "Above error raised while writing key" in format(e): - raise - else: - parent = _get_parent(elem) - raise type(e)( - f"{e}\n\n" - f"Above error raised while writing key {key!r} of {type(elem)} " - f"to {parent}" - ) from e - @wraps(func) def func_wrapper(*args, **kwargs): from anndata._io.specs import Writer @@ -244,7 +238,7 @@ def func_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - re_raise_error(e, elem, key) + re_raise_error(e, elem, key, "writ") return func_wrapper diff --git a/anndata/compat/__init__.py b/anndata/compat/__init__.py index f881cb78d..168c2922d 100644 --- a/anndata/compat/__init__.py +++ b/anndata/compat/__init__.py @@ -11,6 +11,8 @@ import numpy as np import pandas as pd +from .exceptiongroups import add_note # noqa: F401 + class Empty: pass diff --git a/anndata/compat/exceptiongroups.py b/anndata/compat/exceptiongroups.py new file mode 100644 index 000000000..f64090017 --- /dev/null +++ b/anndata/compat/exceptiongroups.py @@ -0,0 +1,12 @@ +import sys + + +def add_note(err: BaseException, msg: str) -> BaseException: + """ + Adds a note to an exception inplace and returns it. 
+ """ + if sys.version_info < (3, 11): + err.__notes__ = getattr(err, "__notes__", []) + [msg] + else: + err.add_note(msg) + return err diff --git a/anndata/tests/helpers.py b/anndata/tests/helpers.py index 16c3915e1..139d3aaf5 100644 --- a/anndata/tests/helpers.py +++ b/anndata/tests/helpers.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +from contextlib import contextmanager from functools import singledispatch, wraps, partial +import re from string import ascii_letters from typing import Tuple, Optional, Type from collections.abc import Mapping, Collection @@ -601,6 +605,35 @@ def _(a): return as_dense_dask_array(a.toarray()) +@contextmanager +def pytest_8_raises(exc_cls, *, match: str | re.Pattern = None): + """Error handling using pytest 8's support for __notes__. + + See: https://github.com/pytest-dev/pytest/pull/11227 + + Remove once pytest 8 is out! + """ + + with pytest.raises(exc_cls) as exc_info: + yield exc_info + + check_error_or_notes_match(exc_info, match) + + +def check_error_or_notes_match(e: pytest.ExceptionInfo, pattern: str | re.Pattern): + """ + Checks whether the printed error message or the notes contains the given pattern. + + DOES NOT WORK IN IPYTHON - because of the way IPython handles exceptions + """ + import traceback + + message = "".join(traceback.format_exception_only(e.type, e.value)) + assert re.search( + pattern, message + ), f"Could not find pattern: '{pattern}' in error:\n\n{message}\n" + + def as_cupy_type(val, typ=None): """ Rough conversion function diff --git a/anndata/tests/test_helpers.py b/anndata/tests/test_helpers.py index eb074c53e..c0c0c094a 100644 --- a/anndata/tests/test_helpers.py +++ b/anndata/tests/test_helpers.py @@ -12,8 +12,10 @@ report_name, gen_adata, asarray, + pytest_8_raises, ) from anndata.utils import dim_len +from anndata.compat import add_note # Testing to see if all error types can have the key name appended. # Currently fails for 22/118 since they have required arguments. Not sure what to do about that. 
@@ -246,3 +248,30 @@ def test_assert_equal_dask_sparse_arrays(): assert_equal(x, y) assert_equal(y, x) + + +@pytest.mark.parametrize( + "error, match", + [ + (Exception("test"), "test"), + (add_note(AssertionError("foo"), "bar"), "bar"), + (add_note(add_note(AssertionError("foo"), "bar"), "baz"), "bar"), + (add_note(add_note(AssertionError("foo"), "bar"), "baz"), "baz"), + ], +) +def test_check_error_notes_success(error, match): + with pytest_8_raises(Exception, match=match): + raise error + + +@pytest.mark.parametrize( + "error, match", + [ + (Exception("test"), "foo"), + (add_note(AssertionError("foo"), "bar"), "baz"), + ], +) +def test_check_error_notes_failure(error, match): + with pytest.raises(AssertionError): + with pytest_8_raises(Exception, match=match): + raise error diff --git a/anndata/tests/test_io_elementwise.py b/anndata/tests/test_io_elementwise.py index a90f563e2..48f93c405 100644 --- a/anndata/tests/test_io_elementwise.py +++ b/anndata/tests/test_io_elementwise.py @@ -15,10 +15,14 @@ import anndata as ad from anndata._io.specs import _REGISTRY, get_spec, IOSpec from anndata._io.specs.registry import IORegistryError -from anndata._io.utils import AnnDataReadError from anndata.compat import _read_attr, H5Group, ZarrGroup from anndata._io.specs import write_elem, read_elem -from anndata.tests.helpers import assert_equal, gen_adata, as_cupy_type +from anndata.tests.helpers import ( + assert_equal, + as_cupy_type, + pytest_8_raises, + gen_adata, +) @pytest.fixture(params=["h5ad", "zarr"]) @@ -134,7 +138,7 @@ def test_io_spec_raw(store): assert_equal(from_disk.raw, adata.raw) -def test_write_to_root(store): +def test_write_anndata_to_root(store): adata = gen_adata((3, 2)) write_elem(store, "/", adata) @@ -157,11 +161,9 @@ def test_read_iospec_not_found(store, attribute, value): write_elem(store, "/", adata) store["obs"].attrs.update({attribute: value}) - with pytest.raises( - AnnDataReadError, match=r"while reading key '/(obs)?'" - ) as exc_info: + with pytest.raises(IORegistryError) as exc_info: read_elem(store) - msg = str(exc_info.value.__cause__) + msg = str(exc_info.value) assert "No read method registered for IOSpec" in msg assert f"{attribute.replace('-', '_')}='{value}'" in msg @@ -175,9 +177,11 @@ def test_write_io_error(store, obj): full_pattern = re.compile( rf"No method registered for writing {type(obj)} into .*Group" ) - with pytest.raises(IORegistryError, match=r"while writing key '/el'") as exc_info: + + with pytest_8_raises(IORegistryError, match=r"while writing key '/el'") as exc_info: write_elem(store, "/el", obj) - msg = str(exc_info.value.__cause__) + + msg = str(exc_info.value) assert re.search(full_pattern, msg) @@ -210,6 +214,50 @@ def _(store, key, adata): pass +@pytest.mark.parametrize( + "value", + [ + pytest.param({"a": 1}, id="dict"), + pytest.param(gen_adata((3, 2)), id="anndata"), + pytest.param(sparse.random(5, 3, format="csr", density=0.5), id="csr_matrix"), + pytest.param(sparse.random(5, 3, format="csc", density=0.5), id="csc_matrix"), + pytest.param(pd.DataFrame({"a": [1, 2, 3]}), id="dataframe"), + pytest.param(pd.Categorical(list("aabccedd")), id="categorical"), + pytest.param( + pd.Categorical(list("aabccedd"), ordered=True), id="categorical-ordered" + ), + pytest.param( + pd.Categorical([1, 2, 1, 3], ordered=True), id="categorical-numeric" + ), + pytest.param( + pd.arrays.IntegerArray( + np.ones(5, dtype=int), mask=np.array([True, False, True, False, True]) + ), + id="nullable-integer", + ), + pytest.param(pd.array([1, 2, 3]), 
id="nullable-integer-no-nulls"), + pytest.param( + pd.arrays.BooleanArray( + np.random.randint(0, 2, size=5, dtype=bool), + mask=np.random.randint(0, 2, size=5, dtype=bool), + ), + id="nullable-boolean", + ), + pytest.param( + pd.array([True, False, True, True]), id="nullable-boolean-no-nulls" + ), + ], +) +def test_write_to_root(store, value): + """ + Test that elements which are written as groups can we written to the root group. + """ + write_elem(store, "/", value) + result = read_elem(store) + + assert_equal(result, value) + + @pytest.mark.parametrize("consolidated", [True, False]) def test_read_zarr_from_group(tmp_path, consolidated): # https://github.com/scverse/anndata/issues/1056 diff --git a/anndata/tests/test_io_utils.py b/anndata/tests/test_io_utils.py index 09c1f6c54..f884020a8 100644 --- a/anndata/tests/test_io_utils.py +++ b/anndata/tests/test_io_utils.py @@ -10,8 +10,9 @@ from anndata.compat import _clean_uns from anndata._io.utils import ( report_read_key_on_error, - AnnDataReadError, ) +from anndata.experimental import read_elem, write_elem +from anndata.tests.helpers import pytest_8_raises @pytest.fixture(params=["h5ad", "zarr"]) @@ -35,12 +36,12 @@ def read_attr(_): with group if hasattr(group, "__enter__") else suppress(): group["X"] = [1, 2, 3] group.create_group("group") - with pytest.raises(AnnDataReadError) as e: + + with pytest_8_raises(NotImplementedError, match=r"/X"): read_attr(group["X"]) - assert "'/X'" in str(e.value) - with pytest.raises(AnnDataReadError) as e: + + with pytest_8_raises(NotImplementedError, match=r"/group"): read_attr(group["group"]) - assert "'/group'" in str(e.value) def test_write_error_info(diskfmt, tmp_path): @@ -50,9 +51,7 @@ def test_write_error_info(diskfmt, tmp_path): # Assuming we don't define a writer for tuples a = ad.AnnData(uns={"a": {"b": {"c": (1, 2, 3)}}}) - with pytest.raises( - IORegistryError, match=r"Above error raised while writing key 'c'" - ): + with pytest_8_raises(IORegistryError, match=r"Error raised while writing key 'c'"): write(a) @@ -69,3 +68,32 @@ def test_clean_uns(): # var’s categories were overwritten by obs’s, # which we can detect here because var has too high codes assert pd.api.types.is_integer_dtype(adata.var["species"]) + + +@pytest.mark.parametrize( + "group_fn", + [ + pytest.param(lambda _: zarr.group(), id="zarr"), + pytest.param(lambda p: h5py.File(p / "test.h5", mode="a"), id="h5py"), + ], +) +def test_only_child_key_reported_on_failure(tmp_path, group_fn): + class Foo: + pass + + group = group_fn(tmp_path) + + # This regex checks that the pattern inside the (?!...) group does not exist in the string + # (?!...) 
is a negative lookahead + # (?s) enables the dot to match newlines + # https://stackoverflow.com/a/406408/130164 <- copilot suggested lol + pattern = r"(?s)((?!Error raised while writing key '/?a').)*$" + + with pytest_8_raises(IORegistryError, match=pattern): + write_elem(group, "/", {"a": {"b": Foo()}}) + + write_elem(group, "/", {"a": {"b": [1, 2, 3]}}) + group["a/b"].attrs["encoding-type"] = "not a real encoding type" + + with pytest_8_raises(IORegistryError, match=pattern): + read_elem(group) diff --git a/anndata/tests/test_readwrite.py b/anndata/tests/test_readwrite.py index b4c97d945..a8fde73e7 100644 --- a/anndata/tests/test_readwrite.py +++ b/anndata/tests/test_readwrite.py @@ -1,8 +1,8 @@ -import re from contextlib import contextmanager from importlib.util import find_spec from os import PathLike from pathlib import Path +import re from string import ascii_letters import warnings @@ -15,10 +15,15 @@ import zarr import anndata as ad -from anndata._io.utils import AnnDataReadError +from anndata._io.specs.registry import IORegistryError from anndata.compat import _read_attr, DaskArray -from anndata.tests.helpers import gen_adata, assert_equal, as_dense_dask_array +from anndata.tests.helpers import ( + gen_adata, + assert_equal, + as_dense_dask_array, + pytest_8_raises, +) HERE = Path(__file__).parent @@ -295,13 +300,14 @@ def test_read_full_io_error(tmp_path, name, read, write): write(adata, path) with store_context(path) as store: store["obs"].attrs["encoding-type"] = "invalid" - with pytest.raises( - AnnDataReadError, match=r"raised while reading key '/obs'" + with pytest_8_raises( + IORegistryError, + match=r"raised while reading key '/obs'", ) as exc_info: read(path) assert re.search( r"No read method registered for IOSpec\(encoding_type='invalid', encoding_version='0.2.0'\)", - str(exc_info.value.__cause__), + str(exc_info.value), ) @@ -611,7 +617,7 @@ def test_dataframe_reserved_columns(tmp_path, diskfmt): to_write.obs[colname] = np.ones(5) with pytest.raises(ValueError) as exc_info: getattr(to_write, f"write_{diskfmt}")(adata_pth) - assert colname in str(exc_info.value.__cause__) + assert colname in str(exc_info.value) for colname in reserved: to_write = orig.copy() to_write.varm["df"] = pd.DataFrame( @@ -619,7 +625,7 @@ def test_dataframe_reserved_columns(tmp_path, diskfmt): ) with pytest.raises(ValueError) as exc_info: getattr(to_write, f"write_{diskfmt}")(adata_pth) - assert colname in str(exc_info.value.__cause__) + assert colname in str(exc_info.value) def test_write_large_categorical(tmp_path, diskfmt): @@ -673,9 +679,10 @@ def test_write_string_types(tmp_path, diskfmt): adata.obs[b"c"] = np.zeros(3) # This should error, and tell you which key is at fault - with pytest.raises(TypeError, match=r"writing key 'obs'") as exc_info: + with pytest_8_raises(TypeError, match=r"writing key 'obs'") as exc_info: write(adata_pth) - assert str(b"c") in str(exc_info.value.__cause__) + + assert str("b'c'") in str(exc_info.value) @pytest.mark.parametrize( diff --git a/docs/release-notes/0.10.0.md b/docs/release-notes/0.10.0.md index cb85d7a5a..fab840f03 100644 --- a/docs/release-notes/0.10.0.md +++ b/docs/release-notes/0.10.0.md @@ -12,6 +12,7 @@ * Improved error messages when combining dataframes with duplicated column names {pr}`1029` {user}`ivirshup` * Improved warnings when modifying views of `AlingedMappings` {pr}`1016` {user}`flying-sheep` {user}`ivirshup` +* `AnnDataReadError`s have been removed. 
The original error is now thrown with additional information in a note {pr}`1055` {user}`ivirshup` ```{rubric} Documentation diff --git a/pyproject.toml b/pyproject.toml index f8e97c1a0..845171ef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "numpy>=1.16.5", # required by pandas 1.x "scipy>1.4", "h5py>=3", + "exceptiongroup; python_version<'3.11'", "natsort", "packaging>=20", "array_api_compat", From 5428f16a7d5f46a2135c4beb15438f02f2b0e8d5 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Thu, 24 Aug 2023 16:58:53 +0200 Subject: [PATCH 7/8] Fix is_categorical_dtype warning (#1099) Co-authored-by: Isaac Virshup Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- anndata/_core/aligned_mapping.py | 13 ++++++--- anndata/_core/anndata.py | 10 ++++--- anndata/_core/merge.py | 40 ++++++++++++++-------------- anndata/tests/test_concatenate.py | 9 +++---- anndata/tests/test_io_utils.py | 2 +- anndata/tests/test_readwrite.py | 21 +++++++-------- docs/release-notes/release-latest.md | 3 ++- 7 files changed, 53 insertions(+), 45 deletions(-) diff --git a/anndata/_core/aligned_mapping.py b/anndata/_core/aligned_mapping.py index 34342a111..f4e359a5e 100644 --- a/anndata/_core/aligned_mapping.py +++ b/anndata/_core/aligned_mapping.py @@ -244,12 +244,17 @@ def _validate_value(self, val: V, key: str) -> V: if ( hasattr(val, "index") and isinstance(val.index, cabc.Collection) - and not (val.index == self.dim_names).all() + and not val.index.equals(self.dim_names) ): # Could probably also re-order index if it’s contained - raise ValueError( - f"value.index does not match parent’s axis {self.axes[0]} names" - ) + try: + pd.testing.assert_index_equal(val.index, self.dim_names) + except AssertionError as e: + msg = f"value.index does not match parent’s axis {self.axes[0]} names:\n{e}" + raise ValueError(msg) from None + else: + msg = "Index.equals and pd.testing.assert_index_equal disagree" + raise AssertionError(msg) return super()._validate_value(val, key) @property diff --git a/anndata/_core/anndata.py b/anndata/_core/anndata.py index b7978896f..8052aa780 100644 --- a/anndata/_core/anndata.py +++ b/anndata/_core/anndata.py @@ -1,6 +1,8 @@ """\ Main class and helper functions. 
""" +from __future__ import annotations + import warnings import collections.abc as cabc from collections import OrderedDict @@ -19,7 +21,7 @@ import numpy as np from numpy import ma import pandas as pd -from pandas.api.types import infer_dtype, is_string_dtype, is_categorical_dtype +from pandas.api.types import infer_dtype, is_string_dtype from scipy import sparse from scipy.sparse import issparse, csr_matrix @@ -1114,9 +1116,11 @@ def __getitem__(self, index: Index) -> "AnnData": oidx, vidx = self._normalize_indices(index) return AnnData(self, oidx=oidx, vidx=vidx, asview=True) - def _remove_unused_categories(self, df_full, df_sub, uns): + def _remove_unused_categories( + self, df_full: pd.DataFrame, df_sub: pd.DataFrame, uns: dict[str, Any] + ): for k in df_full: - if not is_categorical_dtype(df_full[k]): + if not isinstance(df_full[k].dtype, pd.CategoricalDtype): continue all_categories = df_full[k].cat.categories with pd.option_context("mode.chained_assignment", None): diff --git a/anndata/_core/merge.py b/anndata/_core/merge.py index b2ef895ba..f5eddf8b0 100644 --- a/anndata/_core/merge.py +++ b/anndata/_core/merge.py @@ -4,27 +4,25 @@ from __future__ import annotations from collections import OrderedDict -from collections.abc import Mapping, MutableSet -from functools import reduce, singledispatch -from itertools import repeat -from operator import and_, or_, sub -from typing import ( - Any, +from collections.abc import ( Callable, Collection, + Mapping, + MutableSet, Iterable, - Optional, - Tuple, - TypeVar, - Union, - Literal, + Sequence, ) +from functools import reduce, singledispatch +from itertools import repeat +from operator import and_, or_, sub +from typing import Any, Optional, TypeVar, Union, Literal import typing from warnings import warn, filterwarnings from natsort import natsorted import numpy as np import pandas as pd +from pandas.api.extensions import ExtensionDtype from scipy import sparse from scipy.sparse import spmatrix @@ -211,7 +209,7 @@ def unify_dtypes(dfs: Iterable[pd.DataFrame]) -> list[pd.DataFrame]: df_dtypes = [dict(df.dtypes) for df in dfs] columns = reduce(lambda x, y: x.union(y), [df.columns for df in dfs]) - dtypes = {col: list() for col in columns} + dtypes: dict[str, list[np.dtype | ExtensionDtype]] = {col: [] for col in columns} for col in columns: for df in df_dtypes: dtypes[col].append(df.get(col, None)) @@ -235,7 +233,9 @@ def unify_dtypes(dfs: Iterable[pd.DataFrame]) -> list[pd.DataFrame]: return dfs -def try_unifying_dtype(col: list) -> pd.core.dtypes.base.ExtensionDtype | None: +def try_unifying_dtype( + col: Sequence[np.dtype | ExtensionDtype], +) -> pd.core.dtypes.base.ExtensionDtype | None: """ If dtypes can be unified, returns the dtype they would be unified to. @@ -248,12 +248,12 @@ def try_unifying_dtype(col: list) -> pd.core.dtypes.base.ExtensionDtype | None: A list of dtypes to unify. 
Can be numpy/ pandas dtypes, or None (which denotes a missing value) """ - dtypes = set() + dtypes: set[pd.CategoricalDtype] = set() # Categorical - if any([pd.api.types.is_categorical_dtype(x) for x in col]): + if any(isinstance(dtype, pd.CategoricalDtype) for dtype in col): ordered = False for dtype in col: - if pd.api.types.is_categorical_dtype(dtype): + if isinstance(dtype, pd.CategoricalDtype): dtypes.add(dtype) ordered = ordered | dtype.ordered elif not pd.isnull(dtype): @@ -261,13 +261,13 @@ def try_unifying_dtype(col: list) -> pd.core.dtypes.base.ExtensionDtype | None: if len(dtypes) > 0 and not ordered: categories = reduce( lambda x, y: x.union(y), - [x.categories for x in dtypes if not pd.isnull(x)], + [dtype.categories for dtype in dtypes if not pd.isnull(dtype)], ) return pd.CategoricalDtype(natsorted(categories), ordered=False) # Boolean - elif all([pd.api.types.is_bool_dtype(x) or x is None for x in col]): - if any([x is None for x in col]): + elif all(pd.api.types.is_bool_dtype(dtype) or dtype is None for dtype in col): + if any(dtype is None for dtype in col): return pd.BooleanDtype() else: return None @@ -942,7 +942,7 @@ def merge_outer(mappings, batch_keys, *, join_index="-", merge=merge_unique): return out -def _resolve_dim(*, dim: str = None, axis: int = None) -> Tuple[int, str]: +def _resolve_dim(*, dim: str = None, axis: int = None) -> tuple[int, str]: _dims = ("obs", "var") if (dim is None and axis is None) or (dim is not None and axis is not None): raise ValueError( diff --git a/anndata/tests/test_concatenate.py b/anndata/tests/test_concatenate.py index 3da17951b..106c4189f 100644 --- a/anndata/tests/test_concatenate.py +++ b/anndata/tests/test_concatenate.py @@ -8,7 +8,6 @@ import numpy as np from numpy import ma import pandas as pd -from pandas.api.types import is_categorical_dtype import pytest from scipy import sparse from boltons.iterutils import research, remap, default_exit @@ -128,7 +127,7 @@ def fix_known_differences(orig, result, backwards_compat=True): # Possibly need to fix this, ordered categoricals lose orderedness for k, dtype in orig.obs.dtypes.items(): - if is_categorical_dtype(dtype) and dtype.ordered: + if isinstance(dtype, pd.CategoricalDtype) and dtype.ordered: result.obs[k] = result.obs[k].astype(dtype) return orig, result @@ -1184,8 +1183,8 @@ def test_concat_categories_maintain_dtype(): result = concat({"a": a, "b": b, "c": c}, join="outer") - assert pd.api.types.is_categorical_dtype( - result.obs["cat"] + assert isinstance( + result.obs["cat"].dtype, pd.CategoricalDtype ), f"Was {result.obs['cat'].dtype}" assert pd.api.types.is_string_dtype(result.obs["cat_ordered"]) @@ -1212,7 +1211,7 @@ def test_concat_ordered_categoricals_retained(): c = concat([a, b]) - assert pd.api.types.is_categorical_dtype(c.obs["cat_ordered"]) + assert isinstance(c.obs["cat_ordered"].dtype, pd.CategoricalDtype) assert c.obs["cat_ordered"].cat.ordered diff --git a/anndata/tests/test_io_utils.py b/anndata/tests/test_io_utils.py index f884020a8..8b94a5feb 100644 --- a/anndata/tests/test_io_utils.py +++ b/anndata/tests/test_io_utils.py @@ -63,7 +63,7 @@ def test_clean_uns(): ) _clean_uns(adata) assert "species_categories" not in adata.uns - assert pd.api.types.is_categorical_dtype(adata.obs["species"]) + assert isinstance(adata.obs["species"].dtype, pd.CategoricalDtype) assert adata.obs["species"].tolist() == ["a", "b", "a"] # var’s categories were overwritten by obs’s, # which we can detect here because var has too high codes diff --git 
a/anndata/tests/test_readwrite.py b/anndata/tests/test_readwrite.py index a8fde73e7..750a39142 100644 --- a/anndata/tests/test_readwrite.py +++ b/anndata/tests/test_readwrite.py @@ -9,7 +9,6 @@ import h5py import numpy as np import pandas as pd -from pandas.api.types import is_categorical_dtype import pytest from scipy.sparse import csr_matrix, csc_matrix import zarr @@ -129,7 +128,7 @@ def test_readwrite_h5ad(tmp_path, typ, dataset_kwargs, backing_h5ad): X = typ(X_list) adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict) - assert not is_categorical_dtype(adata_src.obs["oanno1"]) + assert not isinstance(adata_src.obs["oanno1"].dtype, pd.CategoricalDtype) adata_src.raw = adata_src adata_src.write(backing_h5ad, **dataset_kwargs) @@ -137,12 +136,12 @@ def test_readwrite_h5ad(tmp_path, typ, dataset_kwargs, backing_h5ad): adata_mid.write(mid_pth, **dataset_kwargs) adata = ad.read_h5ad(mid_pth) - assert is_categorical_dtype(adata.obs["oanno1"]) - assert not is_categorical_dtype(adata.obs["oanno2"]) + assert isinstance(adata.obs["oanno1"].dtype, pd.CategoricalDtype) + assert not isinstance(adata.obs["oanno2"].dtype, pd.CategoricalDtype) assert adata.obs.index.tolist() == ["name1", "name2", "name3"] assert adata.obs["oanno1"].cat.categories.tolist() == ["cat1", "cat2"] assert adata.obs["oanno1c"].cat.categories.tolist() == ["cat1"] - assert is_categorical_dtype(adata.raw.var["vanno2"]) + assert isinstance(adata.raw.var["vanno2"].dtype, pd.CategoricalDtype) pd.testing.assert_frame_equal(adata.obs, adata_src.obs) pd.testing.assert_frame_equal(adata.var, adata_src.var) assert_equal(adata.var.index, adata_src.var.index) @@ -167,16 +166,16 @@ def test_readwrite_zarr(typ, tmp_path): X = typ(X_list) adata_src = ad.AnnData(X, obs=obs_dict, var=var_dict, uns=uns_dict) adata_src.raw = adata_src - assert not is_categorical_dtype(adata_src.obs["oanno1"]) + assert not isinstance(adata_src.obs["oanno1"].dtype, pd.CategoricalDtype) adata_src.write_zarr(tmp_path / "test_zarr_dir", chunks=True) adata = ad.read_zarr(tmp_path / "test_zarr_dir") - assert is_categorical_dtype(adata.obs["oanno1"]) - assert not is_categorical_dtype(adata.obs["oanno2"]) + assert isinstance(adata.obs["oanno1"].dtype, pd.CategoricalDtype) + assert not isinstance(adata.obs["oanno2"].dtype, pd.CategoricalDtype) assert adata.obs.index.tolist() == ["name1", "name2", "name3"] assert adata.obs["oanno1"].cat.categories.tolist() == ["cat1", "cat2"] assert adata.obs["oanno1c"].cat.categories.tolist() == ["cat1"] - assert is_categorical_dtype(adata.raw.var["vanno2"]) + assert isinstance(adata.raw.var["vanno2"].dtype, pd.CategoricalDtype) pd.testing.assert_frame_equal(adata.obs, adata_src.obs) pd.testing.assert_frame_equal(adata.var, adata_src.var) assert_equal(adata.var.index, adata_src.var.index) @@ -251,8 +250,8 @@ def test_readwrite_backed(typ, backing_h5ad): adata_src.write() adata = ad.read(backing_h5ad) - assert is_categorical_dtype(adata.obs["oanno1"]) - assert not is_categorical_dtype(adata.obs["oanno2"]) + assert isinstance(adata.obs["oanno1"].dtype, pd.CategoricalDtype) + assert not isinstance(adata.obs["oanno2"].dtype, pd.CategoricalDtype) assert adata.obs.index.tolist() == ["name1", "name2", "name3"] assert adata.obs["oanno1"].cat.categories.tolist() == ["cat1", "cat2"] assert_equal(adata, adata_src) diff --git a/docs/release-notes/release-latest.md b/docs/release-notes/release-latest.md index 0188b2596..27cf8d866 100644 --- a/docs/release-notes/release-latest.md +++ b/docs/release-notes/release-latest.md @@ -1,7 +1,8 @@ 
-## Version 0.10.0 +## Version 0.10 ```{include} /release-notes/0.10.0.md ``` + ## Version 0.9 ```{include} /release-notes/0.9.3.md From 70c5d732318d1570ffbd8fa524e8912439484046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+syelman@users.noreply.github.com> Date: Fri, 25 Aug 2023 14:51:04 +0200 Subject: [PATCH 8/8] Dask Distributed Write Fix For Zarr (#1079) * init * add tests * give error for h5py and distributed * add importorskip * add dask distributed to tests * fix extra line * Update pyproject.toml Co-authored-by: Philipp A. * pytest mark need * remove unneeded mark --------- Co-authored-by: Philipp A --- anndata/_io/specs/methods.py | 29 +++++++++++++++++++++++++- anndata/tests/test_dask.py | 40 ++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/anndata/_io/specs/methods.py b/anndata/_io/specs/methods.py index ea89d5610..01086a874 100644 --- a/anndata/_io/specs/methods.py +++ b/anndata/_io/specs/methods.py @@ -38,6 +38,18 @@ H5File = h5py.File +#################### +# Dask utils # +#################### + +try: + from dask.utils import SerializableLock as Lock +except ImportError: + from threading import Lock + +# to fix https://github.com/dask/distributed/issues/780 +GLOBAL_LOCK = Lock() + #################### # Dispatch methods # #################### @@ -331,9 +343,24 @@ def write_basic(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})): @_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0")) +def write_basic_dask_zarr(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})): + import dask.array as da + + g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs) + da.store(elem, g, lock=GLOBAL_LOCK) + + +# Adding this seperately because h5py isn't serializable +# https://github.com/pydata/xarray/issues/4242 @_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0")) -def write_basic_dask(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})): +def write_basic_dask_h5(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})): import dask.array as da + import dask.config as dc + + if dc.get("scheduler", None) == "dask.distributed": + raise ValueError( + "Cannot write dask arrays to hdf5 when using distributed scheduler" + ) g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs) da.store(elem, g) diff --git a/anndata/tests/test_dask.py b/anndata/tests/test_dask.py index e2d820e08..cb745a8f5 100644 --- a/anndata/tests/test_dask.py +++ b/anndata/tests/test_dask.py @@ -11,6 +11,8 @@ gen_adata, assert_equal, ) +from anndata.experimental import write_elem, read_elem +from anndata.experimental.merge import as_group from anndata.compat import DaskArray pytest.importorskip("dask.array") @@ -94,6 +96,44 @@ def test_dask_write(adata, tmp_path, diskfmt): assert isinstance(orig.varm["a"], DaskArray) +def test_dask_distributed_write(adata, tmp_path, diskfmt): + import dask.array as da + import dask.distributed as dd + import numpy as np + + pth = tmp_path / f"test_write.{diskfmt}" + g = as_group(pth, mode="w") + + with dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster: + with dd.Client(cluster): + M, N = adata.X.shape + adata.obsm["a"] = da.random.random((M, 10)) + adata.obsm["b"] = da.random.random((M, 10)) + adata.varm["a"] = da.random.random((N, 10)) + orig = adata + if diskfmt == "h5ad": + with pytest.raises( + ValueError, match="Cannot write dask arrays to hdf5" + ): + 
write_elem(g, "", orig) + return + write_elem(g, "", orig) + curr = read_elem(g) + + with pytest.raises(Exception): + assert_equal(curr.obsm["a"], curr.obsm["b"]) + + assert_equal(curr.varm["a"], orig.varm["a"]) + assert_equal(curr.obsm["a"], orig.obsm["a"]) + + assert isinstance(curr.X, np.ndarray) + assert isinstance(curr.obsm["a"], np.ndarray) + assert isinstance(curr.varm["a"], np.ndarray) + assert isinstance(orig.X, DaskArray) + assert isinstance(orig.obsm["a"], DaskArray) + assert isinstance(orig.varm["a"], DaskArray) + + def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt): import dask.array as da import numpy as np diff --git a/pyproject.toml b/pyproject.toml index 845171ef8..790fd971a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ test = [ "joblib", "boltons", "scanpy", - "dask[array]", + "dask[array,distributed]", "awkward>=2.3", "pytest_memray", ]
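
Patch 6 also switches several writers from create_group to require_group, which is what lets group-encoded elements (mappings, dataframes, sparse matrices, categoricals, ...) be written to the root of a store, as exercised by the new test_write_to_root. A minimal sketch of the difference, assuming only h5py and an illustrative file name:

    import h5py

    with h5py.File("example.h5", mode="w") as f:
        # require_group returns an existing group instead of raising,
        # so writers can target the already-present root group "/".
        root = f.require_group("/")
        assert root.name == "/"
        # f.create_group("/") would raise here because "/" already exists,
        # which is why writing group-encoded types to the root failed before.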
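
A sketch of the note-based error reporting from patch 6 and of how the new test helpers match on it; it assumes anndata with this patch installed, and the key name in the note is only illustrative. On Python < 3.11, the exceptiongroup backport added to the dependencies is what makes the note show up in formatted tracebacks:

    import traceback

    from anndata.compat import add_note

    try:
        try:
            raise TypeError("Can't implicitly convert non-string objects to strings")
        except TypeError as e:
            # The I/O wrappers now attach context as a PEP 678 note instead of
            # chaining a separate AnnDataReadError around the original error.
            raise add_note(e, "Error raised while writing key 'obs' of <class 'h5py._hl.group.Group'> to /")
    except TypeError as e:
        # format_exception_only includes __notes__ on 3.11+ (and on older
        # interpreters once exceptiongroup patches traceback formatting);
        # pytest_8_raises/check_error_or_notes_match rely on the same output.
        message = "".join(traceback.format_exception_only(type(e), e))
        assert "writing key 'obs'" in message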
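
Patch 7 replaces pandas.api.types.is_categorical_dtype with isinstance checks against pd.CategoricalDtype, since the former is deprecated as of pandas 2.1. A small sketch of the two spellings (the data is illustrative):

    import pandas as pd

    s = pd.Series(pd.Categorical(list("aab")))

    # Forward-compatible spelling used throughout patch 7:
    assert isinstance(s.dtype, pd.CategoricalDtype)

    # Old spelling; warns on pandas >= 2.1 and is slated for removal:
    # pd.api.types.is_categorical_dtype(s)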
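
Finally, a rough usage sketch of the behaviour patch 8 targets: storing dask-backed arrays through write_elem while a dask.distributed scheduler is active. Shapes, chunking and the store path are illustrative, and it assumes anndata with this patch plus dask[array,distributed] and zarr:

    import dask.array as da
    import dask.distributed as dd
    import zarr

    from anndata.experimental import write_elem

    with dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster:
        with dd.Client(cluster):
            x = da.random.random((100, 50), chunks=(10, 50))

            # Zarr groups can be serialized to the workers, so da.store() works;
            # the zarr writer now also passes a SerializableLock to avoid
            # https://github.com/dask/distributed/issues/780.
            g = zarr.open_group("example.zarr", mode="w")
            write_elem(g, "X", x)

            # h5py handles cannot be shipped to workers, so the hdf5 writer now
            # fails fast with "Cannot write dask arrays to hdf5 when using
            # distributed scheduler" instead of erroring deep inside dask.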