Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): support for zarr-python>=3.0.0b0 #1726

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test = [
"loompy>=3.0.5",
"pytest>=8.2",
"pytest-cov>=2.10",
"zarr<3.0.0a0",
"zarr>=3.0.0b0",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
f"Matrices must have same format. Currently are "
f"{self.format!r} and {sparse_matrix.format!r}"
)
indptr_offset = len(self.group["indices"])
[indptr_offset] = self.group["indices"].shape
if self.group["indptr"].dtype == np.int32:
new_nnz = indptr_offset + len(sparse_matrix.indices)
if new_nnz >= np.iinfo(np.int32).max:
Expand Down
92 changes: 65 additions & 27 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}
elif isinstance(elem, h5py.Dataset):
return h5ad.read_dataset(elem) # TODO: Handle legacy

Expand All @@ -162,7 +162,7 @@
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}

Check warning on line 165 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L165

Added line #L165 was not covered by tests
elif isinstance(elem, ZarrArray):
return zarr.read_dataset(elem) # TODO: Handle legacy

Expand Down Expand Up @@ -335,7 +335,7 @@
@_REGISTRY.register_read(H5Group, IOSpec("dict", "0.1.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("dict", "0.1.0"))
def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisStorable]:
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}


@_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0"))
Expand Down Expand Up @@ -391,7 +391,7 @@
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
"""Write methods which underlying library handles natively."""
f.create_dataset(k, data=elem, **dataset_kwargs)
f.create_dataset(k, data=elem, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)


_REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
Expand All @@ -412,8 +412,12 @@
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import dask.array as da
import zarr

g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
if Version(zarr.__version__) >= Version("3.0.0b0"):
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
else:
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)

Check warning on line 420 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L420

Added line #L420 was not covered by tests
da.store(elem, g, lock=GLOBAL_LOCK)


Expand Down Expand Up @@ -506,23 +510,37 @@
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import numcodecs

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

f.create_dataset(
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
import numcodecs

Check warning on line 516 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L516

Added line #L516 was not covered by tests

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)

Check warning on line 520 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L518-L520

Added lines #L518 - L520 were not covered by tests
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

Check warning on line 523 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L522-L523

Added lines #L522 - L523 were not covered by tests

f.create_dataset(

Check warning on line 525 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L525

Added line #L525 was not covered by tests
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem

Check warning on line 532 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L532

Added line #L532 was not covered by tests
else:
from zarr.codecs import VLenUTF8Codec

f.create_array(
k,
shape=elem.shape,
dtype=str if ad.settings.zarr_write_format == 3 else object,
codecs=[VLenUTF8Codec()] if ad.settings.zarr_write_format == 3 else None,
**dataset_kwargs,
)
f[k][:] = elem


###############
Expand Down Expand Up @@ -577,7 +595,9 @@
):
from anndata.compat import _to_fixed_length_strings

f.create_dataset(k, data=_to_fixed_length_strings(elem), **dataset_kwargs)
f.create_dataset(

Check warning on line 598 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L598

Added line #L598 was not covered by tests
k, data=_to_fixed_length_strings(elem), shape=elem.shape, **dataset_kwargs
)


#################
Expand All @@ -603,9 +623,27 @@
if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)

g.create_dataset("data", data=value.data, **dataset_kwargs)
g.create_dataset("indices", data=value.indices, **dataset_kwargs)
g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
g.create_dataset(
"data",
data=value.data,
shape=value.data.shape,
dtype=value.data.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indices",
data=value.indices,
shape=value.indices.shape,
dtype=value.indices.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indptr",
data=value.indptr,
shape=value.indptr.shape,
dtype=indptr_dtype,
**dataset_kwargs,
)


write_csr = partial(write_sparse_compressed, fmt="csr")
Expand Down Expand Up @@ -1121,7 +1159,7 @@
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
return f.create_dataset(key, data=np.array(value), **dataset_kwargs)
return f.create_dataset(key, data=np.array(value), shape=(), **dataset_kwargs)


