Skip to content

Commit

Permalink
Merge pull request #363 from Tomaz-Vieira/spec_v0_5_pytorch_adapter_t…
Browse files Browse the repository at this point in the history
…ypecheck

Fixes type-checking with spec v0.5 in torchscript model adapter
  • Loading branch information
FynnBe authored Nov 20, 2023
2 parents 6c9bf0f + 26c6f58 commit 14be1e7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bioimageio/core/model_adapters/_torchscript_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
import xarray as xr
from numpy.typing import NDArray

from bioimageio.core.io import download
from bioimageio.spec.utils import download
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import RelativeFilePath

from ._model_adapter import ModelAdapter


class TorchscriptModelAdapter(ModelAdapter):
def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices: Optional[Sequence[str]] = None):
def __init__(self, *, model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None):
super().__init__()
if model_description.weights.torchscript is None:
raise ValueError(f"No torchscript weights found for model {model_description.name}")
Expand All @@ -32,7 +32,7 @@ def __init__(self, *, model_description: Union[v0_4.Model, v0_5.Model], devices:
if len(self.devices) > 1:
warnings.warn("Multiple devices for single torchscript model not yet implemented")

self._model = torch.jit.load(weight_path) # pyright: ignore[reportPrivateImportUsage]
self._model = torch.jit.load(weight_path)
self._model.to(self.devices[0])
self._internal_output_axes = [
tuple(out.axes) if isinstance(out.axes, str) else tuple(a.id for a in out.axes)
Expand Down

0 comments on commit 14be1e7

Please sign in to comment.