diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 9bc2bdd208a92..4140eeee0d111 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -94,6 +94,11 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ed1049b0bd73a..8e86862a62e7d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2422,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.
-#### Outputs (1 - 3) +#### Outputs
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key (optional) : T
+
present_key : T
present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
present_value (optional) : T
+
present_value : T
present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eb9e6d5c62467..16ce3a899fb5e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -374,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -395,6 +396,7 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca332..51c3d3d3a458b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf89..f16567bb6f2b7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 67d750aeac11a..8694dc998c7a8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -6,9 +6,8 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" -// #include "contrib_ops/cpu/utils/console_dumper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_flash_attention_ = true; #endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; - } - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, @@ -143,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - // only kernel implemented for gqa right now - ORT_ENFORCE(use_flash_attention); +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); @@ -155,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -167,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (seqlens_k_buffer != nullptr) { data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 72c9814fad670..a90418ec2243a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; float scale_; bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index be8f5ca0ae3e9..8c21de9ced058 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) // key (K) : (B, S+, D_kv) // value (V) : (B, S+, D_kv) + ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BSNH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); - const auto& value_dims = value->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = sequence_length; - int kv_hidden_size = (key_dims.size() == 3) - ? static_cast(key_dims[2]) - : (kv_num_heads * static_cast(key_dims[3])); + int kv_sequence_length = static_cast(key_dims[1]); + int kv_hidden_size = static_cast(key_dims[2]); int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - if (key_dims[2] != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else { + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing key tensor."); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing value tensor."); + "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. int32_t past_sequence_length = 0; - int present_sequence_length = 0; + int present_sequence_length = kv_sequence_length; if (past_seq_len != nullptr) { + if (past_key == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past KV must be present as share-buffer when using past_seq_len pointer."); + } if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "past_sequence_length tensor must be of one element when using past kv."); @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query, } else { past_sequence_length = static_cast(*((*past_seq_len).template Data())); } + if (past_sequence_length + kv_sequence_length > max_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); + } present_sequence_length = max_sequence_length; } else if (past_key != nullptr) { past_sequence_length = max_sequence_length; // this is the length of past_key tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ab3029ca34886..0455825c364a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -47,6 +48,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { +////////// Auxiliary Kernels for KV prep + // Kernel for seqlens_k __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { int id = blockDim.x * blockIdx.x + threadIdx.x; @@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int present_head_stride = is_bsnh ? H : present_seqlen * H; // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH + // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L const int past_seqlen = present_seqlen - new_seqlen; @@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, } } +// Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int H, + const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, const bool is_bsnh) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = present_seqlen - new_seqlen; - const int present_seqlen = gridDim.x; - const int num_heads = blockDim.y; - const int thread_stride = blockDim.x; - - const int present_batch_stride = present_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; - - // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH - // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; - - while (h < H) { int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { const int past_batch_stride = past_seqlen * num_heads * H; @@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset]; } - h += thread_stride; } } +// Concat new to past in present. Supports past BSNH or past BNSH template -Status QkvToContext( +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int past_seqlen, + const int present_seqlen, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int past_seqlen, + const int present_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, + cudaStream_t stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { - assert(data.use_flash_attention); + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if USE_FLASH_ATTENTION - auto stream = static_cast(ort_stream->GetHandle()); + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + if (data.past_key != nullptr && data.past_key == data.present_key) { + // Share buffer case + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Launch kernel to copy seqlen + int thr_per_blk = 256; + int blk_in_grid = ceil(float(batch_size) / thr_per_blk); + repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + + } else { + // Not share buffer or no past (prompt generation) + // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, + sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.kv_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; const int present_sequence_length = parameters.present_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(head_size)) : parameters.scale; - if (data.use_flash_attention) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(parameters.num_heads % parameters.kv_num_heads == 0); - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = parameters.is_unidirectional; - - if (data.past_key == nullptr && data.present_key == nullptr) { - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.softmax_lse), - parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size, - parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - - } else if (data.past_key == data.present_key) { - // Assume past and present kv share buffer. - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - assert(parameters.past_sequence_length >= 0); - assert(data.past_value != nullptr); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); - - } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) { - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (head_size % 4 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4"); - } - const int H = head_size / 4; - if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } else { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + if (data.past_key != nullptr) { + // Past key case + // concatenate new kv to past kv + if (data.past_key == data.present_key) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); } + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + } else if (num_heads == kv_num_heads) { + // no past or present and no need to ungroup... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // intermediate buffer so q and kv have same num heads... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, + kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, + max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = past_sequence_length + kv_sequence_length; + p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif - return Status::OK(); +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 0bad9eeb61231..8412631078e6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -14,19 +14,28 @@ namespace cuda { template struct GroupQueryAttentionData { + // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; int* seqlens_k = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors T* output = nullptr; T* present_key = nullptr; T* present_value = nullptr; + // Kernel Flags bool use_flash_attention = false; + bool use_memory_efficient_attention = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5f..d7aeef1501cd6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e5..3fe9dbf8ed34a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,6 +703,7 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 76c3f8716ff09..5bc18a4e69b47 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1051,15 +1051,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .Output(2, "present_value", "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 04351cd6e6782..319fed87dc9eb 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -10,7 +10,10 @@ # license information. # ------------------------------------------------------------------------- import math +import os +import platform import random +import unittest import numpy import torch @@ -22,6 +25,8 @@ torch.manual_seed(0) +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + class Formats: BSNH = 0 @@ -159,7 +164,7 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_no_past(config, causal=False): +def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH): nodes = [ helper.make_node( "GroupQueryAttention", @@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False): "key", "value", ], - ["output"], + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, unidirectional=1 if causal else 0, + is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0, domain="com.microsoft", ), ] @@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False): TensorProto.FLOAT16, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen): return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) -# TODO(aciddelgado): rename def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): onnx_model_str = create_packed_multihead_attention_graph(config) qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) @@ -548,8 +573,8 @@ def mha_func(q, k, v, config): return output -def gqa_no_past_func(q, k, v, config, causal=True): - onnx_model_str = create_group_query_attention_graph_no_past(config, causal) +def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH): + onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) @@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True): } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) - ort_output = ort_session.run(None, ort_inputs) + ort_output, _, _ = ort_session.run(None, ort_inputs) ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) return output @@ -689,17 +714,12 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) if causal: # Some rows are completely masked out so we fill them with zero instead of NaN attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: @@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(present_k[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref.shape) - - # print(present_k - k_cache_ref.detach().cpu().numpy()) - # Make sure past-present buffer updating correctly if past_format == Formats.BSNH: assert numpy.allclose( @@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff( ) +class TestMHA(unittest.TestCase): + def test_packed_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST PACKED MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) + + def test_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + torch.manual_seed(69) + print("-------- TEST GQA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1024, 1024), + (1023, 1024), + (2048, 2048), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + if major < 5 or (major == 5 and minor < 3): + return + print("------- MEMORY EFFICIENT ATTENTION ---------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + + def test_gqa_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- TEST GQA PAST ---------") + print("-------- MEMORY EFFICEINT --------") + batches = [2] if pipeline_mode else [1, 2] + seqs = ( + [(1, 128), (3, 1024), (64, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 512), + (16, 128 * 512), + (128, 128), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + + if __name__ == "__main__": - print("-------- TEST PACKED MHA ---------") - for b in [5]: - for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s, 0, n, n, h) - parity_check_mha(config, True) - print("-------- TEST MHA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s2, 0, n, n, h) - parity_check_mha(config, False) - print("-------- TEST GQA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - parity_check_gqa_no_past(config, causal=causal) - print("-------- TEST GQA PAST ---------") - random.seed(69) - for b in [2]: - for s, s2 in [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 512), - (16, 128 * 512), - (128, 128), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + unittest.main()