Skip to content

Commit

Permalink
WIP align with current spec
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Mar 11, 2024
1 parent 92d4373 commit 930f954
Show file tree
Hide file tree
Showing 15 changed files with 129 additions and 143 deletions.
12 changes: 6 additions & 6 deletions bioimageio/core/_prediction_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
self._preprocessing = preprocessing
self._postprocessing = postprocessing
if isinstance(bioimageio_model, v0_4.ModelDescr):
self._input_ids = [TensorId(d.name) for d in bioimageio_model.inputs]
self._output_ids = [TensorId(d.name) for d in bioimageio_model.outputs]
self._input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs]
self._output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs]
else:
self._input_ids = [d.id for d in bioimageio_model.inputs]
self._output_ids = [d.id for d in bioimageio_model.outputs]
Expand All @@ -58,7 +58,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore

def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]:
"""Predict input_tensor with the model without applying pre/postprocessing."""
named_tensors = [named_input_tensors[k] for k in self._input_ids[len(input_tensors) :]]
named_tensors = [named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :]]
return self._adapter.forward(*input_tensors, *named_tensors)

def apply_preprocessing(self, sample: Sample) -> None:
Expand All @@ -71,11 +71,11 @@ def apply_postprocessing(self, sample: Sample) -> None:
for op in self._postprocessing:
op(sample)

def forward_sample(self, input_sample: Sample):
def forward_sample(self, input_sample: Sample) -> Sample:
"""Apply preprocessing, run prediction and apply postprocessing."""
self.apply_preprocessing(input_sample)

prediction_tensors = self.predict(**input_sample.data)
prediction_tensors = self.predict(**{str(k): v for k, v in input_sample.data.items()})
prediction = Sample(data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat)
self.apply_postprocessing(prediction)
return prediction
Expand Down Expand Up @@ -142,7 +142,7 @@ def create_prediction_pipeline(
)

if isinstance(bioimageio_model, v0_4.ModelDescr):
input_ids = [TensorId(ipt.name) for ipt in bioimageio_model.inputs]
input_ids = [TensorId(str(ipt.name)) for ipt in bioimageio_model.inputs]
else:
input_ids = [ipt.id for ipt in bioimageio_model.inputs]

Expand Down
4 changes: 1 addition & 3 deletions bioimageio/core/proc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dataclasses import InitVar, dataclass, field
from typing import (
Collection,
Hashable,
Literal,
Mapping,
Optional,
Expand Down Expand Up @@ -302,7 +301,7 @@ def from_proc_descr(
return cls(
input=tensor_id,
output=tensor_id,
reference_tensor=cast(TensorId, kwargs.reference_tensor),
reference_tensor=TensorId(str(kwargs.reference_tensor)),
axes=axes,
eps=kwargs.eps,
)
Expand Down Expand Up @@ -556,4 +555,3 @@ def get_proc_class(proc_spec: ProcDescr):
return ZeroMeanUnitVariance
else:
assert_never(proc_spec)

7 changes: 3 additions & 4 deletions bioimageio/core/proc_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Sequence,
Set,
Union,
cast,
)

from typing_extensions import assert_never
Expand Down Expand Up @@ -77,8 +76,8 @@ def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcess
post_measures: Set[Measure] = set()

if isinstance(model, v0_4.ModelDescr):
input_ids = {TensorId(d.name) for d in model.inputs}
output_ids = {TensorId(d.name) for d in model.outputs}
input_ids = {TensorId(str(d.name)) for d in model.inputs}
output_ids = {TensorId(str(d.name)) for d in model.outputs}
else:
input_ids = {d.id for d in model.inputs}
output_ids = {d.id for d in model.outputs}
Expand All @@ -98,7 +97,7 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]):

for proc_d in proc_descrs:
proc_class = get_proc_class(proc_d)
tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id
tensor_id = TensorId(str(t_descr.name)) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id
req = proc_class.from_proc_descr(proc_d, tensor_id) # pyright: ignore[reportArgumentType]
for m in req.required_measures:
if m.tensor_id in input_ids:
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/core/utils/_digest_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import xarray as xr

from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.model import AnyModelDescr, v0_4
from bioimageio.spec.utils import load_array


Expand Down
111 changes: 59 additions & 52 deletions bioimageio/core/utils/image_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,69 +31,76 @@
OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr]


def transpose_image(
image: NDArray[Any],
def interprete_array(
nd_array: NDArray[Any],
desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]],
) -> xr.DataArray:
if isinstance(desired_axes, str):
desired_space_axes = [a for a in desired_axes if a in "zyx"]
else:
desired_space_axes = [a for a in desired_axes if a.type == "space"]

ndim = nd_array.ndim
if ndim == 2 and len(desired_space_axes) >= 2:
current_axes = (
SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[0]),
SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]),
)
elif ndim == 3 and len(desired_space_axes) == 2:
current_axes = (
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]),
SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]),
SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]),
)
elif ndim == 3 and len(desired_space_axes) == 3:
current_axes = (
SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[0]),
SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]),
SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]),
)
elif ndim == 4:
current_axes = (
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]),
SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[1]),
SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[2]),
SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[3]),
)
elif ndim == 5:
current_axes = (
BatchAxis(),
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[1])]),
SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[2]),
SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[3]),
SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]),
)
else:
raise ValueError(f"Could not guess a mapping of {nd_array.shape} to {desired_axes}")

current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes)
return xr.DataArray(nd_array, dims=current_axes_ids)


