Call onnx-rewriter when possible in onnxruntime.InferenceSession #19348
base: main
Conversation
```python
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
except:
```
Check notice (Code scanning / CodeQL): Except block handles 'BaseException'
```python
    onnx_model = rewrite_and_optimize_model_bytes(self._model_bytes)
    sess = C.InferenceSession(session_options, onnx_model, False, self._read_config_from_model)
else:
    sess = C.InferenceSession(session_options, onnx_model.SerializeToString(), False, self._read_config_from_model)
```
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
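A minimal sketch of one way to address this: assign the serialized bytes on every path and construct the session once, so no variable can reach the constructor unassigned. Only the names shown in the diff are real; the condition and the `serialized` temporary are assumptions for illustration.

```python
# Sketch only: every branch assigns `serialized` before the session is built,
# so there is no path on which a local variable is used uninitialized.
if HAS_ONNX_REWRITTER and self._model_bytes is not None:
    serialized = rewrite_and_optimize_model_bytes(self._model_bytes)  # from the diff
else:
    serialized = onnx_model.SerializeToString()  # `onnx_model` assumed loaded earlier
sess = C.InferenceSession(session_options, serialized, False, self._read_config_from_model)
```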
```python
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
except:
```
Check warning (Code scanning / lintrunner): RUFF/E722, see https://docs.astral.sh/ruff/rules/bare-except
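Both findings point at the same bare `except:`, which catches `KeyboardInterrupt` and `SystemExit` along with the expected import failure. A minimal sketch of the narrower guard (module names taken from this diff):

```python
HAS_ONNX_REWRITTER = True
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
except ImportError:
    # Only a missing or broken onnxrewriter install disables the rewrite path.
    HAS_ONNX_REWRITTER = False
```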
```python
def rewrite_and_optimize_model_bytes(model):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.ModelProto()
    onnx_model.ParseFromString(self._model_bytes)
```
Check failure (Code scanning / lintrunner): RUFF/F821, see https://docs.astral.sh/ruff/rules/undefined-name
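F821 fires because `self` does not exist inside this free function. A sketch of the fix using the `model` parameter that is actually passed in; the body after the parse follows what the diff's imports suggest, and the exact `optimize`/`rewrite` signatures (ModelProto in, ModelProto out) are an assumption:

```python
import onnx

def rewrite_and_optimize_model_bytes(model):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.ModelProto()
    onnx_model.ParseFromString(model)  # was self._model_bytes: undefined name here
    onnx_model = optimize(onnx_model)  # assumed signature
    onnx_model = rewrite(onnx_model)   # assumed signature
    return onnx_model.SerializeToString()
```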
```python
def rewrite_and_optimize_model_path(model_path):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.load(self._model_path)
```
Check failure (Code scanning / lintrunner): RUFF/F821, see https://docs.astral.sh/ruff/rules/undefined-name
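Same undefined-name problem in the path variant; the sketch below uses the `model_path` parameter, with the rest of the body assumed to mirror the bytes variant:

```python
def rewrite_and_optimize_model_path(model_path):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.load(model_path)  # was self._model_path: undefined name here
    onnx_model = optimize(onnx_model)   # assumed signature
    onnx_model = rewrite(onnx_model)    # assumed signature
    return onnx_model.SerializeToString()
```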
```python
HAS_ONNX_REWRITTER = True
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
```
I could see a few problems:
(1) onnxrewriter is not available in the other language APIs (C/C++/NuGet), which makes results inconsistent across languages.
(2) If onnxrewriter has a bug, you have to uninstall it, since there is no explicit option to disable it.
(3) It might not work well with large models (>2GB), since `.SerializeToString()` is used.
If onnxrewriter is generic enough, why not implement it inside onnxruntime in C++?
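For context on point (3): a single protobuf file is capped at 2GB, so larger models are normally saved with external tensor data and handed to the session by path rather than as serialized bytes. A minimal sketch using the standard onnx API; `onnx_model` is assumed to be an in-memory ModelProto and the file names are placeholders:

```python
import onnx
import onnxruntime

# Store weights outside the protobuf so the .onnx file itself stays under 2GB.
onnx.save_model(
    onnx_model,
    "model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model.data",  # placeholder file name for the external weights
)

# Pass the path, not bytes, so ORT can resolve the external data file.
session_options = onnxruntime.SessionOptions()
sess = onnxruntime.InferenceSession("model.onnx", session_options)
```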
I am OK with whatever way onnx-rewriter should be called; I just used the only available way I have to unblock llama with the onnxrt dynamo backend (without this change, ORT is many times slower than inductor, and we will get no market share from PyTorch 2 features). I can add a flag to turn it on/off. Does that make sense?
As for how this should ultimately be implemented, please talk with @thiagocrepaldi. I expect the exporter (i.e., onnxscript) will eventually include this optimization pass once it has matured.
I've discussed this with Thiago. The current plan is that onnx-rewriter will be invoked by onnxscript and will be part of the exporter workflow. Having a flag is, again, not consistent across the language bindings: we don't want users to think the Python bindings can give better perf via a flag that is not available elsewhere. Can the llama model be unblocked by calling the rewriter separately?
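For reference, a sketch of what calling the rewriter separately could look like as a one-off post-export step, using the imports from this PR; the exact `optimize`/`rewrite` signatures are assumptions (ModelProto in, ModelProto out), and the paths are placeholders:

```python
import onnx
import onnxruntime
from onnxrewriter.optimizer import optimize
from onnxrewriter.rewriter.transformers import rewrite

# Rewrite and optimize once, offline, right after export ...
model = onnx.load("llama_exported.onnx")  # placeholder path
model = optimize(model)   # assumed signature
model = rewrite(model)    # assumed signature
onnx.save(model, "llama_optimized.onnx")  # placeholder path

# ... so InferenceSession itself stays rewriter-free and language-neutral.
sess = onnxruntime.InferenceSession("llama_optimized.onnx")
```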
Nope. The only alternative solution I have in mind is to support a custom post-processing pass in InferenceSession or DORT.

We don't want onnx-rewriter to be called from within InferenceSession; that leads to inconsistencies between the language bindings. This script should be invoked as part of the exporter workflow and enabled with an optional parameter that indicates the ORT version, since the rewriter will be tied to the ORT version (due to the fusions and the availability of the relevant ops in that version of ORT).
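To make the custom post-processing idea concrete: neither InferenceSession nor DORT exposes such a hook today, so the `post_export_passes` parameter below is purely hypothetical, sketched as a user-side wrapper:

```python
from typing import Callable, Sequence

import onnx
import onnxruntime

ModelPass = Callable[[onnx.ModelProto], onnx.ModelProto]

def make_session(model_path: str, post_export_passes: Sequence[ModelPass] = ()):
    """Hypothetical wrapper: apply user-supplied passes before session creation."""
    model = onnx.load(model_path)
    for run_pass in post_export_passes:
        model = run_pass(model)
    return onnxruntime.InferenceSession(model.SerializeToString())

# Usage with the rewriter as one such pass (signatures assumed):
# sess = make_session("model.onnx", [optimize, rewrite])
```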
Please see my comment.