[hotfix] Fix flex attention test (pytorch#1568)
RdoubleA authored Sep 12, 2024
1 parent f729ce1 commit 4d3e48a
Showing 1 changed file with 13 additions and 17 deletions.
tests/torchtune/modules/test_attention_utils.py (30 changes: 13 additions & 17 deletions)
@@ -102,17 +102,16 @@ class TestSDPAOrFlexAttention:
         not _SUPPORTS_FLEX_ATTENTION,
         reason="Please install a nightly build of torch (>=2.5.0) to run this test.",
     )
-    @mock.patch("torchtune.modules.attention_utils.torch.compile")
+    @mock.patch("torchtune.modules.attention_utils.compile_friendly_flex_attention")
     @mock.patch(
         "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention"
     )
-    def test_flex_attention(self, mock_sdpa, mock_compile):
-        mock_flex = mock.MagicMock()
-        mock_compile.return_value = mock_flex
-        q = torch.ones(2, 3, 4)
-        k = torch.ones(2, 3, 4)
-        v = torch.ones(2, 3, 4)
-        attn_mask = torch.ones(2, 3, 4)
+    def test_flex_attention(self, mock_sdpa, mock_flex):
+        # [b, n_h, s, h_d]
+        q = torch.ones(2, 1, 3, 4)
+        k = torch.ones(2, 1, 3, 4)
+        v = torch.ones(2, 1, 3, 4)
+        attn_mask = torch.ones(2, 3, 3)
         dropout_p = 0.0
         is_causal = False

@@ -131,20 +130,17 @@ def test_flex_attention(self, mock_sdpa, mock_compile):
         assert mock_flex.call_count == 1

     @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
-    @mock.patch("torchtune.modules.attention_utils.torch.compile")
     @mock.patch(
         "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention"
     )
-    def test_sdpa_attention(self, mock_sdpa, mock_compile):
-        mock_flex = mock.MagicMock()
-        mock_compile.return_value = mock_flex
-        q = torch.ones(2, 3, 4)
-        k = torch.ones(2, 3, 4)
-        v = torch.ones(2, 3, 4)
-        attn_mask = torch.ones(2, 3, 4)
+    def test_sdpa_attention(self, mock_sdpa):
+        # [b, n_h, s, h_d]
+        q = torch.ones(2, 1, 3, 4)
+        k = torch.ones(2, 1, 3, 4)
+        v = torch.ones(2, 1, 3, 4)
+        attn_mask = torch.ones(2, 3, 3)
         dropout_p = 0.0
         is_causal = False
         _attention_call = _sdpa_or_flex_attention()
         _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
         mock_sdpa.assert_called_once()
-        mock_flex.assert_not_called()
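
For reference, a minimal sketch of how the two tests read after this fix, assembled from the context and added lines of the diff above. The surrounding imports (torch, mock, pytest, plus _SUPPORTS_FLEX_ATTENTION and _sdpa_or_flex_attention from torchtune.modules.attention_utils), the opening @pytest.mark.skipif decorator, and anything else outside the shown hunks are assumptions rather than lines taken from the actual file.

    @pytest.mark.skipif(
        not _SUPPORTS_FLEX_ATTENTION,
        reason="Please install a nightly build of torch (>=2.5.0) to run this test.",
    )
    @mock.patch("torchtune.modules.attention_utils.compile_friendly_flex_attention")
    @mock.patch(
        "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention"
    )
    def test_flex_attention(self, mock_sdpa, mock_flex):
        # [b, n_h, s, h_d]
        q = torch.ones(2, 1, 3, 4)
        k = torch.ones(2, 1, 3, 4)
        v = torch.ones(2, 1, 3, 4)
        attn_mask = torch.ones(2, 3, 3)
        dropout_p = 0.0
        is_causal = False
        # ... (rest of the test body is not shown in this diff) ...
        assert mock_flex.call_count == 1

    @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False)
    @mock.patch(
        "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention"
    )
    def test_sdpa_attention(self, mock_sdpa):
        # [b, n_h, s, h_d]
        q = torch.ones(2, 1, 3, 4)
        k = torch.ones(2, 1, 3, 4)
        v = torch.ones(2, 1, 3, 4)
        attn_mask = torch.ones(2, 3, 3)
        dropout_p = 0.0
        is_causal = False
        _attention_call = _sdpa_or_flex_attention()
        _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal)
        mock_sdpa.assert_called_once()

In short, the fix drops the torch.compile patch and the hand-built MagicMock in favor of patching compile_friendly_flex_attention directly, and feeds 4-D [b, n_h, s, h_d] inputs with a [b, s, s] mask so mock_flex now comes straight from the decorator.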
