Skip to content

Commit

Permalink
Add first rewriting patterns for llama adding onnxruntime contrib ops (
Browse files Browse the repository at this point in the history
…#1622)

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 20, 2024
1 parent 2d13bbe commit 8758d28
Show file tree
Hide file tree
Showing 10 changed files with 622 additions and 20 deletions.
4 changes: 2 additions & 2 deletions onnxscript/optimizer/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def foldable_value(self, name: str, value):
# ONNX does not have a way to represent non-tensor constants, eg. a sequence.
# So, a constant-value of type sequence is not folded, but it can be used
# to optimize subsequent operations when possible.
logger.warning(
logger.info(
"Skip storing constant folded value %s due to unsupported type %s.",
name,
type(value),
)
return None

if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT:
logger.warning(
logger.info(
"Skip storing constant folded nvalue %s due to large size %s.",
name,
value.nbytes,
Expand Down
12 changes: 12 additions & 0 deletions onnxscript/rewriter/generic_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,18 @@ def _match_backward(
graph_node,
)
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs):
if len(list(graph_input.uses())) != len(list(pattern_input.uses())):
self._hint(
"BACKWARD: one input is used outside the pattern",
"-- pattern",
pattern_node,
"-- model",
graph_node,
)
return self.none(starting_node, inspect.currentframe().f_lineno)

for graph_value, pattern_value in zip(graph_node.inputs, pattern_node.inputs):
# TODO(rama): Handle constant-pattern
pattern_pred = pattern_value.producer()
Expand Down
2 changes: 2 additions & 0 deletions onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter import rewrite as _rewrite
from onnxscript.rewriter.onnxruntime import (
fused_matmul_rule_sets,
group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
Expand All @@ -20,6 +21,7 @@
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
*group_normalization_merge_silu.rules.rules,
*fused_matmul_rule_sets.fused_matmul_rule_sets(),
]


Expand Down
179 changes: 179 additions & 0 deletions onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import ClassVar

import onnxscript.rewriter.pattern as orp

op = orp.onnxop


class FusedMatMulDiv1(orp.RewriteRuleAsClass):
    """Replaces ``MatMul + Div`` by onnxruntime's ``com.microsoft.FusedMatMul``.

    ``MatMul(x, y) / c`` with a constant scalar ``c`` becomes a single
    ``FusedMatMul(x, y)`` node with ``alpha = 1 / c``.
    """

    @classmethod
    def pattern(cls, op, x, y, cst):
        return op.Div(op.MatMul(x, y), cst)

    @classmethod
    def check(cls, context, x, y, cst) -> bool:
        """Accept only a constant, size-1, non-zero divisor."""
        if cst.const_value is None:
            return False
        value = cst.const_value.numpy()
        if value.size > 1:
            return False
        # A zero divisor would make ``alpha = 1 / c`` raise in rewrite().
        if value.item() == 0:
            return False
        return True

    @classmethod
    def rewrite(cls, op, x, y, cst):
        # ``item()`` extracts the scalar for any size-1 shape ((), (1,), (1, 1), ...).
        # Indexing with ``value[0]`` only covered shape (1,), and ``float(array)``
        # with ndim > 0 is rejected by recent numpy versions.
        c = float(cst.const_value.numpy().item())
        return op.FusedMatMul(x, y, alpha=1 / c, domain="com.microsoft")


class FusedMatMulDiv2(orp.RewriteRuleAsClass):
    """Replaces ``FusedMatMul + Div`` by a single ``FusedMatMul``.

    The constant scalar divisor is folded into the ``alpha`` attribute of the
    existing ``com.microsoft.FusedMatMul`` node; its transpose-related
    attributes are carried over unchanged.
    """

    @classmethod
    def pattern(cls, op, x, y, cst):
        return op.Div(op.FusedMatMul(x, y, domain="com.microsoft"), cst)

    @classmethod
    def check(cls, context, x, y, cst) -> bool:
        """Accept only a constant, size-1, non-zero divisor."""
        if cst.const_value is None:
            return False
        value = cst.const_value.numpy()
        if value.size > 1:
            return False
        # A zero divisor would make the ``alpha / c`` computation in
        # rewrite() raise ZeroDivisionError.
        if value.item() == 0:
            return False
        return True

    @classmethod
    def rewrite(cls, op, x, y, cst):
        # ``item()`` extracts the scalar for any size-1 shape; ``value[0]``
        # only covered shape (1,).
        c = float(cst.const_value.numpy().item())
        # NOTE(review): assumes the first use of ``x`` is the matched
        # FusedMatMul node -- may pick the wrong node if ``x`` has several
        # consumers; confirm against the pattern-matching guarantees.
        node = list(x.uses())[0][0]  # noqa: RUF015

        kwargs = {}
        alpha = node.attributes.get("alpha", None)
        # Fold the divisor into alpha (default alpha is 1.0).
        kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
        for name in ["transA", "transB", "transBatchA", "transBatchB"]:
            att = node.attributes.get(name)
            if att:
                kwargs[name] = att.value
        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")


class _TransposeMatMulBase(orp.RewriteRuleAsClass):
    """Shared check/rewrite logic for fusing a ``Transpose`` into (Fused)MatMul.

    ``_pos`` selects which input carries the Transpose to absorb:
    1 -> first input (becomes ``transA``), 2 -> second input (``transB``).
    """

    _pos: ClassVar = 1

    @classmethod
    def check(cls, context, x, y) -> bool:
        """Only a permutation swapping the two last axes maps to a trans flag."""
        transposed_input = x if cls._pos == 1 else y
        transpose_node = next(iter(transposed_input.uses()))[0]
        perm = transpose_node.attributes["perm"].value
        swapped_last_two = list(range(len(perm)))
        swapped_last_two[-1], swapped_last_two[-2] = (
            swapped_last_two[-2],
            swapped_last_two[-1],
        )
        return perm == swapped_last_two

    @classmethod
    def rewrite(cls, op, x, y):
        # The (Fused)MatMul node is found through the *other* input, the one
        # not feeding the Transpose.
        other_input = x if cls._pos == 2 else y
        matmul_node = next(iter(other_input.uses()))[0]
        kwargs = {}
        for attr_name in ("alpha", "transA", "transB", "transBatchA", "transBatchB"):
            attr = matmul_node.attributes.get(attr_name)
            if attr:
                kwargs[attr_name] = attr.value
        # Toggle the transpose flag for the input whose Transpose we absorb.
        flag = "transA" if cls._pos == 1 else "transB"
        kwargs[flag] = 1 - kwargs.get(flag, 0)
        return op.FusedMatMul(x, y, **kwargs, domain="com.microsoft")


class TransposeMatMul1(_TransposeMatMulBase):
    """Fuses a ``Transpose`` on the first ``MatMul`` input into FusedMatMul."""

    @classmethod
    def pattern(cls, op, x, y):
        transposed = op.Transpose(x)
        return op.MatMul(transposed, y)


class TransposeFusedMatMul1(TransposeMatMul1):
    """Fuses a ``Transpose`` on the first ``FusedMatMul`` input into FusedMatMul."""

    @classmethod
    def pattern(cls, op, x, y):
        transposed = op.Transpose(x)
        return op.FusedMatMul(transposed, y, domain="com.microsoft")


class TransposeMatMul2(_TransposeMatMulBase):
    """Fuses a ``Transpose`` on the second ``MatMul`` input into FusedMatMul."""

    _pos: ClassVar = 2

    @classmethod
    def pattern(cls, op, x, y):
        transposed = op.Transpose(y)
        return op.MatMul(x, transposed)


class TransposeFusedMatMul2(TransposeMatMul2):
    """Fuses a ``Transpose`` on the second ``FusedMatMul`` input into FusedMatMul."""

    @classmethod
    def pattern(cls, op, x, y):
        transposed = op.Transpose(y)
        return op.FusedMatMul(x, transposed, domain="com.microsoft")


class MatMulTranspose(orp.RewriteRuleAsClass):
    """Replaces ``Transpose(MatMul(x, y))`` by FusedMatMul.

    Relies on the identity ``(x @ y)^T == y^T @ x^T``: the operands are
    swapped and both transpose flags are toggled.
    """

    @classmethod
    def pattern(cls, op, x, y):
        product = op.MatMul(x, y)
        return op.Transpose(product)

    @classmethod
    def check(cls, context, x, y) -> bool:
        """Only a permutation swapping the two last axes can be fused."""
        matmul_node = next(iter(x.uses()))[0]
        transpose_node = next(iter(matmul_node.outputs[0].uses()))[0]
        perm = transpose_node.attributes["perm"].value
        swapped_last_two = list(range(len(perm)))
        swapped_last_two[-1], swapped_last_two[-2] = (
            swapped_last_two[-2],
            swapped_last_two[-1],
        )
        return perm == swapped_last_two

    @classmethod
    def rewrite(cls, op, x, y):
        # The matched (Fused)MatMul node: keep its attributes (if any).
        matmul_node = next(iter(x.uses()))[0]
        kwargs = {}
        for attr_name in ("alpha", "transA", "transB", "transBatchA", "transBatchB"):
            attr = matmul_node.attributes.get(attr_name)
            if attr:
                kwargs[attr_name] = attr.value
        # Swap the operands (below) and toggle both transpose flags.
        # NOTE(review): transBatchA/transBatchB are copied but not swapped --
        # confirm this is correct when batch transposes are present.
        for flag in ("transA", "transB"):
            kwargs[flag] = 1 - kwargs.get(flag, 0)
        return op.FusedMatMul(y, x, **kwargs, domain="com.microsoft")


class FusedMatMulTranspose(MatMulTranspose):
    """Replaces ``Transpose(FusedMatMul(x, y))`` by FusedMatMul."""

    @classmethod
    def pattern(cls, op, x, y):
        product = op.FusedMatMul(x, y, domain="com.microsoft")
        return op.Transpose(product)


def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
    """Returns a set of rules introducing onnxruntime contrib ops.

    This requires onnxruntime to run the model after it is rewritten.

    Returns:
        RewriteRuleSet
    """
    return orp.RewriteRuleSet(
        [
            orp.make_rewrite_rule_from_class(FusedMatMulDiv1, True),
            orp.make_rewrite_rule_from_class(FusedMatMulDiv2, True),
            orp.make_rewrite_rule_from_class(FusedMatMulTranspose, True),
            orp.make_rewrite_rule_from_class(MatMulTranspose, True),
            orp.make_rewrite_rule_from_class(TransposeMatMul1, True),
            orp.make_rewrite_rule_from_class(TransposeFusedMatMul1, True),
            orp.make_rewrite_rule_from_class(TransposeMatMul2, True),
            orp.make_rewrite_rule_from_class(TransposeFusedMatMul2, True),
        ]
    )
Loading

0 comments on commit 8758d28

Please sign in to comment.