Skip to content

Commit

Permalink
More flexible output_shape computation in keras.layers.MultiHeadAttention (#20503)
Browse files Browse the repository at this point in the history

* Made the compute_output_shape method more flexible; now _output_shape can be either an integer or a tuple (previously only a tuple was accepted).
Fix discussed in #19769

* Added unit test

* Minor changes to comments in unit test

* Minor changes to comments in unit test
  • Loading branch information
lcs-crr authored Nov 18, 2024
1 parent 70b7044 commit 8a79442
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
5 changes: 4 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,10 @@ def compute_output_shape(
)

if self._output_shape:
return query_shape[:-1] + self._output_shape
if isinstance(self._output_shape, tuple):
return query_shape[:-1] + self._output_shape
else:
return query_shape[:-1] + (self._output_shape,)

return query_shape

Expand Down
25 changes: 25 additions & 0 deletions keras/src/layers/attention/multi_head_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from keras.src import testing
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.multi_head_attention import MultiHeadAttention


class MultiHeadAttentionTest(testing.TestCase):
Expand Down Expand Up @@ -593,3 +594,27 @@ def test_flash_attention_numerical_correctness(self):
)

self.assertAllClose(output_with_flash, output_without_flash)




def test_multi_head_attention_output_shape_as_int():
    """Test MultiHeadAttention with output_shape as an int."""
    layer = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)
    # Batch of 2, sequence length 4, feature dim 16 for both query and value.
    q = random.uniform((2, 4, 16))
    v = random.uniform((2, 4, 16))
    out = layer(query=q, value=v)

    # An int output_shape should replace the last dimension with that size.
    expected = (2, 4, 8)
    assert out.shape == expected, (f"Expected shape {expected},"
                                   f" got {out.shape}")


def test_multi_head_attention_output_shape_as_tuple():
    """Test MultiHeadAttention with output_shape as a tuple."""
    layer = MultiHeadAttention(num_heads=2, key_dim=16, output_shape=(8, 8))
    # Batch of 2, sequence length 4, feature dim 16 for both query and value.
    q = random.uniform((2, 4, 16))
    v = random.uniform((2, 4, 16))
    out = layer(query=q, value=v)

    # A tuple output_shape should replace the last dimension with all of
    # its entries, expanding the output rank.
    expected = (2, 4, 8, 8)
    assert out.shape == expected, (f"Expected shape {expected},"
                                   f" got {out.shape}")

0 comments on commit 8a79442

Please sign in to comment.