Call onnx-rewriter when possible in onnxruntime.InferenceSession #19348
base: main

Changes from all commits
```diff
@@ -6,6 +6,7 @@
 import collections
 import collections.abc
+import onnx
 import os
 import typing
 import warnings
```
```diff
@@ -16,6 +17,39 @@
 if typing.TYPE_CHECKING:
     import onnxruntime


+HAS_ONNX_REWRITER = True
+try:
+    from onnxrewriter.rewriter.transformers import rewrite
+    from onnxrewriter.optimizer import optimize
+except ImportError:
+    # onnxrewriter is an optional dependency; fall back to the plain session path.
+    HAS_ONNX_REWRITER = False
+
+
+def rewrite_and_optimize_model_bytes(model):
+    # Apply the onnxrewriter optimizer and rewriter to a serialized model.
+    assert HAS_ONNX_REWRITER
+    onnx_model = onnx.ModelProto()
+    onnx_model.ParseFromString(model)
+    onnx_model = optimize(
+        onnx_model,
+        num_iterations=2,
+        onnx_shape_inference=False,
+        function_aware_folding=True,
+    )
+    onnx_model = rewrite(onnx_model)
+    return onnx_model.SerializeToString()
+
+
+def rewrite_and_optimize_model_path(model_path):
+    # Same as above, but load the model from disk first.
+    assert HAS_ONNX_REWRITER
+    onnx_model = onnx.load(model_path)
+    onnx_model = optimize(
+        onnx_model,
+        num_iterations=2,
+        onnx_shape_inference=False,
+        function_aware_folding=True,
+    )
+    onnx_model = rewrite(onnx_model)
+    return onnx_model.SerializeToString()
+
+
 def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
     if device_type == "cuda":
```
@@ -469,9 +503,18 @@ | |
self._register_ep_custom_ops(session_options, providers, provider_options, available_providers) | ||
|
||
if self._model_path: | ||
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model) | ||
self._model_count += 1 | ||
if HAS_ONNX_REWRITTER: | ||
onnx_model = rewrite_and_optimize_model_path(self._model_path) | ||
sess = C.InferenceSession(session_options, onnx_model, True, self._read_config_from_model) | ||
else: | ||
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model) | ||
else: | ||
sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model) | ||
if HAS_ONNX_REWRITTER: | ||
onnx_model = rewrite_and_optimize_model_bytes(self._model_bytes) | ||
sess = C.InferenceSession(session_options, onnx_model, False, self._read_config_from_model) | ||
else: | ||
sess = C.InferenceSession(session_options, onnx_model.SerializeToString(), False, self._read_config_from_model) | ||
Check failure Code scanning / CodeQL Potentially uninitialized local variable Error
Local variable 'onnx_model' may be used before it is initialized.
|
||
|
||
if disabled_optimizers is None: | ||
disabled_optimizers = set() | ||
|
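With this change, the rewrite is applied transparently whenever `onnxrewriter` is importable; callers do not change anything. A minimal sketch of the resulting behavior (the model file and the input name `x` are placeholders):

```python
import numpy as np
import onnxruntime

# If onnxrewriter is installed, the model is optimized and rewritten
# before the native session is created; otherwise the original
# path/bytes are passed through unchanged.
sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
outputs = sess.run(None, {"x": np.zeros((1, 3), dtype=np.float32)})
```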
I could see a few problems:
(1) onnxrewriter is not available in the other language APIs (C/C++/NuGet), which makes the results inconsistent across language bindings.
(2) If onnxrewriter has a bug, you have to uninstall it, since there is no explicit option to disable it.
(3) It might not work well with large models (>2GB), since .SerializeToString() is used.
If onnxrewriter is generic enough, why not implement it inside onnxruntime in C++?
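On point (3): protobuf caps a single serialized message at 2GB, so `SerializeToString()` fails for larger models. A minimal sketch of one standard workaround, assuming `rewritten` is the ModelProto produced by the rewriter and the file names are placeholders:

```python
import onnx

# Store large initializers next to the model file instead of inline,
# keeping the protobuf itself under the 2GB limit.
onnx.save_model(
    rewritten,
    "rewritten_model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="rewritten_model.onnx.data",
    size_threshold=1024,
)
# The session can then be created from the path rather than from bytes.
```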
I am OK with whatever way onnx-rewriter should be called. I just used the only available way I had to unblock llama with the `onnxrt` dynamo backend (without this change, ORT is many times slower than inductor, and we will get no market share from PyTorch 2 features). I can add a flag to turn it on/off. Does that make sense?

For how this should ultimately be implemented, please talk with @thiagocrepaldi. I guess the exporter (i.e., onnxscript) will eventually include this optimization pass once it has matured.
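A sketch of what such an opt-in flag could look like inside this diff; the environment variable name `ORT_ENABLE_ONNX_REWRITER` is illustrative, not an existing option:

```python
import os

# Hypothetical opt-in: run the rewrite only when explicitly requested,
# so a buggy onnxrewriter can be disabled without uninstalling it.
_REWRITER_ENABLED = HAS_ONNX_REWRITER and os.getenv("ORT_ENABLE_ONNX_REWRITER", "0") == "1"

if self._model_path:
    if _REWRITER_ENABLED:
        onnx_model = rewrite_and_optimize_model_path(self._model_path)
        sess = C.InferenceSession(session_options, onnx_model, False, self._read_config_from_model)
    else:
        sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
```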
I've discussed this with Thiago. The current plan is that onnx-rewriter will be invoked by onnxscript and will be part of the exporter workflow. Having a flag is, again, not consistent across language bindings: we don't want users to think the Python bindings give better performance via a flag that is not available to the other bindings. Can the llama model be unblocked by calling the rewriter separately?
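Calling the rewriter separately would look roughly like the helpers in this diff, just run by the user ahead of session creation (a sketch using the same `onnxrewriter` calls as above; the file paths are placeholders):

```python
import onnx
import onnxruntime
from onnxrewriter.optimizer import optimize
from onnxrewriter.rewriter.transformers import rewrite

# Rewrite once, offline, then create sessions from the result as usual.
model = onnx.load("llama.onnx")
model = optimize(model, num_iterations=2, onnx_shape_inference=False, function_aware_folding=True)
model = rewrite(model)
onnx.save(model, "llama.rewritten.onnx")

sess = onnxruntime.InferenceSession("llama.rewritten.onnx")
```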
Nope. The only alternative solution I have in mind is to support a custom post-processing pass in InferenceSession or DORT.
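A sketch of what such a hook might look like; everything here, including the `model_post_process` parameter, is hypothetical and not an existing ORT or DORT API:

```python
from typing import Callable, Optional

import onnxruntime


def create_session(
    model_bytes: bytes,
    model_post_process: Optional[Callable[[bytes], bytes]] = None,
) -> onnxruntime.InferenceSession:
    # Hypothetical hook: let the caller transform the serialized model
    # before the native session is constructed.
    if model_post_process is not None:
        model_bytes = model_post_process(model_bytes)
    return onnxruntime.InferenceSession(model_bytes)


# Usage: plug the rewriter in as the post-processing pass.
# sess = create_session(raw_bytes, model_post_process=rewrite_and_optimize_model_bytes)
```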