def transpose_array(
arary: xr.DataArray,
desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]],
current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None,
) -> xr.DataArray:
"""Transpose an image to match desired axes.
Args:
image: the input image
array: the input array
desired_axes: the desired image axes
current_axes: the axes of the input image
"""
# if the image axes are not given deduce them from the required axes and image shape
if current_axes is None:
if isinstance(desired_axes, str):
desired_space_axes = [a for a in desired_axes if a in "zyx"]
else:
desired_space_axes = [a for a in desired_axes if a.type == "space"]

ndim = image.ndim
if ndim == 2 and len(desired_space_axes) >= 2:
current_axes = (
SpaceInputAxis(id=AxisId("y"), size=image.shape[0]),
SpaceInputAxis(id=AxisId("x"), size=image.shape[1]),
)
elif ndim == 3 and len(desired_space_axes) == 2:
current_axes = (
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]),
SpaceInputAxis(id=AxisId("y"), size=image.shape[1]),
SpaceInputAxis(id=AxisId("x"), size=image.shape[2]),
)
elif ndim == 3 and len(desired_space_axes) == 3:
current_axes = (
SpaceInputAxis(id=AxisId("z"), size=image.shape[0]),
SpaceInputAxis(id=AxisId("y"), size=image.shape[1]),
SpaceInputAxis(id=AxisId("x"), size=image.shape[2]),
)
elif ndim == 4:
current_axes = (
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]),
SpaceInputAxis(id=AxisId("z"), size=image.shape[1]),
SpaceInputAxis(id=AxisId("y"), size=image.shape[2]),
SpaceInputAxis(id=AxisId("x"), size=image.shape[3]),
)
elif ndim == 5:
current_axes = (
BatchAxis(),
ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[1])]),
SpaceInputAxis(id=AxisId("z"), size=image.shape[2]),
SpaceInputAxis(id=AxisId("y"), size=image.shape[3]),
SpaceInputAxis(id=AxisId("x"), size=image.shape[4]),
)
else:
raise ValueError(f"Could not guess a mapping of {image.shape} to {desired_axes}")

current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes)
desired_axes_ids = tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes)
tensor = xr.DataArray(image, dims=current_axes_ids)
desired_axes_ids = (
tuple(map(AxisId, desired_axes)) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes)
)
# expand the missing image axes
missing_axes = tuple(set(desired_axes_ids) - set(current_axes_ids))
tensor = tensor.expand_dims(dim=missing_axes)
missing_axes = tuple(set(desired_axes_ids) - set(map(AxisId, array.dims)))
array = array.expand_dims(dim=missing_axes)
# transpose to the correct axis order
return tensor.transpose(*tuple(desired_axes_ids))
return arraytensor.transpose(*tuple(desired_axes_ids))


def convert_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]):
Expand All @@ -117,7 +124,7 @@ def load_tensor(
is_volume = len([a for a in guess_axes if a.type in ("time", "space")]) > 2

im = imageio.volread(path) if is_volume else imageio.imread(path)
im = transpose_image(im, desired_axes=desired_axes, current_axes=current_axes)
im = transpose_array(im, desired_axes=desired_axes, current_axes=current_axes)

return xr.DataArray(
im, dims=tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include = ["bioimageio", "scripts", "tests"]
pythonPlatform = "All"
pythonVersion = "3.8"
reportDuplicateImport = "error"
reportImplicitStringConcatenation = "warning"
reportImplicitStringConcatenation = "error"
reportIncompatibleMethodOverride = true
reportMatchNotExhaustive = "error"
reportMissingSuperCall = "error"
Expand Down
2 changes: 1 addition & 1 deletion scripts/show_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bioimageio.core import load_description, save_bioimageio_yaml_only

if __name__ == "__main__":
rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml"
rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_descriptions/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml"

local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore
model_as_is = load_description(rdf_source, format_version="discover")
Expand Down
26 changes: 12 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
import subprocess
import warnings
from types import MappingProxyType
Expand All @@ -10,7 +9,6 @@
from pydantic import FilePath
from pytest import FixtureRequest, fixture

os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports
from bioimageio.spec import __version__ as bioimageio_spec_version
from bioimageio.spec._package import save_bioimageio_package

Expand All @@ -35,50 +33,50 @@

MODEL_SOURCES = {
"unet2d_keras": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_keras_tf/rdf.yaml"
),
"unet2d_keras_tf2": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_keras_tf2/rdf.yaml"
),
"unet2d_nuclei_broad_model": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_nuclei_broad/rdf.yaml"
),
"unet2d_expand_output_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_nuclei_broad/rdf_expand_output_shape.yaml"
),
"unet2d_fixed_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_fixed_shape/rdf.yaml"
),
"unet2d_multi_tensor": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_multi_tensor/rdf.yaml"
),
"unet2d_diff_output_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_diff_output_shape/rdf.yaml"
),
"hpa_densenet": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml"
),
"stardist": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models"
"/stardist_example_model/rdf.yaml"
),
"stardist_wrong_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"stardist_example_model/rdf_wrong_shape.yaml"
),
"stardist_wrong_shape2": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"stardist_example_model/rdf_wrong_shape2.yaml"
),
"shape_change": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"upsample_test_model/rdf.yaml"
),
}
Expand Down
7 changes: 1 addition & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import os
import subprocess
from pathlib import Path
from typing import Any, List, Optional, Sequence, Set
from typing import Any, List, Sequence

import numpy as np
import pytest
from pydantic import FilePath

from bioimageio.core import load_description


def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.CompletedProcess[str]":
return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs)
Expand Down
Loading

0 comments on commit 930f954

Please sign in to comment.