From 0297665b976223421c1846b79554315697430f38 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Wed, 5 Jul 2023 18:27:31 +0800
Subject: [PATCH] [shardformer] fix flash attention test utils

---
 tests/test_utils/test_flash_attention.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 7a28b0157384..2334c84dc778 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -24,8 +24,9 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_gpt(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -45,8 +46,9 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_bert(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -69,8 +71,9 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+def test_attention_no_mask(proj_shape, dtype=torch.float16):
+    (B, S, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -89,8 +92,9 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
 
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
 @clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+def test_cross_attention(proj_shape, dtype=torch.float16):
+    (B, S, T, H, D_HEAD) = proj_shape
     D = H * D_HEAD
 
     q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
@@ -110,4 +114,4 @@ def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
     assert list(y.shape) == [B, T, D]
 
     dy = torch.rand_like(y)
-    y.backward(dy)
+    y.backward(dy)
\ No newline at end of file
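
Note (not part of the patch): the change above switches each test from declaring `B, S, H, D_HEAD` as separate `@parameterize` arguments to accepting a single `proj_shape` tuple that is unpacked inside the test body. Below is a minimal, self-contained sketch of that pattern. The `parameterize` stand-in here is hypothetical; it only assumes the behavior relied on by the patch (call the wrapped function once per listed value, passed as a keyword argument) and is not the actual `colossalai.testing.parameterize` implementation.

```python
# Hypothetical stand-in for colossalai.testing.parameterize (assumed behavior:
# invoke the wrapped function once per listed value, passed as a keyword arg).
def parameterize(arg_name, values):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{arg_name: value}, **kwargs)
        return wrapper
    return decorator


@parameterize('proj_shape', [(6, 8, 4, 16)])
def test_attention_shapes(proj_shape, dtype="float16"):
    # Unpack the whole projection shape from one tuple argument, as the
    # patched tests do, instead of listing four separate decorator parameters.
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD
    assert (B, S, D) == (6, 8, 64)


if __name__ == "__main__":
    test_attention_shapes()
```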