Skip to content

Commit

Permalink
Packed QKV and Rotary Embedding Support for sm<80 GQA (#20012)
Browse files Browse the repository at this point in the history
### Description
Add support for packed qkv input and rotary embedding with sm<80 using
memory efficient attention kernel.



### Motivation and Context
Allows lower-end gpus to run gqa with packed qkv input and rotary
embedding.
  • Loading branch information
aciddelgado authored and rachguo committed Mar 25, 2024
1 parent 6d96c65 commit 6469bb5
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 64 deletions.
23 changes: 20 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
do_rotary_ == false &&
key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
Expand All @@ -172,18 +170,31 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (use_memory_efficient_attention && needs_buff) {
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size);
}
size_t rotary_buffer_bytes = 0;
if (use_memory_efficient_attention && do_rotary_) {
rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size;
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
}
size_t fmha_buffer_bytes = 0;
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
}
size_t unpacked_qkv_bytes = 0;
if (use_memory_efficient_attention && parameters.is_packed_qkv) {
unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T));
}
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
#else
constexpr bool use_memory_efficient_attention = false;
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto rotary_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif

// seqlens_k buffer
Expand Down Expand Up @@ -251,7 +262,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
// Rotary
if (unpacked_qkv_buffer != nullptr) {
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}
if (rotary_buffer != nullptr) {
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
}
// Rotary Embedding
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
Expand Down
Loading

0 comments on commit 6469bb5

Please sign in to comment.