diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index 2762d95dd7b00..2ae6d1d0f0386 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -128,8 +128,8 @@ def fuse( return if not ( - self.model.get_constant_value(unsqueeze_3.input[1]) == 3 - and self.model.get_constant_value(unsqueeze_2.input[1]) == 2 + len(unsqueeze_3.input) == 2 and self.model.get_constant_value(unsqueeze_3.input[1]) == 3 + and len(unsqueeze_2.input) == 2 and self.model.get_constant_value(unsqueeze_2.input[1]) == 2 and len(self.model.get_children(gemm, input_name_to_nodes)) == 1 and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1 and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1