
Call onnx-rewriter when possible in onnxruntime.InferenceSession #19348

Draft: wants to merge 1 commit into base: main

Conversation

@wschin (Contributor) commented Jan 31, 2024

No description provided.

try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
except:

Check notice (Code scanning / CodeQL): Except block handles 'BaseException'. Except block directly handles BaseException.
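
A minimal sketch of an import guard that avoids the bare except flagged here, catching only ImportError (module paths taken from the diff above; everything else is illustrative):

HAS_ONNX_REWRITTER = False
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize

    HAS_ONNX_REWRITTER = True
except ImportError:
    # Only swallow a missing package; let real errors (e.g. SyntaxError,
    # KeyboardInterrupt) propagate instead of handling BaseException.
    pass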
    onnx_model = rewrite_and_optimize_model_bytes(self._model_bytes)
    sess = C.InferenceSession(session_options, onnx_model, False, self._read_config_from_model)
else:
    sess = C.InferenceSession(session_options, onnx_model.SerializeToString(), False, self._read_config_from_model)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable. Local variable 'onnx_model' may be used before it is initialized.
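
One way to keep the variable defined on every path is to assign in both branches and create the session once afterwards. A sketch against the control flow shown in the diff, assuming rewrite_and_optimize_model_bytes returns serialized bytes and this runs inside the InferenceSession method that owns self._model_bytes:

if HAS_ONNX_REWRITTER:
    # Assumed to return serialized model bytes, so no SerializeToString() here.
    model_bytes = rewrite_and_optimize_model_bytes(self._model_bytes)
else:
    # No rewriter available: fall back to the unmodified bytes.
    model_bytes = self._model_bytes
sess = C.InferenceSession(session_options, model_bytes, False, self._read_config_from_model)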
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
except:

Check warning (Code scanning / lintrunner): RUFF/E722. Do not use bare except. See https://docs.astral.sh/ruff/rules/bare-except
def rewrite_and_optimize_model_bytes(model):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.ModelProto()
    onnx_model.ParseFromString(self._model_bytes)

Check failure (Code scanning / lintrunner): RUFF/F821. Undefined name 'self'.
def rewrite_and_optimize_model_path(model_path):
    assert HAS_ONNX_REWRITTER
    onnx_model = onnx.load(self._model_path)

Check failure (Code scanning / lintrunner): RUFF/F821. Undefined name 'self'.

HAS_ONNX_REWRITTER = True
try:
    from onnxrewriter.rewriter.transformers import rewrite
    from onnxrewriter.optimizer import optimize
@tianleiwu (Contributor) commented Jan 31, 2024

I can see a few problems:
(1) onnxrewriter is not available in the other language APIs (C/C++/Nuget). It makes the results inconsistent across languages.
(2) If onnxrewriter has a bug, users have to uninstall it, since there is no explicit option to disable it.
(3) It might not work well with large models (>2GB), since .SerializeToString() is used.

If onnxrewriter is generic enough, why not implement it inside onnxruntime in C++?
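
On point (3), a standard workaround for the 2GB protobuf limit is to serialize with external data and hand ORT a path instead of bytes. A sketch using the public onnx API, not part of this PR (onnx_model is assumed to be a ModelProto already in scope; the file names are illustrative):

import onnx
import onnxruntime

# Store tensors outside the protobuf so the serialized graph stays under 2 GB.
onnx.save_model(
    onnx_model,
    "rewritten_model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
)
sess = onnxruntime.InferenceSession(
    "rewritten_model.onnx", providers=["CPUExecutionProvider"]
)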

@wschin (Contributor, Author) replied:

I am OK with whatever way onnx-rewriter should be called. This is just the only available way I have to unblock llama with the onnxrt dynamo backend (without this change, ORT is many times slower than inductor, and we will get no market share from PyTorch 2 features). I can add a flag to turn it on/off (see the sketch below). Does that make sense?

For how this should ultimately be implemented, please talk to @thiagocrepaldi. I guess the exporter (i.e., onnxscript) will eventually include this optimization pass once it matures.
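
A sketch of what such an opt-out flag could look like; the environment variable name is hypothetical and not an existing onnxruntime option:

import os

# Hypothetical escape hatch: lets users disable the rewriter without
# uninstalling the package. The variable name is made up for illustration.
_DISABLE_REWRITER = os.environ.get("ORT_DISABLE_ONNX_REWRITER", "0") == "1"

if HAS_ONNX_REWRITTER and not _DISABLE_REWRITER:
    model_bytes = rewrite_and_optimize_model_bytes(self._model_bytes)
else:
    model_bytes = self._model_bytes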

Contributor replied:

I've discussed this with Thiago. The current plan is that onnx-rewriter will be invoked by onnxscript and will be part of the exporter workflow. Having a flag is, again, not consistent across the different language bindings: we don't want users to think that the Python bindings can give better perf via a flag that is not available for the others. Can the llama model be unblocked by calling the rewriter separately?
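
For reference, calling the rewriter separately would look roughly like this, using the imports from this diff; the rewrite()/optimize() call signatures and file paths are assumptions for illustration:

import onnx
import onnxruntime
from onnxrewriter.rewriter.transformers import rewrite
from onnxrewriter.optimizer import optimize

# Rewrite and optimize offline, then create a plain InferenceSession.
model = onnx.load("llama.onnx")  # hypothetical path
model = rewrite(model)   # assumed signature; only the imports appear in this PR
model = optimize(model)  # assumed signature
onnx.save_model(model, "llama_rewritten.onnx")
sess = onnxruntime.InferenceSession(
    "llama_rewritten.onnx", providers=["CPUExecutionProvider"]
)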

@wschin (Contributor, Author) commented Feb 2, 2024

Nope. The only alternative solution I have in mind is to support a custom post-processing pass in InferenceSession or DORT.

@pranavsharma (Contributor) commented:

We don't want onnx-rewriter to be called from within InferenceSession; this leads to inconsistencies between different language bindings. This script should be invoked as part of the exporter workflow and enabled with an optional parameter that indicates the ORT version, since the rewriter will be tied to the ORT version (due to the fusions and the availability of the relevant ops in that version of ORT).

@pranavsharma (Contributor) left a review comment:

Please see my comment.
