From 73694e13f021653448aaea4408eff3a95bb3e44b Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 7 May 2024 15:43:23 -0700 Subject: [PATCH] Fix bugs in matmul condition function and rename tests (#1512) Fix #1505 The condition function in the gemm and matmul patterns tries to validate whether the reshapes are redundant. In the previous code, we mimicked the matmul broadcasting behavior to compute the final output shape without identifying whether a rank-2 input had been expanded with a 1 along one dimension or was originally 2-dimensional. --------- Co-authored-by: Justin Chu --- .../rewriter/examples/broadcast_matmul.py | 4 +- docs/tutorial/rewriter/rewrite_patterns.md | 2 +- onnxscript/rewriter/broadcast_to_matmul.py | 15 +++--- onnxscript/rewriter/gemm_to_matmul_add.py | 4 +- .../rewriter/gemm_to_matmul_add_test.py | 46 ++++++++++++++++++- 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py index 80e0c0846..22b374e5b 100644 --- a/docs/tutorial/rewriter/examples/broadcast_matmul.py +++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py @@ -64,7 +64,7 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_): # ===================== -def check_if_need_reshape( +def check_if_not_need_reshape( input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ ) -> bool: """If matmul broadcasting is enough, then we don't need the reshapes. 
@@ -184,7 +184,7 @@ def apply_rewrite(model): two_reshapes_matmul_reshape_rule = pattern.RewriteRule( two_reshapes_matmul_reshape_pattern, # target pattern matmul_pattern, # replacement pattern - check_if_need_reshape, # match_condition function + check_if_not_need_reshape, # match_condition function ) # Create a Rewrite Rule Set rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule]) diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md index 60b6a9d51..731238044 100644 --- a/docs/tutorial/rewriter/rewrite_patterns.md +++ b/docs/tutorial/rewriter/rewrite_patterns.md @@ -185,7 +185,7 @@ Similarly for writing the condition checking function, we require only `input_a` In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: ```{literalinclude} examples/broadcast_matmul.py -:pyobject: check_if_need_reshape +:pyobject: check_if_not_need_reshape ``` With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index 20071f3a8..bc45e06b5 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -11,7 +11,7 @@ # condition to check if we need to replace the pattern -def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: +def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: """If matmul broadcasting is enough, then we don't need the reshapes. 
To validate this, we need to check the following: @@ -108,9 +108,12 @@ def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: broadcast_matmul_output_shape = ( longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape ) - if mimic_matmul_broadcast_behavior and dim_b == 2: + if mimic_matmul_broadcast_behavior and dim_b == 2 and input_b_shape[-1] == 1: + # If input_b is expanded to 2-D, then we need to remove the last dimension broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2: + if mimic_matmul_broadcast_behavior and dim_a == 2 and input_a_shape[0] == 1: + # If input_a is expanded to 2-D, then we need to remove the first dimension + # of input_a, which would be the -2nd dimension of the output shape. broadcast_matmul_output_shape.pop(-2) if shape_c != broadcast_matmul_output_shape: logger.info( @@ -149,14 +152,14 @@ def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): two_reshapes_matmul_reshape_rule = pattern.RewriteRule( two_reshapes_matmul_reshape_pattern, matmul, - check_if_need_reshape, + check_if_not_need_reshape, ) one_reshape_matmul_reshape_rule = pattern.RewriteRule( one_reshape_matmul_reshape_pattern, matmul, - # We can use the same check_if_need_reshape function for both the rules, + # We can use the same check_if_not_need_reshape function for both the rules, # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. - check_if_need_reshape, + check_if_not_need_reshape, ) # NOTE: The order of the rules is important. Larger pattern should be checked first. 
diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index cce9865c9..95cb82e30 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -1,5 +1,5 @@ from onnxscript.rewriter import pattern -from onnxscript.rewriter.broadcast_to_matmul import check_if_need_reshape +from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape op = pattern.onnxop @@ -18,4 +18,4 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_need_reshape) +rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/gemm_to_matmul_add_test.py index 337cfa4ca..cb285036b 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/gemm_to_matmul_add_test.py @@ -163,7 +163,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa self.assertEqual(model.graph[2].op_type, "MatMul") self.assertEqual(model.graph[3].op_type, "Add") - def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_not_broadcastable( + def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable( self, ): model_proto = onnx.parser.parse_model( @@ -207,7 +207,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro self.assertEqual(model.graph[2].op_type, "MatMul") self.assertEqual(model.graph[3].op_type, "Add") - def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not_broadcastable( + def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_broadcastable( self, ): model_proto = onnx.parser.parse_model( @@ -228,6 +228,48 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not 
self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) + def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broadcastable( + self, + ): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[3, 5] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + self.assertEqual(replacement_count, 1) + self.assertEqual(len(model.graph), 4) + + def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broadcastable( + self, + ): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[5, 3] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = gemm_to_matmul_add.rule.apply_to_model(model) + self.assertEqual(count, 0) + self.assertEqual(len(model.graph), 5) + def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( self, ):