From 32e53f50cf90706a14932444f48c0a834f0f47ac Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Mon, 20 Nov 2023 16:45:55 +0100 Subject: [PATCH] Fixes some typing issues in tensorflow adapter with spec v0.5 --- .../core/model_adapters/_tensorflow_model_adapter.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py index 9a97a6e7..5ecd3cb3 100644 --- a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -6,7 +6,8 @@ import tensorflow as tf import xarray as xr -from bioimageio.core.io import FileSource, download +from bioimageio.spec.utils import download +from bioimageio.spec.generic.v0_3 import FileSource #FIXME: getre-export from somewhere? from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import RelativeFilePath @@ -87,10 +88,10 @@ def _get_network(self, weight_file: FileSource): # alive in between of forward passes (but then the sessions need to be properly opened / closed) def _forward_tf(self, *input_tensors): input_keys = [ - ipt.name if isinstance(ipt, v0_4.InputTensor) else ipt.id for ipt in self.model_description.inputs + ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id for ipt in self.model_description.inputs ] output_keys = [ - out.name if isinstance(out, v0_4.OutputTensor) else out.id for out in self.model_description.outputs + out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id for out in self.model_description.outputs ] # TODO read from spec @@ -148,7 +149,7 @@ def unload(self) -> None: class TensorflowModelAdapter(TensorflowModelAdapterBase): weight_format = "tensorflow_saved_model_bundle" - def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): if model_description.weights.tensorflow_saved_model_bundle is None: raise ValueError("missing tensorflow_saved_model_bundle weights") @@ -162,7 +163,7 @@ def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: class KerasModelAdapter(TensorflowModelAdapterBase): weight_format = "keras_hdf5" - def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None): + def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None): if model_description.weights.keras_hdf5 is None: raise ValueError("missing keras_hdf5 weights")