From d31670646cc2c886048336f6da52932e008e8c1b Mon Sep 17 00:00:00 2001
From: Ti-Tai Wang
Date: Thu, 23 May 2024 10:46:27 -0700
Subject: [PATCH] Fix broadcast rule of expanding dims (#1567)

Prior to this PR, a single flag was used no matter which of input a or
input b was expanded to 2-D, which made the flag ambiguous to the code
that later trims the expanded dimension from the output shape. This PR
splits it into one flag per input, so each trim is applied only when
its own input was expanded.
---
 onnxscript/rewriter/broadcast_to_matmul.py    | 11 +++++----
 .../rewriter/broadcast_to_matmul_test.py      | 23 +++++++++++++++++++
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py
index ead1bbada..da12ae3ad 100644
--- a/onnxscript/rewriter/broadcast_to_matmul.py
+++ b/onnxscript/rewriter/broadcast_to_matmul.py
@@ -63,7 +63,8 @@ def check_if_not_need_reshape(
     # 1. Check if input shapes are broadcastable
     # 1.a. If the first input is 1-D, check whether
     # the dim matches the last second dim of the second input.
-    mimic_matmul_broadcast_behavior = False
+    mimic_matmul_broadcast_behavior_a = False
+    mimic_matmul_broadcast_behavior_b = False
     if a_rank < 2:
         if b_rank < 2:
             logger.info("Optimization of dot product is not supported yet.")
@@ -74,7 +75,7 @@
         else:
             input_a_shape = [1, *input_a_shape]  # type: ignore[assignment]
             a_rank = len(input_a_shape)
-            mimic_matmul_broadcast_behavior = True
+            mimic_matmul_broadcast_behavior_a = True
     # 1.b. If the second input is 1-D, check whether
     # the dim matches the last dim of the first input.
     if b_rank < 2:
@@ -84,7 +85,7 @@
         else:
             input_b_shape = [*input_b_shape, 1]  # type: ignore[assignment]
             b_rank = len(input_b_shape)
-            mimic_matmul_broadcast_behavior = True
+            mimic_matmul_broadcast_behavior_b = True
     # 1.c. If both inputs are at least 2-D, check whether
     # the last dimension of the first input matches the second
     # last dimension of the second input, and shape[:-2] are
@@ -119,10 +120,10 @@
         *longer_shape[: -len(shorter_shape)],
         *broadcast_matmul_output_shape,
     ]
-    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
+    if mimic_matmul_broadcast_behavior_b and b_rank == 2 and input_b_shape[-1] == 1:
         # If input_b is expanded to 2-D, then we need to remove the last dimension
         broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
-    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
+    if mimic_matmul_broadcast_behavior_a and a_rank == 2 and input_a_shape[0] == 1:
         # If input_a is expanded to 2-D, then we need to remove the first dimension
         # of input_a, which would be the -2nd dimension of the output shape.
         broadcast_matmul_output_shape.pop(-2)
diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py
index cc390d7a3..4f7aecae8 100644
--- a/onnxscript/rewriter/broadcast_to_matmul_test.py
+++ b/onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -251,6 +251,29 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br
         self.assertEqual(count, 1)
         self.assertEqual(len(model.graph), 4)
 
+    def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_second_is_expanded_alike_and_broadcastable(
+        self,
+    ):
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[5] input_x, float[5, 1] input_y) => (float[1] output)
+            {
+                shape_a = Constant<value: tensor = int64[1] {5}>()
+                reshape_x = Reshape (input_x, shape_a)
+                shape_b = Constant<value: tensor = int64[2] {5, 1}>()
+                reshape_y = Reshape (input_y, shape_b)
+                matmul = MatMul (reshape_x, reshape_y)
+                shape_c = Constant<value: tensor = int64[1] {1}>()
+                output = Reshape (matmul, shape_c)
+            }
+        """
+        )
+        model = ir.serde.deserialize_model(model_proto)
+        count = broadcast_to_matmul.rules.apply_to_model(model)
+        self.assertEqual(count, 1)
+        self.assertEqual(len(model.graph), 4)
+
     def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable(
         self,
     ):
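
P.S. For reviewers, a minimal NumPy sketch (not part of the patch; the
arrays and names are illustrative) of the MatMul 1-D promotion rule that
check_if_not_need_reshape mimics, and of why one shared flag is ambiguous
when input b's last dimension happens to be 1:

import numpy as np

a = np.ones(5)       # 1-D first input
b = np.ones((5, 1))  # 2-D second input whose last dim happens to be 1

# a is promoted to (1, 5); (1, 5) @ (5, 1) -> (1, 1); only the
# *prepended* dim of a is removed, so the result is (1,). With a single
# shared flag, the "remove the appended dim of b" trim could also fire
# here, since b_rank == 2 and input_b_shape[-1] == 1 both hold.
assert np.matmul(a, b).shape == (1,)

c = np.ones((1, 5))  # 2-D first input
d = np.ones(5)       # 1-D second input

# d is promoted to (5, 1); (1, 5) @ (5, 1) -> (1, 1); only the
# *appended* dim of d is removed, so the result is (1,).
assert np.matmul(c, d).shape == (1,)

The new test exercises exactly the first case: float[5] x float[5, 1]
must produce float[1], with only the prepended dimension removed.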