From 4f99b4bdbb6b10c46b3f58cc87df41b815e6447e Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 25 Apr 2024 16:11:59 -0700 Subject: [PATCH 1/4] Fix opset import test in pattern_test.py (#1461) From https://github.com/microsoft/onnxscript/pull/1451#discussion_r1580208670, the model string should be modified with added function domain. And now we can add checker. --- onnxscript/rewriter/pattern_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index f268f9d25..a30a2341c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -2,6 +2,7 @@ import unittest import numpy as np +import onnx.checker import onnx.parser from onnxscript import ir @@ -337,7 +338,7 @@ def double(op, x): agraph (float[N] x) => (float[M] z) { - z = afunction (x) + z = pkg.custom.afunction (x) } afunction (x) => (z) @@ -355,6 +356,7 @@ def double(op, x): self.assertEqual( model.functions[("pkg.custom", "afunction", "")].opset_imports["custom.domain"], 10 ) + onnx.checker.check_model(ir.serde.serialize_model(model)) if __name__ == "__main__": From 9280fe7beb925898512596bbef068c02c3d0e30c Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Thu, 25 Apr 2024 16:37:51 -0700 Subject: [PATCH 2/4] Propagate relevant info from old values to new values (#1459) When we create new values to replace old ones, we lose some of the information associated with the old values. Propagate this information. (See https://github.com/microsoft/onnxscript/issues/1455) --- .../rewriter/cast_constant_of_shape_test.py | 13 ++++--------- onnxscript/rewriter/pattern.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/cast_constant_of_shape_test.py index 1a124f7cf..c16ac082d 100644 --- a/onnxscript/rewriter/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/cast_constant_of_shape_test.py @@ -2,7 +2,6 @@ import onnx.checker import onnx.parser -import onnx.printer from onnxscript import ir from onnxscript.rewriter import cast_constant_of_shape @@ -16,8 +15,7 @@ def test_cast_after_constant_of_shape_is_fused(self): agraph (int64[2] input_x) => (float16[1, 4] output) { constant = ConstantOfShape (input_x) - temp = Cast (constant) - output = Identity (temp) + output = Cast (constant) } """ ) @@ -25,11 +23,9 @@ def test_cast_after_constant_of_shape_is_fused(self): model = ir.serde.deserialize_model(input_model_proto) count = cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) - self.assertEqual(len(model.graph), 2) + self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) output_model_proto = ir.serde.serialize_model(model) - # TODO: Eliminating `temp` in above example causes a failure. - # Rewriter changes graph output name, doesn't introduce type for it onnx.checker.check_model(output_model_proto, True) def test_cast_after_constant_of_shape_without_value_is_fused(self): @@ -39,15 +35,14 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self): agraph (int64[2] input_x) => (float16[1, 4] output) { constant = ConstantOfShape (input_x) - temp = Cast (constant) - output = Identity (temp) + output = Cast (constant) } """ ) model = ir.serde.deserialize_model(model_proto) count = cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) - self.assertEqual(len(model.graph), 2) + self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) output_model_proto = ir.serde.serialize_model(model) onnx.checker.check_model(output_model_proto, True) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 12ead6cba..180ac1717 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -910,13 +910,21 @@ def _apply_deltas( else: deleted_nodes, inserted_nodes = delta # Replace deleted nodes with inserted nodes. - # However, we merge the last deleted node and last inserted node - # to avoid replacing the values produced by the last deleted node - # in all places where they are used. So, we reuse the output - # values from the last deleted node and replace the node itself # TODO: simplify this last_deleted = deleted_nodes[-1] last_inserted = inserted_nodes[-1] + + for old_value, new_value in zip(last_deleted.outputs, last_inserted.outputs): + # Propagate relevant info from old value to new value + # TODO(Rama): Perhaps we should merge old and new types. As of now, new + # values don't have type information. Note that this could be a problem + # for semantics-altering rewrite-rules: we should allow users to override + # this for such rules. + new_value.type = old_value.type + new_value.shape = old_value.shape + new_value.const_value = old_value.const_value + new_value.name = old_value.name + # Reconnect the users of the deleted node to use the new outputs _convenience.replace_all_uses_with(last_deleted.outputs, last_inserted.outputs) # Update graph/function outputs if the node generates output From 7dcddeae3089a9c76cdcd9735c9f8ad93258e2d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 26 Apr 2024 08:45:03 -0700 Subject: [PATCH 3/4] Remove `-v` in CI tests to produce shorter logs (#1469) --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 3716f7e74..f28d6ce34 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -78,7 +78,7 @@ jobs: - name: Pull Test Data run: git lfs pull - name: Run tests - run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml + run: nox -t ${{ matrix.nox-tag }} --forcecolor -- --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml env: CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" From 997beb2bdfe75adfbc44f3864844637d4bccc534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 26 Apr 2024 19:13:33 +0200 Subject: [PATCH 4/4] Add one test failing for the optimizer after a model optimized and inlined (#1465) Co-authored-by: Justin Chu Co-authored-by: Ti-Tai Wang --- testdata/dort_models/llama_forward.onnx | 3 ++ tests/optimizer/test_models.py | 39 ++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 testdata/dort_models/llama_forward.onnx diff --git a/testdata/dort_models/llama_forward.onnx b/testdata/dort_models/llama_forward.onnx new file mode 100644 index 000000000..9f3676d1e --- /dev/null +++ b/testdata/dort_models/llama_forward.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6de32573c9127923c867dc047ea5c886042573c3f9383e22299dea42f18a4306 +size 27225 diff --git a/tests/optimizer/test_models.py b/tests/optimizer/test_models.py index 6de8cd2da..ce78a8ac3 100644 --- a/tests/optimizer/test_models.py +++ b/tests/optimizer/test_models.py @@ -6,10 +6,12 @@ import numpy as np import onnx +import onnx.inliner import onnxruntime import parameterized from onnxscript import optimizer +from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils _SKIP_TABLE = {} @@ -64,6 +66,41 @@ def test_model_runs_and_matches_accuracy_after_optimization(self, model_name): for output, expected_output in zip(outputs, expected_outputs): np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3) + def test_optimizer_after_inlining(self): + model_dir = pathlib.Path(model_folder_path) / ".." / "dort_models" + filename = model_dir / "llama_forward.onnx" + + onnx_model = onnx.load(filename) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + + # first time + onnx_model = optimizer.optimize(onnx_model) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + onnx_model = ort_rewriter.rewrite(onnx_model) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + + # inline + onnx_model = onnx.inliner.inline_local_functions(onnx_model) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + + # second time + onnx_model = optimizer.optimize(onnx_model) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + onnx_model = ort_rewriter.rewrite(onnx_model) + onnxruntime.InferenceSession( + onnx_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2)