def write_hdf5_scalar(
Expand Down
21 changes: 17 additions & 4 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Generic, TypeVar

from packaging.version import Version

from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal
from anndata.compat import DaskArray, _read_attr
from anndata.compat import DaskArray, ZarrGroup, _read_attr

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable
Expand Down Expand Up @@ -341,11 +343,22 @@
return lambda *_, **__: None

# Normalize k to absolute path
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)
is_zarr_group_and_is_zarr_package_v2 = False
if isinstance(store, ZarrGroup):
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
is_zarr_group_and_is_zarr_package_v2 = True

Check warning on line 351 in src/anndata/_io/specs/registry.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/registry.py#L351

Added line #L351 was not covered by tests

if is_zarr_group_and_is_zarr_package_v2 or isinstance(store, h5py.Group):
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)

if k == "/":
store.clear()
if isinstance(store, ZarrGroup):
store.store.clear()
else:
store.clear()
elif k in store:
del store[k]

Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def report_read_key_on_error(func):
>>> @report_read_key_on_error
... def read_arr(group):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> z["X"] = [1, 2, 3]
>>> read_arr(z["X"]) # doctest: +SKIP
"""
Expand Down Expand Up @@ -228,7 +228,7 @@ def report_write_key_on_error(func):
>>> @report_write_key_on_error
... def write_arr(group, key, val):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> X = [1, 2, 3]
>>> write_arr(z, "X", X) # doctest: +SKIP
"""
Expand Down
11 changes: 8 additions & 3 deletions src/anndata/_io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from anndata._warnings import OldFormatWarning

from .._core.anndata import AnnData
from .._settings import settings
from ..compat import _clean_uns, _from_fixed_length_strings
from ..experimental import read_dispatched, write_dispatched
from .specs import read_elem
Expand All @@ -38,12 +39,16 @@ def write_zarr(
if adata.raw is not None:
adata.strings_to_categoricals(adata.raw.var)
# TODO: Use spec writing system for this
f = zarr.open(store, mode="w")
f = zarr.open_group(store, mode="w", zarr_version=settings.zarr_write_format)
f.attrs.setdefault("encoding-type", "anndata")
f.attrs.setdefault("encoding-version", "0.1.0")

def callback(func, s, k, elem, dataset_kwargs, iospec):
if chunks is not None and not isinstance(elem, sparse.spmatrix) and k == "/X":
if (
chunks is not None
and not isinstance(elem, sparse.spmatrix)
and k in {"X", "/X"}
):
dataset_kwargs = dict(dataset_kwargs, chunks=chunks)
func(s, k, elem, dataset_kwargs=dataset_kwargs)

Expand Down Expand Up @@ -73,7 +78,7 @@ def callback(func, elem_name: str, elem, iospec):
return AnnData(
**{
k: read_dispatched(v, callback)
for k, v in elem.items()
for k, v in dict(elem).items()
if not k.startswith("raw.")
}
)
Expand Down
29 changes: 25 additions & 4 deletions src/anndata/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,12 +395,20 @@
##################################################################################
# PLACE REGISTERED SETTINGS HERE SO THEY CAN BE PICKED UP FOR DOCSTRING CREATION #
##################################################################################
V = TypeVar("V")


def validate_bool(val: Any) -> None:
if not isinstance(val, bool):
msg = f"{val} not valid boolean"
raise TypeError(msg)
def gen_validator(_type: type[V]) -> Callable[[V], None]:
def validate_type(val: V) -> None:
if not isinstance(val, _type):
msg = f"{val} not valid {_type}"
raise TypeError(msg)

Check warning on line 405 in src/anndata/_settings.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_settings.py#L404-L405

Added lines #L404 - L405 were not covered by tests

return validate_type


validate_bool = gen_validator(bool)
validate_int = gen_validator(int)


settings.register(
Expand Down Expand Up @@ -429,6 +437,19 @@
get_from_env=check_and_get_bool,
)

settings.register(
"zarr_write_format",
default_value=2,
description="Which version of zarr to write to.",
validate=validate_int,
get_from_env=lambda name, default: check_and_get_environ_var(
f"ANNDATA_{name.upper()}",
str(default),
["2", "3"],
lambda x: int(x),
),
)


def validate_sparse_settings(val: Any) -> None:
validate_bool(val)
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def __exit__(self, *_exc_info) -> None:
#############################

if find_spec("zarr") or TYPE_CHECKING:
from zarr.core import Array as ZarrArray
from zarr.hierarchy import Group as ZarrGroup
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup
else:

class ZarrArray:
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def callback(func, elem_name: str, elem, iospec):
elif iospec.encoding_type == "array":
return elem
elif iospec.encoding_type == "dict":
return {k: read_as_backed(v) for k, v in elem.items()}
return {k: read_as_backed(v) for k, v in dict(elem).items()}
else:
return func(elem)

Expand Down
12 changes: 9 additions & 3 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pandas as pd
import pytest
from packaging.version import Version
from pandas.api.types import is_numeric_dtype
from scipy import sparse

Expand Down Expand Up @@ -1040,17 +1041,22 @@
]

