Skip to content

Commit

Permalink
Add one test failing for the optimizer after a model optimized and in…
Browse files Browse the repository at this point in the history
…lined (#1465)

Co-authored-by: Justin Chu <[email protected]>
Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
3 people authored Apr 26, 2024
1 parent 7dcddea commit 997beb2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
3 changes: 3 additions & 0 deletions testdata/dort_models/llama_forward.onnx
Git LFS file not shown
39 changes: 38 additions & 1 deletion tests/optimizer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import numpy as np
import onnx
import onnx.inliner
import onnxruntime
import parameterized

from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter
from onnxscript.utils import evaluation_utils

_SKIP_TABLE = {}
Expand Down Expand Up @@ -64,6 +66,41 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name):
for output, expected_output in zip(outputs, expected_outputs):
np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)

def test_optimizer_after_inlining(self):
model_dir = pathlib.Path(model_folder_path) / ".." / "dort_models"
filename = model_dir / "llama_forward.onnx"

onnx_model = onnx.load(filename)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# first time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# inline
onnx_model = onnx.inliner.inline_local_functions(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# second time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)

0 comments on commit 997beb2

Please sign in to comment.