From 3744f1f185aa4e287d8fafcc3bd5d7b0eaf69ec3 Mon Sep 17 00:00:00 2001 From: Mark Bader Date: Tue, 10 Dec 2024 09:57:50 +0100 Subject: [PATCH] Fix axes configuration when adding existing zarr arrays (#1204) * Change default behaviour for axes when creating zarrita array. * Add type hint. * Update changelog. * Add test. * Fix type annotations for added test. --- webknossos/Changelog.md | 1 + webknossos/tests/dataset/test_layer.py | 31 +++++++++++++++ webknossos/webknossos/dataset/_array.py | 47 ++++++++++++++++++++--- webknossos/webknossos/dataset/layer.py | 1 + webknossos/webknossos/dataset/mag_view.py | 5 ++- 5 files changed, 78 insertions(+), 7 deletions(-) create mode 100644 webknossos/tests/dataset/test_layer.py diff --git a/webknossos/Changelog.md b/webknossos/Changelog.md index e71a64f31..cc8c67d15 100644 --- a/webknossos/Changelog.md +++ b/webknossos/Changelog.md @@ -80,6 +80,7 @@ Removed the CZI installation extra from `pip install webknossos[all]` by default ### Fixed - Fixed an issue with merging annotations with compressed fallback layers. +- Fixed an issue where adding a Zarr array with axes other than `cxyz` led to an error.
[#1204](https://github.com/scalableminds/webknossos-libs/pull/1204) diff --git a/webknossos/tests/dataset/test_layer.py b/webknossos/tests/dataset/test_layer.py new file mode 100644 index 000000000..7374950c6 --- /dev/null +++ b/webknossos/tests/dataset/test_layer.py @@ -0,0 +1,31 @@ +from pathlib import Path + +import numpy as np +from zarrita import Array + +import webknossos as wk + + +def test_add_mag_from_zarrarray(tmp_path: Path) -> None: + dataset = wk.Dataset( + tmp_path / "test_add_mag_from_zarrarray", voxel_size=(10, 10, 10) + ) + layer = dataset.add_layer( + "color", + wk.COLOR_CATEGORY, + data_format="zarr3", + bounding_box=wk.BoundingBox((0, 0, 0), (16, 16, 16)), + ) + zarr_mag_path = tmp_path / "zarr_data" / "mag1.zarr" + zarr_data = np.random.randint(0, 255, (16, 16, 16), dtype="uint8") + zarr_mag = Array.create( + store=zarr_mag_path, shape=(16, 16, 16), chunk_shape=(8, 8, 8), dtype="uint8" + ) + zarr_mag[:] = zarr_data + + layer.add_mag_from_zarrarray("1", zarr_mag_path, extend_layer_bounding_box=False) + + assert layer.get_mag("1").read().shape == (1, 16, 16, 16) + assert layer.get_mag("1").info.num_channels == 1 + assert layer.get_mag("1").info.dimension_names == ("c", "x", "y", "z") + assert (layer.get_mag("1").read()[0] == zarr_data).all() diff --git a/webknossos/webknossos/dataset/_array.py b/webknossos/webknossos/dataset/_array.py index 563de8e62..30fe567f6 100644 --- a/webknossos/webknossos/dataset/_array.py +++ b/webknossos/webknossos/dataset/_array.py @@ -498,15 +498,43 @@ def info(self) -> ArrayInfo: from zarrita.sharding import ShardingCodec zarray = self._zarray + dimension_names: tuple[str, ...] if (names := getattr(zarray.metadata, "dimension_names", None)) is None: - dimension_names = ("c", "x", "y", "z") + if (shape := getattr(zarray.metadata, "shape", None)) is None: + raise ValueError( + "Unable to determine the shape of the Zarrita Array. Neither dimension_names nor shape are present in the metadata file zarr.json." 
+ ) + else: + if len(shape) == 2: + dimension_names = ("x", "y") + num_channels = 1 + elif len(shape) == 3: + dimension_names = ("x", "y", "z") + num_channels = 1 + elif len(shape) == 4: + dimension_names = ("c", "x", "y", "z") + num_channels = shape[0] + else: + raise ValueError( + "Unusual shape for Zarrita array, please specify the dimension names in the metadata file zarr.json." + ) else: dimension_names = names + if (shape := getattr(zarray.metadata, "shape", None)) is None: + shape = VecInt.ones(dimension_names) + if "c" in dimension_names: + num_channels = zarray.metadata.shape[dimension_names.index("c")] + else: + num_channels = 1 x_index, y_index, z_index = ( dimension_names.index("x"), dimension_names.index("y"), dimension_names.index("z"), ) + if "c" not in dimension_names: + shape = (num_channels,) + shape + dimension_names = ("c",) + dimension_names + array_shape = VecInt(shape, axes=dimension_names) if isinstance(zarray, Array): if len(zarray.codec_pipeline.codecs) == 1 and isinstance( zarray.codec_pipeline.codecs[0], ShardingCodec @@ -516,7 +544,7 @@ def info(self) -> ArrayInfo: chunk_shape = sharding_codec.configuration.chunk_shape return ArrayInfo( data_format=DataFormat.Zarr3, - num_channels=zarray.metadata.shape[0], + num_channels=num_channels, voxel_type=zarray.metadata.dtype, compression_mode=self._has_compression_codecs( sharding_codec.codec_pipeline.codecs @@ -536,12 +564,13 @@ def info(self) -> ArrayInfo: chunk_shape[z_index], ) ), + shape=array_shape, dimension_names=dimension_names, ) chunk_shape = zarray.metadata.chunk_grid.configuration.chunk_shape return ArrayInfo( data_format=DataFormat.Zarr3, - num_channels=zarray.metadata.shape[0], + num_channels=num_channels, voxel_type=zarray.metadata.dtype, compression_mode=self._has_compression_codecs( zarray.codec_pipeline.codecs @@ -550,13 +579,14 @@ def info(self) -> ArrayInfo: chunk_shape[x_index], chunk_shape[y_index], chunk_shape[z_index] ) or Vec3Int.full(1), + shape=array_shape, 
chunks_per_shard=Vec3Int.full(1), dimension_names=dimension_names, ) else: return ArrayInfo( data_format=DataFormat.Zarr, - num_channels=zarray.metadata.shape[0], + num_channels=num_channels, voxel_type=zarray.metadata.dtype, compression_mode=zarray.metadata.compressor is not None, chunk_shape=Vec3Int( @@ -565,6 +595,7 @@ def info(self) -> ArrayInfo: zarray.metadata.chunks[z_index], ) or Vec3Int.full(1), + shape=array_shape, chunks_per_shard=Vec3Int.full(1), dimension_names=dimension_names, ) @@ -634,9 +665,13 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarritaArray": def read(self, bbox: NDBoundingBox) -> np.ndarray: shape = bbox.size.to_tuple() zarray = self._zarray - slice_tuple = (slice(None),) + bbox.to_slices() with _blosc_disable_threading(): - data = zarray[slice_tuple] + try: + slice_tuple = (slice(None),) + bbox.to_slices() + data = zarray[slice_tuple] + except IndexError: + # The data is stored without channel axis + data = zarray[bbox.to_slices()] shape_with_channels = (self.info.num_channels,) + shape if data.shape != shape_with_channels: diff --git a/webknossos/webknossos/dataset/layer.py b/webknossos/webknossos/dataset/layer.py index 635ea01cf..4fa7a58ad 100644 --- a/webknossos/webknossos/dataset/layer.py +++ b/webknossos/webknossos/dataset/layer.py @@ -1485,6 +1485,7 @@ def _setup_mag(self, mag: Mag, path: Optional[str] = None) -> None: info.chunk_shape, info.chunks_per_shard, info.compression_mode, + info.shape, False, UPath(resolved_path), ) diff --git a/webknossos/webknossos/dataset/mag_view.py b/webknossos/webknossos/dataset/mag_view.py index e6469934a..2d05f43c5 100644 --- a/webknossos/webknossos/dataset/mag_view.py +++ b/webknossos/webknossos/dataset/mag_view.py @@ -125,6 +125,7 @@ def __init__( chunk_shape: Vec3Int, chunks_per_shard: Vec3Int, compression_mode: bool, + shape: Optional[VecInt] = None, create: bool = False, path: Optional[UPath] = None, ) -> None: @@ -145,7 +146,9 @@ def __init__( layer.num_channels, 
*VecInt.ones(layer.bounding_box.axes), axes=("c",) + layer.bounding_box.axes, - ), + ) + if shape is None + else shape, dimension_names=("c",) + layer.bounding_box.axes, ) if create: