From 47fb031b98da291d005a71a782f29f72514db806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 11 Jun 2024 11:15:46 +0200 Subject: [PATCH] Proposal to group function defining a pattern into a class. (#1596) Signed-off-by: Xavier Dupre --- onnxscript/rewriter/llama_rule_sets.py | 116 ++++++++++++++----------- onnxscript/rewriter/pattern.py | 52 +++++++++++ 2 files changed, 115 insertions(+), 53 deletions(-) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index f6a347773..96aa25905 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -9,65 +9,75 @@ op = orp.onnxop -def transpose_identity_pattern(op, x, perm): - return op.Transpose(x, perm=perm) - - -def transpose_identity_check(context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool: - if isinstance(perm, ir.RefAttr): - return False - if perm.type == ir.AttributeType.INTS: - if perm.value == list(range(len(perm.value))): - return True - return False - - -def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None): - return op.Identity(x) - - -def transpose_transpose_pattern(op, x, perm1, perm2): - return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2) - +class TransposeIdentity(orp.RewriteRuleAsClass): + """Replaces ``Transpose(. perm=perm)`` + when the permutation is identity. + """ -def transpose_transpose_check( - context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr -) -> bool: - if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr): + @classmethod + def pattern(cls, op, x, perm): + return op.Transpose(x, perm=perm) + + @classmethod + def check(cls, context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool: + if isinstance(perm, ir.RefAttr): + return False + if perm.type == ir.AttributeType.INTS: + if perm.value == list(range(len(perm.value))): + return True return False - return True - - -def _apply_transpose(perm: tuple[int, ...], on: list[int]) -> list[int]: - assert len(perm) == len(on), "length mismatch" - res = [-1 for i in on] - for i, p in enumerate(perm): - res[i] = on[p] - return res - -def _apply_transposes(perms: list[tuple[int, ...]], on: list[int] | None = None) -> list[int]: - if on is None: - on = list(range(len(perms[0]))) - for p in perms: - on = _apply_transpose(p, on) - return on - - -def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr): - first = list(range(len(perm1.value))) - last = _apply_transposes([perm1.value, perm2.value]) - if first == last: + @classmethod + def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): return op.Identity(x) - return op.Transpose(x, perm=last) -transpose_identity_rule = orp.RewriteRule( - transpose_identity_pattern, transpose_identity_rewrite, transpose_identity_check -) -transpose_transpose_rule = orp.RewriteRule( - transpose_transpose_pattern, transpose_transpose_rewrite, transpose_transpose_check -) +class TransposeTranspose(orp.RewriteRuleAsClass): + """Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)`` + when both permutations are inverse. + """ + + @classmethod + def pattern(cls, op, x, perm1, perm2): + return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2) + + @classmethod + def check( + cls, context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr + ) -> bool: + if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr): + return False + return True + + @classmethod + def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]: + assert len(perm) == len(on), "length mismatch" + res = [-1 for i in on] + for i, p in enumerate(perm): + res[i] = on[p] + return res + + @classmethod + def _apply_transposes( + cls, perms: list[tuple[int, ...]], on: list[int] | None = None + ) -> list[int]: + if on is None: + on = list(range(len(perms[0]))) + for p in perms: + on = cls._apply_transpose(p, on) + return on + + @classmethod + def rewrite(cls, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr): + first = list(range(len(perm1.value))) + last = cls._apply_transposes([perm1.value, perm2.value]) + if first == last: + return op.Identity(x) + return op.Transpose(x, perm=last) + + +transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity) +transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose) def llama_p0_rule_set() -> orp.RewriteRuleSet: diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 7a48b0629..11df934d7 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1119,6 +1119,58 @@ def replace_pattern(new_pattern): return [replace_pattern(p) for p in self._target_pattern.commute()] +class RewriteRuleAsClass: + """Defines a class grouping method pattern, rewrite, check. + This class is then given to function :func:`make_rewrite_rule_from_class` + to define a new rule. + """ + + @classmethod + def pattern(cls, op, *_) -> Any: + raise NotImplementedError("Method 'pattern' must be overwritten.") + + @classmethod + def rewrite(cls, op, *_) -> Any: + raise NotImplementedError("Method 'rewrite' must be overwritten.") + + @classmethod + def check(cls, context, *_) -> bool: + return True + + +def make_rewrite_rule_from_class(rule_class: type | RewriteRuleAsClass) -> RewriteRule: + """Creates a RewriteRule from a class defining the function + pattern, rewrite, check with class method. It makes it is easier + to read when a module contains multiple patterns. + + Example:: + + class TransposeIdentity(RewriteRuleAsClass): + @classmethod + def pattern(cls, op, x, perm): + return op.Transpose(x, perm=perm) + + @classmethod + def check(cls, context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool: + if isinstance(perm, ir.RefAttr): + return False + if perm.type == ir.AttributeType.INTS: + if perm.value == list(range(len(perm.value))): + return True + return False + + @classmethod + def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): + return op.Identity(x) + + transpose_identity_rule = make_rewrite_rule_from_class(TransposeIdentity) + """ + assert hasattr(rule_class, "pattern"), f"Method 'pattern' is missing from {rule_class!r}." + assert hasattr(rule_class, "rewrite"), f"Method 'rewrite' is missing from {rule_class!r}." + assert hasattr(rule_class, "check"), f"Method 'check' is missing from {rule_class!r}." + return RewriteRule(rule_class.pattern, rule_class.rewrite, rule_class.check) + + def _apply_delta( graph_or_function: ir.Graph | ir.Function, node: ir.Node,