Skip to content

Commit

Permalink
[rewriter][docs] Add Tutorial for Pattern-based Rewrites (#1413)
Browse files Browse the repository at this point in the history
Add Tutorial for Pattern-based Rewrites
  • Loading branch information
shubhambhokare1 authored May 2, 2024
1 parent 03b55e3 commit b81a38a
Show file tree
Hide file tree
Showing 20 changed files with 597 additions and 49 deletions.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ tutorial/index
api/index
intermediate_representation/index
auto_examples/index
rewriter/index
articles/index
```

Expand Down
198 changes: 198 additions & 0 deletions docs/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Onnx Pattern Rewriting with match condition parameter.
This script shows how to define a rewriting rule based on patterns while
utilizing the match condition parameter.
First we write a dummy model with several Reshape nodes and a MatMul node
===================
"""

import logging

import numpy as np
import onnx

import onnxscript
from onnxscript import FLOAT, ir, opset18, script
from onnxscript.rewriter import _ir_utils, pattern

logger = logging.getLogger(__name__)


@script()
def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1, 4, 512, 64]:
    """Dummy model: Reshape both 4-D inputs to 3-D, MatMul them, then Reshape back to 4-D.

    The two input Reshapes are redundant because MatMul broadcasting can
    handle the leading batch dimension directly — this is the pattern the
    tutorial's rewrite rule removes.
    """
    # NOTE: Modified from `value_ints` to `value`
    shape_a = opset18.Constant(value=[4, 512, 512])
    reshape_a = opset18.Reshape(A, shape_a)
    shape_b = opset18.Constant(value=[4, 512, 64])
    reshape_b = opset18.Reshape(B, shape_b)
    matmul = opset18.MatMul(reshape_a, reshape_b)
    # Restore the original leading batch dimension on the result.
    shape_c = opset18.Constant(value=[1, 4, 512, 64])
    result = opset18.Reshape(matmul, shape_c)
    return result


# Convert the onnxscript function into an ONNX ModelProto and validate it.
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================

_op = pattern.onnxop


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
    """Target pattern: Reshape each input, MatMul them, Reshape the product."""
    lhs = _op.Reshape(input_a, shape_a)
    rhs = _op.Reshape(input_b, shape_b)
    product = _op.MatMul(lhs, rhs)
    return _op.Reshape(product, shape_c)


####################################
# The replacement pattern
# =====================


def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
    """Replacement pattern: a single MatMul relying on implicit broadcasting."""
    replacement = op.MatMul(input_a, input_b)
    return replacement


####################################
# Write condition to check if we need to replace the pattern
# =====================


def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool:
    """Match condition: return True when the Reshape nodes are redundant.

    If matmul broadcasting is enough, then we don't need the reshapes.
    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
    If the above are true, then we don't need the reshapes.

    Args:
        input_a: First MatMul operand (before its Reshape), an ir.Value.
        input_b: Second MatMul operand (before its Reshape), an ir.Value.
        shape_c: The shape fed to the final Reshape; must be a constant.

    Returns:
        True if the rewrite (dropping all three Reshapes) is safe.
    """
    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    # TODO: Get a helper func to get const_value
    # BUGFIX: check const_value for None *before* calling .numpy() on it.
    # The original called .numpy() first, which raised AttributeError when
    # shape_c was not a constant, making the None guard unreachable.
    shape_c_tensor = _ir_utils.propagate_const_value(shape_c).const_value
    if shape_c_tensor is None:
        logger.info("shape_c is not a constant value.")
        return False
    shape_c = shape_c_tensor.numpy()
    if not isinstance(shape_c, np.ndarray):
        logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
        return False
    if len(shape_c.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c.shape,
        )
        return False
    shape_c = shape_c.tolist()

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    input_a_shape = list(input_a_shape)
    input_b_shape = list(input_b_shape)

    dim_a = len(input_a_shape)
    dim_b = len(input_b_shape)

    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the last second dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if dim_a < 2:
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # MatMul promotes a 1-D first operand by prepending a 1.
            input_a_shape = [1, *input_a_shape]
            dim_a = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if dim_b < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # MatMul promotes a 1-D second operand by appending a 1.
            input_b_shape = [*input_b_shape, 1]
            dim_b = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            # idx == 0 compares the contracted dims; batch dims start at idx 1.
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if dim_a > dim_b:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = (
        longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
    )
    # Undo the 1-D promotion applied above: MatMul drops the inserted dim.
    if mimic_matmul_broadcast_behavior and dim_b == 2:
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and dim_a == 2:
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
    """Build the reshape-elimination rule set and apply it to *model*."""
    # One rule: target pattern, replacement pattern, and the match-condition
    # function that guards whether the rewrite fires.
    reshape_rule = pattern.RewriteRule(
        two_reshapes_matmul_reshape_pattern,
        matmul_pattern,
        check_if_need_reshape,
    )
    # Rules are applied through a RewriteRuleSet.
    ruleset = pattern.RewriteRuleSet([reshape_rule])
    rewritten = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=ruleset,
    )
    return rewritten


# Apply the rewrite and validate that the result is still a well-formed model.
_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
161 changes: 161 additions & 0 deletions docs/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
First we write a dummy model with a GELU activation
===================
"""

import math

import onnx

import onnxscript
from onnxscript import FLOAT, ir, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    """Dummy model computing GELU(X + Y) via the erf formulation:
    0.5 * x * (Erf(x / sqrt(2)) + 1).
    """
    input_add = opset18.Add(X, Y)
    sqrt2 = opset18.Constant(value_float=math.sqrt(2))
    erf = opset18.Erf(input_add / sqrt2)
    add_const = opset18.Constant(value_float=1.0)
    plus_one = erf + add_const
    mul1 = input_add * plus_one
    mul_const = opset18.Constant(value_float=0.5)
    result = mul_const * mul1
    return result


# Convert the onnxscript function into an ONNX ModelProto and validate it.
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# Model demonstrating multiple patterns and variations of GELU activation
# =====================


@script()
def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    """Model with two erf-GELU subgraphs that differ only in the order of the
    final multiplication (0.5 * m vs m * 0.5) — used to demonstrate rule sets
    and commute=True matching.
    """
    # Create first GELU variant (final multiply: const * value)
    sqrt2_v1 = opset18.Constant(value_float=math.sqrt(2))
    erf_v1 = opset18.Erf(X / sqrt2_v1)
    add_const_v1 = opset18.Constant(value_float=1.0)
    plus_one_v1 = erf_v1 + add_const_v1
    mul1_v1 = X * plus_one_v1
    mul_const_v1 = opset18.Constant(value_float=0.5)
    gelu1 = mul_const_v1 * mul1_v1

    # Create second GELU variant (final multiply commuted: value * const)
    sqrt2_v2 = opset18.Constant(value_float=math.sqrt(2))
    erf_v2 = opset18.Erf(Y / sqrt2_v2)
    add_const_v2 = opset18.Constant(value_float=1.0)
    plus_one_v2 = erf_v2 + add_const_v2
    mul1_v2 = Y * plus_one_v2
    mul_const_v2 = opset18.Constant(value_float=0.5)
    gelu2 = mul1_v2 * mul_const_v2

    # Add both GELU functions
    result = opset18.Add(gelu1, gelu2)
    return result


# NOTE: rebinds `commute_model` from the script function to its ModelProto.
commute_model = commute_model.to_model_proto()
onnx.checker.check_model(commute_model)


####################################
# The target pattern
# =====================

_op = pattern.onnxop


def erf_gelu_pattern(x):
    """Target pattern: erf-based GELU with the 0.5 factor on the left.

    The operand order is part of the pattern — `0.5 * m` and `m * 0.5`
    are distinct graphs, hence the second variant below.
    """
    return 0.5 * (x * (_op.Erf(x / math.sqrt(2)) + 1.0))


def erf_gelu_pattern_2(x):
    """Target pattern: erf-based GELU with the 0.5 factor on the right
    (the commuted form of `erf_gelu_pattern`).
    """
    return (x * (_op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5


####################################
# The replacement pattern
# =====================


def gelu(op, x: ir.Value):
    """Replacement pattern: a single fused Gelu node from the com.microsoft domain."""
    fused = op.Gelu(x, domain="com.microsoft")
    return fused


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
    """Rewrite *model* with the single left-factor erf-GELU rule."""
    gelu_rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    rewritten = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[gelu_rule],
    )
    return rewritten


def apply_rewrite_with_ruleset(model):
    """Rewrite *model* with both GELU pattern variants bundled in one rule set."""
    # One rule per pattern variant, both replaced by the same fused Gelu.
    rules = [
        pattern.RewriteRule(target, gelu)
        for target in (erf_gelu_pattern, erf_gelu_pattern_2)
    ]
    ruleset = pattern.RewriteRuleSet(rules)
    # A plain list of rules could be passed as pattern_rewrite_rules as well.
    rewritten = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=ruleset,
    )
    return rewritten


def apply_rewrite_with_commute(model):
    """Rewrite *model* with one rule, letting commute=True cover commuted forms."""
    base_rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    # commute=True expands the rule set with commutative variants of the pattern.
    ruleset = pattern.RewriteRuleSet([base_rule], commute=True)
    rewritten = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=ruleset,
    )
    return rewritten


# Rewrite-Simple: single rule applied to the simple model.
model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(model_with_rewrite)

# Rewrite-Single-Patterns
# Incorrect number of rewrites: the single rule matches only one of the two
# GELU variants in `commute_model`; the commuted variant is left in place.
model_with_single_rewrite_ruleset = apply_rewrite(commute_model)
onnx.checker.check_model(model_with_single_rewrite_ruleset)

# Rewrite-Multiple-Patterns-RuleSet: both variants covered by explicit rules.
model_with_rewrite_ruleset = apply_rewrite_with_ruleset(commute_model)
onnx.checker.check_model(model_with_rewrite_ruleset)

# Rewrite-Multiple-Patterns-Commute: both variants covered via commute=True.
model_with_rewrite_commute = apply_rewrite_with_commute(commute_model)
onnx.checker.check_model(model_with_rewrite_commute)
Binary file added docs/rewriter/examples/img/broadcast_01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/broadcast_02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_03_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_04_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_05_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_06_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_07_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions docs/rewriter/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Rewriter Tutorials

```{toctree}
rewrite_patterns
```
Loading

0 comments on commit b81a38a

Please sign in to comment.