diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py index 80e0c0846..22b374e5b 100644 --- a/docs/tutorial/rewriter/examples/broadcast_matmul.py +++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py @@ -64,7 +64,7 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_): # ===================== -def check_if_need_reshape( +def check_if_not_need_reshape( input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ ) -> bool: """If matmul broadcasting is enough, then we don't need the reshapes. @@ -184,7 +184,7 @@ def apply_rewrite(model): two_reshapes_matmul_reshape_rule = pattern.RewriteRule( two_reshapes_matmul_reshape_pattern, # target pattern matmul_pattern, # replacement pattern - check_if_need_reshape, # match_condition function + check_if_not_need_reshape, # match_condition function ) # Create a Rewrite Rule Set rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule]) diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md index 60b6a9d51..731238044 100644 --- a/docs/tutorial/rewriter/rewrite_patterns.md +++ b/docs/tutorial/rewriter/rewrite_patterns.md @@ -185,7 +185,7 @@ Similarly for writing the condition checking function, we require only `input_a` In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: ```{literalinclude} examples/broadcast_matmul.py -:pyobject: check_if_need_reshape +:pyobject: check_if_not_need_reshape ``` With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. 
diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index 20071f3a8..bc45e06b5 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -11,7 +11,7 @@ # condition to check if we need to replace the pattern -def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: +def check_if_not_need_reshape(input_a, input_b, shape_c, **_) -> bool: """If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: @@ -108,9 +108,12 @@ def check_if_need_reshape(input_a, input_b, shape_c, **_) -> bool: broadcast_matmul_output_shape = ( longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape ) - if mimic_matmul_broadcast_behavior and dim_b == 2: + if mimic_matmul_broadcast_behavior and dim_b == 2 and input_b_shape[-1] == 1: + # If input_b is expanded to 2-D, then we need to remove the last dimension of the output shape broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2: + if mimic_matmul_broadcast_behavior and dim_a == 2 and input_a_shape[0] == 1: + # If input_a is expanded to 2-D, then we need to remove the first dimension + # of input_a, which would be the -2nd dimension of the output shape. 
broadcast_matmul_output_shape.pop(-2) if shape_c != broadcast_matmul_output_shape: logger.info( @@ -149,14 +152,14 @@ def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): two_reshapes_matmul_reshape_rule = pattern.RewriteRule( two_reshapes_matmul_reshape_pattern, matmul, - check_if_need_reshape, + check_if_not_need_reshape, ) one_reshape_matmul_reshape_rule = pattern.RewriteRule( one_reshape_matmul_reshape_pattern, matmul, - # We can use the same check_if_need_reshape function for both the rules, + # We can use the same check_if_not_need_reshape function for both the rules, # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. - check_if_need_reshape, + check_if_not_need_reshape, ) # NOTE: The order of the rules is important. Larger pattern should be checked first. diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py index cce9865c9..95cb82e30 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -1,5 +1,5 @@ from onnxscript.rewriter import pattern -from onnxscript.rewriter.broadcast_to_matmul import check_if_need_reshape +from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape op = pattern.onnxop @@ -18,4 +18,4 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_need_reshape) +rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/gemm_to_matmul_add_test.py index 337cfa4ca..cb285036b 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/gemm_to_matmul_add_test.py @@ -163,7 +163,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa self.assertEqual(model.graph[2].op_type, "MatMul") 
self.assertEqual(model.graph[3].op_type, "Add") - def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_not_broadcastable( + def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable( self, ): model_proto = onnx.parser.parse_model( @@ -207,7 +207,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro self.assertEqual(model.graph[2].op_type, "MatMul") self.assertEqual(model.graph[3].op_type, "Add") - def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not_broadcastable( + def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_broadcastable( self, ): model_proto = onnx.parser.parse_model( @@ -228,6 +228,48 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) + def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broadcastable( + self, + ): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[3, 5] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + self.assertEqual(replacement_count, 1) + self.assertEqual(len(model.graph), 4) + + def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broadcastable( + self, + ): + model_proto = onnx.parser.parse_model( + """ + + agraph (float[5, 3] input_x, float[5, 10] input_y, float[3, 10] input_z) => (float[3, 10] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + model = 
ir.serde.deserialize_model(model_proto) + count = gemm_to_matmul_add.rule.apply_to_model(model) + self.assertEqual(count, 0) + self.assertEqual(len(model.graph), 5) + def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( self, ):