Skip to content

Commit

Permalink
Complete subsections
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Apr 30, 2024
1 parent fa47e41 commit 3d8ff4c
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 131 deletions.
117 changes: 61 additions & 56 deletions docs/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,32 @@
===================
"""

import logging

Check notice

Code scanning / CodeQL

Unused import Note documentation

Import of 'logging' is not used.

import numpy

Check notice

Code scanning / CodeQL

Unused import Note documentation

Import of 'numpy' is not used.

import math
import numpy as np
import torch
import onnx

Check notice

Code scanning / CodeQL

Unused import Note documentation

Import of 'onnx' is not used.
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
from onnxscript.rewriter import pattern


def original_model():
inputs = [
oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 512]),
oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 64]),
]
nodes = [
oh.make_node("Constant", inputs=[], outputs=["_onx_shape_const0"], value=oh.make_tensor("shape_a", onnx.TensorProto.INT64, [3], np.array([4, 512, 512]).astype(np.int64))),
oh.make_node("Reshape", ["x", "_onx_shape_const0"], ["_onx_reshape0"]),
oh.make_node("Constant", inputs=[], outputs=["_onx_shape_const1"], value=oh.make_tensor("shape_b", onnx.TensorProto.INT64, [3], np.array([4, 512, 64]).astype(np.int64))),
oh.make_node("Reshape", ["y", "_onx_shape_const1"], ["_onx_reshape1"]),
oh.make_node("MatMul", ["_onx_reshape0", "_onx_reshape1"], ["_onx_matmul"]),
oh.make_node("Constant", inputs=[], outputs=["_onx_shape_const2"], value=oh.make_tensor("shape_c", onnx.TensorProto.INT64, [4], np.array([1, 4, 512, 64]).astype(np.int64))),
oh.make_node("Reshape", ["_onx_matmul", "_onx_shape_const2"], ["_onx_reshape2"]),
]
outputs = [
oh.make_tensor_value_info("_onx_reshape2", onnx.TensorProto.FLOAT, []),
]
model = oh.make_model(
oh.make_graph(
nodes,
"experiment",
inputs,
outputs,
),
opset_imports=[
oh.make_opsetid("", 18),
oh.make_opsetid("com.microsoft", 18),
],
)
return model
from onnxscript import FLOAT, opset18, script
from onnxscript import ir

Check notice

Code scanning / CodeQL

Unused import Note documentation

Import of 'ir' is not used.
from onnxscript.rewriter import _ir_utils, pattern


@script()
def original_model(A: FLOAT[1, 4, 512, 512], B: FLOAT[1, 4, 512, 64]) -> FLOAT[1, 4, 512, 64]:
shape_a = opset18.Constant(value_ints=[4, 512, 512])
reshape_a = opset18.Reshape(A, shape_a)

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning documentation

This assignment to 'reshape_a' is unnecessary as it is
redefined
before this value is used.
shape_b = opset18.Constant(value_ints=[4, 512, 64])
reshape_a = opset18.Reshape(B, shape_b)
matmul = opset18.MatMul(reshape_a, reshape_a)
shape_c = opset18.Constant(value_ints=[1, 4, 512, 64])
result = opset18.Reshape(matmul, shape_c)
return result


model = original_model()
onnx.save(model, 'test.onnx')
onnx.checker.check_model(model)
model = original_model.to_model_proto()
# onnx.checker.check_model(model)


####################################
Expand All @@ -76,10 +55,7 @@ def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shap
# =====================


def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c):
del shape_a # Unused
del shape_b # Unused
del shape_c # Unused
def matmul(op, input_a, input_b, **_):
return op.MatMul(input_a, input_b)


Expand All @@ -89,7 +65,14 @@ def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c):


def check_if_need_reshape(match_bindings) -> bool:
"""
"""If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
1. Input shapes check: input_a and input_b should be broadcastable
2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
If the above are true, then we don't need the reshapes.
Args:
match_bindings: The match binding dictionary from a MatchResult.
Expand All @@ -99,7 +82,9 @@ def check_if_need_reshape(match_bindings) -> bool:
"""
input_a_shape = match_bindings["input_a"].shape
input_b_shape = match_bindings["input_b"].shape
shape_c = match_bindings["shape_c"].value_as_np_array
# TODO: Get a helper func to get const_value
shape_c_value = _ir_utils.propagate_const_value(match_bindings["shape_c"])
shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr]
if shape_c is None:
return False
if not isinstance(shape_c, np.ndarray):
Expand All @@ -118,6 +103,8 @@ def check_if_need_reshape(match_bindings) -> bool:
if input_a_shape is None or input_b_shape is None or shape_c is None:
logger.info("Shape information is not available for the inputs and outputs.")
return False
input_a_shape = list(input_a_shape)
input_b_shape = list(input_b_shape)

