Skip to content

Commit

Permalink
Temporary fix for MHA shape infer to allow GQA
Browse files (browse the repository at this point in the history)
  • Loading branch information
PatriceVignola committed Jan 21, 2024
1 parent bdd56f2 commit e9c073e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -2238,9 +2238,9 @@ def _infer_MultiHeadAttention(self, node):  # noqa: N802
# By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
output_shape = query_shape
if key_shape is not None and len(key_shape) == 3:
- value_shape = self._try_get_shape(node, 2)
- if value_shape is not None and len(value_shape) == 3:
- output_shape[2] = value_shape[2]
+ # value_shape = self._try_get_shape(node, 2)
+ # if value_shape is not None and len(value_shape) == 3:
+ # output_shape[2] = value_shape[2]
total_sequence_length = key_shape[1]

output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
Expand Down

0 comments on commit e9c073e

Please sign in to comment.