Commit: add unit test and fix build

aciddelgado committed Oct 12, 2023
1 parent ef82d4d commit 54f2526
Showing 2 changed files with 163 additions and 94 deletions.
13 changes: 0 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -14,29 +14,17 @@ namespace cuda {

 template <typename T>
 struct GroupQueryAttentionData {
-<<<<<<< HEAD
-=======
   // Input Tensors
->>>>>>> aciddelgado/gqa_memeff
   const T* query = nullptr;
   const T* key = nullptr;
   const T* value = nullptr;
   const T* past_key = nullptr;
   const T* past_value = nullptr;
-<<<<<<< HEAD
-=======
   // Flash buffers
->>>>>>> aciddelgado/gqa_memeff
   T* softmax_lse = nullptr;
   T* softmax_lse_accum = nullptr;
   T* out_accum = nullptr;
   int* seqlens_k = nullptr;
-<<<<<<< HEAD
-  T* output = nullptr;
-  T* present_key = nullptr;
-  T* present_value = nullptr;
-  bool use_flash_attention = false;
-=======
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
   T* k = nullptr;
@@ -48,7 +36,6 @@ struct GroupQueryAttentionData {
   // Kernel Flags
   bool use_flash_attention = false;
   bool use_memory_efficient_attention = false;
->>>>>>> aciddelgado/gqa_memeff
 };
 
 template <typename T>
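The kernel flags that survive this conflict resolution (use_flash_attention, use_memory_efficient_attention) are what the Python tests below steer through the ORT_DISABLE_FLASH_ATTENTION environment variable. A minimal sketch of how that toggle is exercised from Python; the model path and provider list are illustrative assumptions, not part of this commit:

    # Sketch: disable the flash-attention kernel so ONNX Runtime falls back to
    # the memory-efficient attention path. The variable must be set before the
    # InferenceSession is created. "model.onnx" is a placeholder.
    import os

    import onnxruntime as ort

    os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"  # "0" re-enables flash attention
    sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])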
244 changes: 163 additions & 81 deletions onnxruntime/test/python/transformers/test_flash_attn.py
@@ -9,8 +9,10 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # -------------------------------------------------------------------------
+import os
 import math
 import random
+import unittest
 
 import numpy
 import torch
@@ -1157,85 +1159,165 @@ def parity_check_gqa_past_no_buff(
 )
 
 
+class TestMHA(unittest.TestCase):
+    def test_packed_mha(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST PACKED MHA ---------")
+        for b in [5]:
+            for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
+                for n in [6]:
+                    for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
+                        config = Config(b, s, s, 0, n, n, h)
+                        parity_check_mha(config, True)
+
+    def test_mha(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST MHA ---------")
+        for b in [5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (512, 256),
+                (1024, 1024),
+                (1023, 1024),
+                (1024, 1023),
+                (2048, 2048),
+            ]:
+                for n in [6]:
+                    for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
+                        config = Config(b, s, s2, 0, n, n, h)
+                        parity_check_mha(config, False)
+
+
+class TestGQA(unittest.TestCase):
+    def test_gqa_no_past(self):
+        major, minor = torch.cuda.get_device_capability()
+        if major < 5 or (major == 5 and minor < 3):
+            return
+        torch.manual_seed(69)
+        print("-------- TEST GQA ---------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+        for b in [1, 5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (512, 256),
+                (1024, 1024),
+                (1023, 1024),
+                (1024, 1023),
+                (2048, 2048),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True, False]:
+                            config = Config(b, s, s2, 0, n, n2, h)
+                            parity_check_gqa_no_past(config, causal=causal)
+        if major < 8:
+            return
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        for b in [1, 5]:
+            for s, s2 in [
+                (113, 203),
+                (128, 217),
+                (113, 211),
+                (108, 256),
+                (256, 512),
+                (1024, 1024),
+                (1023, 1024),
+                (2048, 2048),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True, False]:
+                            config = Config(b, s, s2, 0, n, n2, h)
+                            parity_check_gqa_no_past(config, causal=causal)
+
+    def test_gqa_past(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major < 8:
+            return
+        print("-------- TEST GQA PAST ---------")
+        random.seed(69)
+        for b in [2]:
+            for s, s2 in [
+                (1, 128),
+                (1, 339),
+                (3, 1024),
+                (64, 800),
+                (64, 256),
+                (3, 799),
+                (64, 2048),
+                (16, 20000),
+                (1, 128 * 512),
+                (16, 128 * 512),
+                (128, 128),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True]:
+                            for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+        if major < 8:

Check warning (Code scanning / CodeQL): Redundant comparison. Test is always false, because of this condition.

+            return

Check warning (Code scanning / CodeQL): Unreachable code. This statement is unreachable.

+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        for b in [2]:
+            for s, s2 in [
+                (1, 128),
+                (1, 339),
+                (3, 1024),
+                (64, 800),
+                (64, 256),
+                (3, 799),
+                (64, 2048),
+                (16, 20000),
+                (128, 128),
+            ]:
+                for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
+                    for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
+                        for causal in [True]:
+                            for past_kv_format in [Formats.BNSH, Formats.BSNH]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    causal=causal,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                )
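The two CodeQL warnings flag real dead code: test_gqa_past already returns at the top when major < 8, so the second major < 8 check can never be true and the return under it is unreachable. A sketch of one way the duplicated gate could be folded away; the _run_gqa_past_cases helper is hypothetical, standing in for the loop nests above (a real refactor would also have to parameterize the slightly different sequence-length lists of the two loop nests):

    # Hypothetical restructuring, not part of this commit: gate once, then run
    # both kernel paths by flipping the environment variable.
    def test_gqa_past(self):
        major, _ = torch.cuda.get_device_capability()
        if major < 8:
            return  # both kernel paths below require SM80 or newer
        for disable_flash in ("0", "1"):
            os.environ["ORT_DISABLE_FLASH_ATTENTION"] = disable_flash
            self._run_gqa_past_cases()  # hypothetical helper holding the loops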


 if __name__ == "__main__":
-    print("-------- TEST PACKED MHA ---------")
-    for b in [5]:
-        for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]:
-            for n in [6]:
-                for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
-                    config = Config(b, s, s, 0, n, n, h)
-                    parity_check_mha(config, True)
-    print("-------- TEST MHA ---------")
-    for b in [5]:
-        for s, s2 in [
-            (113, 203),
-            (128, 217),
-            (113, 211),
-            (108, 256),
-            (256, 512),
-            (512, 256),
-            (1024, 1024),
-            (1023, 1024),
-            (1024, 1023),
-            (2048, 2048),
-        ]:
-            for n in [6]:
-                for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]:
-                    config = Config(b, s, s2, 0, n, n, h)
-                    parity_check_mha(config, False)
-    print("-------- TEST GQA ---------")
-    for b in [5]:
-        for s, s2 in [
-            (113, 203),
-            (128, 217),
-            (113, 211),
-            (108, 256),
-            (256, 512),
-            # (512, 256),
-            (1024, 1024),
-            (1023, 1024),
-            # (1024, 1023),
-            (2048, 2048),
-        ]:
-            for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
-                for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
-                    for causal in [True, False]:
-                        torch.manual_seed(69)
-                        config = Config(b, s, s2, 0, n, n2, h)
-                        parity_check_gqa_no_past(config, causal=causal)
-    print("-------- TEST GQA PAST ---------")
-    random.seed(69)
-    for b in [2]:
-        for s, s2 in [
-            (1, 128),
-            (1, 339),
-            (3, 1024),
-            (64, 800),
-            (64, 256),
-            (3, 799),
-            (64, 2048),
-            (16, 20000),
-            (1, 128 * 512),
-            (16, 128 * 512),
-            (128, 128),
-        ]:
-            for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]:
-                for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]:
-                    for causal in [True]:
-                        for past_kv_format in [Formats.BNSH, Formats.BSNH]:
-                            sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                            config = Config(b, s, s2, sp, n, n2, h)
-                            parity_check_gqa_past(
-                                config,
-                                causal=causal,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
-                            parity_check_gqa_past_no_buff(
-                                config,
-                                causal=causal,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
+    unittest.main()
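With the __main__ block reduced to unittest.main(), subsets of the suite can be selected instead of always running every loop nest. A small sketch using only stdlib unittest; the module and class names are the ones introduced in this diff:

    # Run just the GQA tests with verbose output; equivalent to
    # "python -m unittest test_flash_attn.TestGQA -v" from the same directory.
    import unittest

    from test_flash_attn import TestGQA

    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestGQA)
    unittest.TextTestRunner(verbosity=2).run(suite)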
