Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rewriter][docs] Add Tutorial for Pattern-based Rewrites #1413

Merged
merged 9 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ tutorial/index
api/index
intermediate_representation/index
auto_examples/index
rewriter/index
articles/index
```

Expand Down
198 changes: 198 additions & 0 deletions docs/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Onnx Pattern Rewriting with match condition parameter.

This script shows how to define a rewriting rule based on patterns while
utilizing the match condition parameter.

First we write a dummy model with several Reshape nodes and a MatMul node
===================
"""

import logging
Fixed Show fixed Hide fixed

import numpy as np
import onnx
Fixed Show fixed Hide fixed

import onnxscript
from onnxscript import FLOAT, ir, opset18, script
from onnxscript.rewriter import _ir_utils, pattern

logger = logging.getLogger(__name__)


# Dummy model: both inputs are reshaped from 4-D to 3-D, multiplied with
# MatMul, and the product is reshaped back to 4-D. The rewrite tutorial
# shows these Reshapes are redundant because MatMul broadcasts natively.
# (Scrape-artifact lines that broke the file have been removed.)
@script()
def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1, 4, 512, 64]:
    # NOTE: Modified from `value_ints` to `value`
    shape_a = opset18.Constant(value=[4, 512, 512])
    reshape_a = opset18.Reshape(A, shape_a)
    shape_b = opset18.Constant(value=[4, 512, 64])
    reshape_b = opset18.Reshape(B, shape_b)
    matmul = opset18.MatMul(reshape_a, reshape_b)
    shape_c = opset18.Constant(value=[1, 4, 512, 64])
    result = opset18.Reshape(matmul, shape_c)
    return result


# Materialize the script function into an ONNX ModelProto and validate it.
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================

# `pattern.onnxop` provides the op builders used to express target patterns.
_op = pattern.onnxop


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
    """Target pattern: Reshape(MatMul(Reshape(a, sa), Reshape(b, sb)), sc)."""
    lhs = _op.Reshape(input_a, shape_a)
    rhs = _op.Reshape(input_b, shape_b)
    product = _op.MatMul(lhs, rhs)
    return _op.Reshape(product, shape_c)


####################################
# The replacement pattern
# =====================


def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
    """Replacement pattern: a single MatMul, relying on its built-in broadcasting.

    Extra matched values (the shapes) arrive via **_ and are ignored.
    (A scrape-artifact line that broke the function body has been removed.)
    """
    return op.MatMul(input_a, input_b)


####################################
# Write condition to check if we need to replace the pattern
# =====================


def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool:
    """If matmul broadcasting is enough, then we don't need the reshapes.

    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)

    If the above are true, then we don't need the reshapes.
    """
    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    # TODO: Get a helper func to get const_value
    shape_c_value = _ir_utils.propagate_const_value(shape_c)
    # Bug fix: check for a missing constant BEFORE dereferencing it — the
    # previous code called .numpy() first, so a non-constant shape raised
    # AttributeError instead of cleanly rejecting the match.
    shape_c_tensor = shape_c_value.const_value
    if shape_c_tensor is None:
        return False
    shape_c = shape_c_tensor.numpy()
    if not isinstance(shape_c, np.ndarray):
        logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
        return False
    if len(shape_c.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c.shape,
        )
        return False
    shape_c = shape_c.tolist()

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None or shape_c is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    input_a_shape = list(input_a_shape)
    input_b_shape = list(input_b_shape)

    dim_a = len(input_a_shape)
    dim_b = len(input_b_shape)

    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the last second dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if dim_a < 2:
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # MatMul treats a 1-D first operand as a row vector: prepend a 1.
            input_a_shape = [1, *input_a_shape]
            dim_a = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if dim_b < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # MatMul treats a 1-D second operand as a column vector: append a 1.
            input_b_shape = [*input_b_shape, 1]
            dim_b = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            # idx == 0 is the contracted (inner) dimension; the batch dims are
            # accumulated front-to-back onto the output shape.
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if dim_a > dim_b:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = (
        longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
    )
    # Undo the padding added in 1.a/1.b so the comparison uses MatMul's real
    # output rank for originally 1-D operands.
    if mimic_matmul_broadcast_behavior and dim_b == 2:
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and dim_a == 2:
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
    """Build the reshape-elimination rule and apply it to `model`.

    (Scrape-artifact lines that broke the function body have been removed.)
    """
    # Create rewrite rules
    two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
        two_reshapes_matmul_reshape_pattern,  # target pattern
        matmul_pattern,  # replacement pattern
        check_if_need_reshape,  # match_condition function
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])
    # Apply rewrite while passing match_condition
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite


# Apply the rewrite and confirm the resulting model is still valid ONNX.
_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)
161 changes: 161 additions & 0 deletions docs/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Onnx Pattern Rewriting.
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

This script shows how to define a rewriting rule based on patterns.

First a dummy model with a GELU activation
===================
"""

import math
Fixed Show fixed Hide fixed

import onnx

import onnxscript
from onnxscript import FLOAT, ir, opset18, script
from onnxscript.rewriter import pattern


# Dummy model computing the erf-based GELU of X + Y:
#   0.5 * z * (Erf(z / sqrt(2)) + 1), with z = X + Y.
@script()
def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    input_add = opset18.Add(X, Y)
    sqrt2 = opset18.Constant(value_float=math.sqrt(2))
    erf = opset18.Erf(input_add / sqrt2)
    add_const = opset18.Constant(value_float=1.0)
    plus_one = erf + add_const
    mul1 = input_add * plus_one
    mul_const = opset18.Constant(value_float=0.5)
    result = mul_const * mul1
    return result


# Materialize the script function into an ONNX ModelProto and validate it.
_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# Model demonstrating multiple patterns and variations of GELU activation
# =====================


# Model with two GELU subgraphs that differ only in operand order of the
# final multiplication (`const * x` vs `x * const`) — used to demonstrate
# rule sets and commutative matching.
@script()
def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
    # Create first GELU variant
    sqrt2_v1 = opset18.Constant(value_float=math.sqrt(2))
    erf_v1 = opset18.Erf(X / sqrt2_v1)
    add_const_v1 = opset18.Constant(value_float=1.0)
    plus_one_v1 = erf_v1 + add_const_v1
    mul1_v1 = X * plus_one_v1
    mul_const_v1 = opset18.Constant(value_float=0.5)
    gelu1 = mul_const_v1 * mul1_v1

    # Create second GELU variant
    sqrt2_v2 = opset18.Constant(value_float=math.sqrt(2))
    erf_v2 = opset18.Erf(Y / sqrt2_v2)
    add_const_v2 = opset18.Constant(value_float=1.0)
    plus_one_v2 = erf_v2 + add_const_v2
    mul1_v2 = Y * plus_one_v2
    mul_const_v2 = opset18.Constant(value_float=0.5)
    # Note the commuted operand order compared to gelu1.
    gelu2 = mul1_v2 * mul_const_v2

    # Add both GELU functions
    result = opset18.Add(gelu1, gelu2)
    return result


# NOTE: rebinds `commute_model` from the script function to its ModelProto.
commute_model = commute_model.to_model_proto()
onnx.checker.check_model(commute_model)


####################################
# The target pattern
# =====================

# `pattern.onnxop` provides the op builders used to express target patterns.
_op = pattern.onnxop


def erf_gelu_pattern(x):
    """Target pattern: the erf-based GELU, 0.5 * (x * (Erf(x / sqrt(2)) + 1))."""
    scaled = _op.Erf(x / math.sqrt(2))
    inner = x * (scaled + 1.0)
    return 0.5 * inner


def erf_gelu_pattern_2(x):
    """Commuted GELU target pattern: the 0.5 factor on the right-hand side."""
    scaled = _op.Erf(x / math.sqrt(2))
    inner = x * (scaled + 1.0)
    return inner * 0.5


####################################
# The replacement pattern
# =====================


def gelu(op, x: ir.Value):
    """Replacement pattern: the fused Gelu op from the com.microsoft domain.

    (A scrape-artifact line that broke the function body has been removed.)
    """
    return op.Gelu(x, domain="com.microsoft")


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
    """Apply the single erf-GELU rewrite rule to `model`.

    (A scrape-artifact line that broke the function body has been removed.)
    """
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied


def apply_rewrite_with_ruleset(model):
    """Apply both GELU pattern variants to `model` via a RewriteRuleSet.

    (A scrape-artifact line that broke the function body has been removed.)
    """
    # Create multiple rules
    rule1 = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    rule2 = pattern.RewriteRule(
        erf_gelu_pattern_2,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    # Create a Rewrite Rule Set with multiple rules.
    rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
        # pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing multiple rules
    )
    return model_with_rewrite_applied


def apply_rewrite_with_commute(model):
    """Apply one GELU rule with commute=True so operand-order variants also match.

    (A scrape-artifact line that broke the function body has been removed.)
    """
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement Pattern
    )
    # Create a Rewrite Rule Set with commute=True
    rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite_applied


# Rewrite-Simple
model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(model_with_rewrite)

# Rewrite-Single-Patterns
# Incorrect number of rewrites: the single rule matches only the
# `const * x` form, so the commuted GELU variant in this model is missed.
model_with_single_rewrite_ruleset = apply_rewrite(commute_model)
onnx.checker.check_model(model_with_single_rewrite_ruleset)

# Rewrite-Multiple-Patterns-RuleSet
model_with_rewrite_ruleset = apply_rewrite_with_ruleset(commute_model)
onnx.checker.check_model(model_with_rewrite_ruleset)

# Rewrite-Multiple-Patterns-Commute
model_with_rewrite_commute = apply_rewrite_with_commute(commute_model)
onnx.checker.check_model(model_with_rewrite_commute)
Binary file added docs/rewriter/examples/img/broadcast_01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/broadcast_02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_02.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_03_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_04_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_05_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_06_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_07_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions docs/rewriter/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Rewriter Tutorials

```{toctree}
rewrite_patterns
```
Loading
Loading