Fix: Onnx model export not always working properly in half precision (#119)

* build: Add onnxconverter-common as dependency

* feat: Allow exporting mixed precision onnx models if a broken model is detected

* build: Upgrade version

* docs: Update changelog
lorenzomammana authored May 30, 2024
1 parent 508406d commit 7c4aa31
Showing 5 changed files with 156 additions and 42 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@
# Changelog
All notable changes to this project will be documented in this file.

### [2.1.8]

#### Added

- Add onnxconverter-common to the dependencies to allow exporting ONNX models in mixed precision when issues are encountered exporting the model entirely in half precision.

### [2.1.7]

#### Fixed
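For context, the "issues" mentioned in the changelog entry typically surface as NaN outputs from a model exported entirely in FP16, even though the exported file itself passes the ONNX checker. A minimal sketch of such a runtime check with onnxruntime (the model path, provider list, and dummy input are assumptions for illustration):

```python
import numpy as np
import onnxruntime as ort

# Assumed: "model.onnx" was exported in half precision.
session = ort.InferenceSession(
    "model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
input_meta = session.get_inputs()[0]
# Replace dynamic dimensions (reported as strings) with a batch size of 1.
shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
dummy = np.random.rand(*shape).astype(np.float16)

outputs = session.run(None, {input_meta.name: dummy})
# NaNs in any output indicate a numerically broken FP16 export.
if any(np.isnan(out).any() for out in outputs):
    print("FP16 export is broken; fall back to mixed precision")
```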
80 changes: 48 additions & 32 deletions poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quadra"
version = "2.1.7"
version = "2.1.8"
description = "Deep Learning experiment orchestration library"
authors = [
"Federico Belotti <[email protected]>",
@@ -86,6 +86,7 @@ typing_extensions = { version = "4.11.0", python = "<3.10" }
onnx = { version = "1.15.0", optional = true }
onnxsim = { version = "0.4.28", optional = true }
onnxruntime_gpu = { version = "1.17.0", optional = true, source = "onnx_cu12" }
onnxconverter-common = { version = "^1.14.0", optional = true }

[[tool.poetry.source]]
name = "torch_cu121"
@@ -141,7 +142,7 @@ mike = "1.1.2"
cairosvg = "2.7.0"

[tool.poetry.extras]
onnx = ["onnx", "onnxsim", "onnxruntime_gpu"]
onnx = ["onnx", "onnxsim", "onnxruntime_gpu", "onnxconverter-common"]

[tool.poetry_bumpversion.file."quadra/__init__.py"]
search = '__version__ = "{current_version}"'
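Since onnxconverter-common is listed under the onnx extra rather than the base dependencies, it is only installed when that extra is requested, e.g. `pip install "quadra[onnx]"` or `poetry install --extras onnx` (commands shown for illustration; use whichever workflow the project documents).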
2 changes: 1 addition & 1 deletion quadra/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.1.7"
__version__ = "2.1.8"


def get_version():
104 changes: 97 additions & 7 deletions quadra/utils/export.py
@@ -7,6 +7,7 @@
import torch
from anomalib.models.cflow import CflowLightning
from omegaconf import DictConfig, ListConfig, OmegaConf
from onnxconverter_common import auto_convert_mixed_precision
from torch import nn

from quadra.models.base import ModelSignatureWrapper
@@ -250,14 +251,14 @@ def export_onnx_model(
for i, _ in enumerate(output_names):
dynamic_axes[output_names[i]] = {0: "batch_size"}

-onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))
+modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

onnx_config["input_names"] = input_names
onnx_config["output_names"] = output_names
onnx_config["dynamic_axes"] = dynamic_axes
modified_onnx_config["input_names"] = input_names
modified_onnx_config["output_names"] = output_names
modified_onnx_config["dynamic_axes"] = dynamic_axes

-simplify = onnx_config.pop("simplify", False)
-_ = onnx_config.pop("fixed_batch_size", None)
+simplify = modified_onnx_config.pop("simplify", False)
+_ = modified_onnx_config.pop("fixed_batch_size", None)

if len(inp) == 1:
inp = inp[0]
@@ -269,7 +270,7 @@
raise ValueError("ONNX export does not support model with dict inputs")

try:
-torch.onnx.export(model=model, args=inp, f=model_path, **onnx_config)
+torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config)

onnx_model = onnx.load(model_path)
# Check if ONNX model is valid
@@ -280,6 +281,19 @@

log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

if half_precision:
is_export_ok = _safe_export_half_precision_onnx(
model=model,
export_model_path=model_path,
inp=inp,
onnx_config=onnx_config,
input_shapes=input_shapes,
input_names=input_names,
)

if not is_export_ok:
return None

if simplify:
log.info("Attempting to simplify ONNX model")
onnx_model = onnx.load(model_path)
@@ -302,6 +316,82 @@
return os.path.join(os.getcwd(), model_path), input_shapes


def _safe_export_half_precision_onnx(
model: nn.Module,
export_model_path: str,
inp: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
onnx_config: DictConfig,
input_shapes: list[Any],
input_names: list[str],
) -> bool:
"""Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt a more
stable export and overwrite the original model.

Args:
model: PyTorch model to be exported
export_model_path: Path to save the model
inp: Input tensors for the model
onnx_config: ONNX export configuration
input_shapes: Input shapes for the model
input_names: Input names for the model

Returns:
True if the model is stable or a more stable export succeeded, False otherwise.
"""
test_fp_16_model: BaseEvaluationModel = import_deployment_model(
export_model_path, OmegaConf.create({"onnx": {}}), "cuda:0"
)
if not isinstance(inp, Sequence):
inp = [inp]

test_output = test_fp_16_model(*inp)

if not isinstance(test_output, Sequence):
test_output = [test_output]

# Check if there are nan values in any of the outputs
is_broken_model = any(torch.isnan(out).any() for out in test_output)

if is_broken_model:
try:
log.warning(
"The exported half precision ONNX model contains NaN values, attempting with a more stable export..."
)
# Cast back the fp16 model to fp32 to simulate the export with fp32
model = model.float()
log.info("Starting to export model in full precision")
export_output = export_onnx_model(
model=model,
output_path=os.path.dirname(export_model_path),
onnx_config=onnx_config,
input_shapes=input_shapes,
half_precision=False,
model_name=os.path.basename(export_model_path),
)
if export_output is not None:
export_model_path, _ = export_output
else:
log.warning("Failed to export model")
return False

model_fp32 = onnx.load(export_model_path)
test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
log.warning("Attempting to convert model in mixed precision, this may take a while...")
model_fp16 = auto_convert_mixed_precision(model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False)
onnx.save(model_fp16, export_model_path)

onnx_model = onnx.load(export_model_path)
# Check if ONNX model is valid
onnx.checker.check_model(onnx_model)
return True
except Exception as e:
log.debug("Failed to export model with mixed precision with error: %s", e)
return False
else:
log.info("Exported half precision ONNX model does not contain NaN values, model is stable")
return True


def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
"""Export pytorch model's parameter dictionary using a deserialized state_dict.
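The heavy lifting in the fallback path above is done by onnxconverter-common's auto_convert_mixed_precision, which takes a full precision model and converts as much of the graph as possible to FP16, keeping in FP32 any regions whose conversion would push outputs outside the given tolerances. A condensed sketch of that conversion step in isolation (file names and the input feed are illustrative; rtol/atol mirror the values used in the diff):

```python
import numpy as np
import onnx
from onnxconverter_common import auto_convert_mixed_precision

# Assumed: an FP32 export and one representative input batch.
model_fp32 = onnx.load("model_fp32.onnx")
test_data = {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}

# Converts the graph to FP16 incrementally, validating outputs against
# the FP32 model within rtol/atol; failing regions stay in FP32.
model_fp16 = auto_convert_mixed_precision(
    model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False
)

onnx.checker.check_model(model_fp16)
onnx.save(model_fp16, "model_mixed.onnx")
```

Note that with keep_io_types=False the converted model's inputs and outputs become float16 as well, so downstream callers must feed half precision tensors; passing keep_io_types=True instead preserves FP32 inputs and outputs via inserted casts.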
