diff --git a/pyproject.toml b/pyproject.toml index a5e398e0f..257ca05b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "exceptiongroup; python_version<'3.11'", "natsort", "packaging>=20.0", + "pint", # array-api-compat 1.5 has https://github.com/scverse/anndata/issues/1410 "array_api_compat>1.4,!=1.5", ] diff --git a/src/anndata/__init__.py b/src/anndata/__init__.py index c2006cd72..0b8ea96dc 100644 --- a/src/anndata/__init__.py +++ b/src/anndata/__init__.py @@ -21,7 +21,7 @@ # Backport package for exception groups import exceptiongroup # noqa: F401 -from ._core.anndata import AnnData +from ._core.anndata import AnnData, units from ._core.merge import concat from ._core.raw import Raw from ._io import ( @@ -81,4 +81,5 @@ def read(*args, **kwargs): "ExperimentalFeatureWarning", "experimental", "settings", + "units", ] diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 9184fa9c9..c5248ec7d 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -25,6 +25,7 @@ from natsort import natsorted from numpy import ma from pandas.api.types import infer_dtype +from pint import UnitRegistry from scipy import sparse from scipy.sparse import issparse @@ -66,6 +67,8 @@ if TYPE_CHECKING: from os import PathLike +units = UnitRegistry() + class StorageType(Enum): Array = np.ndarray diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 2822154da..a4cb58f35 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -13,7 +13,7 @@ from scipy import sparse import anndata as ad -from anndata import AnnData, Raw +from anndata import AnnData, Raw, units from anndata._core import views from anndata._core.index import _normalize_indices from anndata._core.merge import intersect_keys @@ -293,6 +293,27 @@ def write_raw(f, k, raw, _writer, dataset_kwargs=MappingProxyType({})): _writer.write_elem(g, "varm", dict(raw.varm), dataset_kwargs=dataset_kwargs) +######## +# Pint # +######## + + +@_REGISTRY.register_read(H5Group, IOSpec("pint.Quantity", "0.1.0")) +@_REGISTRY.register_read(ZarrGroup, IOSpec("pint.Quantity", "0.1.0")) +def read_quantity(elem, _reader): + v_magnitude = _reader.read_elem(elem["magnitude"]) + v_units = units[_reader.read_elem(elem["units"])] + return v_magnitude * v_units + + +@_REGISTRY.register_write(H5Group, units.Quantity, IOSpec("pint.Quantity", "0.1.0")) +@_REGISTRY.register_write(ZarrGroup, units.Quantity, IOSpec("pint.Quantity", "0.1.0")) +def write_quantity(f, k, v, _writer, dataset_kwargs=MappingProxyType({})): + g = f.require_group(k) + _writer.write_elem(g, "magnitude", v.magnitude, dataset_kwargs=dataset_kwargs) + _writer.write_elem(g, "units", str(v.units), dataset_kwargs=dataset_kwargs) + + ############ # Mappings # ############ diff --git a/src/anndata/tests/test_readwrite.py b/src/anndata/tests/test_readwrite.py index 1c2630f98..5b7954fb5 100644 --- a/src/anndata/tests/test_readwrite.py +++ b/src/anndata/tests/test_readwrite.py @@ -821,3 +821,17 @@ def test_io_dtype(tmp_path, diskfmt, dtype): curr = read(pth) assert curr.X.dtype == dtype + + +def test_readwrite_units(read, write, name, tmp_path): + X_arr = np.array(X_list) + adata = ad.AnnData( + X=X_arr, + uns={"size": 100 * ad.units["um"]}, + obsm={"X_spatial": X_arr * ad.units.mm}, + ) + write(tmp_path / name, adata) + ad_read = read(tmp_path / name) + + assert adata.uns["spot_size"] == ad_read.uns["spot_size"] + assert (adata.obsm["X_spatial_units"] == ad_read.obsm["X_spatial_units"]).all()