From b62e65e2a0287043e35db84a639063bd6c492e65 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Tue, 19 Sep 2023 20:28:09 +0900 Subject: [PATCH] Add length checks to fusion_transpose.py --- onnxruntime/python/tools/transformers/fusion_transpose.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index 2762d95dd7b00..2ae6d1d0f0386 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -128,8 +128,8 @@ def fuse( return if not ( - self.model.get_constant_value(unsqueeze_3.input[1]) == 3 - and self.model.get_constant_value(unsqueeze_2.input[1]) == 2 + len(unsqueeze_3.input) == 2 and self.model.get_constant_value(unsqueeze_3.input[1]) == 3 + and len(unsqueeze_2.input) == 2 and self.model.get_constant_value(unsqueeze_2.input[1]) == 2 and len(self.model.get_children(gemm, input_name_to_nodes)) == 1 and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1 and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1