Skip to content

Commit

Permalink
Fix axes configuration when adding existing zarr arrays (#1204)
Browse files Browse the repository at this point in the history
* Change default behaviour for axes when creating zarrita array.

* Add type hint.

* Update changelog.

* Add test.

* Fix type annotations for added test.
  • Loading branch information
markbader authored Dec 10, 2024
1 parent 8498db0 commit 3744f1f
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 7 deletions.
1 change: 1 addition & 0 deletions webknossos/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Removed the CZI installation extra from `pip install webknossos[all]` by default

### Fixed
- Fixed an issue with merging annotations with compressed fallback layers.
- Fixed an issue where adding a Zarr array with axes other than `cxyz` led to an error. [#1204](https://github.com/scalableminds/webknossos-libs/pull/1204)



Expand Down
31 changes: 31 additions & 0 deletions webknossos/tests/dataset/test_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

import numpy as np
from zarrita import Array

import webknossos as wk


def test_add_mag_from_zarrarray(tmp_path: Path) -> None:
    """Check that a Zarr v3 array stored without a channel axis can be added as a mag.

    The source array is created with a purely spatial shape of (16, 16, 16)
    and without passing any dimension names. After it is attached through
    ``add_mag_from_zarrarray`` the mag is expected to report a single channel,
    the axes ``("c", "x", "y", "z")`` and to return the voxel values that were
    written to the source array.
    """
    spatial_shape = (16, 16, 16)

    dataset = wk.Dataset(
        tmp_path / "test_add_mag_from_zarrarray", voxel_size=(10, 10, 10)
    )
    layer = dataset.add_layer(
        "color",
        wk.COLOR_CATEGORY,
        data_format="zarr3",
        bounding_box=wk.BoundingBox((0, 0, 0), spatial_shape),
    )

    # A seeded generator keeps a failing run reproducible; the exclusive upper
    # bound of 256 makes every uint8 value (including 255) a possible sample.
    rng = np.random.default_rng(seed=0)
    zarr_data = rng.integers(0, 256, size=spatial_shape, dtype="uint8")

    zarr_mag_path = tmp_path / "zarr_data" / "mag1.zarr"
    zarr_mag = Array.create(
        store=zarr_mag_path,
        shape=spatial_shape,
        chunk_shape=(8, 8, 8),
        dtype="uint8",
    )
    zarr_mag[:] = zarr_data

    layer.add_mag_from_zarrarray("1", zarr_mag_path, extend_layer_bounding_box=False)

    # Fetch the mag and read its data once instead of once per assertion.
    mag = layer.get_mag("1")
    data = mag.read()

    # The read result carries a leading channel axis even though the source
    # array was written without one.
    assert data.shape == (1, *spatial_shape)
    assert mag.info.num_channels == 1
    assert mag.info.dimension_names == ("c", "x", "y", "z")
    assert (data[0] == zarr_data).all()
47 changes: 41 additions & 6 deletions webknossos/webknossos/dataset/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,43 @@ def info(self) -> ArrayInfo:
from zarrita.sharding import ShardingCodec

zarray = self._zarray
dimension_names: tuple[str, ...]
if (names := getattr(zarray.metadata, "dimension_names", None)) is None:
dimension_names = ("c", "x", "y", "z")
if (shape := getattr(zarray.metadata, "shape", None)) is None:
raise ValueError(
"Unable to determine the shape of the Zarrita Array. Neither dimension_names nor shape are present in the metadata file zarr.json."
)
else:
if len(shape) == 2:
dimension_names = ("x", "y")
num_channels = 1
elif len(shape) == 3:
dimension_names = ("x", "y", "z")
num_channels = 1
elif len(shape) == 4:
dimension_names = ("c", "x", "y", "z")
num_channels = shape[0]
else:
raise ValueError(
"Unusual shape for Zarrita array, please specify the dimension names in the metadata file zarr.json."
)
else:
dimension_names = names
if (shape := getattr(zarray.metadata, "shape", None)) is None:
shape = VecInt.ones(dimension_names)
if "c" in dimension_names:
num_channels = zarray.metadata.shape[dimension_names.index("c")]
else:
num_channels = 1
x_index, y_index, z_index = (
dimension_names.index("x"),
dimension_names.index("y"),
dimension_names.index("z"),
)
if "c" not in dimension_names:
shape = (num_channels,) + shape
dimension_names = ("c",) + dimension_names
array_shape = VecInt(shape, axes=dimension_names)
if isinstance(zarray, Array):
if len(zarray.codec_pipeline.codecs) == 1 and isinstance(
zarray.codec_pipeline.codecs[0], ShardingCodec
Expand All @@ -516,7 +544,7 @@ def info(self) -> ArrayInfo:
chunk_shape = sharding_codec.configuration.chunk_shape
return ArrayInfo(
data_format=DataFormat.Zarr3,
num_channels=zarray.metadata.shape[0],
num_channels=num_channels,
voxel_type=zarray.metadata.dtype,
compression_mode=self._has_compression_codecs(
sharding_codec.codec_pipeline.codecs
Expand All @@ -536,12 +564,13 @@ def info(self) -> ArrayInfo:
chunk_shape[z_index],
)
),
shape=array_shape,
dimension_names=dimension_names,
)
chunk_shape = zarray.metadata.chunk_grid.configuration.chunk_shape
return ArrayInfo(
data_format=DataFormat.Zarr3,
num_channels=zarray.metadata.shape[0],
num_channels=num_channels,
voxel_type=zarray.metadata.dtype,
compression_mode=self._has_compression_codecs(
zarray.codec_pipeline.codecs
Expand All @@ -550,13 +579,14 @@ def info(self) -> ArrayInfo:
chunk_shape[x_index], chunk_shape[y_index], chunk_shape[z_index]
)
or Vec3Int.full(1),
shape=array_shape,
chunks_per_shard=Vec3Int.full(1),
dimension_names=dimension_names,
)
else:
return ArrayInfo(
data_format=DataFormat.Zarr,
num_channels=zarray.metadata.shape[0],
num_channels=num_channels,
voxel_type=zarray.metadata.dtype,
compression_mode=zarray.metadata.compressor is not None,
chunk_shape=Vec3Int(
Expand All @@ -565,6 +595,7 @@ def info(self) -> ArrayInfo:
zarray.metadata.chunks[z_index],
)
or Vec3Int.full(1),
shape=array_shape,
chunks_per_shard=Vec3Int.full(1),
dimension_names=dimension_names,
)
Expand Down Expand Up @@ -634,9 +665,13 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarritaArray":
def read(self, bbox: NDBoundingBox) -> np.ndarray:
shape = bbox.size.to_tuple()
zarray = self._zarray
slice_tuple = (slice(None),) + bbox.to_slices()
with _blosc_disable_threading():
data = zarray[slice_tuple]
try:
slice_tuple = (slice(None),) + bbox.to_slices()
data = zarray[slice_tuple]
except IndexError:
# The data is stored without channel axis
data = zarray[bbox.to_slices()]

shape_with_channels = (self.info.num_channels,) + shape
if data.shape != shape_with_channels:
Expand Down
1 change: 1 addition & 0 deletions webknossos/webknossos/dataset/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,7 @@ def _setup_mag(self, mag: Mag, path: Optional[str] = None) -> None:
info.chunk_shape,
info.chunks_per_shard,
info.compression_mode,
info.shape,
False,
UPath(resolved_path),
)
Expand Down
5 changes: 4 additions & 1 deletion webknossos/webknossos/dataset/mag_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
chunk_shape: Vec3Int,
chunks_per_shard: Vec3Int,
compression_mode: bool,
shape: Optional[VecInt] = None,
create: bool = False,
path: Optional[UPath] = None,
) -> None:
Expand All @@ -145,7 +146,9 @@ def __init__(
layer.num_channels,
*VecInt.ones(layer.bounding_box.axes),
axes=("c",) + layer.bounding_box.axes,
),
)
if shape is None
else shape,
dimension_names=("c",) + layer.bounding_box.axes,
)
if create:
Expand Down

0 comments on commit 3744f1f

Please sign in to comment.