From 54f252647288a48545b0b6db4474cbff420a8f0c Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Thu, 12 Oct 2023 16:30:02 -0700
Subject: [PATCH] add unit test and fix build

---
 .../cuda/bert/group_query_attention_impl.h    |  13 -
 .../python/transformers/test_flash_attn.py    | 244 ++++++++++++------
 2 files changed, 163 insertions(+), 94 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index abc031b69493e..8412631078e6a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -14,29 +14,17 @@ namespace cuda {

 template <typename T>
 struct GroupQueryAttentionData {
-<<<<<<< HEAD
-=======
   // Input Tensors
->>>>>>> aciddelgado/gqa_memeff
   const T* query = nullptr;
   const T* key = nullptr;
   const T* value = nullptr;
   const T* past_key = nullptr;
   const T* past_value = nullptr;
-<<<<<<< HEAD
-=======
   // Flash buffers
->>>>>>> aciddelgado/gqa_memeff
   T* softmax_lse = nullptr;
   T* softmax_lse_accum = nullptr;
   T* out_accum = nullptr;
   int* seqlens_k = nullptr;
-<<<<<<< HEAD
-  T* output = nullptr;
-  T* present_key = nullptr;
-  T* present_value = nullptr;
-  bool use_flash_attention = false;
-=======
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
   T* k = nullptr;
@@ -48,7 +36,6 @@ struct GroupQueryAttentionData {
   // Kernel Flags
   bool use_flash_attention = false;
   bool use_memory_efficient_attention = false;
->>>>>>> aciddelgado/gqa_memeff
 };

 template <typename T>
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index 7d639ee8d5fb3..79389bd628076 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn.py
@@ -9,8 +9,10 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # -------------------------------------------------------------------------
+import os
 import math
 import random
+import unittest

 import numpy
 import torch
@@ -1157,85 +1159,165 @@ def parity_check_gqa_past_no_buff(
     )


+class TestMHA(unittest.TestCase):
+    def test_packed_mha(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST PACKED MHA ---------")
+        for b in [5]:
+            for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
+                for n in [6]:
+                    for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
+                        config = Config(b, s, s, 0, n, n, h)
+                        parity_check_mha(config, True)
+
+    def test_mha(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST MHA ---------")
+        for b in [5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (512, 256),
+                (1024, 1024),
+                (1023, 1024),
+                (1024, 1023),
+                (2048, 2048),
+            ]:
+                for n in [6]:
+                    for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
+                        config = Config(b, s, s2, 0, n, n, h)
+                        parity_check_mha(config, False)
+
+
+class TestGQA(unittest.TestCase):
+    def test_gqa_no_past(self):
+        major, minor = torch.cuda.get_device_capability()
+        if major < 5 or (major == 5 and minor < 3):
+            return
+        torch.manual_seed(69)
+        print("-------- TEST GQA ---------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+        for b in [1, 5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (512, 256),
+                (1024, 1024),
+                (1023, 1024),
+                (1024, 1023),
+                (2048, 2048),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True, False]:
+                            config = Config(b, s, s2, 0, n, n2, h)
+                            parity_check_gqa_no_past(config, causal=causal)
+        if major < 8:
+            return
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        for b in [1, 5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (1024, 1024),
+                (1023, 1024),
+                (2048, 2048),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True, False]:
+                            config = Config(b, s, s2, 0, n, n2, h)
+                            parity_check_gqa_no_past(config, causal=causal)
+
+    def test_gqa_past(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST GQA PAST ---------")
+        random.seed(69)
+        for b in [2]:
+            for s, s2 in [
+                (1, 128),
+                (1, 339),
+                (3, 1024),
+                (64, 800),
+                (64, 256),
+                (3, 799),
+                (64, 2048),
+                (16, 20000),
+                (1, 128 * 512),
+                (16, 128 * 512),
+                (128, 128),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True]:
+                            for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+        if major < 8:
+            return
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        for b in [2]:
+            for s, s2 in [
+                (1, 128),
+                (1, 339),
+                (3, 1024),
+                (64, 800),
+                (64, 256),
+                (3, 799),
+                (64, 2048),
+                (16, 20000),
+                (128, 128),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True]:
+                            for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+
+
 if __name__ == "__main__":
-    print("-------- TEST PACKED MHA ---------")
-    for b in [5]:
-        for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
-            for n in [6]:
-                for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
-                    config = Config(b, s, s, 0, n, n, h)
-                    parity_check_mha(config, True)
-    print("-------- TEST MHA ---------")
-    for b in [5]:
-        for s, s2 in [
-            (113, 203),
-            (128, 217),
-            (113, 211),
-            (108, 256),
-            (256, 512),
-            (512, 256),
-            (1024, 1024),
-            (1023, 1024),
-            (1024, 1023),
-            (2048, 2048),
-        ]:
-            for n in [6]:
-                for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
-                    config = Config(b, s, s2, 0, n, n, h)
-                    parity_check_mha(config, False)
-    print("-------- TEST GQA ---------")
-    for b in [5]:
-        for s, s2 in [
-            (113, 203),
-            (128, 217),
-            (113, 211),
-            (108, 256),
-            (256, 512),
-            # (512, 256),
-            (1024, 1024),
-            (1023, 1024),
-            # (1024, 1023),
-            (2048, 2048),
-        ]:
-            for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
-                for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
-                    for causal in [True, False]:
-                        torch.manual_seed(69)
-                        config = Config(b, s, s2, 0, n, n2, h)
-                        parity_check_gqa_no_past(config, causal=causal)
-    print("-------- TEST GQA PAST ---------")
-    random.seed(69)
-    for b in [2]:
-        for s, s2 in [
-            (1, 128),
-            (1, 339),
-            (3, 1024),
-            (64, 800),
-            (64, 256),
-            (3, 799),
-            (64, 2048),
-            (16, 20000),
-            (1, 128 * 512),
-            (16, 128 * 512),
-            (128, 128),
-        ]:
-            for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
-                for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
-                    for causal in [True]:
-                        for past_kv_format in [Formats.BNSH, Formats.BSNH]:
-                            sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                            config = Config(b, s, s2, sp, n, n2, h)
-                            parity_check_gqa_past(
-                                config,
-                                causal=causal,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
-                            parity_check_gqa_past_no_buff(
-                                config,
-                                causal=causal,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
+    unittest.main()
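
Note for reviewers: because the tests are now wrapped in unittest.TestCase classes, individual GQA cases can be exercised through the standard unittest loader instead of running the whole file. The snippet below is a minimal sketch, not part of the patch; it assumes a CUDA-enabled onnxruntime build and that it is launched from onnxruntime/test/python/transformers so that test_flash_attn is importable.

# Minimal sketch (not part of the patch): run a subset of the new tests via the
# standard unittest loader. Assumes a CUDA-enabled onnxruntime build and that the
# working directory is onnxruntime/test/python/transformers.
import unittest

loader = unittest.TestLoader()
suite = loader.loadTestsFromNames(
    [
        "test_flash_attn.TestGQA.test_gqa_no_past",
        "test_flash_attn.TestGQA.test_gqa_past",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)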