Skip to content

Commit

Permalink
improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Mar 15, 2024
1 parent ccc84eb commit f798344
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 128 deletions.
8 changes: 4 additions & 4 deletions bioimageio/core/model_adapters/_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ class ModelAdapter(ABC):
"""
Represents model *without* any preprocessing or postprocessing.
>>> from bioimageio.core import read_description
>>> model = read_description()
>>> from bioimageio.core import load_description
>>> model = load_description()
>>> print("option 1:")
option 1:
>>> adapter = ModelAdapter.create(model)
>>> adapter.forward()
>>> adapter.forward # (...)
>>> adapter.unload()
>>> print("option 2:")
option 2:
>>> with ModelAdapter.create(model) as adapter:
>>> adapter.forward()
>>> adapter.forward # (...)
"""

Expand Down
157 changes: 36 additions & 121 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@

import subprocess
import warnings
from types import MappingProxyType
from typing import List, Set
from typing import List

from filelock import FileLock
from loguru import logger
from pydantic import FilePath
from pytest import FixtureRequest, TempPathFactory, fixture
from pytest import FixtureRequest, fixture

from bioimageio.spec import __version__ as bioimageio_spec_version
from bioimageio.spec._package import save_bioimageio_package

warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}")

Expand Down Expand Up @@ -47,7 +43,7 @@
),
"unet2d_expand_output_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
"unet2d_nuclei_broad/rdf_expand_output_shape_v0_4.bioimageio.yaml"
"unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml"
),
"unet2d_fixed_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/"
Expand Down Expand Up @@ -110,59 +106,6 @@
skip_tensorflow = tensorflow is None
skip_tensorflow_js = True # TODO: add a tensorflow_js example model

# load all model packages we need for testing
load_model_packages: Set[str] = set()
if not skip_torch:
load_model_packages |= set(TORCH_MODELS + TORCHSCRIPT_MODELS)

if not skip_onnx:
load_model_packages |= set(ONNX_MODELS)

if not skip_tensorflow:
load_model_packages |= set(TENSORFLOW_JS_MODELS)
if tf_major_version == 1:
load_model_packages |= set(KERAS_TF1_MODELS)
load_model_packages |= set(TENSORFLOW1_MODELS)
elif tf_major_version == 2:
load_model_packages |= set(KERAS_TF2_MODELS)
load_model_packages |= set(TENSORFLOW2_MODELS)


@fixture(scope="session")
def model_packages(
tmp_path_factory: TempPathFactory, worker_id: str
) -> MappingProxyType[str, FilePath]:
"""prepare model packages (only run with one worker)
see https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
"""
root_tmp_dir = tmp_path_factory.getbasetemp().parent

packages = MappingProxyType(
{
name: (root_tmp_dir / name).with_suffix(".zip")
for name in load_model_packages
}
)

def generate_packages():
for name in load_model_packages:
actual_out = save_bioimageio_package(
MODEL_SOURCES[name], output_path=packages[name]
)
assert actual_out == packages[name]

info_path = root_tmp_dir / "packages_created"
if worker_id == "master":
# no workers
generate_packages()
else:
with FileLock(info_path.with_suffix(".lock")):
if not info_path.is_file():
generate_packages()
_ = info_path.write_text("")

return packages


@fixture(scope="session")
def mamba_cmd():
Expand All @@ -185,24 +128,18 @@ def mamba_cmd():


@fixture(params=[] if skip_torch else TORCH_MODELS)
def any_torch_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_torch_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(params=[] if skip_torch else TORCHSCRIPT_MODELS)
def any_torchscript_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_torchscript_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(params=[] if skip_onnx else ONNX_MODELS)
def any_onnx_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_onnx_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(
Expand All @@ -212,10 +149,8 @@ def any_onnx_model(
else TENSORFLOW1_MODELS if tf_major_version == 1 else TENSORFLOW2_MODELS
)
)
def any_tensorflow_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_tensorflow_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(
Expand All @@ -225,25 +160,21 @@ def any_tensorflow_model(
else KERAS_TF1_MODELS if tf_major_version == 1 else KERAS_TF2_MODELS
)
)
def any_keras_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_keras_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(params=[] if skip_tensorflow_js else TENSORFLOW_JS_MODELS)
def any_tensorflow_js_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def any_tensorflow_js_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# fixture to test with all models that should run in the current environment
# we exclude stardist_wrong_shape here because it is not a valid model
# and included only to test that validation for this model fails
@fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"})
def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]):
return model_packages[request.param]
@fixture(params=set(MODEL_SOURCES) - {"stardist_wrong_shape", "stardist_wrong_shape2"})
def any_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# TODO it would be nice to just generate fixtures for all the individual models dynamically
Expand All @@ -256,10 +187,8 @@ def any_model(request: FixtureRequest, model_packages: MappingProxyType[str, Fil
@fixture(
params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]
)
def unet2d_fixed_shape_or_not(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_fixed_shape_or_not(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(
Expand All @@ -269,10 +198,8 @@ def unet2d_fixed_shape_or_not(
else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
)
)
def convert_to_onnx(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def convert_to_onnx(request: FixtureRequest):
return MODEL_SOURCES[request.param]


@fixture(
Expand All @@ -282,50 +209,38 @@ def convert_to_onnx(
else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]
)
)
def unet2d_keras(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_keras(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing torch
@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
def unet2d_nuclei_broad_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_nuclei_broad_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing torch
@fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"])
def unet2d_diff_output_shape(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_diff_output_shape(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing torch
@fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"])
def unet2d_expand_output_shape(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_expand_output_shape(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing torch
@fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
def unet2d_fixed_shape(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def unet2d_fixed_shape(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing torch
@fixture(params=[] if skip_torch else ["shape_change"])
def shape_change_model(
request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]
):
return model_packages[request.param]
def shape_change_model(request: FixtureRequest):
return MODEL_SOURCES[request.param]


# written as model group to automatically skip on missing tensorflow 1
Expand All @@ -346,5 +261,5 @@ def stardist_wrong_shape2(request: FixtureRequest):

# written as model group to automatically skip on missing tensorflow 1
@fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
def stardist(request: FixtureRequest, model_packages: MappingProxyType[str, FilePath]):
return model_packages[request.param]
def stardist(request: FixtureRequest):
return MODEL_SOURCES[request.param]
6 changes: 6 additions & 0 deletions tests/test_any_model_fixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from bioimageio.spec import load_description_and_validate_format_only


def test_model(any_model: str):
    """Run a format-only validation of the model description and fail with its report."""
    validation_summary = load_description_and_validate_format_only(any_model)
    assert validation_summary.status == "passed", validation_summary.format()
4 changes: 3 additions & 1 deletion tests/test_prediction_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat
from bioimageio.core._prediction_pipeline import create_prediction_pipeline

bio_model = load_description(model_package)
assert isinstance(bio_model, (ModelDescr, ModelDescr04))
assert isinstance(
bio_model, (ModelDescr, ModelDescr04)
), bio_model.validation_summary.format()
pp = create_prediction_pipeline(
bioimageio_model=bio_model, weight_format=weights_format
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_resource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def test_test_model(any_model: Path):
from bioimageio.core._resource_tests import test_model

summary = test_model(any_model)
assert summary.status == "passed"
assert summary.status == "passed", summary.format()


def test_test_resource(any_model: Path):
from bioimageio.core._resource_tests import test_description

summary = test_description(any_model)
assert summary.status == "passed"
assert summary.status == "passed", summary.format()

0 comments on commit f798344

Please sign in to comment.