GQA Memory Efficient Kernel #17920

Merged 20 commits on Nov 2, 2023
Changes from 2 commits
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -99,7 +99,6 @@ size_t GetSequenceOffsetSize(int batch_size, bool has_padding) {
// There are batch_size + 1 offsets Without padding (or padding removed), and 2 * batch_size + 1 with padding.
size_t bytes = sizeof(int) * ((has_padding ? 2 * batch_size : batch_size) + 1);
return AlignSize(bytes);
- ;
}

size_t GetAttentionWorkspaceSize(
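The hunk above only removes a stray `;` after the return in GetSequenceOffsetSize. For reference, that function allocates batch_size + 1 cumulative offsets without padding and 2 * batch_size + 1 ints with padding, then aligns the byte count. A standalone sketch of that rule (the alignment boundary here is an assumption, not taken from the source):

```cpp
// Standalone sketch of the sizing rule in GetSequenceOffsetSize (not the ORT source).
// Without padding there are batch_size + 1 cumulative offsets; with padding an extra
// set of per-sequence counts is kept, giving 2 * batch_size + 1 ints in total.
#include <cstddef>
#include <cstdio>

// The real AlignSize() rounds up to a CUDA-friendly boundary; the exact boundary is an
// implementation detail, so it is passed in here.
static size_t AlignUp(size_t bytes, size_t alignment) {
  return (bytes + alignment - 1) / alignment * alignment;
}

static size_t SequenceOffsetBytes(int batch_size, bool has_padding, size_t alignment) {
  const size_t count = (has_padding ? 2 * static_cast<size_t>(batch_size)
                                    : static_cast<size_t>(batch_size)) + 1;
  return AlignUp(count * sizeof(int), alignment);
}

int main() {
  // batch_size = 8: 9 ints (36 bytes) without padding, 17 ints (68 bytes) with padding;
  // with a 16-byte boundary these round up to 48 and 80 bytes respectively.
  std::printf("%zu %zu\n", SequenceOffsetBytes(8, false, 16), SequenceOffsetBytes(8, true, 16));
  return 0;
}
```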
@@ -51,7 +51,7 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.num_keys = params.kv_sequence_length;

if (params.causal) {
- p.custom_mask_type = Attention::CausalFromTopLeft;
+ p.custom_mask_type = Attention::CausalFromBottomRight;
}

// Input format is BxSxNxH, output is BxSxNxH
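The one-line change above switches the causal mask from top-left to bottom-right alignment. With a KV cache the key sequence is longer than the new query sequence, and bottom-right alignment keeps the cached positions visible to every new query token. A small illustration of the two conventions (standalone sketch, not the cutlass FMHA API):

```cpp
// Illustrative sketch: which key positions each query token may attend to under the
// two causal-mask conventions, for num_queries < num_keys (the KV-cache case).
#include <cstdio>

// Top-left alignment: query row i sees keys j <= i, ignoring the cache offset.
static bool VisibleTopLeft(int q, int k) { return k <= q; }

// Bottom-right alignment: the last query row lines up with the last key column,
// so query row i sees keys j <= i + (num_keys - num_queries).
static bool VisibleBottomRight(int q, int k, int num_queries, int num_keys) {
  return k <= q + (num_keys - num_queries);
}

int main() {
  const int num_queries = 2, num_keys = 5;  // e.g. 2 new tokens on top of 3 cached ones
  for (int q = 0; q < num_queries; ++q) {
    for (int k = 0; k < num_keys; ++k) std::printf("%c", VisibleTopLeft(q, k) ? '1' : '.');
    std::printf("   ");
    for (int k = 0; k < num_keys; ++k) std::printf("%c", VisibleBottomRight(q, k, num_queries, num_keys) ? '1' : '.');
    std::printf("\n");
  }
  // Prints: 1....   1111.
  //         11...   11111
  // Top-left would hide the cached keys from the new queries; bottom-right keeps them
  // visible, which is what decoding with a KV cache needs.
  return 0;
}
```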
74 changes: 60 additions & 14 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -6,6 +6,7 @@
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
// #include "contrib_ops/cpu/utils/console_dumper.h"
@@ -55,6 +56,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_flash_attention_ = true;
#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+disable_memory_efficient_attention_ = sizeof(T) != 2 ||
+ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
+#else
+disable_memory_efficient_attention_ = true;
+#endif
}

template <typename T>
@@ -92,18 +100,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);

-std::vector<int64_t> present_dims;
-if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
-present_dims = {
-parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
-} else { // BNSH
-present_dims = {
-parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
-}
-TensorShape present_shape(present_dims);
-Tensor* present_key = context->Output(1, present_shape);
-Tensor* present_value = context->Output(2, present_shape);
-
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
@@ -149,8 +145,50 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

-// only kernel implemented for gqa right now
-ORT_ENFORCE(use_flash_attention);
+#if USE_MEMORY_EFFICIENT_ATTENTION
+int sm = (device_prop.major * 10) + device_prop.minor;
+bool use_memory_efficient_attention =
+!use_flash_attention &&
+!disable_memory_efficient_attention_ &&
+(parameters.head_size & 7) == 0 &&
+parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
+(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
+has_memory_efficient_attention(sm, sizeof(T) == 2);
+// allocate buffers
+size_t kv_buffer_bytes = 0;
+const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads) || (parameters.past_kv_format != AttentionQkvFormat::Q_K_V_BSNH) || (past_key != nullptr && parameters.present_sequence_length != parameters.past_sequence_length + parameters.kv_sequence_length);
+if (use_memory_efficient_attention && needs_buff) {
+if (past_key == nullptr) {
+kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.kv_sequence_length * parameters.head_size);
+} else {
+kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
+}
+}
+size_t fmha_buffer_bytes = 0;
+if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
+fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
+}
+auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
+#else
+constexpr bool use_memory_efficient_attention = false;
+auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+#endif
+
+std::vector<int64_t> present_dims;
+if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+present_dims = {
+parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
+} else { // BNSH
+present_dims = {
+parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
+}
+TensorShape present_shape(present_dims);
+Tensor* present_key = context->Output(1, present_shape);
+Tensor* present_value = context->Output(2, present_shape);

data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
@@ -161,6 +199,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.use_flash_attention = use_flash_attention;
+data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
@@ -173,6 +212,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (seqlens_k_buffer != nullptr) {
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
}
+if (k_buffer != nullptr) {
+data.k = reinterpret_cast<CudaT*>(k_buffer.get());
+data.v = reinterpret_cast<CudaT*>(v_buffer.get());
+}
+if (fmha_buffer != nullptr) {
+data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
+}

cublasHandle_t cublas = GetCublasHandle(context);

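The group_query_attention.cc changes above wire in a memory-efficient (cutlass FMHA) path alongside flash attention: it is used only when flash attention is not, the head size is a multiple of 8, the SM version supports the kernel, and fp32 inputs meet a minimum sequence length. When the query and KV head counts differ, or the past KV has to be repacked, K and V are staged into scratch buffers sized for the full query head count and the total KV length. A rough sketch of that sizing, with a hypothetical struct and a simplified condition standing in for the real parameters:

```cpp
// Rough sketch of the K/V scratch-buffer sizing in the diff above (hypothetical struct,
// simplified condition; not the onnxruntime API).
#include <cstddef>
#include <cstdio>

struct GqaParams {  // illustrative subset of the real GQA parameters
  int batch_size, num_heads, kv_num_heads;
  int kv_sequence_length, past_sequence_length, head_size;
};

// K and V are expanded to the full query head count before the memory-efficient kernel,
// so the scratch buffer covers num_heads (not kv_num_heads) and the total KV length.
static size_t KvScratchBytes(const GqaParams& p, bool has_past, size_t element_size) {
  const bool needs_buffer = (p.num_heads != p.kv_num_heads) || has_past;  // simplified
  if (!needs_buffer) return 0;
  const size_t total_kv = has_past
      ? static_cast<size_t>(p.past_sequence_length) + p.kv_sequence_length
      : static_cast<size_t>(p.kv_sequence_length);
  return element_size * p.batch_size * p.num_heads * total_kv * p.head_size;
}

int main() {
  // Example: batch 1, 32 query heads sharing 8 KV heads, head size 128, fp16 (2 bytes),
  // 511 cached tokens plus 1 new token.
  const GqaParams p{1, 32, 8, 1, 511, 128};
  std::printf("%zu bytes each for K and V\n", KvScratchBytes(p, /*has_past=*/true, /*element_size=*/2));
  // 2 * 1 * 32 * 512 * 128 = 4,194,304 bytes per tensor.
  return 0;
}
```

The buffers use the query head count because grouped KV heads are repeated to one per query head before the memory-efficient kernel runs, which is also why `needs_buff` in the diff triggers whenever num_heads != kv_num_heads.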
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
+bool disable_memory_efficient_attention_;
};

} // namespace cuda