Skip to content

Commit

Permalink
Merge pull request #364 from Tomaz-Vieira/spec_v0_5_tensorflow_adapte…
Browse files Browse the repository at this point in the history
…r_typecheck

Fixes some typing issues in tensor flow adapter
  • Loading branch information
FynnBe authored Nov 20, 2023
2 parents 14be1e7 + 32e53f5 commit d6d0fea
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions bioimageio/core/model_adapters/_tensorflow_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import tensorflow as tf
import xarray as xr

from bioimageio.core.io import FileSource, download
from bioimageio.spec.utils import download
from bioimageio.spec.generic.v0_3 import FileSource #FIXME: getre-export from somewhere?
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import RelativeFilePath

Expand Down Expand Up @@ -87,10 +88,10 @@ def _get_network(self, weight_file: FileSource):
# alive in between of forward passes (but then the sessions need to be properly opened / closed)
def _forward_tf(self, *input_tensors):
input_keys = [
ipt.name if isinstance(ipt, v0_4.InputTensor) else ipt.id for ipt in self.model_description.inputs
ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id for ipt in self.model_description.inputs
]
output_keys = [
out.name if isinstance(out, v0_4.OutputTensor) else out.id for out in self.model_description.outputs
out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id for out in self.model_description.outputs
]

# TODO read from spec
Expand Down Expand Up @@ -148,7 +149,7 @@ def unload(self) -> None:
class TensorflowModelAdapter(TensorflowModelAdapterBase):
weight_format = "tensorflow_saved_model_bundle"

def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
if model_description.weights.tensorflow_saved_model_bundle is None:
raise ValueError("missing tensorflow_saved_model_bundle weights")

Expand All @@ -162,7 +163,7 @@ def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices:
class KerasModelAdapter(TensorflowModelAdapterBase):
weight_format = "keras_hdf5"

def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
if model_description.weights.keras_hdf5 is None:
raise ValueError("missing keras_hdf5 weights")

Expand Down

0 comments on commit d6d0fea

Please sign in to comment.