Skip to content

Commit

Permalink
all model adapters in backends
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Dec 6, 2024
1 parent a8a50ec commit 9690574
Show file tree
Hide file tree
Showing 14 changed files with 482 additions and 270 deletions.
3 changes: 2 additions & 1 deletion bioimageio/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from ._settings import settings
from .axis import Axis, AxisId
from .backends import create_model_adapter
from .block_meta import BlockMeta
from .common import MemberId
from .prediction import predict, predict_many
Expand Down Expand Up @@ -73,6 +74,7 @@
"commands",
"common",
"compute_dataset_measures",
"create_model_adapter",
"create_prediction_pipeline",
"digest_spec",
"dump_description",
Expand Down Expand Up @@ -104,7 +106,6 @@
"Stat",
"tensor",
"Tensor",
"test_description_in_conda_env",
"test_description",
"test_model",
"test_resource",
Expand Down
127 changes: 0 additions & 127 deletions bioimageio/core/_create_model_adapter.py

This file was deleted.

93 changes: 0 additions & 93 deletions bioimageio/core/_model_adapter.py

This file was deleted.

3 changes: 2 additions & 1 deletion bioimageio/core/_resource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from bioimageio.spec._internal.common_nodes import ResourceDescrBase
from bioimageio.spec._internal.io import is_yaml_value
from bioimageio.spec._internal.io_utils import read_yaml, write_yaml
from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import WeightsFormat
from bioimageio.spec.summary import (
Expand Down Expand Up @@ -192,7 +193,7 @@ def test_description(
decimal=decimal,
determinism=determinism,
expected_type=expected_type,
sha256=sha256,
sha256=sha256,
)
return rd.validation_summary

Expand Down
Empty file.
3 changes: 3 additions & 0 deletions bioimageio/core/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._model_adapter import create_model_adapter

__all__ = ["create_model_adapter"]
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def create(
for wf in weight_format_priority_order:
if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
try:
from ._pytorch_model_adapter import PytorchModelAdapter
from .pytorch_backend import PytorchModelAdapter

return PytorchModelAdapter(
outputs=model_description.outputs,
Expand All @@ -87,7 +87,7 @@ def create(
and weights.tensorflow_saved_model_bundle is not None
):
try:
from ._tensorflow_model_adapter import TensorflowModelAdapter
from .tensorflow_backend import TensorflowModelAdapter

return TensorflowModelAdapter(
model_description=model_description, devices=devices
Expand All @@ -96,7 +96,7 @@ def create(
errors.append((wf, e))
elif wf == "onnx" and weights.onnx is not None:
try:
from ._onnx_model_adapter import ONNXModelAdapter
from .onnx_backend import ONNXModelAdapter

return ONNXModelAdapter(
model_description=model_description, devices=devices
Expand All @@ -105,7 +105,7 @@ def create(
errors.append((wf, e))
elif wf == "torchscript" and weights.torchscript is not None:
try:
from ._torchscript_model_adapter import TorchscriptModelAdapter
from .torchscript_backend import TorchscriptModelAdapter

return TorchscriptModelAdapter(
model_description=model_description, devices=devices
Expand All @@ -117,13 +117,10 @@ def create(
# we try to first import the keras model adapter using the separate package and,
# if it is not available, try to load the one using tf
try:
from ._keras import (
KerasModelAdapter,
keras, # type: ignore
)

if keras is None:
from ._tensorflow_model_adapter import KerasModelAdapter
try:
from .keras_backend import KerasModelAdapter
except Exception:
from .tensorflow_backend import KerasModelAdapter

return KerasModelAdapter(
model_description=model_description, devices=devices
Expand All @@ -134,10 +131,11 @@ def create(
assert errors
if len(weight_format_priority_order) == 1:
assert len(errors) == 1
wf, e = errors[0]
raise ValueError(
f"The '{weight_format_priority_order[0]}' model adapter could not be created"
+ f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
) from errors[0][1]
f"The '{wf}' model adapter could not be created"
+ f" in this environment:\n{e.__class__.__name__}({e}).\n\n"
) from e

else:
error_list = "\n - ".join(
Expand Down Expand Up @@ -165,13 +163,3 @@ def unload(self):
Unload model from any devices, freeing their memory.
The moder adapter should be considered unusable afterwards.
"""


def get_weight_formats() -> List[str]:
"""
Return list of supported weight types
"""
return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)


create_model_adapter = ModelAdapter.create
Loading

0 comments on commit 9690574

Please sign in to comment.