Skip to content

Commit

Permalink
Mostly fixes typing in torchscript converter. Missing impl for v0_5
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomaz-Vieira committed Nov 22, 2023
1 parent 2defed2 commit e2da2c4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 54 deletions.
4 changes: 2 additions & 2 deletions bioimageio/core/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def convert_torch_weights_to_torchscript(
output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
):
ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
sys.exit(ret_code)
torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
sys.exit(0)

convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__

Expand Down
113 changes: 61 additions & 52 deletions bioimageio/core/weight_converter/torch/torchscript.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,71 @@
import warnings
from typing import List, Sequence
from typing_extensions import Any
from pathlib import Path
from typing import Union

import numpy as np
import torch
from numpy.testing import assert_array_almost_equal

import bioimageio.spec as spec
from bioimageio.spec import load_description
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec import load_description
from bioimageio.spec.common import InvalidDescription
from bioimageio.spec.utils import download

from .utils import load_model


def _check_predictions(model, scripted_model, model_spec, input_data):
assert isinstance(input_data, list)

def _check(input_):
# FIXME: remove Any
def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]):
def _check(input_: Sequence[torch.Tensor]) -> None:
# get the expected output to validate the torchscript weights
expected_outputs = model(*input_)
if isinstance(expected_outputs, (torch.Tensor)):
expected_outputs = [expected_outputs]
expected_outputs = [out.numpy() for out in expected_outputs]
expected_tensors = model(*input_)
if isinstance(expected_tensors, torch.Tensor):
expected_tensors = [expected_tensors]
expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors]

outputs = scripted_model(*input_)
if isinstance(outputs, (torch.Tensor)):
outputs = [outputs]
outputs = [out.numpy() for out in outputs]
output_tensors = scripted_model(*input_)
if isinstance(output_tensors, torch.Tensor):
output_tensors = [output_tensors]
outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors]

try:
for exp, out in zip(expected_outputs, outputs):
assert_array_almost_equal(exp, out, decimal=4)
return 0
except AssertionError as e:
msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
warnings.warn(msg)
return 1
raise ValueError(f"Results before and after weights conversion do not agree:\n {str(e)}")

ret = _check(input_data)
n_inputs = len(model_spec.inputs)
# check has not passed or we have more than one input? then return immediately
if ret == 1 or n_inputs > 1:
return ret
_check(input_data)

if len(model_spec.inputs) > 1:
return # FIXME: why don't we check multiple inputs?

# do we have fixed input size or variable?
# if variable, we need to check multiple sizes!
shape_spec = model_spec.inputs[0].shape
try: # we have a variable shape
min_shape = shape_spec.min
step = shape_spec.step
except AttributeError: # we have fixed shape
return ret
input_descr = model_spec.inputs[0]
if isinstance(input_descr, v0_4.InputTensorDescr):
if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape):
return
min_shape = input_descr.shape.min
step = input_descr.shape.step
else:
raise NotImplementedError("FIXME: Can't handle v0.5 parameterized inputs yet")

half_step = [st // 2 for st in step]
max_steps = 4
step_factor = 1

# check that input and output agree for decreasing input sizes
while True:
for step_factor in range(1, max_steps + 1):
slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step)
this_input = [inp[slice_] for inp in input_data]
this_shape = this_input[0].shape
if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
return ret

ret = _check(this_input)
if ret == 1:
return ret
step_factor += 1
if step_factor > max_steps:
return ret

raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}")
_check(this_input)

def convert_weights_to_torchscript(
model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True
model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True
):
"""Convert model weights from format 'pytorch_state_dict' to 'torchscript'.
Expand All @@ -82,24 +75,40 @@ def convert_weights_to_torchscript(
use_tracing: whether to use tracing or scripting to export the torchscript format
"""
if isinstance(model_spec, (str, Path)):
model_spec = load_description(Path(model_spec))
loaded_spec = load_description(Path(model_spec))
if isinstance(loaded_spec, InvalidDescription):
raise ValueError(f"Bad resource description: {loaded_spec}")
if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr")
model_spec = loaded_spec

state_dict_weights_descr = model_spec.weights.pytorch_state_dict
if state_dict_weights_descr is None:
raise ValueError(f"The provided model does not have weights in the pytorch state dict format")

with torch.no_grad():
# load input and expected output data
input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs]
if isinstance(model_spec, v0_4.ModelDescr):
downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs]
else:
downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs]

input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs]
input_data = [torch.from_numpy(inp) for inp in input_data]

# instantiate model and get reference output
model = load_model(model_spec)
model = load_model(state_dict_weights_descr)

# make scripted model
# FIXME: remove Any
if use_tracing:
scripted_model = torch.jit.trace(model, input_data)
scripted_model: Any = torch.jit.trace(model, input_data)
else:
scripted_model = torch.jit.script(model)

# check the scripted model
ret = _check_predictions(model, scripted_model, model_spec, input_data)
scripted_model: Any = torch.jit.script(model)

ret = _check_predictions(
model=model,
scripted_model=scripted_model,
model_spec=model_spec,
input_data=input_data
)

# save the torchscript model
scripted_model.save(str(output_path)) # does not support Path, so need to cast to str
Expand Down

0 comments on commit e2da2c4

Please sign in to comment.