if find_spec("zarr") or TYPE_CHECKING:
from zarr import DirectoryStore
import zarr

if Version(zarr.__version__) > Version("3.0.0b0"):
from zarr.storage import LocalStore
else:
from zarr.storage import DirectoryStore as LocalStore

Check warning on line 1049 in src/anndata/tests/helpers.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/tests/helpers.py#L1049

Added line #L1049 was not covered by tests
else:

class DirectoryStore:
class LocalStore:

Check warning on line 1052 in src/anndata/tests/helpers.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/tests/helpers.py#L1052

Added line #L1052 was not covered by tests
def __init__(self, *_args, **_kwargs) -> None:
cls_name = type(self).__name__
msg = f"zarr must be imported to create a {cls_name} instance."
raise ImportError(msg)


class AccessTrackingStore(DirectoryStore):
class AccessTrackingStore(LocalStore):
_access_count: Counter[str]
_accessed_keys: dict[str, list[str]]

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def tokenize_anndata(adata: ad.AnnData):
res.extend([tokenize(adata.obs), tokenize(adata.var)])
for attr in ["obsm", "varm", "obsp", "varp", "layers"]:
elem = getattr(adata, attr)
res.append(tokenize(list(elem.items())))
res.append(tokenize(list(dict(elem).items())))
res.append(joblib.hash(adata.uns))
if adata.raw is not None:
res.append(tokenize(adata.raw.to_adata()))
Expand Down
5 changes: 4 additions & 1 deletion tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def read_zarr_backed(path):
def callback(func, elem_name, elem, iospec):
if iospec.encoding_type == "anndata" or elem_name.endswith("/"):
return AnnData(
**{k: read_dispatched(v, callback) for k, v in elem.items()}
**{
k: read_dispatched(v, callback)
for k, v in dict(elem).items()
}
)
if iospec.encoding_type in {"csc_matrix", "csr_matrix"}:
return sparse_dataset(elem)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_io_dispatched.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ def zarr_reader(func, elem_name: str, elem, iospec):
write_dispatched(f, "/", adata, callback=h5ad_writer)
_ = read_dispatched(f, h5ad_reader)

with zarr.open_group(zarr_path, "w") as f:
write_dispatched(f, "/", adata, callback=zarr_writer)
_ = read_dispatched(f, zarr_reader)
f = zarr.open_group(zarr_path, "w", zarr_version=ad.settings.zarr_write_format)
write_dispatched(f, "/", adata, callback=zarr_writer)
_ = read_dispatched(f, zarr_reader)

assert h5ad_write_keys == zarr_write_keys
assert h5ad_read_keys == zarr_read_keys
Expand Down
Loading
Loading