dim_a = len(input_a_shape)
dim_b = len(input_b_shape)
Expand Down Expand Up @@ -196,14 +183,32 @@ def check_if_need_reshape(match_bindings) -> bool:
# Create Rewrite Rule and Apply to Model
# =====================

two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
two_reshapes_matmul_reshape_pattern, # target pattern
matmul_with_two_shape_inputs, # replacement pattern
check_if_need_reshape, # match_condition function
)
model_with_rewrite = onnxscript.rewriter.rewrite(

def apply_rewrite(
model,
two_reshapes_matmul_reshape_pattern, # target pattern
matmul, # replacement pattern
check_if_need_reshape, # match_condition function
):
# Create rewrite rules
two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
two_reshapes_matmul_reshape_pattern, # target pattern
matmul, # replacement pattern
check_if_need_reshape, # match_condition function
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet(rules=[two_reshapes_matmul_reshape_rule])

Check notice

Code scanning / CodeQL

Unused local variable Note documentation

Variable rewrite_rule_set is not used.
# Apply rewrite while passing match_condition
model_with_rewrite = onnxscript.rewriter.rewrite(

Check notice

Code scanning / CodeQL

Unused local variable Note documentation

Variable model_with_rewrite is not used.
model,
pattern_rewrite_rules=[two_reshapes_matmul_reshape_rule],
)


model_with_rewrite = apply_rewrite(
model,
pattern_rewrite_rules=[two_reshapes_matmul_reshape_rule],
two_reshapes_matmul_reshape_pattern,
matmul,
check_if_need_reshape,
)

Check warning

Code scanning / CodeQL

Use of the return value of a procedure Warning documentation

The result of
apply_rewrite
is used even though it is always None.

onnx.checker.check_model(model_with_rewrite)
# onnx.checker.check_model(model_with_rewrite)
172 changes: 122 additions & 50 deletions docs/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,88 +6,160 @@
===================
"""


import math
import numpy as np

import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


def original_model():
inputs = [
oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, shape=[]),
oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, shape=[]),
]
nodes = [
oh.make_node("Add", ["x", "y"], ["_onx_add0"]),
oh.make_node("Constant", inputs=[], outputs=['_onx_const0'], value_float=math.sqrt(2)),
oh.make_node("Div", ["_onx_add0", "_onx_const0"], ["_onx_div0"]),
oh.make_node("Erf", ["_onx_div0"], ["_onx_erf0"]),
oh.make_node("Constant", inputs=[], outputs=['_onx_const1'], value_float=1.0),
oh.make_node("Add", ["_onx_erf0", "_onx_const1"], ["_onx_add1"]),
oh.make_node("Mul", ["_onx_add0", "_onx_add1"], ["_onx_mul0"]),
oh.make_node("Constant", inputs=[], outputs=['_onx_const2'], value_float=0.5),
oh.make_node("Mul", ["_onx_const2", "_onx_mul0"], ["_onx_mul1"]),

]
outputs = [
oh.make_tensor_value_info("_onx_mul1", onnx.TensorProto.FLOAT, []),
]
model = oh.make_model(
oh.make_graph(
nodes,
"experiment",
inputs,
outputs,
),
opset_imports=[
oh.make_opsetid("", 18),
oh.make_opsetid("com.microsoft", 18),
],
)
return model
@script()
def original_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
input_add = opset18.Add(X, Y)
sqrt2 = opset18.Constant(value_float=math.sqrt(2))
erf = opset18.Erf(input_add / sqrt2)
add_const = opset18.Constant(value_float=1.0)
plus_one = erf + add_const
mul1 = input_add * plus_one
mul_const = opset18.Constant(value_float=0.5)
result = mul_const * mul1
return result


model = original_model()
model = original_model.to_model_proto()
onnx.checker.check_model(model)


####################################
# Model demonstrating multiple patterns and variations of GELU activation
# =====================


@script()
def commute_model(X: FLOAT[64, 128], Y: FLOAT[64, 128]) -> FLOAT[64, 128]:
# Create first GELU variant
sqrt2_v1 = opset18.Constant(value_float=math.sqrt(2))
erf_v1 = opset18.Erf(X / sqrt2_v1)
add_const_v1 = opset18.Constant(value_float=1.0)
plus_one_v1 = erf_v1 + add_const_v1
mul1_v1 = X * plus_one_v1
mul_const_v1 = opset18.Constant(value_float=0.5)
gelu1 = mul_const_v1 * mul1_v1

# Create second GELU variant
sqrt2_v2 = opset18.Constant(value_float=math.sqrt(2))
erf_v2 = opset18.Erf(Y / sqrt2_v2)
add_const_v2 = opset18.Constant(value_float=1.0)
plus_one_v2 = erf_v2 + add_const_v2
mul1_v2 = Y * plus_one_v2
mul_const_v2 = opset18.Constant(value_float=0.5)
gelu2 = mul1_v2 * mul_const_v2

# Add both GELU functions
result = opset18.Add(gelu1, gelu2)
return result


commute_model = commute_model.to_model_proto()
# onnx.checker.check_model(commute_model)


####################################
# The target pattern
# =====================

op = pattern.onnxop
msft_op = pattern.msft_op


def erf_gelu_pattern(x):
return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


def erf_gelu_pattern_2(x):
return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5


####################################
# The replacement pattern
# =====================


def gelu(x):
return msft_op.Gelu(x)
def gelu(op, x):
return op.Gelu(x, domain="com.microsoft")


####################################
# Create Rewrite Rule and Apply to Model
# =====================

rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
)
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=[rule],

def apply_rewrite(model, target_pattern, replacement_pattern):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
)
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=[rule],
)
return model_with_rewrite_applied


def apply_rewrite_with_ruleset(
model, erf_gelu_pattern, erf_gelu_pattern_2, replacement_pattern
):
# Create multiple rules
rule1 = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
)
rule2 = pattern.RewriteRule(
erf_gelu_pattern_2, # Target Pattern
gelu, # Replacement Pattern
)
# Create a Rewrite Rule Set with multiple rules.
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
# Apply rewrites
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
# pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing mutliple rules

Check warning on line 128 in docs/rewriter/examples/erfgelu.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "mutliple" is a misspelling of "multiple" Raw Output: ./docs/rewriter/examples/erfgelu.py:128:80: "mutliple" is a misspelling of "multiple"
)
return model_with_rewrite_applied


def apply_rewrite_with_commute(model, target_pattern, replacement_pattern):
rule = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement Pattern
)
# Create a Rewrite Rule Set with commute=True
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule], commute=True)
# Apply rewrites
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite_applied


# Rewrite-Simple
model_with_rewrite = apply_rewrite(model, erf_gelu_pattern, gelu)
# onnx.checker.check_model(model_with_rewrite)

# Rewrite-Single-Patterns
# Incorrect number of rewrites
model_with_single_rewrite_ruleset = apply_rewrite(commute_model, erf_gelu_pattern, gelu)
# onnx.checker.check_model(model_with_single_rewrite_ruleset)

# Rewrite-Multiple-Patterns-RuleSet
model_with_rewrite_ruleset = apply_rewrite_with_ruleset(
commute_model, erf_gelu_pattern, erf_gelu_pattern_2, gelu
)

onnx.checker.check_model(model_with_rewrite)
# onnx.checker.check_model(model_with_rewrite_ruleset)

# Rewrite-Multiple-Patterns-Commute
model_with_rewrite_commute = apply_rewrite_with_commute(commute_model, erf_gelu_pattern, gelu)
# onnx.checker.check_model(model_with_rewrite_commute)
Binary file added docs/rewriter/examples/img/erfgelu_03_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_04_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_05_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_06_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/rewriter/examples/img/erfgelu_07_commute.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 3d8ff4c

Please sign in to comment.