Skip to content

Commit

Permalink
Proposal to group function defining a pattern into a class. (#1596)
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 11, 2024
1 parent 4c3a6be commit 47fb031
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 53 deletions.
116 changes: 63 additions & 53 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,65 +9,75 @@
op = orp.onnxop


def transpose_identity_pattern(op, x, perm):
return op.Transpose(x, perm=perm)


def transpose_identity_check(context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
if isinstance(perm, ir.RefAttr):
return False
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False


def transpose_identity_rewrite(op, x: ir.Value, perm: ir.Attr | None = None):
return op.Identity(x)


def transpose_transpose_pattern(op, x, perm1, perm2):
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

class TransposeIdentity(orp.RewriteRuleAsClass):
"""Replaces ``Transpose(. perm=perm)``
when the permutation is identity.
"""

def transpose_transpose_check(
context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
@classmethod
def pattern(cls, op, x, perm):
return op.Transpose(x, perm=perm)

@classmethod
def check(cls, context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
if isinstance(perm, ir.RefAttr):
return False
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False
return True


def _apply_transpose(perm: tuple[int, ...], on: list[int]) -> list[int]:
assert len(perm) == len(on), "length mismatch"
res = [-1 for i in on]
for i, p in enumerate(perm):
res[i] = on[p]
return res


def _apply_transposes(perms: list[tuple[int, ...]], on: list[int] | None = None) -> list[int]:
if on is None:
on = list(range(len(perms[0])))
for p in perms:
on = _apply_transpose(p, on)
return on


def transpose_transpose_rewrite(op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
first = list(range(len(perm1.value)))
last = _apply_transposes([perm1.value, perm2.value])
if first == last:
@classmethod
def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None):
return op.Identity(x)
return op.Transpose(x, perm=last)


transpose_identity_rule = orp.RewriteRule(
transpose_identity_pattern, transpose_identity_rewrite, transpose_identity_check
)
transpose_transpose_rule = orp.RewriteRule(
transpose_transpose_pattern, transpose_transpose_rewrite, transpose_transpose_check
)
class TransposeTranspose(orp.RewriteRuleAsClass):
"""Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
when both permutations are inverse.
"""

@classmethod
def pattern(cls, op, x, perm1, perm2):
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

@classmethod
def check(
cls, context, x: ir.Value, perm1: ir.Attr | ir.RefAttr, perm2: ir.Attr | ir.RefAttr
) -> bool:
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
return False
return True

@classmethod
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
assert len(perm) == len(on), "length mismatch"
res = [-1 for i in on]
for i, p in enumerate(perm):
res[i] = on[p]
return res

@classmethod
def _apply_transposes(
cls, perms: list[tuple[int, ...]], on: list[int] | None = None
) -> list[int]:
if on is None:
on = list(range(len(perms[0])))
for p in perms:
on = cls._apply_transpose(p, on)
return on

@classmethod
def rewrite(cls, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
first = list(range(len(perm1.value)))
last = cls._apply_transposes([perm1.value, perm2.value])
if first == last:
return op.Identity(x)
return op.Transpose(x, perm=last)


transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)


def llama_p0_rule_set() -> orp.RewriteRuleSet:
Expand Down
52 changes: 52 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,58 @@ def replace_pattern(new_pattern):
return [replace_pattern(p) for p in self._target_pattern.commute()]


class RewriteRuleAsClass:
"""Defines a class grouping method pattern, rewrite, check.
This class is then given to function :func:`make_rewrite_rule_from_class`
to define a new rule.
"""

@classmethod
def pattern(cls, op, *_) -> Any:
raise NotImplementedError("Method 'pattern' must be overwritten.")

@classmethod
def rewrite(cls, op, *_) -> Any:
raise NotImplementedError("Method 'rewrite' must be overwritten.")

@classmethod
def check(cls, context, *_) -> bool:
return True


def make_rewrite_rule_from_class(rule_class: type | RewriteRuleAsClass) -> RewriteRule:
"""Creates a RewriteRule from a class defining the function
pattern, rewrite, check with class method. It makes it is easier
to read when a module contains multiple patterns.
Example::
class TransposeIdentity(RewriteRuleAsClass):
@classmethod
def pattern(cls, op, x, perm):
return op.Transpose(x, perm=perm)
@classmethod
def check(cls, context, x: ir.Value, perm: ir.Attr | ir.RefAttr) -> bool:
if isinstance(perm, ir.RefAttr):
return False
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False
@classmethod
def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None):
return op.Identity(x)
transpose_identity_rule = make_rewrite_rule_from_class(TransposeIdentity)
"""
assert hasattr(rule_class, "pattern"), f"Method 'pattern' is missing from {rule_class!r}."
assert hasattr(rule_class, "rewrite"), f"Method 'rewrite' is missing from {rule_class!r}."
assert hasattr(rule_class, "check"), f"Method 'check' is missing from {rule_class!r}."
return RewriteRule(rule_class.pattern, rule_class.rewrite, rule_class.check)


def _apply_delta(
graph_or_function: ir.Graph | ir.Function,
node: ir.Node,
Expand Down

0 comments on commit 47fb031

Please sign in to comment.