Fix broadcast rule of expanding dims (#1567)
Prior to this PR, the same flag was used no matter which of input a or b was
expanded, which made the flag ambiguous to the code that later adjusts the
output shape (a sketch of the misfire follows).
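
For concreteness, a minimal sketch of the failure mode the separate flags rule out. The shapes are hypothetical and this mirrors only the flag logic of check_if_not_need_reshape, not the full rewriter:

    # Hypothetical repro of the pre-PR ambiguity (invented shapes,
    # flag logic only -- not the full rewriter).
    input_a_shape = [1, 5]  # genuinely 2-D; never expanded
    input_b_shape = [5]     # 1-D; the matmul rule expands it to [5, 1]

    mimic_matmul_broadcast_behavior = False  # pre-PR: one flag shared by both inputs
    if len(input_b_shape) < 2:
        input_b_shape = [*input_b_shape, 1]
        mimic_matmul_broadcast_behavior = True

    # MatMul of [1, 5] @ [5, 1] produces [1, 1].
    output_shape = [input_a_shape[0], input_b_shape[-1]]

    if mimic_matmul_broadcast_behavior and len(input_b_shape) == 2 and input_b_shape[-1] == 1:
        output_shape = output_shape[:-1]  # correct: b was expanded, drop the appended dim
    if mimic_matmul_broadcast_behavior and len(input_a_shape) == 2 and input_a_shape[0] == 1:
        # Misfire: a was never expanded, but the shared flag cannot tell.
        # With separate flags, mimic_matmul_broadcast_behavior_a stays
        # False and this branch is skipped.
        output_shape.pop(-2)  # raises IndexError on the length-1 list [1]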
titaiwangms authored May 23, 2024
1 parent dac54d8 commit d316706
Showing 2 changed files with 29 additions and 5 deletions.
11 changes: 6 additions & 5 deletions onnxscript/rewriter/broadcast_to_matmul.py
@@ -63,7 +63,8 @@ def check_if_not_need_reshape(
     # 1. Check if input shapes are broadcastable
     # 1.a. If the first input is 1-D, check whether
     # the dim matches the last second dim of the second input.
-    mimic_matmul_broadcast_behavior = False
+    mimic_matmul_broadcast_behavior_a = False
+    mimic_matmul_broadcast_behavior_b = False
     if a_rank < 2:
         if b_rank < 2:
             logger.info("Optimization of dot product is not supported yet.")
@@ -74,7 +75,7 @@ def check_if_not_need_reshape(
         else:
             input_a_shape = [1, *input_a_shape]  # type: ignore[assignment]
             a_rank = len(input_a_shape)
-            mimic_matmul_broadcast_behavior = True
+            mimic_matmul_broadcast_behavior_a = True
     # 1.b. If the second input is 1-D, check whether
     # the dim matches the last dim of the first input.
     if b_rank < 2:
@@ -84,7 +85,7 @@ def check_if_not_need_reshape(
         else:
             input_b_shape = [*input_b_shape, 1]  # type: ignore[assignment]
             b_rank = len(input_b_shape)
-            mimic_matmul_broadcast_behavior = True
+            mimic_matmul_broadcast_behavior_b = True
     # 1.c. If both inputs are at least 2-D, check whether
     # the last dimension of the first input matches the second
     # last dimension of the second input, and shape[:-2] are
@@ -119,10 +120,10 @@ def check_if_not_need_reshape(
         *longer_shape[: -len(shorter_shape)],
         *broadcast_matmul_output_shape,
     ]
-    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
+    if mimic_matmul_broadcast_behavior_b and b_rank == 2 and input_b_shape[-1] == 1:
         # If input_b is expanded to 2-D, then we need to remove the last dimension
         broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
-    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
+    if mimic_matmul_broadcast_behavior_a and a_rank == 2 and input_a_shape[0] == 1:
         # If input_a is expanded to 2-D, then we need to remove the first dimension
         # of input_a, which would be the -2nd dimension of the output shape.
         broadcast_matmul_output_shape.pop(-2)
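For context, the mimic_* flags model ONNX MatMul's numpy-style promotion of 1-D inputs: a 1-D first input gets a leading 1 prepended, a 1-D second input gets a trailing 1 appended, and the promoted dimension is removed from the result. A quick NumPy check of that rule:

    import numpy as np

    a = np.ones((5,))       # 1-D first input
    b = np.ones((5, 3))
    # a is promoted to (1, 5): (1, 5) @ (5, 3) -> (1, 3),
    # then the prepended dim is removed, giving (3,).
    assert np.matmul(a, b).shape == (3,)

    c = np.ones((2, 3, 5))
    d = np.ones((5,))       # 1-D second input
    # d is promoted to (5, 1): (2, 3, 5) @ (5, 1) -> (2, 3, 1),
    # then the appended dim is removed, giving (2, 3).
    assert np.matmul(c, d).shape == (2, 3)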
23 changes: 23 additions & 0 deletions onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -251,6 +251,29 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br
         self.assertEqual(count, 1)
         self.assertEqual(len(model.graph), 4)
 
+    def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_second_is_expanded_alike_and_broadcastable(
+        self,
+    ):
+        model_proto = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[5] input_x, float[5, 1] input_y) => (float[1] output)
+            {
+                shape_a = Constant<value: tensor = int64[2] {1, 5}>()
+                reshape_x = Reshape (input_x, shape_a)
+                shape_b = Constant<value: tensor = int64[2] {5, 1}>()
+                reshape_y = Reshape (input_y, shape_b)
+                matmul = MatMul (reshape_x, reshape_y)
+                shape_c = Constant<value: tensor = int64[1] {1}>()
+                output = Reshape (matmul, shape_c)
+            }
+            """
+        )
+        model = ir.serde.deserialize_model(model_proto)
+        count = broadcast_to_matmul.rules.apply_to_model(model)
+        self.assertEqual(count, 1)
+        self.assertEqual(len(model.graph), 4)
+
     def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable(
         self,
     ):
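A reading of the new test's assertions: one rule application replaces the two input Reshapes, the MatMul, and the output Reshape (four nodes) with a single MatMul over the original inputs, and the three shape Constants presumably remain as dead nodes, hence the expected graph length of 7 - 4 + 1 = 4. A hedged sketch of the post-rewrite graph, assumed rather than taken from the PR:

    import onnx.parser

    # Assumed post-rewrite model: MatMul([5], [5, 1]) natively yields [1]
    # under numpy-style promotion, so the Reshapes were redundant. The
    # now-unused shape Constants are kept, matching the asserted count of 4.
    rewritten = onnx.parser.parse_model(
        """
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[5] input_x, float[5, 1] input_y) => (float[1] output)
        {
            shape_a = Constant<value: tensor = int64[2] {1, 5}>()
            shape_b = Constant<value: tensor = int64[2] {5, 1}>()
            shape_c = Constant<value: tensor = int64[1] {1}>()
            output = MatMul (input_x, input_y)
        }
        """
    )
    print(len(rewritten.graph.node))  # 4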
