
Commit 0555bf5
update image helpers
FynnBe committed Mar 12, 2024
1 parent 1a5c50b commit 0555bf5
Showing 9 changed files with 208 additions and 150 deletions.
33 changes: 14 additions & 19 deletions bioimageio/core/_resource_tests.py
@@ -7,8 +7,8 @@
 
 from bioimageio.core._prediction_pipeline import create_prediction_pipeline
 from bioimageio.core.common import AxisId, BatchSize
-from bioimageio.core.utils import VERSION
-from bioimageio.core.utils.image_helper import pad_to
+from bioimageio.core.utils import VERSION, get_test_inputs
+from bioimageio.core.utils.image_helper import resize_to
 from bioimageio.spec import InvalidDescr, ResourceDescr, build_description, dump_description, load_description
 from bioimageio.spec._internal.common_nodes import ResourceDescrBase
 from bioimageio.spec._internal.io_utils import load_array
@@ -19,7 +19,7 @@
 
 
 def test_model(
-    source: PermissiveFileSource,
+    source: Union[v0_5.ModelDescr, PermissiveFileSource],
     weight_format: Optional[WeightsFormat] = None,
     devices: Optional[List[str]] = None,
     decimal: int = 4,
@@ -82,18 +82,17 @@ def load_description_and_test(
     _test_expected_resource_type(rd, expected_type)
 
     if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
-        if isinstance(rd, v0_4.ModelDescr):
-            _test_model_inference_v0_4(rd, weight_format, devices, decimal)
-        else:
-            _test_model_inference_impl(rd, weight_format, devices)
+        _test_model_inference(rd, weight_format, devices, decimal)
+        if not isinstance(rd, v0_4.ModelDescr):
+            _test_model_inference_parametrized(rd, weight_format, devices)
 
     # TODO: add execution of jupyter notebooks
     # TODO: add more tests
 
     return rd
 
 
-def _test_model_inference_v0_4(
+def _test_model_inference(
     model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
     weight_format: Optional[WeightsFormat],
     devices: Optional[List[str]],
@@ -107,11 +106,11 @@ def _test_model_inference_v0_4(
         expected = [xr.DataArray(load_array(src), dims=d.axes) for src, d in zip(model.test_outputs, model.outputs)]
     else:
         inputs = [
-            xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes))
+            xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes))
             for d in model.inputs
         ]
         expected = [
-            xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes))
+            xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(str(a.id) for a in d.axes))
             for d in model.outputs
         ]
 
@@ -152,7 +151,7 @@ def _test_model_inference_v0_4(
     )
 
 
-def _test_model_inference_impl(
+def _test_model_inference_parametrized(
     model: v0_5.ModelDescr,
     weight_format: Optional[WeightsFormat],
     devices: Optional[List[str]],
@@ -162,10 +161,7 @@ def _test_model_inference_impl(
         return
 
     try:
-        test_inputs = [
-            xr.DataArray(load_array(d.test_tensor.download().path), dims=tuple(a.id for a in d.axes))
-            for d in model.inputs
-        ]
+        test_inputs = get_test_inputs(model)
 
         def generate_test_cases():
             tested: Set[str] = set()
@@ -178,7 +174,7 @@ def generate_test_cases():
                     tested.add(hashable_target_size)
 
                 resized_test_inputs = [
-                    pad_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs)
+                    resize_to(t, target_sizes[t_descr.id]) for t, t_descr in zip(test_inputs, model.inputs)
                 ]
                 expected_output_shapes = [target_sizes[t_descr.id] for t_descr in model.outputs]
                 yield n, batch_size, resized_test_inputs, expected_output_shapes
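
The hunk above swaps pad_to for resize_to when preparing the parametrized test inputs. A minimal, hedged sketch of that call pattern follows; the per-axis size mapping and the resize_to signature are inferred from the call site in this diff (not from library documentation), and the tensor shape is purely illustrative.

import numpy as np
import xarray as xr

from bioimageio.core.common import AxisId
from bioimageio.core.utils.image_helper import resize_to

# illustrative single-sample, single-channel tensor with spatial axes y and x
t = xr.DataArray(np.zeros((1, 1, 128, 128), dtype="float32"), dims=("batch", "channel", "y", "x"))

# assumed: a mapping of axis id -> target size, mirroring target_sizes[t_descr.id] above
target_sizes = {AxisId("batch"): 1, AxisId("channel"): 1, AxisId("y"): 256, AxisId("x"): 256}
resized = resize_to(t, target_sizes)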
@@ -203,8 +199,7 @@ def generate_test_cases():
 
             model.validation_summary.add_detail(
                 ValidationDetail(
-                    name="Reproduce test outputs from test inputs with batch_size:"
-                    + f" {batch_size} and size parameter n: {n}",
+                    name="Run inference for inputs with batch_size:" + f" {batch_size} and size parameter n: {n}",
                     status="passed" if error is None else "failed",
                     errors=(
                         []
@@ -224,7 +219,7 @@ def generate_test_cases():
         tb = traceback.format_tb(e.__traceback__)
         model.validation_summary.add_detail(
             ValidationDetail(
-                name="Reproduce test outputs from test inputs",
+                name="Run inference for parametrized inputs",
                 status="failed",
                 errors=[
                     ErrorEntry(
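
With the widened signature above, test_model now also accepts an already-loaded model description, not just a file source. A hedged usage sketch, assuming a local rdf.yaml that describes a format 0.5 model and assuming that the validation details are attached to the passed description:

from bioimageio.core._resource_tests import test_model
from bioimageio.spec import load_description

descr = load_description("rdf.yaml")  # hypothetical path; may also yield an InvalidDescr
test_model(descr)                     # a v0_5.ModelDescr instance is now a valid `source`
test_model("rdf.yaml")                # a file source still works as before
print(descr.validation_summary)       # details added by the inference tests (assumed to land here)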
10 changes: 9 additions & 1 deletion bioimageio/core/common.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Literal
 
 import xarray as xr
 
@@ -10,6 +10,14 @@
 
 TensorId = v0_5.TensorId
 AxisId = v0_5.AxisId
+
+
+@dataclass
+class Axis:
+    id: AxisId
+    type: Literal["batch", "channel", "index", "space", "time"]
+
+
 BatchSize = int
 Tensor = xr.DataArray
 
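
The new Axis dataclass pairs an axis id with its type. A small sketch of how it might be constructed; creating AxisId values from plain strings is assumed to behave like the other validated string types in bioimageio.spec:

from bioimageio.core.common import Axis, AxisId

axes = [
    Axis(id=AxisId("batch"), type="batch"),
    Axis(id=AxisId("channel"), type="channel"),
    Axis(id=AxisId("y"), type="space"),
    Axis(id=AxisId("x"), type="space"),
]

# plain-str dims for xarray, matching the str(a.id) pattern used elsewhere in this commit
dims = tuple(str(a.id) for a in axes)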
4 changes: 2 additions & 2 deletions bioimageio/core/utils/_digest_spec.py
@@ -10,11 +10,11 @@ def get_test_inputs(model: AnyModelDescr) -> List[xr.DataArray]:
     if isinstance(model, v0_4.ModelDescr):
         return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.inputs, model.test_inputs)]
     else:
-        return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.inputs]
+        return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.inputs]
 
 
 def get_test_outputs(model: AnyModelDescr) -> List[xr.DataArray]:
     if isinstance(model, v0_4.ModelDescr):
         return [xr.DataArray(load_array(tt), dims=tuple(d.axes)) for d, tt in zip(model.outputs, model.test_outputs)]
     else:
-        return [xr.DataArray(load_array(d.test_tensor), dims=tuple(a.id for a in d.axes)) for d in model.outputs]
+        return [xr.DataArray(load_array(d.test_tensor), dims=tuple(str(a.id) for a in d.axes)) for d in model.outputs]
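
A hedged usage sketch of the two helpers above; the RDF source is a hypothetical placeholder and error handling (e.g. an InvalidDescr result from load_description) is omitted:

from bioimageio.core.utils._digest_spec import get_test_inputs, get_test_outputs
from bioimageio.spec import load_description

model = load_description("rdf.yaml")  # hypothetical model RDF
test_inputs = get_test_inputs(model)  # list of xr.DataArray; dims are plain-str axis ids
test_outputs = get_test_outputs(model)
print([t.dims for t in test_inputs], [t.dims for t in test_outputs])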

