diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef4c4ae906243..f2631cccfd787 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2238,9 +2238,9 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. output_shape = query_shape if key_shape is not None and len(key_shape) == 3: - value_shape = self._try_get_shape(node, 2) - if value_shape is not None and len(value_shape) == 3: - output_shape[2] = value_shape[2] + # value_shape = self._try_get_shape(node, 2) + # if value_shape is not None and len(value_shape) == 3: + # output_shape[2] = value_shape[2] total_sequence_length = key_shape[1] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type