Commit fa47e41 (parent 22ca960): 3 changed files with 253 additions and 5 deletions
@@ -0,0 +1,209 @@
"""Onnx Pattern Rewriting with match condition parameter. | ||
This script shows how to define a rewriting rule based on patterns while | ||
utilizing the match condition parameter. | ||
First we write a dummy model with a several Reshape nodes and a Matmul node | ||
=================== | ||
""" | ||
|
||
|
||
import math | ||
import numpy as np | ||
import torch | ||
import onnx | ||
import onnx.helper as oh | ||
import onnx.numpy_helper as onh | ||
|
||
import onnxscript | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
def original_model():
    inputs = [
        oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 512]),
        oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 64]),
    ]
    nodes = [
        oh.make_node(
            "Constant",
            inputs=[],
            outputs=["_onx_shape_const0"],
            value=oh.make_tensor(
                "shape_a", onnx.TensorProto.INT64, [3], np.array([4, 512, 512], dtype=np.int64)
            ),
        ),
        oh.make_node("Reshape", ["x", "_onx_shape_const0"], ["_onx_reshape0"]),
        oh.make_node(
            "Constant",
            inputs=[],
            outputs=["_onx_shape_const1"],
            value=oh.make_tensor(
                "shape_b", onnx.TensorProto.INT64, [3], np.array([4, 512, 64], dtype=np.int64)
            ),
        ),
        oh.make_node("Reshape", ["y", "_onx_shape_const1"], ["_onx_reshape1"]),
        oh.make_node("MatMul", ["_onx_reshape0", "_onx_reshape1"], ["_onx_matmul"]),
        oh.make_node(
            "Constant",
            inputs=[],
            outputs=["_onx_shape_const2"],
            value=oh.make_tensor(
                "shape_c", onnx.TensorProto.INT64, [4], np.array([1, 4, 512, 64], dtype=np.int64)
            ),
        ),
        oh.make_node("Reshape", ["_onx_matmul", "_onx_shape_const2"], ["_onx_reshape2"]),
    ]
    outputs = [
        # Declare the actual output shape so the match condition can read it.
        oh.make_tensor_value_info("_onx_reshape2", onnx.TensorProto.FLOAT, [1, 4, 512, 64]),
    ]
    model = oh.make_model(
        oh.make_graph(
            nodes,
            "experiment",
            inputs,
            outputs,
        ),
        opset_imports=[
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 18),
        ],
    )
    return model


model = original_model()
onnx.save(model, "test.onnx")
onnx.checker.check_model(model)
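
# Optional inspection (added for illustration; not in the original commit):
# dump the graph as text to see the Reshape -> MatMul -> Reshape chain that
# the rule below is meant to collapse.
print(oh.printable_graph(model.graph))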


####################################
# The target pattern
# =====================

op = pattern.onnxop
msft_op = pattern.msft_op  # unused in this example; exposes com.microsoft contrib ops


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
    reshape_a = op.Reshape(input_a, shape_a)
    reshape_b = op.Reshape(input_b, shape_b)
    matmul = op.MatMul(reshape_a, reshape_b)
    return op.Reshape(matmul, shape_c)
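

# NOTE (added for clarity; not in the original commit): the parameters of the
# pattern function (input_a, input_b, shape_a, ...) act as pattern variables.
# When the rule matches, each variable is bound to the concrete graph value it
# matched, and those bindings are what the match-condition function below
# receives via `match_bindings`.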


####################################
# The replacement pattern
# =====================


def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c):
    del shape_a  # Unused
    del shape_b  # Unused
    del shape_c  # Unused
    return op.MatMul(input_a, input_b)
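

# NOTE (added for clarity; not in the original commit): the replacement
# function takes the same pattern variables as the target pattern; the shape
# inputs are deliberately dropped, since a plain MatMul on the original 4-D
# inputs makes the surrounding Reshapes redundant.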


####################################
# Write condition to check if we need to replace the pattern
# =====================


def check_if_need_reshape(match_bindings) -> bool:
    """Returns True if the Reshape nodes are redundant and can be removed.

    Args:
        match_bindings: The match binding dictionary from a MatchResult.

    Returns:
        bool: True if we need to replace the pattern, False otherwise.
    """
    input_a_shape = match_bindings["input_a"].shape
    input_b_shape = match_bindings["input_b"].shape
    shape_c = match_bindings["shape_c"].value_as_np_array
    if shape_c is None:
        return False
    if not isinstance(shape_c, np.ndarray):
        logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
        return False
    if len(shape_c.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c.shape,
        )
        return False
    shape_c = shape_c.tolist()

    # NOTE: When the match covers only a subset of the pattern, the MatchResult
    # may not carry shape information, so bail out if any shape is unknown.
    if input_a_shape is None or input_b_shape is None or shape_c is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False

    dim_a = len(input_a_shape)
    dim_b = len(input_b_shape)

    # 1. Check if input shapes are broadcastable.
    # 1.a. If the first input is 1-D, check whether
    # its dim matches the second-to-last dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if dim_a < 2:
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_a_shape = [1, *input_a_shape]
            dim_a = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.b. If the second input is 1-D, check whether
    # its dim matches the last dim of the first input.
    if dim_b < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_b_shape = [*input_b_shape, 1]
            dim_b = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.c. If both inputs are at least 2-D, check whether the last dimension
    # of the first input matches the second-to-last dimension of the second
    # input, and whether shape[:-2] of both inputs are broadcastable.
    input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if the output shape is the same as the output shape of
    # MatMul(input_a, input_b). Prepend broadcast_matmul_output_shape with
    # the extra leading dims of the longer input shape.
    if dim_a > dim_b:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = (
        longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
    )
    if mimic_matmul_broadcast_behavior and dim_b == 2:
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and dim_a == 2:
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True
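

# Worked example (added for illustration; not in the original commit), using
# the shapes from original_model():
#   input_a ("x"): [1, 4, 512, 512]; input_b ("y"): [1, 4, 512, 64];
#   shape_c: [1, 4, 512, 64].
# Both inputs are already >= 2-D, the contraction dims agree (512 == 512),
# the batch dims [1, 4] broadcast trivially, and the broadcast MatMul output
# [1, 4, 512, 64] equals shape_c, so the condition returns True and the
# Reshapes can be removed.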


####################################
# Create Rewrite Rule and Apply to Model
# =====================

two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
    two_reshapes_matmul_reshape_pattern,  # target pattern
    matmul_with_two_shape_inputs,  # replacement pattern
    check_if_need_reshape,  # match_condition function
)
model_with_rewrite = onnxscript.rewriter.rewrite(
    model,
    pattern_rewrite_rules=[two_reshapes_matmul_reshape_rule],
)

onnx.checker.check_model(model_with_rewrite)
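
# Sanity check (added for illustration; not in the original commit): the
# Reshape/MatMul/Reshape chain should now be a single MatMul, though the
# Constant nodes that fed the removed Reshapes may linger until a separate
# dead-code-elimination pass runs.
print([n.op_type for n in model_with_rewrite.graph.node])
# The filename below is hypothetical, chosen to mirror "test.onnx" above.
onnx.save(model_with_rewrite, "test_rewritten.onnx")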