From e9c073ef24d119855eedfb423cc2e081faf5c5c9 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 21 Jan 2024 01:42:14 -0800 Subject: [PATCH] Temporary fix for MHA shape infer to allow GQA --- onnxruntime/python/tools/symbolic_shape_infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef4c4ae906243..f2631cccfd787 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2238,9 +2238,9 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. output_shape = query_shape if key_shape is not None and len(key_shape) == 3: - value_shape = self._try_get_shape(node, 2) - if value_shape is not None and len(value_shape) == 3: - output_shape[2] = value_shape[2] + # value_shape = self._try_get_shape(node, 2) + # if value_shape is not None and len(value_shape) == 3: + # output_shape[2] = value_shape[2] total_sequence_length = key_shape[1] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type