Skip to content

Commit

Permalink
GQA Memory Efficient Kernel (microsoft#17920)
Browse files Browse the repository at this point in the history
Implement Cutlass Memory Efficient Attention Kernel into Group Query
Attention Operator.

### Motivation and Context
Before this change, Group Query Attention Operator was supported only by
Flash-Attention. While this is the most efficient kernel for the
operation, it only supports sm >= 80. Cutlass Memory Efficient Attention
Kernel supports sm >= 53, allowing us to support a broader range of GPU
hardware.
  • Loading branch information
aciddelgado authored and kleiti committed Mar 22, 2024
1 parent aca0171 commit 9a2ba9d
Show file tree
Hide file tree
Showing 14 changed files with 843 additions and 312 deletions.
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
"bert/group_query_attention_helper.h"
"bert/group_query_attention.h"
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
"bert/group_query_attention_impl.cu"
)

if (NOT onnxruntime_ENABLE_ATEN)
Expand Down
6 changes: 3 additions & 3 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2422,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
</dl>

#### Outputs (1 - 3)
#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dt><tt>present_key</tt> : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dt><tt>present_value</tt> : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
</dl>

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ Status EfficientAttention(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
p.max_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
Expand All @@ -395,6 +396,7 @@ Status EfficientAttention(
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
? data.scratch
: nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.num_keys = params.kv_sequence_length;

if (params.causal) {
p.custom_mask_type = Attention::CausalFromTopLeft;
p.custom_mask_type = Attention::CausalFromBottomRight;
}

// Input format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
// We use max_sequence_length to calculate KV stride
if (params.is_kv_bsnh) {
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
} else {
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.max_sequence_length * params.qk_head_size;
p.v_strideH = params.max_sequence_length * params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.qk_head_size;
p.v_strideM = params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}
}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ namespace cuda {
struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
bool is_kv_bsnh = true;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
int32_t max_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
Expand Down
73 changes: 57 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
// #include "contrib_ops/cpu/utils/console_dumper.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
Expand Down Expand Up @@ -55,6 +54,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
}

template <typename T>
Expand Down Expand Up @@ -92,18 +98,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> present_dims;
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);

#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
Expand Down Expand Up @@ -143,8 +137,47 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

// only kernel implemented for gqa right now
ORT_ENFORCE(use_flash_attention);
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
// allocate buffers
size_t kv_buffer_bytes = 0;
// need a buffer if we must ungroup kv
const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
if (use_memory_efficient_attention && needs_buff) {
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);
}
size_t fmha_buffer_bytes = 0;
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
}
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
#else
constexpr bool use_memory_efficient_attention = false;
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif

std::vector<int64_t> present_dims;
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);

data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
Expand All @@ -155,6 +188,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
Expand All @@ -167,6 +201,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (seqlens_k_buffer != nullptr) {
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
}
if (k_buffer != nullptr) {
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
}
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}

cublasHandle_t cublas = GetCublasHandle(context);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
};

} // namespace cuda
Expand Down
88 changes: 38 additions & 50 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query,
// query (Q) : (B, S, D)
// key (K) : (B, S+, D_kv)
// value (V) : (B, S+, D_kv)
ORT_UNUSED_PARAMETER(value);

AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;

const auto& query_dims = query->Shape().GetDims();
const auto& key_dims = key->Shape().GetDims();
const auto& value_dims = value->Shape().GetDims();

if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
Expand All @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query,
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = static_cast<int>(q_hidden_size) / num_heads;

int kv_sequence_length = sequence_length;
int kv_hidden_size = (key_dims.size() == 3)
? static_cast<int>(key_dims[2])
: (kv_num_heads * static_cast<int>(key_dims[3]));
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_hidden_size = static_cast<int>(key_dims[2]);

int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
Expand Down Expand Up @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent");
}

if (key != nullptr) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
if (key_dims[2] != value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
}

qkv_format = Q_K_V_BSNH;
kv_sequence_length = static_cast<int>(key_dims[1]);
} else {
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing key tensor.");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

if (value != nullptr) {
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}

if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}

if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
} else {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing value tensor.");
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}

if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}

// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
int32_t past_sequence_length = 0;
int present_sequence_length = 0;
int present_sequence_length = kv_sequence_length;
if (past_seq_len != nullptr) {
if (past_key == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Past KV must be present as share-buffer when using past_seq_len pointer.");
}
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"past_sequence_length tensor must be of one element when using past kv.");
Expand All @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query,
} else {
past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
}
if (past_sequence_length + kv_sequence_length > max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");
}
present_sequence_length = max_sequence_length;
} else if (past_key != nullptr) {
past_sequence_length = max_sequence_length; // this is the length of past_key tensor
Expand Down
Loading

0 comments on commit 9a2ba9d

Please sign in to comment.