Skip to content

Commit

Permalink
add one test failing for the optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Apr 26, 2024
1 parent 0d98619 commit 6500709
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
3 changes: 3 additions & 0 deletions testdata/dort_models/llama_forward.onnx
Git LFS file not shown
40 changes: 39 additions & 1 deletion tests/optimizer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import parameterized

from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter
from onnxscript.utils import evaluation_utils

_SKIP_TABLE = {}
Expand Down Expand Up @@ -64,6 +65,43 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name):
for output, expected_output in zip(outputs, expected_outputs):
np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)

def test_optimizer_after_inlining(self):
model_dir = pathlib.Path(model_folder_path) / ".." / "dort_models"
filename = model_dir / "llama_forward.onnx"
if not filename.exists():
self.skipTest(f"Model {filename!r} does not exist")

onnx_model = onnx.load(filename)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# first time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# inline
onnx_model = onnx.inliner.inline_local_functions(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)

# second time
onnx_model = optimizer.optimize(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
onnx_model = ort_rewriter.rewrite(onnx_model)
onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)

0 comments on commit 6500709

Please sign in to comment.