From 8758d2879df54efc8b3d2cd2b2a0935b18405555 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 20 Jun 2024 12:27:16 +0200
Subject: [PATCH] Add first rewriting patterns for llama adding onnxruntime
 contrib ops (#1622)

Signed-off-by: Xavier Dupre
---
 onnxscript/optimizer/constant_folding.py      |   4 +-
 onnxscript/rewriter/generic_pattern.py        |  12 +
 onnxscript/rewriter/onnxruntime/__init__.py   |   2 +
 .../onnxruntime/fused_matmul_rule_sets.py     | 179 +++++++++
 .../fused_matmul_rule_sets_test.py            | 363 ++++++++++++++++++
 onnxscript/rewriter/pattern.py                |  34 +-
 .../tools/benchmark/benchmark_helpers.py      |  34 +-
 onnxscript/tools/benchmark/export_model.py    |   2 +-
 .../tools/benchmark/export_model_batch.py     |   8 +-
 .../tools/benchmark/export_model_test.py      |   4 +-
 10 files changed, 622 insertions(+), 20 deletions(-)
 create mode 100644 onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
 create mode 100644 onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py

diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py
index 82c0f2536..d119c41e9 100644
--- a/onnxscript/optimizer/constant_folding.py
+++ b/onnxscript/optimizer/constant_folding.py
@@ -82,7 +82,7 @@ def foldable_value(self, name: str, value):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
             # So, a constant-value of type sequence is not folded, but it can be used
             # to optimize subsequent operations when possible.
-            logger.warning(
+            logger.info(
                 "Skip storing constant folded value %s due to unsupported type %s.",
                 name,
                 type(value),
@@ -90,7 +90,7 @@ def foldable_value(self, name: str, value):
             return None

         if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT:
-            logger.warning(
+            logger.info(
                 "Skip storing constant folded nvalue %s due to large size %s.",
                 name,
                 value.nbytes,
diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py
index 51957ff47..d0daf2e06 100644
--- a/onnxscript/rewriter/generic_pattern.py
+++ b/onnxscript/rewriter/generic_pattern.py
@@ -296,6 +296,18 @@ def _match_backward(
                 graph_node,
             )
             return self.none(starting_node, inspect.currentframe().f_lineno)
+
+        for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs):
+            if len(list(graph_input.uses())) != len(list(pattern_input.uses())):
+                self._hint(
+                    "BACKWARD: one input is used outside the pattern",
+                    "-- pattern",
+                    pattern_node,
+                    "-- model",
+                    graph_node,
+                )
+                return self.none(starting_node, inspect.currentframe().f_lineno)
+
         for graph_value, pattern_value in zip(graph_node.inputs, pattern_node.inputs):
             # TODO(rama): Handle constant-pattern
             pattern_pred = pattern_value.producer()
diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py
index f76dd680c..aa7b9a0ae 100644
--- a/onnxscript/rewriter/onnxruntime/__init__.py
+++ b/onnxscript/rewriter/onnxruntime/__init__.py
@@ -7,6 +7,7 @@
 from onnxscript.rewriter import function_rule, pattern
 from onnxscript.rewriter import rewrite as _rewrite
 from onnxscript.rewriter.onnxruntime import (
+    fused_matmul_rule_sets,
     group_normalization_merge_silu,
     instance_to_group_normalization,
     softmax,
@@ -20,6 +21,7 @@
     *instance_to_group_normalization.rules.rules,
     # NOTE: group normalization merge silu should be applied after instance to group normalization
     *group_normalization_merge_silu.rules.rules,
+    *fused_matmul_rule_sets.fused_matmul_rule_sets(),
 ]
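With the import registered above, the fused-matmul rules ship as part of ORT_PATTERN_REWRITE_RULES. A minimal sketch of applying them, assuming `model_proto` is an already-loaded onnx.ModelProto (this mirrors the path the benchmark helpers below take):

    import onnxscript.rewriter.onnxruntime as ort_rules
    import onnxscript.rewriter.pattern as orp
    from onnxscript import ir

    ir_model = ir.serde.deserialize_model(model_proto)
    # Apply every onnxruntime pattern-rewrite rule, including the new fused-matmul ones.
    orp.RewriteRuleSet(ort_rules.ORT_PATTERN_REWRITE_RULES).apply_to_model(ir_model)
    model_proto = ir.serde.serialize_model(ir_model)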
diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
new file mode 100644
index 000000000..83f263304
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
@@ -0,0 +1,179 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import ClassVar
+
+import onnxscript.rewriter.pattern as orp
+
+op = orp.onnxop
+
+
+class FusedMatMulDiv1(orp.RewriteRuleAsClass):
+    """Replaces ``MatMul + Div`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y, cst):
+        return op.Div(op.MatMul(x, y), cst)
+
+    @classmethod
+    def check(cls, context, x, y, cst) -> bool:
+        if cst.const_value is None:
+            return False
+        value = cst.const_value.numpy()
+        if value.size > 1:
+            return False
+        return True
+
+    @classmethod
+    def rewrite(cls, op, x, y, cst):
+        value = cst.const_value.numpy()
+        c = float(value[0] if value.shape == (1,) else value)
+        return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft")
+
+
+class FusedMatMulDiv2(orp.RewriteRuleAsClass):
+    """Replaces ``FusedMatMul + Div`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y, cst):
+        return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst)
+
+    @classmethod
+    def check(cls, context, x, y, cst) -> bool:
+        if cst.const_value is None:
+            return False
+        if cst.const_value.numpy().size > 1:
+            return False
+        return True
+
+    @classmethod
+    def rewrite(cls, op, x, y, cst):
+        value = cst.const_value.numpy()
+        c = float(value[0] if value.shape == (1,) else value)
+        node = list(x.uses())[0][0]  # noqa: RUF015
+
+        kwargs = {}
+        alpha = node.attributes.get("alpha", None)
+        kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
+        for name in ["transA", "transB", "transBatchA", "transBatchB"]:
+            att = node.attributes.get(name)
+            if att:
+                kwargs[name] = att.value
+        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
+
+
+class _TransposeMatMulBase(orp.RewriteRuleAsClass):
+    _pos: ClassVar = 1
+
+    @classmethod
+    def check(cls, context, x, y) -> bool:
+        perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value  # noqa: RUF015
+        expected_perm = list(range(len(perm)))
+        expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
+        return perm == expected_perm
+
+    @classmethod
+    def rewrite(cls, op, x, y):
+        node = list((x if cls._pos == 2 else y).uses())[0][0]  # noqa: RUF015
+        kwargs = {}
+        for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
+            att = node.attributes.get(name)
+            if att:
+                kwargs[name] = att.value
+        name = "transA" if cls._pos == 1 else "transB"
+        kwargs[name] = 1 - kwargs.get(name, 0)
+        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")
+
+
+class TransposeMatMul1(_TransposeMatMulBase):
+    """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.MatMul(op.Transpose(x), y)
+
+
+class TransposeFusedMatMul1(TransposeMatMul1):
+    """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.FusedMatMul(op.Transpose(x), y, domain="com.microsoft")
+
+
+class TransposeMatMul2(_TransposeMatMulBase):
+    """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
+
+    _pos: ClassVar = 2
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.MatMul(x, op.Transpose(y))
+
+
+class TransposeFusedMatMul2(TransposeMatMul2):
+    """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.FusedMatMul(x, op.Transpose(y), domain="com.microsoft")
+
+
+class MatMulTranspose(orp.RewriteRuleAsClass):
+    """Replaces ``MatMul + Transpose`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.Transpose(op.MatMul(x, y))
+
+    @classmethod
+    def check(cls, context, x, y) -> bool:
+        matmul = list(x.uses())[0][0]  # noqa: RUF015
+        transpose = list(matmul.outputs[0].uses())[0][0]  # noqa: RUF015
+        perm = transpose.attributes["perm"].value
+        expected_perm = list(range(len(perm)))
+        expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
+        return perm == expected_perm
+
+    @classmethod
+    def rewrite(cls, op, x, y):
+        node = list(x.uses())[0][0]  # noqa: RUF015
+        kwargs = {}
+        for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
+            att = node.attributes.get(name)
+            if att:
+                kwargs[name] = att.value
+        for name in ["transA", "transB"]:
+            kwargs[name] = 1 - kwargs.get(name, 0)
+        return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft")
+
+
+class FusedMatMulTranspose(MatMulTranspose):
+    """Replaces ``MatMul + Transpose`` by FusedMatMul."""
+
+    @classmethod
+    def pattern(cls, op, x, y):
+        return op.Transpose(op.FusedMatMul(x, y, domain="com.microsoft"))
+
+
+def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
+    """Returns a set of rules introducing onnxruntime contrib ops.
+
+    onnxruntime is required to run the model after it is rewritten.
+
+    Returns:
+        RewriteRuleSet
+    """
+    return orp.RewriteRuleSet(
+        [
+            orp.make_rewrite_rule_from_class(FusedMatMulDiv1, True),
+            orp.make_rewrite_rule_from_class(FusedMatMulDiv2, True),
+            orp.make_rewrite_rule_from_class(FusedMatMulTranspose, True),
+            orp.make_rewrite_rule_from_class(MatMulTranspose, True),
+            orp.make_rewrite_rule_from_class(TransposeMatMul1, True),
+            orp.make_rewrite_rule_from_class(TransposeFusedMatMul1, True),
+            orp.make_rewrite_rule_from_class(TransposeMatMul2, True),
+            orp.make_rewrite_rule_from_class(TransposeFusedMatMul2, True),
+        ]
+    )
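To make the fusion concrete, here is a hedged sketch condensed from the tests that follow: a two-node MatMul + scalar-Div graph that the rule set collapses into a single FusedMatMul with alpha = 1/0.6 (the shapes and the 0.6 constant are arbitrary illustration values):

    import numpy as np
    import onnx
    from onnxscript import ir
    from onnxscript.rewriter.onnxruntime import fused_matmul_rule_sets

    model = onnx.helper.make_model(
        onnx.helper.make_graph(
            [
                onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]),
                onnx.helper.make_node("Div", ["xy", "C"], ["Z"]),
            ],
            "demo",
            [
                onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [2, 3]),
                onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2]),
            ],
            [onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [None, None])],
            [onnx.numpy_helper.from_array(np.array([0.6], dtype=np.float32), name="C")],
        ),
        opset_imports=[onnx.helper.make_opsetid("", 18)],
    )
    ir_model = ir.serde.deserialize_model(model)
    fused_matmul_rule_sets.fused_matmul_rule_sets().apply_to_model(ir_model)
    print([n.op_type for n in ir.serde.serialize_model(ir_model).graph.node])  # ['FusedMatMul']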
diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py
new file mode 100644
index 000000000..a7d170e69
--- /dev/null
+++ b/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py
@@ -0,0 +1,363 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+from typing import Any
+
+import numpy as np
+import onnx
+import onnx.reference
+import onnx.reference.op_run
+
+import onnxscript.rewriter.onnxruntime.fused_matmul_rule_sets as fused_matmul_rule_sets
+from onnxscript import ir
+
+FLOAT = onnx.TensorProto.FLOAT
+
+
+class FusedMatMul(onnx.reference.op_run.OpRun):
+    op_domain = "com.microsoft"
+
+    def _run(
+        self,
+        A,
+        B,
+        alpha: float = 1,
+        transA: int = 0,
+        transB: int = 0,
+        transBatchA: int = 0,
+        transBatchB: int = 0,
+    ):
+        assert transBatchA == 0, f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}"
+        assert transBatchB == 0, f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}"
+        if transA:
+            perm = list(range(len(A.shape)))
+            dim = len(perm)
+            perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
+            A = np.transpose(A, perm)
+        if transB:
+            perm = list(range(len(B.shape)))
+            dim = len(perm)
+            perm[dim - 2], perm[dim - 1] = perm[dim - 1], perm[dim - 2]
+            B = np.transpose(B, perm)
+        a = np.array(alpha, dtype=A.dtype)
+        return (np.matmul(A, B) * a,)
+
+
+class OrtRuleSetsTest(unittest.TestCase):
+    def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
+        feeds: dict[str, Any] = {}
+        for i in model.graph.input:
+            ish = tuple(i.type.tensor_type.shape.dim)
+            # Creates an input tensor with a dimension defined by the onnx model
+            # or equal to i + 2, where i is the dimension index.
+            # The tensor is kept small to make the test fast.
+            shape = tuple(
+                (d.dim_value if d.dim_value > 0 else i + 2) for i, d in enumerate(ish)
+            )
+            if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT:
+                feeds[i.name] = np.random.randn(*shape).astype(np.float32)
+            else:
+                raise AssertionError(f"Not implemented for input {i}")
+        return feeds
+
+    def _check_model(
+        self,
+        model: onnx.ModelProto,
+        optimized_model: onnx.ModelProto,
+        feeds: dict[str, Any] | None = None,
+        atol: float = 0.0,
+        rtol: float = 1e-7,
+    ):
+        if not feeds:
+            feeds = self._get_random_inputs(model)
+        ref = onnx.reference.ReferenceEvaluator(model, new_ops=[FusedMatMul])
+        opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul])
+        expected = ref.run(None, feeds)
+        got = opt.run(None, feeds)
+        self.assertEqual(len(expected), len(got))
+        for a, b in zip(expected, got):
+            np.testing.assert_allclose(a, b, atol=atol, rtol=rtol)
+
+    @classmethod
+    def _fused_matmul_div_models(cls):
+        models = [
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node(
+                            "FusedMatMul",
+                            ["X", "Y"],
+                            ["xyc"],
+                            transA=1,
+                            transB=0,
+                            alpha=0.4,
+                            transBatchA=0,
+                            transBatchB=0,
+                            domain="com.microsoft",
+                        ),
+                        onnx.helper.make_node("Div", ["xyc", "D"], ["Z"]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, [6, "a"]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                    [
+                        onnx.numpy_helper.from_array(
+                            np.array([0.8], dtype=np.float32), name="D"
+                        ),
+                    ],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                    onnx.helper.make_opsetid("com.microsoft", 1),
+                ],
+            ),
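+            # Same fusion, but starting from a plain MatMul followed by a Div
+            # by a one-element initializer.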
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]),
+                        onnx.helper.make_node("Div", ["xy", "C"], ["Z"]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, ["a", 6]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                    [
+                        onnx.numpy_helper.from_array(
+                            np.array([0.6], dtype=np.float32), name="C"
+                        )
+                    ],
+                ),
+                opset_imports=[onnx.helper.make_opsetid("", 18)],
+            ),
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]),
+                        onnx.helper.make_node("Div", ["xy", "C"], ["xyc"]),
+                        onnx.helper.make_node("Div", ["xyc", "D"], ["Z"]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, ["a", 6]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                    [
+                        onnx.numpy_helper.from_array(
+                            np.array([0.6], dtype=np.float32), name="C"
+                        ),
+                        onnx.numpy_helper.from_array(
+                            np.array([0.8], dtype=np.float32), name="D"
+                        ),
+                    ],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                ],
+            ),
+        ]
+        return models
+
+    def test_ort_rule_set_fused_matmul_div(self):
+        for model_proto in self._fused_matmul_div_models():
+            ir_model = ir.serde.deserialize_model(model_proto)
+            rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets()
+            rule_set.apply_to_model(ir_model)
+            rewritten_model = ir.serde.serialize_model(ir_model)
+
+            self.assertEqual(["FusedMatMul"], [n.op_type for n in rewritten_model.graph.node])
+            self._check_model(model_proto, rewritten_model, atol=1e-6)
+
+    @classmethod
+    def _transposed_fused_matmul_div_models(cls):
+        models = [
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node(
+                            "FusedMatMul",
+                            ["X", "Y"],
+                            ["xy"],
+                            domain="com.microsoft",
+                            alpha=0.5,
+                        ),
+                        onnx.helper.make_node("Transpose", ["xy"], ["Z"], perm=[1, 0]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                    onnx.helper.make_opsetid("com.microsoft", 1),
+                ],
+            ),
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]),
+                        onnx.helper.make_node("Transpose", ["xy"], ["Z"], perm=[1, 0]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                    onnx.helper.make_opsetid("com.microsoft", 1),
+                ],
+            ),
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]),
+                        onnx.helper.make_node("MatMul", ["Xt", "Y"], ["Z"]),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                    onnx.helper.make_opsetid("com.microsoft", 1),
+                ],
+            ),
+            onnx.helper.make_model(
+                onnx.helper.make_graph(
+                    [
+                        onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]),
+                        onnx.helper.make_node(
+                            "FusedMatMul",
+                            ["Xt", "Y"],
+                            ["Z"],
+                            domain="com.microsoft",
+                            alpha=0.5,
+                        ),
+                    ],
+                    "name",
+                    [
+                        onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]),
+                        onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]),
+                    ],
+                    [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])],
+                ),
+                opset_imports=[
+                    onnx.helper.make_opsetid("", 18),
+                    onnx.helper.make_opsetid("com.microsoft", 1),
+                ],
+            ),
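+            # The remaining models transpose the right-hand input Y instead of X.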
onnx.helper.make_node("Transpose", ["Y"], ["Yt"], perm=[1, 0]), + onnx.helper.make_node("MatMul", ["X", "Yt"], ["Z"]), + ], + "name", + [ + onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), + onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), + ], + [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], + ), + opset_imports=[ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ], + ), + onnx.helper.make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Transpose", ["Y"], ["Yt"], perm=[1, 0]), + onnx.helper.make_node( + "FusedMatMul", + ["X", "Yt"], + ["Z"], + domain="com.microsoft", + alpha=0.5, + ), + ], + "name", + [ + onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), + onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), + ], + [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], + ), + opset_imports=[ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ], + ), + ] + return models + + def test_ort_rule_set_transpose_fused_matmul_div(self): + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + for model_proto in self._transposed_fused_matmul_div_models(): + ir_model = ir.serde.deserialize_model(model_proto) + rule_set.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + + self.assertEqual(["FusedMatMul"], [n.op_type for n in rewritten_model.graph.node]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @classmethod + def _should_not_match(cls): + models = [ + onnx.helper.make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]), + onnx.helper.make_node("MatMul", ["Xt", "Y"], ["Z"]), + onnx.helper.make_node("Transpose", ["Xt"], ["W"], perm=[1, 0]), + ], + "name", + [ + onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), + onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), + ], + [ + onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None]), + onnx.helper.make_tensor_value_info("W", FLOAT, [None, None]), + ], + ), + opset_imports=[ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ], + ), + ] + return models + + def test_should_not_match(self): + for model_proto in self._should_not_match(): + ir_model = ir.serde.deserialize_model(model_proto) + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + rule_set.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + + self.assertEqual( + ["Transpose", "MatMul", "Transpose"], + [n.op_type for n in rewritten_model.graph.node], + ) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 534ce7997..d8bdb6e65 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1034,6 +1034,7 @@ def __init__( condition_function: Callable | None = None, matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None, verbose: int = 0, + name: str | None = None, ) -> None: """Create a rewrite rule. @@ -1048,6 +1049,7 @@ def __init__( matcher: The pattern matcher that will be used to match the pattern. If not provided, a default matcher will be used. verbose: The verbosity level of the rule. 
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 534ce7997..d8bdb6e65 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -1034,6 +1034,7 @@ def __init__(
         condition_function: Callable | None = None,
         matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None,
         verbose: int = 0,
+        name: str | None = None,
     ) -> None:
         """Create a rewrite rule.

@@ -1048,6 +1049,7 @@ def __init__(
             matcher: The pattern matcher that will be used to match the pattern.
                 If not provided, a default matcher will be used.
             verbose: The verbosity level of the rule.
+            name: An optional name for the rule, used for debugging.

        """
        if not isinstance(target_pattern, GraphPattern):
@@ -1070,6 +1072,14 @@ def __init__(
         else:
             self._matcher = matcher(self._target_pattern)
         self._verbose = verbose
+        self.name = name
+
+    def __str__(self) -> str:
+        if self.name:
+            return f"{self.__class__.__name__}(..., name={self.name!r})"
+        return (
+            f"{self.__class__.__name__}({self._target_pattern}, {self._replacement_pattern})"
+        )

     def try_rewrite(
         self,
@@ -1141,7 +1151,9 @@ def check(cls, context, *_) -> bool:
         return True


-def make_rewrite_rule_from_class(rule_class: type | RewriteRuleAsClass) -> RewriteRule:
+def make_rewrite_rule_from_class(
+    rule_class: type | RewriteRuleAsClass, generic: bool = False
+) -> RewriteRule:
     """Creates a RewriteRule from a class defining the functions pattern, rewrite, and check as class methods.

     It makes the code easier to read when a module contains multiple patterns.
@@ -1171,7 +1183,22 @@ def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None):
     assert hasattr(rule_class, "pattern"), f"Method 'pattern' is missing from {rule_class!r}."
     assert hasattr(rule_class, "rewrite"), f"Method 'rewrite' is missing from {rule_class!r}."
     assert hasattr(rule_class, "check"), f"Method 'check' is missing from {rule_class!r}."
-    return RewriteRule(rule_class.pattern, rule_class.rewrite, rule_class.check)
+    if generic:
+        import onnxscript.rewriter.generic_pattern as orpp
+
+        return RewriteRule(
+            rule_class.pattern,
+            rule_class.rewrite,
+            rule_class.check,
+            orpp.GenericPatternMatcher,
+            name=rule_class.__name__,  # type: ignore[union-attr]
+        )
+    return RewriteRule(
+        rule_class.pattern,
+        rule_class.rewrite,
+        rule_class.check,
+        name=rule_class.__name__,  # type: ignore[union-attr]
+    )


 def _apply_delta(
@@ -1258,3 +1285,6 @@ def apply_to_model(self, model: ir.Model, verbose: int | None = None) -> int:
         for function in model.functions.values():
             count += self._apply_to_graph_or_function(model, function, verbose=verbose)
         return count
+
+    def __iter__(self):
+        yield from self.rules
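A quick sketch of the new debugging affordances: rules built from a class with generic=True select the GenericPatternMatcher and carry their class name, and rule sets are now iterable (output format per the __str__ added above):

    import onnxscript.rewriter.pattern as orp
    from onnxscript.rewriter.onnxruntime import fused_matmul_rule_sets as fmrs

    rule = orp.make_rewrite_rule_from_class(fmrs.FusedMatMulDiv1, True)
    print(rule)  # RewriteRule(..., name='FusedMatMulDiv1')

    for rule in fmrs.fused_matmul_rule_sets():  # uses the new RewriteRuleSet.__iter__
        print(rule)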
diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py
index 12e074c34..36d9084fa 100644
--- a/onnxscript/tools/benchmark/benchmark_helpers.py
+++ b/onnxscript/tools/benchmark/benchmark_helpers.py
@@ -21,6 +21,8 @@
 import onnxscript.optimizer
 import onnxscript.rewriter
 import onnxscript.rewriter.llama_rule_sets as rules
+import onnxscript.rewriter.onnxruntime as ort_rules
+import onnxscript.rewriter.pattern as orp
 from onnxscript import ir
 from onnxscript.optimizer.remove_unused import remove_unused_nodes

@@ -216,7 +218,7 @@ def common_export(
         inputs: inputs
         dynamic_shapes: dynamic shapes
         target_opset: target opset
-        optimization: optimization scenario
+        optimization: optimization scenario, '/'-separated values
         verbose: verbosity
         stats: if not None, populates this dictionary with statistics about time

@@ -257,6 +259,7 @@ def common_export(

     if stats is not None:
         stats["export_time"] = time.perf_counter() - begin
+        stats["filesize"] = os.stat(filename).st_size

     if verbose:
         print(f"[common_export] exporter done in {time.perf_counter() - begin}s")
@@ -303,8 +306,9 @@ def apply_rule_sets(
     Returns:
         optimized model
     """
+    assert rule_sets, "No need to call apply_rule_sets for an empty set."
     if verbose:
-        print("[apply_rule_sets] deserialize model")
+        print(f"[apply_rule_sets] deserialize model before {rule_sets}")
     begin = time.perf_counter()
     ir_model = ir.serde.deserialize_model(model_proto)
     end = time.perf_counter() - begin
@@ -319,11 +323,14 @@ def apply_rule_sets(

         if rule_set_name == "llama0":
             rule_set = rules.llama_p0_rule_set()
+        elif rule_set_name == "onnxruntime":
+            rule_set = orp.RewriteRuleSet(ort_rules.ORT_PATTERN_REWRITE_RULES)
         else:
             raise ValueError(f"Unexpected rule_set name {rule_set_name!r}")

         begin = time.perf_counter()
         rule_set.apply_to_model(ir_model)
+        remove_unused_nodes(ir_model)
         end = time.perf_counter() - begin
         if stats is not None:
             stats[f"opt_rule_{rule_set_name}_time"] = end
@@ -366,7 +373,7 @@ def optimize_model_proto(

     Args:
         model_proto: ModelProto
-        optimization: comma separated value
+        optimization: '/'-separated values
         verbose: verbosity
         stats: if not None, populates this dictionary with statistics

@@ -376,13 +383,25 @@ def optimize_model_proto(
     if not optimization:
         return model_proto

-    for value in optimization.split(","):
+    known_rule_sets = {"llama0", "onnxruntime"}
+
+    rule_sets: list[str] = []
+    for value in optimization.split("/"):
+        if value in known_rule_sets:
+            rule_sets.append(value)
+            continue
+        if value not in known_rule_sets and rule_sets:
+            model_proto = apply_rule_sets(model_proto, rule_sets, stats=stats, verbose=verbose)
+            del rule_sets[:]
+            continue
+
         if verbose:
             print(f"[optimize_model_proto] start {value}")

         n_nodes = len(model_proto.graph.node)
         n_functions = len(model_proto.functions)
         begin = time.perf_counter()
+
         if value == "optimize":
             model_proto = onnxscript.optimizer.optimize(
                 model_proto,
@@ -396,11 +415,6 @@ def optimize_model_proto(

         elif value == "inline":
             model_proto = onnx.inliner.inline_local_functions(model_proto)

-        elif value == "llama0":
-            model_proto = apply_rule_sets(
-                model_proto, ["llama0"], stats=stats, verbose=verbose
-            )
-
         else:
             raise AssertionError(
                 f"Optimization step {value!r} is not implemented in {optimization!r}"
@@ -418,6 +432,8 @@ def optimize_model_proto(
             f"[optimize_model_proto] {value} done in {end} "
             f"with +/- {delta} nodes, +/- {deltaf} functions"
         )
+    if rule_sets:
+        model_proto = apply_rule_sets(model_proto, rule_sets, stats=stats, verbose=verbose)

     return model_proto
diff --git a/onnxscript/tools/benchmark/export_model.py b/onnxscript/tools/benchmark/export_model.py
index 16f599057..88d40dc27 100644
--- a/onnxscript/tools/benchmark/export_model.py
+++ b/onnxscript/tools/benchmark/export_model.py
@@ -25,7 +25,7 @@ def main(args=None):
         Example with a medium llama model::

-            python -m onnxscript.tools.benchmark.export_model --model llama --device cuda --config large --num_hidden_layers=1 --dtype=float32 --dynamic=0 --verbose=1 --exporter=dynamo
+            python -m onnxscript.tools.benchmark.export_model --model llama --device cuda --config medium --num_hidden_layers=1 --dtype=float32 --dynamic=0 --verbose=1 --exporter=dynamo --optimization=rewrite/optimize/inline/llama0/onnxruntime
         """
         ),
         repeat=(10, "number of inferences to measure"),
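For reference, rule-set names in the '/'-separated optimization string are batched so that consecutive entries share one apply_rule_sets call. A standalone sketch of that parsing, mirroring the loop above with print standing in for the real steps:

    optimization = "rewrite/optimize/inline/llama0/onnxruntime"
    known_rule_sets = {"llama0", "onnxruntime"}

    pending: list[str] = []
    for value in optimization.split("/"):
        if value in known_rule_sets:
            pending.append(value)  # rule-set names accumulate...
            continue
        if pending:
            print("apply_rule_sets", pending)  # ...and flush before a plain step
            pending.clear()
        print("step", value)
    if pending:
        print("apply_rule_sets", pending)
    # step rewrite / step optimize / step inline / apply_rule_sets ['llama0', 'onnxruntime']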
optimization="optimize,rewrite,inline", exporter="script"), - dict(ort_optimize=0, optimization="optimize,rewrite,inline", exporter="script"), + dict(ort_optimize=1, optimization="optimize/rewrite/inline", exporter="script"), + dict(ort_optimize=0, optimization="optimize/rewrite/inline", exporter="script"), dict(ort_optimize=1, optimization="", exporter="dynamo"), - dict(ort_optimize=1, optimization="optimize,rewrite,inline", exporter="dynamo"), - dict(ort_optimize=0, optimization="optimize,rewrite,inline", exporter="dynamo"), + dict(ort_optimize=1, optimization="optimize/rewrite/inline", exporter="dynamo"), + dict(ort_optimize=0, optimization="optimize/rewrite/inline", exporter="dynamo"), ] common_kwargs: dict[str, Any] = kwargs.copy() common_kwargs["verbose"] = max(common_kwargs["verbose"] - 1, 0) diff --git a/onnxscript/tools/benchmark/export_model_test.py b/onnxscript/tools/benchmark/export_model_test.py index 6806e3135..aadb842ad 100644 --- a/onnxscript/tools/benchmark/export_model_test.py +++ b/onnxscript/tools/benchmark/export_model_test.py @@ -132,7 +132,7 @@ def test_export_model_phi_cpu_dynamo_llama0(self): "--exporter", "dynamo", "--optimization", - "rewrite,optimize,inline,llama0", + "rewrite/optimize/inline/llama0/onnxruntime", "--model", "phi", ] @@ -162,7 +162,7 @@ def test_export_model_phi3_cpu_dynamo_llama0(self): "--exporter", "dynamo", "--optimization", - "rewrite,optimize,inline,llama0", + "rewrite/optimize/inline/llama0", "--model", "phi3", ]