diff --git a/testdata/dort_models/llama_forward.onnx b/testdata/dort_models/llama_forward.onnx
new file mode 100644
index 000000000..9f3676d1e
--- /dev/null
+++ b/testdata/dort_models/llama_forward.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6de32573c9127923c867dc047ea5c886042573c3f9383e22299dea42f18a4306
+size 27225
diff --git a/tests/optimizer/test_models.py b/tests/optimizer/test_models.py
index 6de8cd2da..ce78a8ac3 100644
--- a/tests/optimizer/test_models.py
+++ b/tests/optimizer/test_models.py
@@ -6,10 +6,12 @@
 import numpy as np
 import onnx
+import onnx.inliner
 import onnxruntime
 import parameterized
 
 from onnxscript import optimizer
+from onnxscript.rewriter import onnxruntime as ort_rewriter
 from onnxscript.utils import evaluation_utils
 
 _SKIP_TABLE = {}
 
@@ -64,6 +66,41 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name):
         for output, expected_output in zip(outputs, expected_outputs):
             np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3)
 
+    def test_optimizer_after_inlining(self):
+        model_dir = pathlib.Path(model_folder_path) / ".." / "dort_models"
+        filename = model_dir / "llama_forward.onnx"
+
+        onnx_model = onnx.load(filename)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+
+        # First pass: optimize the model and apply the onnxruntime-specific rewrites.
+        onnx_model = optimizer.optimize(onnx_model)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        onnx_model = ort_rewriter.rewrite(onnx_model)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+
+        # Inline all local functions into the main graph.
+        onnx_model = onnx.inliner.inline_local_functions(onnx_model)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+
+        # Second pass: the optimizer and rewriter must still handle the inlined model.
+        onnx_model = optimizer.optimize(onnx_model)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        onnx_model = ort_rewriter.rewrite(onnx_model)
+        onnxruntime.InferenceSession(
+            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)