Spec v0.5: fixing pytorch conversion #365

Merged
4 changes: 2 additions & 2 deletions bioimageio/core/__main__.py

```diff
@@ -280,8 +280,8 @@ def convert_torch_weights_to_torchscript(
     output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
     use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
 ):
-    ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
-    sys.exit(ret_code)
+    torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
+    sys.exit(0)
 
 convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__
```
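Since the converter no longer returns an exit code, failures now propagate as exceptions. A minimal caller-side sketch of that contract (the `rdf.yaml` path is a placeholder, not from this PR):

```python
# Sketch only: errors now surface as exceptions from the converter,
# so a wrapper that still wants exit codes has to catch them itself.
import sys
from pathlib import Path

from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript

try:
    convert_weights_to_torchscript("rdf.yaml", Path("weights.pt"), use_tracing=True)
except (ValueError, TypeError) as e:
    sys.exit(f"conversion failed: {e}")
sys.exit(0)
```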
13 changes: 7 additions & 6 deletions bioimageio/core/model_adapters/_pytorch_model_adapter.py

```diff
@@ -7,6 +7,7 @@
 
 from bioimageio.core.utils import import_callable
 from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.utils import download
 
 from ._model_adapter import ModelAdapter
 
@@ -15,8 +16,8 @@ class PytorchModelAdapter(ModelAdapter):
     def __init__(
         self,
         *,
-        outputs: Union[Sequence[v0_4.OutputTensor], Sequence[v0_5.OutputTensor]],
-        weights: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights],
+        outputs: Union[Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr]],
+        weights: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
         devices: Optional[Sequence[str]] = None,
     ):
         super().__init__()
@@ -25,7 +26,7 @@ def __init__(
         self._devices = self.get_devices(devices)
         self._network = self._network.to(self._devices[0])
 
-        state: Any = torch.load(weights.source, map_location=self._devices[0])
+        state: Any = torch.load(download(weights.source).path, map_location=self._devices[0])
         _ = self._network.load_state_dict(state)
 
         self._network = self._network.eval()
@@ -50,16 +51,16 @@ def unload(self) -> None:
         torch.cuda.empty_cache()  # release reserved memory
 
     @staticmethod
-    def get_network(weight_spec: Union[v0_4.PytorchStateDictWeights, v0_5.PytorchStateDictWeights]):
+    def get_network(weight_spec: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr]) -> torch.nn.Module:
         arch = import_callable(
             weight_spec.architecture,
             sha256=weight_spec.architecture_sha256
-            if isinstance(weight_spec, v0_4.PytorchStateDictWeights)
+            if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
             else weight_spec.sha256,
         )
         model_kwargs = (
             weight_spec.kwargs
-            if isinstance(weight_spec, v0_4.PytorchStateDictWeights)
+            if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
             else weight_spec.architecture.kwargs
         )
         network = arch(**model_kwargs)
```
77 changes: 43 additions & 34 deletions bioimageio/core/weight_converter/torch/onnx.py

```diff
@@ -1,33 +1,25 @@
 import warnings
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, List, Sequence, cast
 
 import numpy as np
 import torch
 from numpy.testing import assert_array_almost_equal
 
 from bioimageio.spec import load_description
-from bioimageio.spec._internal.types import BioimageioYamlSource
 from bioimageio.spec.model import v0_4, v0_5
-
-try:
-    import onnxruntime as rt
-except ImportError:
-    rt = None
-
-# def add_converted_onnx_weights(model_spec: AnyModel, *, opset_version: Optional[int] = 12, use_tracing: bool = True,
-#                                verbose: bool = True,
-#                                test_decimal: int = 4):
-
-
-# def add_onnx_weights_from_pytorch_state_dict(model_spec: Union[BioimageioYamlSource, AnyModel], test_decimals: int = 4):
-
+from bioimageio.core.weight_converter.torch.utils import load_model
+from bioimageio.spec.common import InvalidDescription
+from bioimageio.spec.utils import download
 
 def add_onnx_weights(
-    source_model: Union[BioimageioYamlSource, AnyModel],
+    model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
     *,
     output_path: Path,
     use_tracing: bool = True,
     test_decimal: int = 4,
+    verbose: bool = False,
+    opset_version: "int | None" = None,
 ):
     """Convert model weights from format 'pytorch_state_dict' to 'onnx'.
@@ -37,42 +29,59 @@ def add_onnx_weights(
     use_tracing: whether to use tracing or scripting to export the onnx format
     test_decimal: precision for testing whether the results agree
     """
-    if isinstance(source_model, (str, Path)):
-        model = load_description(Path(source_model))
-        assert isinstance(model, (v0_4.Model, v0_5.Model))
+    if isinstance(model_spec, (str, Path)):
+        loaded_spec = load_description(Path(model_spec))
+        if isinstance(loaded_spec, InvalidDescription):
+            raise ValueError(f"Bad resource description: {loaded_spec}")
+        if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
+            raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr")
+        model_spec = loaded_spec
+
+    state_dict_weights_descr = model_spec.weights.pytorch_state_dict
+    if state_dict_weights_descr is None:
+        raise ValueError("The provided model does not have weights in the pytorch state dict format")
 
     with torch.no_grad():
-        # load input and expected output data
-        input_data = [np.load(ipt).astype("float32") for ipt in model.test_inputs]
+        if isinstance(model_spec, v0_4.ModelDescr):
+            downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs]
+        else:
+            downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs]
+
+        input_data: List[np.ndarray[Any, Any]] = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs]
         input_tensors = [torch.from_numpy(inp) for inp in input_data]
 
-        # instantiate and generate the expected output
-        model = load_model(model_spec)
-        expected_outputs = model(*input_tensors)
-        if isinstance(expected_outputs, torch.Tensor):
-            expected_outputs = [expected_outputs]
-        expected_outputs = [out.numpy() for out in expected_outputs]
+        model = load_model(state_dict_weights_descr)
+
+        expected_tensors = model(*input_tensors)
+        if isinstance(expected_tensors, torch.Tensor):
+            expected_tensors = [expected_tensors]
+        expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors]
 
         if use_tracing:
             torch.onnx.export(
                 model,
-                input_tensors if len(input_tensors) > 1 else input_tensors[0],
-                output_path,
+                tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0],
+                str(output_path),
                 verbose=verbose,
                 opset_version=opset_version,
             )
         else:
             raise NotImplementedError
 
-        if rt is None:
+        try:
+            import onnxruntime as rt  # pyright: ignore [reportMissingTypeStubs]
+        except ImportError:
             msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
             warnings.warn(msg)
-            return 1
+            return
 
         # check the onnx model
-        sess = rt.InferenceSession(str(output_path))  # does not support Path, so need to cast to str
-        onnx_inputs = {input_name.name: inp for input_name, inp in zip(sess.get_inputs(), input_data)}
-        outputs = sess.run(None, onnx_inputs)
+        sess = rt.InferenceSession(str(output_path))
+        onnx_input_node_args = cast(List[Any], sess.get_inputs())  # FIXME: remove cast, try using rt.NodeArg instead of Any
+        onnx_inputs: Dict[str, np.ndarray[Any, Any]] = {
+            input_name.name: inp for input_name, inp in zip(onnx_input_node_args, input_data)
+        }
+        outputs = cast(Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs))  # FIXME: remove cast
 
         try:
             for exp, out in zip(expected_outputs, outputs):
```
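A usage sketch for the updated signature; the file paths and the opset value are placeholders, and `add_onnx_weights` is imported from the module shown above:

```python
# Sketch only: convert pytorch_state_dict weights to onnx and let the
# built-in onnxruntime round-trip check validate the export.
from pathlib import Path

from bioimageio.core.weight_converter.torch.onnx import add_onnx_weights

add_onnx_weights(
    "rdf.yaml",  # str | Path | v0_4.ModelDescr | v0_5.ModelDescr
    output_path=Path("weights.onnx"),
    use_tracing=True,
    opset_version=15,  # placeholder; any opset your torch version supports
)
```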
135 changes: 77 additions & 58 deletions bioimageio/core/weight_converter/torch/torchscript.py

```diff
@@ -1,78 +1,81 @@
 import warnings
+from typing import List, Sequence
+from typing_extensions import Any, assert_never
 from pathlib import Path
 from typing import Union
 
 import numpy as np
 import torch
 from numpy.testing import assert_array_almost_equal
 
-import bioimageio.spec as spec
-from bioimageio.spec import load_description
 from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec import load_description
+from bioimageio.spec.common import InvalidDescription
+from bioimageio.spec.utils import download
 
 from .utils import load_model
 
+# FIXME: remove Any
+def _check_predictions(model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", input_data: Sequence[torch.Tensor]):
+    def _check(input_: Sequence[torch.Tensor]) -> None:
+        expected_tensors = model(*input_)
+        if isinstance(expected_tensors, torch.Tensor):
+            expected_tensors = [expected_tensors]
+        expected_outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in expected_tensors]
 
-def _check_predictions(model, scripted_model, model_spec, input_data):
-    assert isinstance(input_data, list)
-
-    def _check(input_):
-        # get the expected output to validate the torchscript weights
-        expected_outputs = model(*input_)
-        if isinstance(expected_outputs, (torch.Tensor)):
-            expected_outputs = [expected_outputs]
-        expected_outputs = [out.numpy() for out in expected_outputs]
-
-        outputs = scripted_model(*input_)
-        if isinstance(outputs, (torch.Tensor)):
-            outputs = [outputs]
-        outputs = [out.numpy() for out in outputs]
+        output_tensors = scripted_model(*input_)
+        if isinstance(output_tensors, torch.Tensor):
+            output_tensors = [output_tensors]
+        outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors]
 
         try:
             for exp, out in zip(expected_outputs, outputs):
                 assert_array_almost_equal(exp, out, decimal=4)
-            return 0
         except AssertionError as e:
-            msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}"
-            warnings.warn(msg)
-            return 1
-
-    ret = _check(input_data)
-    n_inputs = len(model_spec.inputs)
-    # check has not passed or we have more than one input? then return immediately
-    if ret == 1 or n_inputs > 1:
-        return ret
-
-    # do we have fixed input size or variable?
-    # if variable, we need to check multiple sizes!
-    shape_spec = model_spec.inputs[0].shape
-    try:  # we have a variable shape
-        min_shape = shape_spec.min
-        step = shape_spec.step
-    except AttributeError:  # we have fixed shape
-        return ret
+            raise ValueError(f"Results before and after weights conversion do not agree:\n {str(e)}")
+
+    _check(input_data)
+
+    if len(model_spec.inputs) > 1:
+        return  # FIXME: why don't we check multiple inputs?
+
+    input_descr = model_spec.inputs[0]
+    if isinstance(input_descr, v0_4.InputTensorDescr):
+        if not isinstance(input_descr.shape, v0_4.ParametrizedInputShape):
+            return
+        min_shape = input_descr.shape.min
+        step = input_descr.shape.step
+    else:
+        min_shape: List[int] = []
+        step: List[int] = []
+        for axis in input_descr.axes:
+            if isinstance(axis.size, v0_5.ParameterizedSize):
+                min_shape.append(axis.size.min)
+                step.append(axis.size.step)
+            elif isinstance(axis.size, int):
+                min_shape.append(axis.size)
+                step.append(0)
+            elif isinstance(axis.size, (v0_5.AxisId, v0_5.TensorAxisId, type(None))):
+                raise NotImplementedError(f"Can't verify inputs that don't specify their shape fully: {axis}")
+            elif isinstance(axis.size, v0_5.SizeReference):  # pyright: ignore [reportUnnecessaryIsInstance]
+                raise NotImplementedError(f"Can't handle axes like '{axis}' yet")
+            else:
+                assert_never(axis.size)
 
     half_step = [st // 2 for st in step]
     max_steps = 4
-    step_factor = 1
 
     # check that input and output agree for decreasing input sizes
-    while True:
+    for step_factor in range(1, max_steps + 1):
         slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step)
         this_input = [inp[slice_] for inp in input_data]
         this_shape = this_input[0].shape
         if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)):
-            return ret
-
-        ret = _check(this_input)
-        if ret == 1:
-            return ret
-        step_factor += 1
-        if step_factor > max_steps:
-            return ret
+            raise ValueError(f"Mismatched shapes: {this_shape}. Expected at least {min_shape}")
+        _check(this_input)
 
 def convert_weights_to_torchscript(
-    model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True
+    model_spec: Union[str, Path, v0_4.ModelDescr, v0_5.ModelDescr], output_path: Path, use_tracing: bool = True
 ):
     """Convert model weights from format 'pytorch_state_dict' to 'torchscript'.
 
@@ -82,24 +85,40 @@ def convert_weights_to_torchscript(
     use_tracing: whether to use tracing or scripting to export the torchscript format
     """
     if isinstance(model_spec, (str, Path)):
-        model_spec = load_description(Path(model_spec))
+        loaded_spec = load_description(Path(model_spec))
+        if isinstance(loaded_spec, InvalidDescription):
+            raise ValueError(f"Bad resource description: {loaded_spec}")
+        if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)):
+            raise TypeError(f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr")
+        model_spec = loaded_spec
+
+    state_dict_weights_descr = model_spec.weights.pytorch_state_dict
+    if state_dict_weights_descr is None:
+        raise ValueError("The provided model does not have weights in the pytorch state dict format")
 
     with torch.no_grad():
         # load input and expected output data
-        input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs]
+        if isinstance(model_spec, v0_4.ModelDescr):
+            downloaded_test_inputs = [download(inp) for inp in model_spec.test_inputs]
+        else:
+            downloaded_test_inputs = [inp.test_tensor.download() for inp in model_spec.inputs]
+
+        input_data = [np.load(dl.path).astype("float32") for dl in downloaded_test_inputs]
         input_data = [torch.from_numpy(inp) for inp in input_data]
 
         # instantiate model and get reference output
-        model = load_model(model_spec)
+        model = load_model(state_dict_weights_descr)
 
         # make scripted model
+        # FIXME: remove Any
        if use_tracing:
-            scripted_model = torch.jit.trace(model, input_data)
+            scripted_model: Any = torch.jit.trace(model, input_data)
         else:
-            scripted_model = torch.jit.script(model)
+            scripted_model: Any = torch.jit.script(model)
 
-        # check the scripted model
-        ret = _check_predictions(model, scripted_model, model_spec, input_data)
+        ret = _check_predictions(
+            model=model,
+            scripted_model=scripted_model,
+            model_spec=model_spec,
+            input_data=input_data
+        )
 
         # save the torchscript model
         scripted_model.save(str(output_path))  # does not support Path, so need to cast to str
```
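The multi-size check in `_check_predictions` only runs for parametrized inputs. A small, self-contained sketch with made-up min/step values shows which shapes the new `for step_factor` loop probes:

```python
# Hypothetical numbers: one input whose x/y axes are parametrized with
# min 64 and step 16 (cf. v0_5.ParameterizedSize); batch/channel are fixed.
min_shape = [1, 1, 64, 64]
step = [0, 0, 16, 16]
half_step = [st // 2 for st in step]  # [0, 0, 8, 8]

shape = (1, 1, 128, 128)  # shape of the test input tensor
for step_factor in range(1, 5):
    # crop step_factor * half_step from both ends of each parametrized axis
    slice_ = tuple(
        slice(None) if st == 0 else slice(step_factor * st, -step_factor * st)
        for st in half_step
    )
    cropped = tuple(
        n if s == slice(None) else n - 2 * s.start for n, s in zip(shape, slice_)
    )
    print(step_factor, cropped)
    # 1 (1, 1, 112, 112); 2 (1, 1, 96, 96); 3 (1, 1, 80, 80); 4 (1, 1, 64, 64)
```

Each probe stays at or above `min_shape`, so all four factors get checked; anything smaller raises the new `ValueError` instead of silently returning as before.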
17 changes: 10 additions & 7 deletions bioimageio/core/weight_converter/torch/utils.py

```diff
@@ -1,12 +1,15 @@
 import torch
-from bioimageio.core.prediction_pipeline._model_adapters._pytorch_model_adapter import PytorchModelAdapter
+
+from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
+from bioimageio.spec.model import v0_4, v0_5
+from bioimageio.spec.utils import download
 
 
 # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
 # and for each weight format
-def load_model(node):
-    model = PytorchModelAdapter.get_nn_instance(node)
-    state = torch.load(node.weights["pytorch_state_dict"].source, map_location="cpu")
-    model.load_state_dict(state)
-    model.eval()
-    return model
+def load_model(node: "v0_4.PytorchStateDictWeightsDescr | v0_5.PytorchStateDictWeightsDescr"):
+    model = PytorchModelAdapter.get_network(node)
+    state = torch.load(download(node.source).path, map_location="cpu")
+    _ = model.load_state_dict(state)  # FIXME: check incompatible keys?
+    return model.eval()
```
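A sketch of the new `load_model` contract, which now takes the `pytorch_state_dict` weights description instead of a whole model node (paths are placeholders):

```python
# Sketch only, assuming a valid v0_5 model RDF at rdf.yaml.
from bioimageio.core.weight_converter.torch.utils import load_model
from bioimageio.spec import load_description
from bioimageio.spec.model import v0_5

descr = load_description("rdf.yaml")
assert isinstance(descr, v0_5.ModelDescr)

weights = descr.weights.pytorch_state_dict
assert weights is not None  # load_model requires pytorch_state_dict weights

model = load_model(weights)  # downloads the state dict, returns the network in eval() mode
```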
6 changes: 5 additions & 1 deletion tests/weight_converter/torch/test_torchscript.py

```diff
@@ -1,4 +1,8 @@
-def test_torchscript_converter(any_torch_model, tmp_path):
+from pathlib import Path
+from bioimageio.spec.model import v0_4, v0_5
+
+
+def test_torchscript_converter(any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path):
     from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript
 
     out_path = tmp_path / "weights.pt"
```
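The diff is truncated here; a hedged sketch of how the test presumably continues (the exact call and assertions are not visible in the diff):

```python
# Assumption: the any_torch_model fixture yields a model description with
# pytorch_state_dict weights, and the test asserts the file gets written.
def test_torchscript_converter(any_torch_model, tmp_path):
    from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript

    out_path = tmp_path / "weights.pt"
    convert_weights_to_torchscript(any_torch_model, out_path)
    assert out_path.exists()
```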