Call onnx-rewriter when possible in onnxruntime.InferenceSession #19348

Draft · wants to merge 1 commit into main
47 changes: 45 additions & 2 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -6,6 +6,7 @@

import collections
import collections.abc
import onnx
import os
import typing
import warnings
@@ -16,6 +17,39 @@
if typing.TYPE_CHECKING:
    import onnxruntime

HAS_ONNX_REWRITTER = True
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
@tianleiwu (Contributor) commented on Jan 31, 2024
I can see a few problems:
(1) onnxrewriter is not available to the other language APIs (C/C++/Nuget), which makes the results inconsistent across languages.
(2) If onnxrewriter has a bug, you have to uninstall it, since there is no explicit option to disable it.
(3) It might not work well with large models (>2 GB), since .SerializeToString() is used.

If onnxrewriter is generic enough, why not implement it inside onnxruntime in C++?
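On point (3): the 2 GB ceiling comes from protobuf's maximum message size, which SerializeToString() runs into. A minimal sketch of the failure mode and the usual external-data workaround (file names here are illustrative, not from this PR):

import onnx

model = onnx.load("large_model.onnx")
try:
    model_bytes = model.SerializeToString()  # raises ValueError for models over the 2 GB protobuf limit
except ValueError:
    # Save tensors as external data next to the .onnx file and hand the
    # *path* (not bytes) to InferenceSession instead.
    onnx.save_model(
        model,
        "large_model_external.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
    )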

@wschin (Contributor, Author) replied

I am OK with whichever way onnx-rewriter gets called; this is just the only available way I have to unblock llama with the onnxrt dynamo backend (without this change, ORT is many times slower than inductor, and we will get no market share from PyTorch 2 features). I can add a flag to turn it on/off; does that make sense? A sketch of what such a flag could look like is below.

As for how this should ultimately be implemented, please talk to @thiagocrepaldi. I expect the exporter (i.e., onnxscript) will eventually include this optimization pass once it has matured.
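A minimal sketch of the proposed on/off flag, using the existing SessionOptions config-entry mechanism; the key name "session.use_onnx_rewriter" is hypothetical, not an existing option:

import onnxruntime

so = onnxruntime.SessionOptions()
# Opt in explicitly; the default would leave the rewriter off so that other
# language bindings stay consistent. Hypothetical key name.
so.add_session_config_entry("session.use_onnx_rewriter", "1")
sess = onnxruntime.InferenceSession("model.onnx", sess_options=so, providers=["CPUExecutionProvider"])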

Contributor replied

I've discussed this with Thiago. The current plan is that onnx-rewriter will be invoked by onnxscript and will be part of the exporter workflow. Having a flag is, again, not consistent across the language bindings; we don't want users to think the Python bindings can give better perf via a flag that is not available to the others. Can the llama model be unblocked by calling the rewriter separately, as in the sketch below?
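A sketch of that separate, offline call, assuming the same onnxrewriter entry points this PR imports; "llama.onnx" is an illustrative path:

import onnx
import onnxruntime
from onnxrewriter.optimizer import optimize
from onnxrewriter.rewriter.transformers import rewrite

# Run the optimizer and pattern rewriter once, ahead of time, and save the
# result; every language binding then loads the already-rewritten model.
model = onnx.load("llama.onnx")
model = optimize(model, num_iterations=2, onnx_shape_inference=False, function_aware_folding=True)
model = rewrite(model)
onnx.save(model, "llama_rewritten.onnx")

sess = onnxruntime.InferenceSession("llama_rewritten.onnx", providers=["CPUExecutionProvider"])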

@wschin (Contributor, Author) replied on Feb 2, 2024

Nope. The only alternative solution I have in mind is to support a custom post-processing pass in InferenceSession or DORT; a sketch of that idea is below.
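A hypothetical sketch of that custom post-processing hook; neither create_session_with_pass nor a post_process argument exists in onnxruntime today:

import onnx
import onnxruntime

def create_session_with_pass(model_path, post_process, **session_kwargs):
    # Apply a user-supplied ModelProto -> ModelProto pass (for example,
    # onnxrewriter's optimize + rewrite) before the native session is created.
    model = onnx.load(model_path)
    model = post_process(model)
    return onnxruntime.InferenceSession(model.SerializeToString(), **session_kwargs)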

except ImportError:
    HAS_ONNX_REWRITTER = False

def rewrite_and_optimize_model_bytes(model_bytes):
    assert HAS_ONNX_REWRITTER
    # Parse the serialized model, run the onnxrewriter optimizer and pattern
    # rewriter over it, then serialize back to bytes for the native session.
    onnx_model = onnx.ModelProto()
    onnx_model.ParseFromString(model_bytes)
    onnx_model = optimize(
        onnx_model,
        num_iterations=2,
        onnx_shape_inference=False,
        function_aware_folding=True,
    )
    onnx_model = rewrite(onnx_model)
    return onnx_model.SerializeToString()

def rewrite_and_optimize_model_path(model_path):
    assert HAS_ONNX_REWRITTER
    # Load the model from disk, run the onnxrewriter optimizer and pattern
    # rewriter over it, then serialize to bytes for the native session.
    onnx_model = onnx.load(model_path)
    onnx_model = optimize(
        onnx_model,
        num_iterations=2,
        onnx_shape_inference=False,
        function_aware_folding=True,
    )
    onnx_model = rewrite(onnx_model)
    return onnx_model.SerializeToString()

def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
    if device_type == "cuda":
@@ -469,9 +503,18 @@
        self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)

        if self._model_path:
            if HAS_ONNX_REWRITTER:
                # Rewrite/optimize the on-disk model, then pass the resulting
                # serialized bytes (not a path) to the native session.
                rewritten_bytes = rewrite_and_optimize_model_path(self._model_path)
                sess = C.InferenceSession(session_options, rewritten_bytes, False, self._read_config_from_model)
            else:
                sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
        else:
            if HAS_ONNX_REWRITTER:
                rewritten_bytes = rewrite_and_optimize_model_bytes(self._model_bytes)
                sess = C.InferenceSession(session_options, rewritten_bytes, False, self._read_config_from_model)
            else:
                sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)

        if disabled_optimizers is None:
            disabled_optimizers = set()
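With this patch applied, the rewrite is meant to happen transparently at session creation whenever onnxrewriter is importable; a short usage sketch ("model.onnx" and the input shape are illustrative):

import numpy as np
import onnxruntime

# If onnxrewriter is installed, the model is optimized and rewritten before
# the native runtime sees it; otherwise the original path or bytes are
# passed through unchanged.
sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
outputs = sess.run(None, {input_name: np.zeros((1, 3), dtype=np.float32)})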