Skip to content

Commit

Permalink
Fix bugs in matmul condition function and rename tests (#1512)
Browse files Browse the repository at this point in the history
Fix #1505 

The condition function in the gemm and matmul patterns tries to validate
whether the reshapes are redundant. Previously, the code mimicked the
matmul broadcasting behavior to derive the final output shape without
identifying whether a rank-2 input had been expanded from rank 1 by
adding a dimension of size 1, or was originally rank 2.

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
titaiwangms and justinchuby authored May 7, 2024
1 parent ea1eda9 commit 73694e1
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/tutorial/rewriter/examples/broadcast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
# =====================


def check_if_need_reshape(
def check_if_not_need_reshape(
input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
Expand Down Expand Up @@ -184,7 +184,7 @@ def apply_rewrite(model):
two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
two_reshapes_matmul_reshape_pattern, # target pattern
matmul_pattern, # replacement pattern
check_if_need_reshape, # match_condition function
check_if_not_need_reshape, # match_condition function
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ Similarly for writing the condition checking function, we require only `input_a`
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows:

```{literalinclude} examples/broadcast_matmul.py
:pyobject: check_if_need_reshape
:pyobject: check_if_not_need_reshape
```

With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite.
Expand Down
15 changes: 9 additions & 6 deletions onnxscript/rewriter/broadcast_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# condition to check if we need to replace the pattern
def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool:
def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool:
"""If matmul broadcasting is enough, then we don't need the reshapes.
To validate this, we need to check the following:
Expand Down Expand Up @@ -108,9 +108,12 @@ def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool:
broadcast_matmul_output_shape = (
longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape
)
if mimic_matmul_broadcast_behavior and dim_b == 2:
if mimic_matmul_broadcast_behavior and dim_b == 2 and input_b_shape[-1] == 1:
# If input_b is expanded to 2-D, then we need to remove the last dimension
broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
if mimic_matmul_broadcast_behavior and dim_a == 2:
if mimic_matmul_broadcast_behavior and dim_a == 2 and input_a_shape[0] == 1:
# If input_a is expanded to 2-D, then we need to remove the first dimension
# of input_a, which would be the -2nd dimension of the output shape.
broadcast_matmul_output_shape.pop(-2)
if shape_c != broadcast_matmul_output_shape:
logger.info(
Expand Down Expand Up @@ -149,14 +152,14 @@ def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c):
two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
two_reshapes_matmul_reshape_pattern,
matmul,
check_if_need_reshape,
check_if_not_need_reshape,
)
one_reshape_matmul_reshape_rule = pattern.RewriteRule(
one_reshape_matmul_reshape_pattern,
matmul,
# We can use the same check_if_need_reshape function for both the rules,
# We can use the same check_if_not_need_reshape function for both the rules,
# as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern.
check_if_need_reshape,
check_if_not_need_reshape,
)

# NOTE: The order of the rules is important. Larger pattern should be checked first.
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/rewriter/gemm_to_matmul_add.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from onnxscript.rewriter import pattern
from onnxscript.rewriter.broadcast_to_matmul import check_if_need_reshape
from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape

op = pattern.onnxop

Expand All @@ -18,4 +18,4 @@ def matmul_add(op, input_a, input_b, input_c, **_):
return op.Add(matmul, input_c)


rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_need_reshape)
rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape)
46 changes: 44 additions & 2 deletions onnxscript/rewriter/gemm_to_matmul_add_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa
self.assertEqual(model.graph[2].op_type, "MatMul")
self.assertEqual(model.graph[3].op_type, "Add")

def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_not_broadcastable(
def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable(
self,
):
model_proto = onnx.parser.parse_model(
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro
self.assertEqual(model.graph[2].op_type, "MatMul")
self.assertEqual(model.graph[3].op_type, "Add")

def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not_broadcastable(
def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_broadcastable(
self,
):
model_proto = onnx.parser.parse_model(
Expand All @@ -228,6 +228,48 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not
self.assertEqual(count, 0)
self.assertEqual(len(model.graph), 5)

def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broadcastable(
self,
):
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[3, 5] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output)
{
shape_a = Constant<value: tensor = int64[2] {3, 5}>()
reshape_x = Reshape (input_x, shape_a)
gemm = Gemm<alpha=1.0, beta=1.0> (reshape_x, input_y, input_z)
shape_d = Constant<value: tensor = int64[2] {3, 10}>()
output = Reshape (gemm, shape_d)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
replacement_count = gemm_to_matmul_add.rule.apply_to_model(model)
self.assertEqual(replacement_count, 1)
self.assertEqual(len(model.graph), 4)

def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broadcastable(
self,
):
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[5, 3] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output)
{
shape_a = Constant<value: tensor = int64[2] {3, 5}>()
reshape_x = Reshape (input_x, shape_a)
gemm = Gemm<alpha=1.0, beta=1.0> (reshape_x, input_y, input_z)
shape_d = Constant<value: tensor = int64[2] {3, 10}>()
output = Reshape (gemm, shape_d)
}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = gemm_to_matmul_add.rule.apply_to_model(model)
self.assertEqual(count, 0)
self.assertEqual(len(model.graph), 5)

def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted(
self,
):
Expand Down

0 comments on commit 73694e1

Please sign in to comment.