From e2da2c44e29fdd6b2caa8c285564bea0944c433e Mon Sep 17 00:00:00 2001 From: Tomaz Vieira Date: Wed, 22 Nov 2023 16:16:50 +0100 Subject: [PATCH] Mostly fixes typing in torchscript converter. Missing impl for v0_5 --- bioimageio/core/__main__.py | 4 +- .../weight_converter/torch/torchscript.py | 113 ++++++++++-------- 2 files changed, 63 insertions(+), 54 deletions(-) diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index 75da0316..aabd2b05 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -280,8 +280,8 @@ def convert_torch_weights_to_torchscript( output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."), use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), ): - ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) - sys.exit(ret_code) + torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) + sys.exit(0) convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__ diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py index 3feca51c..0ebe6201 100644 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ b/bioimageio/core/weight_converter/torch/torchscript.py @@ -1,4 +1,5 @@ -import warnings +from typing import List, Sequence +from typing_extensions import Any from pathlib import Path from typing import Union @@ -6,73 +7,65 @@ import torch from numpy.testing import assert_array_almost_equal -import bioimageio.spec as spec from bioimageio.spec import load_description +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec import load_description +from bioimageio.spec.common import InvalidDescription +from bioimageio.spec.utils import download from .utils import load_model - -def _check_predictions(model, scripted_model, model_spec, input_data): - assert isinstance(input_data, list) - - def _check(input_): +# FIXME: remove Any +def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]): + def _check(input_: Sequence[torch.Tensor]) -> None: # get the expected output to validate the torchscript weights - expected_outputs = model(*input_) - if isinstance(expected_outputs, (torch.Tensor)): - expected_outputs = [expected_outputs] - expected_outputs = [out.numpy() for out in expected_outputs] + expected_tensors = model(*input_) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors] - outputs = scripted_model(*input_) - if isinstance(outputs, (torch.Tensor)): - outputs = [outputs] - outputs = [out.numpy() for out in outputs] + output_tensors = scripted_model(*input_) + if isinstance(output_tensors, torch.Tensor): + output_tensors = [output_tensors] + outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors] try: for exp, out in zip(expected_outputs, outputs): assert_array_almost_equal(exp, out, decimal=4) - return 0 except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - warnings.warn(msg) - return 1 + raise ValueError(f"Results before and after weights conversion do not agree:\n {str(e)}") - ret = _check(input_data) - n_inputs = len(model_spec.inputs) - # check has not passed or we have more tahn one input? then return immediately - if ret == 1 or n_inputs > 1: - return ret + _check(input_data) + + if len(model_spec.inputs) > 1: + return # FIXME: why don't we check multiple inputs? # do we have fixed input size or variable? # if variable, we need to check multiple sizes! - shape_spec = model_spec.inputs[0].shape - try: # we have a variable shape - min_shape = shape_spec.min - step = shape_spec.step - except AttributeError: # we have fixed shape - return ret + input_descr = model_spec.inputs[0] + if isinstance(input_descr, v0_4.InputTensorDescr): + if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape): + return + min_shape = input_descr.shape.min + step = input_descr.shape.step + else: + raise NotImplementedError("FIXME: Can't handle v0.5 parameterized inputs yet") half_step = [st // 2 for st in step] max_steps = 4 step_factor = 1 # check that input and output agree for decreasing input sizes - while True: + for step_factor in range(1, max_steps + 1): slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step) this_input = [inp[slice_] for inp in input_data] this_shape = this_input[0].shape if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - return ret - - ret = _check(this_input) - if ret == 1: - return ret - step_factor += 1 - if step_factor > max_steps: - return ret - + raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}") + _check(this_input) def convert_weights_to_torchscript( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True + model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True ): """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. @@ -82,24 +75,40 @@ def convert_weights_to_torchscript( use_tracing: whether to use tracing or scripting to export the torchscript format """ if isinstance(model_spec, (str, Path)): - model_spec = load_description(Path(model_spec)) + loaded_spec = load_description(Path(model_spec)) + if isinstance(loaded_spec, InvalidDescription): + raise ValueError(f"Bad resource description: {loaded_spec}") + if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr") + model_spec = loaded_spec + + state_dict_weights_descr = model_spec.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError(f"The provided model does not have weights in the pytorch state dict format") with torch.no_grad(): - # load input and expected output data - input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs] + if isinstance(model_spec, v0_4.ModelDescr): + downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs] + else: + downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs] + + input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs] input_data = [torch.from_numpy(inp) for inp in input_data] - # instantiate model and get reference output - model = load_model(model_spec) + model = load_model(state_dict_weights_descr) - # make scripted model + # FIXME: remove Any if use_tracing: - scripted_model = torch.jit.trace(model, input_data) + scripted_model: Any = torch.jit.trace(model, input_data) else: - scripted_model = torch.jit.script(model) - - # check the scripted model - ret = _check_predictions(model, scripted_model, model_spec, input_data) + scripted_model: Any = torch.jit.script(model) + + ret = _check_predictions( + model=model, + scripted_model=scripted_model, + model_spec=model_spec, + input_data=input_data + ) # save the torchscript model scripted_model.save(str(output_path)) # does not support Path, so need to cast to str