Skip to content

Commit

Permalink
WIP update adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Feb 13, 2024
1 parent ca02570 commit ce80896
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion bioimageio/core/model_adapters/_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ModelAdapter(ABC):
@classmethod
def create(
cls,
model_description: Union[v0_4.Model, v0_5.Model],
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
*,
devices: Optional[Sequence[str]] = None,
weight_format_priority_order: NotEmpty[Sequence[WeightsFormat]] = DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
Expand Down
4 changes: 3 additions & 1 deletion bioimageio/core/model_adapters/_onnx_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


class ONNXModelAdapter(ModelAdapter):
def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
def __init__(
self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None
):
super().__init__()
self._internal_output_axes = [
tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes)
Expand Down
19 changes: 9 additions & 10 deletions bioimageio/core/model_adapters/_tensorflow_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import tensorflow as tf
import xarray as xr

from bioimageio.spec.utils import download
from bioimageio.spec.generic.v0_3 import FileSource #FIXME: getre-export from somewhere?
from bioimageio.spec.common import FileSource, RelativeFilePath
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import RelativeFilePath
from bioimageio.spec.utils import download

from ._model_adapter import ModelAdapter

Expand Down Expand Up @@ -54,11 +53,7 @@ def __init__(
if devices is not None:
warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}")

weight_file = self.require_unzipped(
weights.source.get_absolute(model_description.root)
if isinstance(weights.source, RelativeFilePath)
else weights.source
)
weight_file = self.require_unzipped(weights.source)
self._network = self._get_network(weight_file)
self._internal_output_axes = [
tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes)
Expand Down Expand Up @@ -149,7 +144,9 @@ def unload(self) -> None:
class TensorflowModelAdapter(TensorflowModelAdapterBase):
weight_format = "tensorflow_saved_model_bundle"

def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
def __init__(
self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None
):
if model_description.weights.tensorflow_saved_model_bundle is None:
raise ValueError("missing tensorflow_saved_model_bundle weights")

Expand All @@ -163,7 +160,9 @@ def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr]
class KerasModelAdapter(TensorflowModelAdapterBase):
weight_format = "keras_hdf5"

def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
def __init__(
self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None
):
if model_description.weights.keras_hdf5 is None:
raise ValueError("missing keras_hdf5 weights")

Expand Down
7 changes: 2 additions & 5 deletions bioimageio/core/model_adapters/_torchscript_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import xarray as xr
from numpy.typing import NDArray

from bioimageio.spec.common import RelativeFilePath
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import RelativeFilePath
from bioimageio.spec.utils import download

from ._model_adapter import ModelAdapter
Expand All @@ -22,10 +22,7 @@ def __init__(
if model_description.weights.torchscript is None:
raise ValueError(f"No torchscript weights found for model {model_description.name}")

src = model_description.weights.torchscript.source
weight_path = download(
src.get_absolute(model_description.root) if isinstance(src, RelativeFilePath) else src
).path
weight_path = download(model_description.weights.torchscript.source).path
if devices is None:
self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
else:
Expand Down

0 comments on commit ce80896

Please sign in to comment.