Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GQA Memory Efficient Kernel #17920

Merged
merged 20 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
"bert/group_query_attention_helper.h"
"bert/group_query_attention.h"
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
"bert/group_query_attention_impl.cu"
)

if (NOT onnxruntime_ENABLE_ATEN)
Expand Down
6 changes: 3 additions & 3 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2421,14 +2421,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
</dl>

#### Outputs (1 - 3)
#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dt><tt>present_key</tt> : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dt><tt>present_value</tt> : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
</dl>

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ Status EfficientAttention(
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
p.max_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
Expand All @@ -393,6 +394,7 @@ Status EfficientAttention(
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
p.is_kv_bsnh = true;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
? data.scratch
: nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.num_keys = params.kv_sequence_length;

if (params.causal) {
p.custom_mask_type = Attention::CausalFromTopLeft;
p.custom_mask_type = Attention::CausalFromBottomRight;
}

// Input format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
// We use max_sequence_length to calculate KV stride
if (params.is_kv_bsnh) {
// Input Q, K, V format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.max_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
} else {
// Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.max_sequence_length * params.qk_head_size;
p.v_strideH = params.max_sequence_length * params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;

p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.qk_head_size;
p.v_strideM = params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;

p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length;
p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length;
p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}
}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ namespace cuda {
struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
bool is_kv_bsnh = true;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
int32_t max_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
Expand Down
73 changes: 57 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention.h"
#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
// #include "contrib_ops/cpu/utils/console_dumper.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
Expand Down Expand Up @@ -55,6 +54,13 @@
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);

Check warning on line 60 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc#L60

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:60:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
#else
disable_memory_efficient_attention_ = true;
#endif
}

template <typename T>
Expand Down Expand Up @@ -92,18 +98,6 @@
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> present_dims;
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);

#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported(device_prop,
Expand Down Expand Up @@ -149,8 +143,47 @@
auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

// only kernel implemented for gqa right now
ORT_ENFORCE(use_flash_attention);
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
// allocate buffers
size_t kv_buffer_bytes = 0;
// need a buffer if we must ungroup kv
const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads);
if (use_memory_efficient_attention && needs_buff) {
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size);

Check warning on line 160 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc#L160

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:160:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
size_t fmha_buffer_bytes = 0;
if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {

Check warning on line 163 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc#L163

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:163:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));

Check warning on line 164 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc#L164

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:164:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
#else
constexpr bool use_memory_efficient_attention = false;
auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif

std::vector<int64_t> present_dims;

Check warning on line 176 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc#L176

Add #include <vector> for vector<> [build/include_what_you_use] [4]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:176:  Add #include <vector> for vector<>  [build/include_what_you_use] [4]
if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
present_dims = {
parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
} else { // BNSH
present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
}
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);

data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
Expand All @@ -161,6 +194,7 @@
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
Expand All @@ -173,6 +207,13 @@
if (seqlens_k_buffer != nullptr) {
data.seqlens_k = reinterpret_cast<int*>(seqlens_k_buffer.get());
}
if (k_buffer != nullptr) {
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
}
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}

cublasHandle_t cublas = GetCublasHandle(context);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel {
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
};

} // namespace cuda
Expand Down
88 changes: 38 additions & 50 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
// query (Q) : (B, S, D)
// key (K) : (B, S+, D_kv)
// value (V) : (B, S+, D_kv)
ORT_UNUSED_PARAMETER(value);
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved

AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BSNH;

const auto& query_dims = query->Shape().GetDims();
const auto& key_dims = key->Shape().GetDims();
const auto& value_dims = value->Shape().GetDims();

if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
Expand All @@ -47,10 +47,8 @@
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = static_cast<int>(q_hidden_size) / num_heads;

int kv_sequence_length = sequence_length;
int kv_hidden_size = (key_dims.size() == 3)
? static_cast<int>(key_dims[2])
: (kv_num_heads * static_cast<int>(key_dims[3]));
int kv_sequence_length = static_cast<int>(key_dims[1]);
int kv_hidden_size = static_cast<int>(key_dims[2]);

int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
Expand Down Expand Up @@ -134,63 +132,49 @@
"Input 'past_key' and 'past_value' shall be both present or both absent");
}

if (key != nullptr) {
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
if (key_dims[2] != value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
}

qkv_format = Q_K_V_BSNH;
kv_sequence_length = static_cast<int>(key_dims[1]);
} else {
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing key tensor.");
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}

if (value != nullptr) {
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}

if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
}

if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
} else {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Missing value tensor.");
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}

if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}

tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
// When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
int32_t past_sequence_length = 0;
int present_sequence_length = 0;
int present_sequence_length = kv_sequence_length;
if (past_seq_len != nullptr) {
if (past_key == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Past KV must be present as share-buffer when using past_seq_len pointer.");
}
if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"past_sequence_length tensor must be of one element when using past kv.");
Expand All @@ -200,6 +184,10 @@
} else {
past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
}
if (past_sequence_length + kv_sequence_length > max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length");

Check warning on line 189 in onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h#L189

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h:189:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
present_sequence_length = max_sequence_length;
} else if (past_key != nullptr) {
past_sequence_length = max_sequence_length; // this is the length of past_key tensor
Expand Down
Loading
Loading