diff --git a/funlib/persistence/arrays/array.py b/funlib/persistence/arrays/array.py index 2317e63..fa7e3bc 100644 --- a/funlib/persistence/arrays/array.py +++ b/funlib/persistence/arrays/array.py @@ -5,6 +5,7 @@ import dask.array as da import numpy as np from dask.array.optimization import fuse_slice +from zarr import Array as ZarrArray from funlib.geometry import Coordinate, Roi @@ -83,6 +84,10 @@ def __init__( shape=self._source_data.shape, ) + # used for custom metadata unrelated to indexing with physical units + # only used if not reading from zarr and there is no built in `.attrs` + self._attrs = {} + if lazy_op is not None: self.apply_lazy_ops(lazy_op) @@ -93,6 +98,18 @@ def __init__( self.validate() + @property + def attrs(self) -> dict: + """ + Return dict that can be used to store custom metadata. Will be persistent + for zarr arrays. If reading from zarr, any existing metadata (such as + voxel_size, axis_names, etc.) will also be exposed here. + """ + if isinstance(self._source_data, ZarrArray): + return self._source_data.attrs + else: + return self._attrs + @property def chunk_shape(self) -> Coordinate: return Coordinate(self.data.chunksize) diff --git a/funlib/persistence/arrays/datasets.py b/funlib/persistence/arrays/datasets.py index 5b033d2..a9636c3 100644 --- a/funlib/persistence/arrays/datasets.py +++ b/funlib/persistence/arrays/datasets.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Sequence, Union +from typing import Any, Optional, Sequence, Union import numpy as np import zarr @@ -128,6 +128,7 @@ def prepare_ds( chunk_shape: Optional[Sequence[int]] = None, dtype: DTypeLike = np.float32, mode: str = "a", + custom_metadata: dict[str, Any] | None = None, **kwargs, ) -> Array: """Prepare a Zarr or N5 dataset. @@ -179,6 +180,11 @@ def prepare_ds( The mode to open the dataset in. 
See https://zarr.readthedocs.io/en/stable/api/creation.html#zarr.creation.open_array + custom_metadata: + + A dictionary of custom metadata to add to the dataset. This will be written to the + zarr .attrs object. + kwargs: See additional arguments available here: @@ -319,14 +325,18 @@ raise ArrayNotFoundError(f"Nothing found at path {store}") default_metadata_format = get_default_metadata_format() - ds.attrs.put( - { - default_metadata_format.axis_names_attr: combined_metadata.axis_names, - default_metadata_format.units_attr: combined_metadata.units, - default_metadata_format.voxel_size_attr: combined_metadata.voxel_size, - default_metadata_format.offset_attr: combined_metadata.offset, - } - ) + our_metadata = { + default_metadata_format.axis_names_attr: combined_metadata.axis_names, + default_metadata_format.units_attr: combined_metadata.units, + default_metadata_format.voxel_size_attr: combined_metadata.voxel_size, + default_metadata_format.offset_attr: combined_metadata.offset, + } + # check keys don't conflict + if custom_metadata is not None: + assert set(our_metadata.keys()).isdisjoint(custom_metadata.keys()) + our_metadata.update(custom_metadata) + + ds.attrs.put(our_metadata) # open array array = Array(ds, offset, voxel_size, axis_names, units) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index cfea528..aa00f6d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -15,6 +15,23 @@ } +@pytest.mark.parametrize("store", stores.keys()) +def test_metadata(tmpdir, store): + store = tmpdir / store + + # test custom metadata passed to prepare_ds is written to attrs and persists across reopen + array = prepare_ds( + store, + (10, 10), + mode="w", + custom_metadata={"custom": "metadata"}, + ) + assert array.attrs["custom"] == "metadata" + array.attrs["custom2"] = "new metadata" + + assert open_ds(store).attrs["custom2"] == "new metadata" + + @pytest.mark.parametrize("store", stores.keys()) @pytest.mark.parametrize("dtype", [np.float32,
np.uint8, np.uint64]) def test_helpers(tmpdir, store, dtype):