diff --git a/funlib/persistence/arrays/ome_datasets.py b/funlib/persistence/arrays/ome_datasets.py index 4b40586..69719ef 100644 --- a/funlib/persistence/arrays/ome_datasets.py +++ b/funlib/persistence/arrays/ome_datasets.py @@ -1,9 +1,7 @@ import logging -from itertools import chain from pathlib import Path from typing import Sequence -import numpy as np from iohub.ngff import AxisMeta, TransformationMeta, open_ome_zarr from numpy.typing import DTypeLike @@ -11,6 +9,7 @@ from .array import Array from .metadata import MetaData, OME_MetaDataFormat +from .ome_tmp import get_effective_scale, get_effective_translation logger = logging.getLogger(__name__) @@ -57,45 +56,15 @@ def open_ome_ds( units = [axis.unit for axis in axes if axis.unit is not None] types = [axis.type for axis in axes if axis.type is not None] - base_transform = ome_zarr.metadata.multiscales[0].coordinate_transformations - img_transforms = [ - ome_zarr.metadata.multiscales[0].datasets[i].coordinate_transformations - for i, dataset_meta in enumerate(ome_zarr.metadata.multiscales[0].datasets) - if dataset_meta.path == name - ][0] - - scales = [ - t.scale for t in chain(base_transform, img_transforms) if t.type == "scale" - ] - translations = [ - t.translation - for t in chain(base_transform, img_transforms) - if t.type == "translation" - ] - assert all( - all(np.isclose(scale, Coordinate(scale))) for scale in scales - ), f"funlib.persistence only supports integer scales: {scales}" - assert all( - all(np.isclose(translation, Coordinate(translation))) - for translation in translations - ), f"funlib.persistence only supports integer translations: {translations}" - scales = [Coordinate(s) for s in scales] - - # apply translations in order to get final scale/transform for this array - base_scale = scales[0] * 0 + 1 - base_translation = scales[0] * 0 - for t in chain(base_transform, img_transforms): - if t.type == "translation": - base_translation += base_scale * Coordinate(t.translation) - elif t.type == "scale": - base_scale *= Coordinate(t.scale) + scale = get_effective_scale(ome_zarr, name) + offset = get_effective_translation(ome_zarr, name) dataset = ome_zarr[name] metadata = OME_MetaDataFormat().parse( dataset.shape, - offset=list(base_translation), - voxel_size=list(base_scale), + offset=list(offset), + voxel_size=list(scale), axis_names=axis_names, units=units, types=types, diff --git a/funlib/persistence/arrays/ome_tmp.py b/funlib/persistence/arrays/ome_tmp.py new file mode 100644 index 0000000..deb580d --- /dev/null +++ b/funlib/persistence/arrays/ome_tmp.py @@ -0,0 +1,92 @@ +from typing import Literal + +import numpy as np +from iohub.ngff import TransformationMeta + + +def _get_all_transforms(node, image: str | Literal["*"]) -> list[TransformationMeta]: + """Get all transforms metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[TransformationMeta] + All transforms applicable to this image or FOV. + """ + transforms: list[TransformationMeta] = ( + [t for t in node.metadata.multiscales[0].coordinate_transformations] + if node.metadata.multiscales[0].coordinate_transformations is not None + else [] + ) + if image != "*" and image in node: + for i, dataset_meta in enumerate(node.metadata.multiscales[0].datasets): + if dataset_meta.path == image: + transforms.extend( + node.metadata.multiscales[0].datasets[i].coordinate_transformations + ) + elif image != "*": + raise ValueError(f"Key {image} not recognized.") + return transforms + + +def get_effective_scale( + node, + image: str | Literal["*"], +) -> list[float]: + """Get the effective coordinate scale metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[float] + A list of floats representing the total scale + for the image or FOV for each axis. + """ + transforms = node._get_all_transforms(image) + + full_scale = np.ones(len(node.axes), dtype=float) + for transform in transforms: + if transform.type == "scale": + full_scale *= np.array(transform.scale) + + return [float(x) for x in full_scale] + + +def get_effective_translation( + node, + image: str | Literal["*"], +) -> TransformationMeta: + """Get the effective coordinate translation metadata + for one image array or the whole FOV. + + Parameters + ---------- + image : str | Literal["*"] + Name of one image array (e.g. "0") to query, + or "*" for the whole FOV + + Returns + ------- + list[float] + A list of floats representing the total translation + for the image or FOV for each axis. + """ + transforms = node._get_all_transforms(image) + full_translation = np.zeros(len(node.axes), dtype=float) + for transform in transforms: + if transform.type == "translation": + full_translation += np.array(transform.translation) + + return [float(x) for x in full_translation]