Skip to content

Commit

Permalink
Unify single-output and multi-output pattern matchers (#1515)
Browse files Browse the repository at this point in the history
* Migrate all rewrite-rule functions to uniformly use an extra
first-parameter "op/context".
* Align and factor out common logic of the two rewrite-algorithms. Now,
they differ only in the core-matching algorithm, and share all the rest.

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
gramalingam and justinchuby authored May 10, 2024
1 parent fefea96 commit 66d34e4
Show file tree
Hide file tree
Showing 15 changed files with 532 additions and 529 deletions.
15 changes: 7 additions & 8 deletions docs/tutorial/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,12 @@ def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1
# The target pattern
# =====================

_op = pattern.onnxop


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
reshape_a = _op.Reshape(input_a, shape_a)
reshape_b = _op.Reshape(input_b, shape_b)
matmul = _op.MatMul(reshape_a, reshape_b)
return _op.Reshape(matmul, shape_c)
def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
reshape_a = op.Reshape(input_a, shape_a)
reshape_b = op.Reshape(input_b, shape_b)
matmul = op.MatMul(reshape_a, reshape_b)
return op.Reshape(matmul, shape_c)


####################################
Expand All @@ -65,7 +63,7 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):


def check_if_not_need_reshape(
input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
Expand All @@ -75,6 +73,7 @@ def check_if_not_need_reshape(
If the above are true, then we don't need the reshapes.
"""
del context # Reserved for future extensions
input_a_shape = input_a.shape
input_b_shape = input_b.shape
# TODO: Get a helper func to get const_value
Expand Down
18 changes: 8 additions & 10 deletions docs/tutorial/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
# The target pattern
# =====================

_op = pattern.onnxop

def erf_gelu_pattern(op, x):
return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

def erf_gelu_pattern(x):
return 0.5 * (x * (_op.Erf(x / math.sqrt(2)) + 1.0))


def erf_gelu_pattern_2(x):
return (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
def erf_gelu_pattern_2(op, x):
return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5


####################################
Expand All @@ -98,7 +96,7 @@ def gelu(op, x: ir.Value):
def apply_rewrite(model):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
gelu, # Replacement
)
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model,
Expand All @@ -111,11 +109,11 @@ def apply_rewrite_with_ruleset(model):
# Create multiple rules
rule1 = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
gelu, # Replacement
)
rule2 = pattern.RewriteRule(
erf_gelu_pattern_2, # Target Pattern
gelu, # Replacement Pattern
gelu, # Replacement
)
# Create a Rewrite Rule Set with multiple rules.
rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
Expand All @@ -131,7 +129,7 @@ def apply_rewrite_with_ruleset(model):
def apply_rewrite_with_commute(model):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
gelu, # Replacement
)
# Create a Rewrite Rule Set with commute=True
rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
Expand Down
1 change: 0 additions & 1 deletion docs/tutorial/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ Firstly, include all the rewriter relevant imports.
from onnxscript.rewriter import pattern
from onnxscript import ir

_op = pattern.onnxop
```

Then create a target pattern that needs to be replaced using onnxscript operators.
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/rewriter/broadcast_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# condition to check if we need to replace the pattern
def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool:
def check_if_not_need_reshape(context, input_a, input_b, shape_c, **_) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
Expand Down Expand Up @@ -126,7 +126,7 @@ def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool:
return True


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
# TODO: Modified from `value_ints` to `value` to match pattern in benchmark models.
# This implementation misses pattern of Constants with `value_ints` attribute.
# See more at https://github.com/microsoft/onnx-rewriter/issues/191.
Expand All @@ -142,7 +142,7 @@ def matmul(op, input_a, input_b, **_):
return op.MatMul(input_a, input_b)


def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c):
def one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c):
reshape_a = op.Reshape(input_a, shape_a)
matmul = op.MatMul(reshape_a, input_b)
return op.Reshape(matmul, shape_c)
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/cast_constant_of_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = logging.getLogger(__name__)


def cast_constant_of_shape(shape, scalar, dtype):
def cast_constant_of_shape(op, shape, scalar, dtype):
constant = op.ConstantOfShape(shape, value=scalar)
return op.Cast(constant, to=dtype)

Expand All @@ -23,7 +23,7 @@ def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir
return op.ConstantOfShape(shape, value=cast_value)


def cast_constant_of_shape_without_value(shape, dtype):
def cast_constant_of_shape_without_value(op, shape, dtype):
constant = op.ConstantOfShape(shape)
return op.Cast(constant, to=dtype)

Expand Down
7 changes: 1 addition & 6 deletions onnxscript/rewriter/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

from onnxscript.rewriter import pattern

op = pattern.onnxop


# Pattern to match against
def erf_gelu_pattern(x):
def erf_gelu_pattern(op, x):
# erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
# half = pattern.Constant(0.5)
# sqrt2 = pattern.Constant(1.4142)
Expand All @@ -19,9 +17,6 @@ def erf_gelu_pattern(x):
return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


msft_op = pattern.msft_op


# Replacement
def gelu(op, x):
return op.Gelu(x, domain="com.microsoft")
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/gemm_to_matmul_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


# Pattern to match against
def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c):
def reshape_gemm_reshape_pattern(op, input_a, input_b, input_c, shape_a, shape_c):
reshape_a = op.Reshape(input_a, shape_a)
# TODO: Temporary workaround to support benchmodels.
# Tracked by https://github.com/microsoft/onnx-rewriter/issues/197.
Expand Down
Loading

0 comments on commit 66d34e4

Please sign in to comment.