Skip to content

Commit

Permalink
Add match_condition example
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Apr 25, 2024
1 parent 22ca960 commit fa47e41
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 5 deletions.
209 changes: 209 additions & 0 deletions docs/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Onnx Pattern Rewriting with match condition parameter.
This script shows how to define a rewriting rule based on patterns while
utilizing the match condition parameter.
First we write a dummy model with several Reshape nodes and a MatMul node
===================
"""


import math
import numpy as np
import torch
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
from onnxscript.rewriter import pattern


def original_model():
    """Build the demo model: Reshape(x), Reshape(y) -> MatMul -> Reshape.

    Both 4-D inputs are flattened to 3-D through Reshape nodes fed by INT64
    Constant shape tensors, multiplied with MatMul, and the product is
    reshaped back to 4-D. This is the Reshape/MatMul/Reshape pattern that the
    rewrite rule defined later in this file targets.

    Returns:
        onnx.ModelProto: model with inputs ``x`` (1x4x512x512) and ``y``
        (1x4x512x64) and the single output ``_onx_reshape2`` (1x4x512x64).
    """

    def shape_constant(output_name, tensor_name, dims):
        # Constant node that emits `dims` as a 1-D INT64 tensor.
        return oh.make_node(
            "Constant",
            inputs=[],
            outputs=[output_name],
            value=oh.make_tensor(
                tensor_name,
                onnx.TensorProto.INT64,
                [len(dims)],
                np.array(dims).astype(np.int64),
            ),
        )

    inputs = [
        oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 512]),
        oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, shape=[1, 4, 512, 64]),
    ]
    nodes = [
        shape_constant("_onx_shape_const0", "shape_a", [4, 512, 512]),
        oh.make_node("Reshape", ["x", "_onx_shape_const0"], ["_onx_reshape0"]),
        shape_constant("_onx_shape_const1", "shape_b", [4, 512, 64]),
        oh.make_node("Reshape", ["y", "_onx_shape_const1"], ["_onx_reshape1"]),
        oh.make_node("MatMul", ["_onx_reshape0", "_onx_reshape1"], ["_onx_matmul"]),
        shape_constant("_onx_shape_const2", "shape_c", [1, 4, 512, 64]),
        oh.make_node("Reshape", ["_onx_matmul", "_onx_shape_const2"], ["_onx_reshape2"]),
    ]
    outputs = [
        # Declare the real output shape: the final Reshape produces
        # [1, 4, 512, 64]. An empty shape list would declare a rank-0 scalar.
        oh.make_tensor_value_info(
            "_onx_reshape2", onnx.TensorProto.FLOAT, [1, 4, 512, 64]
        ),
    ]
    graph = oh.make_graph(nodes, "experiment", inputs, outputs)
    return oh.make_model(
        graph,
        opset_imports=[
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 18),
        ],
    )


# Materialise the example model, write it to ``test.onnx`` in the current
# working directory, then run the ONNX checker on it.
model = original_model()
onnx.save(model, "test.onnx")
onnx.checker.check_model(model)


####################################
# The target pattern
# =====================

# Short aliases for the pattern builders exposed by
# ``onnxscript.rewriter.pattern``. ``op`` is the one the pattern functions in
# this file call to describe Reshape and MatMul nodes.
op = pattern.onnxop
# NOTE(review): ``msft_op`` is not referenced in the visible part of this
# example; presumably the com.microsoft-domain counterpart of ``op`` --
# confirm before relying on or removing it.
msft_op = pattern.msft_op


def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c):
    """Target pattern: Reshape both operands, MatMul them, Reshape the product.

    Args:
        input_a: First operand, reshaped with ``shape_a`` before the MatMul.
        input_b: Second operand, reshaped with ``shape_b`` before the MatMul.
        shape_a: Shape input of the Reshape applied to ``input_a``.
        shape_b: Shape input of the Reshape applied to ``input_b``.
        shape_c: Shape input of the Reshape applied to the MatMul result.

    Returns:
        The output of the final Reshape.
    """
    # Arguments are evaluated left to right, so the builder still sees
    # Reshape(a), Reshape(b), MatMul, Reshape in that order.
    return op.Reshape(
        op.MatMul(
            op.Reshape(input_a, shape_a),
            op.Reshape(input_b, shape_b),
        ),
        shape_c,
    )


####################################
# The replacement pattern
# =====================


def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c):
    """Replacement pattern: a single MatMul on the un-reshaped operands.

    The signature mirrors the target pattern, so the three shape parameters
    are accepted but deliberately discarded.

    Args:
        input_a: First MatMul operand.
        input_b: Second MatMul operand.
        shape_a: Unused.
        shape_b: Unused.
        shape_c: Unused.

    Returns:
        The output of ``MatMul(input_a, input_b)``.
    """
    del shape_a, shape_b, shape_c  # Unused
    return op.MatMul(input_a, input_b)


####################################
# Write condition to check if we need to replace the pattern
# =====================


def check_if_need_reshape(match_bindings) -> bool:
    """Decide whether the matched Reshape/MatMul/Reshape can become one MatMul.

    The Reshape nodes are redundant when ``MatMul(input_a, input_b)`` applied
    directly to the un-reshaped inputs is valid under numpy-style MatMul
    broadcasting and already produces exactly the shape requested by
    ``shape_c``.

    Args:
        match_bindings: The match binding dictionary from a MatchResult. The
            ``"input_a"`` and ``"input_b"`` entries are read through
            ``.shape`` and the ``"shape_c"`` entry through
            ``.value_as_np_array``.

    Returns:
        bool: True if we need to replace the pattern, False otherwise.
    """
    # Imported locally so the function is self-contained: the module defines
    # no ``logger`` of its own.
    import logging

    logger = logging.getLogger(__name__)

    input_a_shape = match_bindings["input_a"].shape
    input_b_shape = match_bindings["input_b"].shape
    shape_c = match_bindings["shape_c"].value_as_np_array
    if shape_c is None:
        return False
    if not isinstance(shape_c, np.ndarray):
        logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c))
        return False
    if len(shape_c.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c.shape,
        )
        return False
    shape_c = shape_c.tolist()

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False

    # Work on plain lists so tuple (or other sequence) shapes can be sliced
    # and concatenated below.
    input_a_shape = list(input_a_shape)
    input_b_shape = list(input_b_shape)
    if not input_a_shape or not input_b_shape:
        # MatMul is not defined for rank-0 (scalar) operands.
        logger.info("Original shape is not MatMul compatible.")
        return False

    # 1. Check if input shapes are broadcastable
    # 1.a. Mimic the MatMul treatment of 1-D operands: a 1-D first input is
    # promoted to a row vector and a 1-D second input to a column vector.
    # Each promotion is tracked on its own because the inserted unit
    # dimension must later be removed from the matching side only.
    a_promoted = len(input_a_shape) == 1
    b_promoted = len(input_b_shape) == 1
    if a_promoted:
        input_a_shape = [1, *input_a_shape]
    if b_promoted:
        input_b_shape = [*input_b_shape, 1]

    # 1.b. The contraction dimensions (last of a, second last of b) must be
    # equal; MatMul never broadcasts them.
    if input_a_shape[-1] != input_b_shape[-2]:
        logger.info("Original shape is not MatMul compatible.")
        return False

    # 1.c. The remaining leading (batch) dimensions must be broadcastable:
    # right-aligned, and each pair either equal or containing a 1.
    batch_a = input_a_shape[:-2]
    batch_b = input_b_shape[:-2]
    batch_rank = max(len(batch_a), len(batch_b))
    batch_a = [1] * (batch_rank - len(batch_a)) + batch_a
    batch_b = [1] * (batch_rank - len(batch_b)) + batch_b
    broadcast_matmul_output_shape = []
    for dim_from_a, dim_from_b in zip(batch_a, batch_b):
        if dim_from_a == dim_from_b or dim_from_b == 1:
            broadcast_matmul_output_shape.append(dim_from_a)
        elif dim_from_a == 1:
            broadcast_matmul_output_shape.append(dim_from_b)
        else:
            logger.info("Original shape is not broadcastable.")
            return False

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    broadcast_matmul_output_shape += [input_a_shape[-2], input_b_shape[-1]]
    if a_promoted:
        # Drop the row dimension that was inserted for a 1-D first input.
        broadcast_matmul_output_shape.pop(-2)
    if b_promoted:
        # Drop the column dimension that was inserted for a 1-D second input.
        broadcast_matmul_output_shape.pop()
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True


####################################
# Create Rewrite Rule and Apply to Model
# =====================

# Bundle the three pieces into one rule. They are passed positionally, in
# this order: target pattern, replacement pattern, match_condition function.
two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
    two_reshapes_matmul_reshape_pattern,
    matmul_with_two_shape_inputs,
    check_if_need_reshape,
)

# Apply the rule to the example model; the single rule is wrapped in a list.
model_with_rewrite = onnxscript.rewriter.rewrite(
    model, pattern_rewrite_rules=[two_reshapes_matmul_reshape_rule]
)

# Run the ONNX checker on the rewritten model as well.
onnx.checker.check_model(model_with_rewrite)
Binary file added docs/rewriter/examples/img/broadcast_01.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 44 additions & 5 deletions docs/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ For this example, we do not require a `match_condition` so that option is skippe
:lines: 84-87
```

Now that the rewrite rule has been created,
1. The user's model is converted into an intermediate IR representation.
2. The rewrite rule is applied to the nodes in the graph.
3. The intermediate IR with the replaced pattern nodes is converted back to an ONNX Model Proto.
In order to do the steps mentioned above:
Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:

1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`.
2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]`.
3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of type `Sequence[PatternRewriteRule]`.

Note: `pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types, so if only a single rewrite rule is to be passed, it needs to be passed as part of a sequence.

The snippet below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above:

```{literalinclude} examples/erfgelu.py
:lines: 88-91
Expand All @@ -67,6 +71,41 @@ The graph (on the left) consists of the target pattern before the rewrite rule i

## Utilizing match_condition parameter for pattern-matching

This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration.

Let us consider a model which consists of the following pattern.

![target_pattern](examples/img/broadcast_01.png)

Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx Matmul behaves like numpy.matmul and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following:

1. Input shapes check: input_a and input_b should be broadcastable
2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)

If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows:

```{literalinclude} examples/broadcast_matmul.py
:pyobject: check_if_need_reshape
```

Once we have the match_condition function, we can write a target pattern and replacement pattern in a similar way to the first example.

```{literalinclude} examples/broadcast_matmul.py
:pyobject: two_reshapes_matmul_reshape_pattern
```

```{literalinclude} examples/broadcast_matmul.py
:pyobject: matmul_with_two_shape_inputs
```

With all the necessary components in place, the pattern rewrite rule is created and then the `rewriter.rewrite` is called to apply the rewrite.

```{literalinclude} examples/broadcast_matmul.py
:lines: 199-203
```
```{literalinclude} examples/broadcast_matmul.py
:lines: 204-207
```

## Utilizing commute parameter for pattern-matching

Expand Down

0 comments on commit fa47e41

Please sign in to comment.