From 930f9544b9d9369c1d6c6b755b808cbc07930f5f Mon Sep 17 00:00:00 2001 From: fynnbe Date: Mon, 11 Mar 2024 10:27:57 +0100 Subject: [PATCH] WIP align with current spec --- bioimageio/core/_prediction_pipeline.py | 12 +- bioimageio/core/proc_ops.py | 4 +- bioimageio/core/proc_setup.py | 7 +- bioimageio/core/utils/_digest_spec.py | 2 +- bioimageio/core/utils/image_helper.py | 111 ++++++++++-------- pyproject.toml | 2 +- scripts/show_diff.py | 2 +- tests/conftest.py | 26 ++-- tests/test_cli.py | 7 +- tests/test_prediction.py | 49 ++++---- ...t_prediction_pipeline_device_management.py | 41 +++---- tests/utils/test_image_helper.py | 6 +- .../weight_converter/keras/test_tensorflow.py | 1 + tests/weight_converter/torch/test_onnx.py | 1 + .../torch/test_torchscript.py | 1 + 15 files changed, 129 insertions(+), 143 deletions(-) diff --git a/bioimageio/core/_prediction_pipeline.py b/bioimageio/core/_prediction_pipeline.py index 4f7db9e2..912aa9dd 100644 --- a/bioimageio/core/_prediction_pipeline.py +++ b/bioimageio/core/_prediction_pipeline.py @@ -37,8 +37,8 @@ def __init__( self._preprocessing = preprocessing self._postprocessing = postprocessing if isinstance(bioimageio_model, v0_4.ModelDescr): - self._input_ids = [TensorId(d.name) for d in bioimageio_model.inputs] - self._output_ids = [TensorId(d.name) for d in bioimageio_model.outputs] + self._input_ids = [TensorId(str(d.name)) for d in bioimageio_model.inputs] + self._output_ids = [TensorId(str(d.name)) for d in bioimageio_model.outputs] else: self._input_ids = [d.id for d in bioimageio_model.inputs] self._output_ids = [d.id for d in bioimageio_model.outputs] @@ -58,7 +58,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore def predict(self, *input_tensors: xr.DataArray, **named_input_tensors: xr.DataArray) -> List[xr.DataArray]: """Predict input_tensor with the model without applying pre/postprocessing.""" - named_tensors = [named_input_tensors[k] for k in self._input_ids[len(input_tensors) :]] + named_tensors = [named_input_tensors[str(k)] for k in self._input_ids[len(input_tensors) :]] return self._adapter.forward(*input_tensors, *named_tensors) def apply_preprocessing(self, sample: Sample) -> None: @@ -71,11 +71,11 @@ def apply_postprocessing(self, sample: Sample) -> None: for op in self._postprocessing: op(sample) - def forward_sample(self, input_sample: Sample): + def forward_sample(self, input_sample: Sample) -> Sample: """Apply preprocessing, run prediction and apply postprocessing.""" self.apply_preprocessing(input_sample) - prediction_tensors = self.predict(**input_sample.data) + prediction_tensors = self.predict(**{str(k): v for k, v in input_sample.data.items()}) prediction = Sample(data=dict(zip(self._output_ids, prediction_tensors)), stat=input_sample.stat) self.apply_postprocessing(prediction) return prediction @@ -142,7 +142,7 @@ def create_prediction_pipeline( ) if isinstance(bioimageio_model, v0_4.ModelDescr): - input_ids = [TensorId(ipt.name) for ipt in bioimageio_model.inputs] + input_ids = [TensorId(str(ipt.name)) for ipt in bioimageio_model.inputs] else: input_ids = [ipt.id for ipt in bioimageio_model.inputs] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py index 8a7b15f6..7c179e28 100644 --- a/bioimageio/core/proc_ops.py +++ b/bioimageio/core/proc_ops.py @@ -3,7 +3,6 @@ from dataclasses import InitVar, dataclass, field from typing import ( Collection, - Hashable, Literal, Mapping, Optional, @@ -302,7 +301,7 @@ def from_proc_descr( return cls( input=tensor_id, output=tensor_id, - reference_tensor=cast(TensorId, kwargs.reference_tensor), + reference_tensor=TensorId(str(kwargs.reference_tensor)), axes=axes, eps=kwargs.eps, ) @@ -556,4 +555,3 @@ def get_proc_class(proc_spec: ProcDescr): return ZeroMeanUnitVariance else: assert_never(proc_spec) - diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py index a71ba023..a375a2b7 100644 --- a/bioimageio/core/proc_setup.py +++ b/bioimageio/core/proc_setup.py @@ -7,7 +7,6 @@ Sequence, Set, Union, - cast, ) from typing_extensions import assert_never @@ -77,8 +76,8 @@ def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcess post_measures: Set[Measure] = set() if isinstance(model, v0_4.ModelDescr): - input_ids = {TensorId(d.name) for d in model.inputs} - output_ids = {TensorId(d.name) for d in model.outputs} + input_ids = {TensorId(str(d.name)) for d in model.inputs} + output_ids = {TensorId(str(d.name)) for d in model.outputs} else: input_ids = {d.id for d in model.inputs} output_ids = {d.id for d in model.outputs} @@ -98,7 +97,7 @@ def prepare_procs(tensor_descrs: Sequence[TensorDescr]): for proc_d in proc_descrs: proc_class = get_proc_class(proc_d) - tensor_id = cast(TensorId, t_descr.name) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id + tensor_id = TensorId(str(t_descr.name)) if isinstance(t_descr, v0_4.TensorDescrBase) else t_descr.id req = proc_class.from_proc_descr(proc_d, tensor_id) # pyright: ignore[reportArgumentType] for m in req.required_measures: if m.tensor_id in input_ids: diff --git a/bioimageio/core/utils/_digest_spec.py b/bioimageio/core/utils/_digest_spec.py index 42ba8974..ad41789f 100644 --- a/bioimageio/core/utils/_digest_spec.py +++ b/bioimageio/core/utils/_digest_spec.py @@ -2,7 +2,7 @@ import xarray as xr -from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.model import AnyModelDescr, v0_4 from bioimageio.spec.utils import load_array diff --git a/bioimageio/core/utils/image_helper.py b/bioimageio/core/utils/image_helper.py index 3dac1772..80303260 100644 --- a/bioimageio/core/utils/image_helper.py +++ b/bioimageio/core/utils/image_helper.py @@ -31,69 +31,76 @@ OutputTensor = Union[OutputTensorDescr04, OutputTensorDescr] -def transpose_image( - image: NDArray[Any], +def interprete_array( + nd_array: NDArray[Any], + desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], +) -> xr.DataArray: + if isinstance(desired_axes, str): + desired_space_axes = [a for a in desired_axes if a in "zyx"] + else: + desired_space_axes = [a for a in desired_axes if a.type == "space"] + + ndim = nd_array.ndim + if ndim == 2 and len(desired_space_axes) >= 2: + current_axes = ( + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[0]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[1]), + ) + elif ndim == 3 and len(desired_space_axes) == 2: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + ) + elif ndim == 3 and len(desired_space_axes) == 3: + current_axes = ( + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[0]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[2]), + ) + elif ndim == 4: + current_axes = ( + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[0])]), + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[1]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[3]), + ) + elif ndim == 5: + current_axes = ( + BatchAxis(), + ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(nd_array.shape[1])]), + SpaceInputAxis(id=AxisId("z"), size=nd_array.shape[2]), + SpaceInputAxis(id=AxisId("y"), size=nd_array.shape[3]), + SpaceInputAxis(id=AxisId("x"), size=nd_array.shape[4]), + ) + else: + raise ValueError(f"Could not guess a mapping of {nd_array.shape} to {desired_axes}") + + current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) + return xr.DataArray(nd_array, dims=current_axes_ids) + + +def transpose_array( + arary: xr.DataArray, desired_axes: Union[v0_4.AxesStr, Sequence[AnyAxis]], current_axes: Optional[Union[v0_4.AxesStr, Sequence[AnyAxis]]] = None, ) -> xr.DataArray: """Transpose an image to match desired axes. Args: - image: the input image + array: the input array desired_axes: the desired image axes current_axes: the axes of the input image """ - # if the image axes are not given deduce them from the required axes and image shape - if current_axes is None: - if isinstance(desired_axes, str): - desired_space_axes = [a for a in desired_axes if a in "zyx"] - else: - desired_space_axes = [a for a in desired_axes if a.type == "space"] - - ndim = image.ndim - if ndim == 2 and len(desired_space_axes) >= 2: - current_axes = ( - SpaceInputAxis(id=AxisId("y"), size=image.shape[0]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[1]), - ) - elif ndim == 3 and len(desired_space_axes) == 2: - current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), - ) - elif ndim == 3 and len(desired_space_axes) == 3: - current_axes = ( - SpaceInputAxis(id=AxisId("z"), size=image.shape[0]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[2]), - ) - elif ndim == 4: - current_axes = ( - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[0])]), - SpaceInputAxis(id=AxisId("z"), size=image.shape[1]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[2]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[3]), - ) - elif ndim == 5: - current_axes = ( - BatchAxis(), - ChannelAxis(channel_names=[Identifier(f"channel{i}") for i in range(image.shape[1])]), - SpaceInputAxis(id=AxisId("z"), size=image.shape[2]), - SpaceInputAxis(id=AxisId("y"), size=image.shape[3]), - SpaceInputAxis(id=AxisId("x"), size=image.shape[4]), - ) - else: - raise ValueError(f"Could not guess a mapping of {image.shape} to {desired_axes}") - current_axes_ids = tuple(current_axes) if isinstance(current_axes, str) else tuple(a.id for a in current_axes) - desired_axes_ids = tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) - tensor = xr.DataArray(image, dims=current_axes_ids) + desired_axes_ids = ( + tuple(map(AxisId, desired_axes)) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) + ) # expand the missing image axes - missing_axes = tuple(set(desired_axes_ids) - set(current_axes_ids)) - tensor = tensor.expand_dims(dim=missing_axes) + missing_axes = tuple(set(desired_axes_ids) - set(map(AxisId, array.dims))) + array = array.expand_dims(dim=missing_axes) # transpose to the correct axis order - return tensor.transpose(*tuple(desired_axes_ids)) + return arraytensor.transpose(*tuple(desired_axes_ids)) def convert_axes_for_known_shape(axes: v0_4.AxesStr, shape: Sequence[int]): @@ -117,7 +124,7 @@ def load_tensor( is_volume = len([a for a in guess_axes if a.type in ("time", "space")]) > 2 im = imageio.volread(path) if is_volume else imageio.imread(path) - im = transpose_image(im, desired_axes=desired_axes, current_axes=current_axes) + im = transpose_array(im, desired_axes=desired_axes, current_axes=current_axes) return xr.DataArray( im, dims=tuple(desired_axes) if isinstance(desired_axes, str) else tuple(a.id for a in desired_axes) diff --git a/pyproject.toml b/pyproject.toml index db0f8626..7d715f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ include = ["bioimageio", "scripts", "tests"] pythonPlatform = "All" pythonVersion = "3.8" reportDuplicateImport = "error" -reportImplicitStringConcatenation = "warning" +reportImplicitStringConcatenation = "error" reportIncompatibleMethodOverride = true reportMatchNotExhaustive = "error" reportMissingSuperCall = "error" diff --git a/scripts/show_diff.py b/scripts/show_diff.py index 77623343..4a5d2223 100644 --- a/scripts/show_diff.py +++ b/scripts/show_diff.py @@ -7,7 +7,7 @@ from bioimageio.core import load_description, save_bioimageio_yaml_only if __name__ == "__main__": - rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_specs/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" + rdf_source = "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/pydantic_axes/example_descriptions/models/unet2d_nuclei_broad/rdf_v0_4_9.yaml" local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore model_as_is = load_description(rdf_source, format_version="discover") diff --git a/tests/conftest.py b/tests/conftest.py index 9c31410d..324586d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import subprocess import warnings from types import MappingProxyType @@ -10,7 +9,6 @@ from pydantic import FilePath from pytest import FixtureRequest, fixture -os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports from bioimageio.spec import __version__ as bioimageio_spec_version from bioimageio.spec._package import save_bioimageio_package @@ -35,50 +33,50 @@ MODEL_SOURCES = { "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_keras_tf/rdf.yaml" ), "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_keras_tf2/rdf.yaml" ), "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_nuclei_broad/rdf.yaml" ), "unet2d_expand_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_nuclei_broad/rdf_expand_output_shape.yaml" ), "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_fixed_shape/rdf.yaml" ), "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_multi_tensor/rdf.yaml" ), "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "unet2d_diff_output_shape/rdf.yaml" ), "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" ), "stardist": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" "/stardist_example_model/rdf.yaml" ), "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "stardist_example_model/rdf_wrong_shape.yaml" ), "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "stardist_example_model/rdf_wrong_shape2.yaml" ), "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" "upsample_test_model/rdf.yaml" ), } diff --git a/tests/test_cli.py b/tests/test_cli.py index 967d5d80..601882b3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,14 +1,9 @@ -import os import subprocess -from pathlib import Path -from typing import Any, List, Optional, Sequence, Set +from typing import Any, List, Sequence -import numpy as np import pytest from pydantic import FilePath -from bioimageio.core import load_description - def run_subprocess(commands: Sequence[str], **kwargs: Any) -> "subprocess.CompletedProcess[str]": return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index a0e34b08..a95eb3f4 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -4,8 +4,11 @@ import numpy as np from numpy.testing import assert_array_almost_equal +from bioimageio.core.utils import get_test_inputs from bioimageio.spec import load_description -from bioimageio.spec.model.v0_5 import ModelDescr +from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 +from bioimageio.spec.model.v0_5 import InputTensorDescr, ModelDescr def test_predict_image(any_model: Path, tmpdir: Path): @@ -26,7 +29,7 @@ def test_predict_image(any_model: Path, tmpdir: Path): assert_array_almost_equal(res, exp, decimal=4) -def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): +def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not: Path, tmpdir: Path): from bioimageio.core.prediction import predict_image spec = load_description(unet2d_fixed_shape_or_not) @@ -44,24 +47,18 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): assert_array_almost_equal(res, exp, decimal=4) -def _test_predict_with_padding(model, tmp_path): +def _test_predict_with_padding(any_model: Path, tmp_path: Path): from bioimageio.core.prediction import predict_image - spec = load_description(model) - assert isinstance(spec, Model) + model = load_description(any_model) + assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) - input_spec, output_spec = spec.inputs[0], spec.outputs[0] - channel_axis = input_spec.axes.index("c") + input_spec, output_spec = model.inputs[0], model.outputs[0] + channel_axis = "c" if isinstance(input_spec, InputTensorDescr_v0_4) else [a.id for a in input_spec.axes][0] channel_first = channel_axis == 1 - image = np.load(str(spec.test_inputs[0])) - assert image.shape[channel_axis] == 1 - if channel_first: - image = image[0, 0] - else: - image = image[0, ..., 0] - original_shape = image.shape - assert image.ndim == 2 + # TODO: check more tensors + image = get_test_inputs(model)[0] if isinstance(output_spec.shape, list): n_channels = output_spec.shape[channel_axis] @@ -106,15 +103,17 @@ def check_result(): assert res.shape == exp_shape # test with dynamic padding - predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) + predict_image(any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) check_result() # test with fixed padding - predict_image(model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}) + predict_image( + any_model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"} + ) check_result() # test with automated padding - predict_image(model, in_path, out_path, padding=True) + predict_image(any_model, in_path, out_path, padding=True) check_result() @@ -133,7 +132,7 @@ def test_predict_image_with_padding_channel_last(stardist, tmp_path): _test_predict_with_padding(stardist, tmp_path) -def _test_predict_image_with_tiling(model, tmp_path: Path, exp_mean_deviation): +def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): from bioimageio.core.prediction import predict_image spec = load_description(model) @@ -166,27 +165,27 @@ def check_result(): # prediction with tiling with the parameters above may not be suited for any model # so we only run it for the pytorch unet2d here -def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model, tmp_path: Path): +def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) -def test_predict_image_with_tiling_2(unet2d_diff_output_shape, tmp_path: Path): +def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) -def test_predict_image_with_tiling_3(shape_change_model, tmp_path: Path): +def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) -def test_predict_image_with_tiling_channel_last(stardist, tmp_path: Path): +def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): _test_predict_image_with_tiling(stardist, tmp_path, 0.13) -def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape, tmp_path: Path): +def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape: Path, tmp_path: Path): _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) -def test_predict_images(unet2d_nuclei_broad_model, tmp_path: Path): +def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): from bioimageio.core.prediction import predict_images n_images = 5 diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py index 16354d18..1236383a 100644 --- a/tests/test_prediction_pipeline_device_management.py +++ b/tests/test_prediction_pipeline_device_management.py @@ -1,14 +1,12 @@ from pathlib import Path -import numpy as np -import xarray as xr from numpy.testing import assert_array_almost_equal +from bioimageio.core import load_description +from bioimageio.core.utils import get_test_inputs, get_test_outputs from bioimageio.core.utils.testing import skip_on -from bioimageio.spec import load_description from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat -from bioimageio.spec.utils import load_array class TooFewDevicesException(Exception): @@ -27,24 +25,13 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert isinstance(bio_model, (ModelDescr, ModelDescr04)) pred_pipe = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) - if isinstance(bio_model, ModelDescr04): - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] - else: - inputs = [ - xr.DataArray(load_array(ipt.test_tensor), dims=tuple(a.id for a in ipt.axes)) for ipt in bio_model.inputs - ] + inputs = get_test_inputs(bio_model) with pred_pipe as pp: outputs = pp.forward(*inputs) assert isinstance(outputs, list) - expected_outputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_outputs, bio_model.outputs) - ] + expected_outputs = get_test_outputs(bio_model) assert len(outputs) == len(expected_outputs) for out, exp in zip(outputs, expected_outputs): @@ -59,26 +46,26 @@ def _test_device_management(model_package: Path, weight_format: WeightsFormat): assert_array_almost_equal(out, exp, decimal=4) -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torch(any_torch_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_torch(any_torch_model: Path): _test_device_management(any_torch_model, "pytorch_state_dict") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torchscript(any_torchscript_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_torchscript(any_torchscript_model: Path): _test_device_management(any_torchscript_model, "torchscript") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_onnx(any_onnx_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_onnx(any_onnx_model: Path): _test_device_management(any_onnx_model, "onnx") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_tensorflow(any_tensorflow_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_tensorflow(any_tensorflow_model: Path): _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_keras(any_keras_model): +@skip_on(TooFewDevicesException, reason="Too few devices") # pyright: ignore[reportArgumentType] +def test_device_management_keras(any_keras_model: Path): _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/utils/test_image_helper.py b/tests/utils/test_image_helper.py index 8e86a919..6e0e9c08 100644 --- a/tests/utils/test_image_helper.py +++ b/tests/utils/test_image_helper.py @@ -2,18 +2,18 @@ def test_transform_input_image(): - from bioimageio.core.utils.image_helper import transpose_image + from bioimageio.core.utils.image_helper import transpose_array ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] im = np.random.rand(256, 256) for axes in ax_list: - inp = transpose_image(im, axes) + inp = transpose_array(im, axes) assert inp.ndim == len(axes) ax_list = ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] vol = np.random.rand(64, 64, 64) for axes in ax_list: - inp = transpose_image(vol, axes) + inp = transpose_array(vol, axes) assert inp.ndim == len(axes) diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 6cc42c57..069b6f23 100644 --- a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking import zipfile from pathlib import Path diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py index c2efbcd8..a0315650 100644 --- a/tests/weight_converter/torch/test_onnx.py +++ b/tests/weight_converter/torch/test_onnx.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking import os from pathlib import Path diff --git a/tests/weight_converter/torch/test_torchscript.py b/tests/weight_converter/torch/test_torchscript.py index e3f6e42c..945e778b 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -1,3 +1,4 @@ +# type: ignore # TODO enable type checking from pathlib import Path import pytest