[shardformer] fix flash attention test utils
flybird11111 committed Jul 5, 2023
1 parent 190a6ea commit 0297665
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/test_utils/test_flash_attention.py
@@ -24,8 +24,9 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_gpt(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -45,8 +46,9 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_bert(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -69,8 +71,9 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_no_mask(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -89,8 +92,9 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+def test_cross_attention(proj_shape, dtype=torch.float16):
+    (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
@@ -110,4 +114,4 @@ def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
     assert list(y.shape) == [B, T, D]
 
     dy = torch.rand_like(y)
-    y.backward(dy)
\ No newline at end of file
+    y.backward(dy)
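
Each hunk applies the same pattern: the projection dimensions are packed into a single proj_shape tuple for colossalai's parameterize decorator and unpacked inside the test body. Below is a minimal standalone sketch of that pattern. It assumes, as in the original test module, that parameterize is importable from colossalai.testing and passes each listed value as a single keyword argument; the check_attention_shapes helper and its plain-tensor QKV split are illustrative, not taken from the repository.

# Minimal sketch of the proj_shape pattern (the helper name and the plain-tensor
# QKV split are illustrative; they are not taken from the repository).
import torch

from colossalai.testing import parameterize


@parameterize('proj_shape', [(6, 8, 4, 16)])
def check_attention_shapes(proj_shape, dtype=torch.float16):
    # parameterize supplies each tuple from the list as a single `proj_shape`
    # keyword argument, so the individual dimensions are unpacked in the body.
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    # Stand-in for the fused QKV projection (c_attn) used in the real tests.
    qkv = torch.rand(B, S, 3 * D, dtype=dtype)
    q, k, v = torch.chunk(qkv, 3, dim=-1)
    assert q.shape == (B, S, D)
    assert k.shape == (B, S, D)
    assert v.shape == (B, S, D)


if __name__ == '__main__':
    # Called without arguments; the decorator iterates over the configured shapes.
    check_attention_shapes()

This appears to be why the comma-separated 'B, S, H, D_HEAD' spec was replaced: colossalai's parameterize treats its first argument as one parameter name rather than a pytest-style comma-separated list of names.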
