From 406cd324e0f7de5ecb58e55c267924b790a49e85 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:43:12 -0700 Subject: [PATCH] [CUDA] GroupQueryAttention operator using FlashAttention (#17674) ### Description Added Group Query Attention op, supporting integer multiple number of heads for Q / KV. As of now, this op can only use FlashAttention kernel, meaning it only supports sm>=80 on Linux. Results from onnxruntime/test/python/transformers/benchmark_gqa.py show an on-average ~37% speed-up over Decoder Masked Multi-Head Attention, with even greater improvements for long past sequence lengths. ``` op batch s_kv heads h_dim ms TFLOPS gqa 16 2048 8 32 0.34 0.10 dmmha 16 2048 8 32 0.39 0.09 --------- gqa 16 2048 8 64 0.45 0.15 dmmha 16 2048 8 64 0.61 0.11 --------- gqa 16 2048 8 128 0.54 0.25 dmmha 16 2048 8 128 0.83 0.16 --------- gqa 16 2048 16 32 0.45 0.15 dmmha 16 2048 16 32 0.69 0.10 --------- gqa 16 2048 16 64 0.69 0.19 dmmha 16 2048 16 64 0.83 0.16 --------- gqa 16 2048 16 128 0.71 0.38 dmmha 16 2048 16 128 1.28 0.21 --------- gqa 16 2048 32 32 0.58 0.23 dmmha 16 2048 32 32 0.77 0.17 --------- gqa 16 2048 32 64 0.58 0.46 dmmha 16 2048 32 64 1.25 0.21 --------- gqa 16 2048 32 128 0.76 0.71 dmmha 16 2048 32 128 2.15 0.25 --------- gqa 16 2048 64 32 0.68 0.39 dmmha 16 2048 64 32 1.23 0.22 --------- gqa 16 2048 64 64 0.77 0.70 dmmha 16 2048 64 64 2.11 0.25 --------- gqa 16 2048 64 128 1.10 0.97 dmmha 16 2048 64 128 4.06 0.26 --------- gqa 16 2048 128 32 1.00 0.54 dmmha 16 2048 128 32 2.09 0.26 --------- gqa 16 2048 128 64 1.10 0.97 dmmha 16 2048 128 64 4.08 0.26 ``` ### Motivation and Context As of now, this op is targeted for use on LLama models, as it supports kv-caching and different number of heads for Q and KV (Grouped Query Attention). We plan to add support for more platforms, input formats, etc. in the future. 
--------- Co-authored-by: Tianlei Wu Co-authored-by: tlwu@microsoft.com --- docs/ContribOperators.md | 68 +- docs/OperatorKernels.md | 1 + .../contrib_ops/cpu/bert/attention_common.h | 20 + .../cuda/bert/add_bias_transpose.cu | 9 +- .../contrib_ops/cuda/bert/bert_padding.cu | 44 +- .../cuda/bert/embed_layer_norm_impl.cu | 13 +- .../contrib_ops/cuda/bert/fast_gelu_impl.cu | 6 +- .../cuda/bert/flash_attention/block_info.h | 12 +- .../cuda/bert/flash_attention/flash.h | 25 +- .../cuda/bert/flash_attention/flash_api.cc | 231 +++++- .../cuda/bert/flash_attention/flash_api.h | 36 +- .../bert/flash_attention/flash_fwd_kernel.h | 714 ++++++++++++++++- .../flash_fwd_launch_template.h | 83 +- .../flash_fwd_split_hdim128_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim160_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim192_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim224_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim256_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim32_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim64_fp16_sm80.cu | 15 + .../flash_fwd_split_hdim96_fp16_sm80.cu | 15 + .../cuda/bert/flash_attention/kernel_traits.h | 27 +- .../cuda/bert/flash_attention/utils.h | 130 ++-- .../cuda/bert/group_query_attention.cc | 185 +++++ .../cuda/bert/group_query_attention.h | 34 + .../cuda/bert/group_query_attention_helper.h | 253 ++++++ .../cuda/bert/group_query_attention_impl.cu | 279 +++++++ .../cuda/bert/group_query_attention_impl.h | 42 + .../cuda/bert/longformer_attention_impl.cu | 92 +-- .../cuda/bert/skip_layer_norm_impl.cu | 9 +- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/bert_defs.cc | 135 +++- onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../test/python/transformers/benchmark_gqa.py | 339 ++++++++ .../python/transformers/test_flash_attn.py | 722 +++++++++++++++++- tools/ci_build/build.py | 6 +- 36 files changed, 3448 insertions(+), 191 deletions(-) create mode 100644 
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h create mode 100644 onnxruntime/test/python/transformers/benchmark_gqa.py diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 888bcdbb9e21b..2a16bdbf7b55d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -42,6 +42,7 @@ Do not modify directly.* * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm + * com.microsoft.GroupQueryAttention * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention @@ -1170,9 +1171,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
present_key (optional) : T
-
past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
present_value (optional) : T
-
past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
#### Type Constraints @@ -2268,6 +2269,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GroupQueryAttention** + + Group Query Self/Cross Attention. + + Supports different number of heads for q and kv. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
is_past_bsnh : int
+
Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).
+
kv_num_heads : int (required)
+
Number of attention heads for k and v
+
num_heads : int (required)
+
Number of attention heads for q
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
unidirectional : int
+
Whether every token can only attend to previous tokens. Default value is 1.
+
+ +#### Inputs (3 - 6) + +
+
query : T
+
Query with shape (batch_size, sequence_length, hidden_size)
+
key : T
+
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
value : T
+
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
past_key (optional) : T
+
past state key with support for format BSNH or BNSH. When past_key uses the same tensor as present_key (k-v cache), it is of length max_sequence_length; otherwise it is of length past_sequence_length.
+
past_value (optional) : T
+
past state value with support for format BSNH or BNSH. When past_value uses the same tensor as present_value (k-v cache), it is of length max_sequence_length; otherwise it is of length past_sequence_length.
+
past_sequence_length (optional) : M
+
When buffered past_key and past_value are used (present_key uses the same tensor as past_key), it is required to specify past_sequence_length (could be 0). Otherwise, past_sequence_length is inferred from past_key.
+
+ +#### Outputs (1 - 3) + +
+
output : T
+
3D output tensor with shape (batch_size, sequence_length, hidden_size)
+
present_key (optional) : T
+
present state key with support for format BSNH or BNSH. When past_key uses the same tensor as present_key (k-v buffer), it is of length max_sequence_length; otherwise it is of length past_sequence_length + kv_sequence_length.
+
present_value (optional) : T
+
present state value with support for format BSNH or BNSH. When past_value uses the same tensor as present_value (k-v buffer), it is of length max_sequence_length; otherwise it is of length past_sequence_length + kv_sequence_length.
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output to float16 tensors.
+
M : tensor(int32), tensor(int64)
+
Constrain past sequence length to int tensor.
+
+ + ### **com.microsoft.Inverse** #### Version diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 14b6b339c11f3..ce9d8aabfede3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -840,6 +840,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)
**T** = tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 4c9c15d07a9b8..5184dd99309b1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -82,6 +82,26 @@ struct PackedAttentionParameters { bool broadcast_res_pos_bias; }; +// Parameters deduced from node attributes and inputs/outputs. +struct GroupQueryAttentionParameters { + int batch_size; + int sequence_length; + int past_sequence_length; // actual sequence length of past_key and past_value + int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present) + int present_sequence_length; // past_sequence_length + kv_sequence_length + int max_sequence_length; // allocated length of past_key and past_value + int hidden_size; + int num_heads; + int head_size; + int kv_hidden_size; + int kv_num_heads; + bool is_unidirectional; // causal + float scale; + int num_splits; // number of splits for splitkv + AttentionQkvFormat qkv_format; + AttentionQkvFormat past_kv_format; +}; + namespace attention { // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). 
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index d846f55f1e28d..626e4c0b87a3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o T* k_smem = q_smem + rotary_embedding_dim; const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (head_idx) / half_rotary_dim; - const int intra_half_idx = (head_idx) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; + const int half_idx = (head_idx) / half_rotary_dim; + const int intra_half_idx = (head_idx) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; if (do_rotary) { *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; @@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co } } - template __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) { // Format 3 for cutlass memory efficient attention @@ -651,7 +650,7 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (qk_head_size != 64 && qk_head_size !=128) { + if (qk_head_size != 64 && qk_head_size != 128) { ORT_THROW("qk_head_size must be 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu index 2af748d8d4a62..32ed961a68049 100644 --- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu +++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu @@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK) const int* 
attention_masks, const int batch_size, const int sequence_length) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int batch_id = blockIdx.x; - const int* batch_mask = attention_masks + (batch_id * sequence_length); - const bool leftmost_non_zero = (batch_mask[0] != 0); - int biggest_position = 0; - - for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { - if (leftmost_non_zero == (batch_mask[i] != 0)) { - biggest_position = i; - } else { - break; - } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int batch_id = blockIdx.x; + const int* batch_mask = attention_masks + (batch_id * sequence_length); + const bool leftmost_non_zero = (batch_mask[0] != 0); + int biggest_position = 0; + + for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { + if (leftmost_non_zero == (batch_mask[i] != 0)) { + biggest_position = i; + } else { + break; } + } - int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); + int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); - if (threadIdx.x == 0) { - int batch_offset = batch_id * sequence_length; - trt_mha_padding_offset[2 * batch_id] = batch_offset; - trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; - if (batch_id == gridDim.x - 1) { - trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; - } + if (threadIdx.x == 0) { + int batch_offset = batch_id * sequence_length; + trt_mha_padding_offset[2 * batch_id] = batch_offset; + trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; + if (batch_id == gridDim.x - 1) { + trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; } + } } // only support simple left padding with mask 0s on leading left, diff --git 
a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index a2dfca8cd6f09..ae53eca541fa5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_ } inline Status ComputeMaskIndex(cudaStream_t stream, - const int sequence_length, - const int batch_size, - const int* mask, - int* mask_index) { + const int sequence_length, + const int batch_size, + const int* mask, + int* mask_index) { // Mask idx is of length batch_size and assumes the valid region is contiguous starting // from the beginning of the sequence @@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel( } if (nullptr == position_ids) { position_id = blockIdx.x; - } else if (broadcast_position_ids){ + } else if (broadcast_position_ids) { position_id = position_ids[sequence_position % gridDim.x]; } else { position_id = position_ids[sequence_position]; @@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel( void* embedding_sum, const int* position_ids, const bool broadcast_position_ids) { - if (mask_index != nullptr) { if (nullptr == input_mask) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)); } else { ORT_RETURN_IF_ERROR( - ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); + ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu index 1b0de47a834ec..c9498eb1bcd7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu @@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int template <> Status 
LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const float* input, const float* bias, float* output, bool /*use_half2*/) { + const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; FastGeluKernel<<>>(A, B, C, input_length, bias_length, @@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const half* input, const half* bias, half* output, bool use_half2) { + const half* input, const half* bias, half* output, bool use_half2) { constexpr int blockSize = 256; if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) { const int n = input_length / 2; @@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { + const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; // remove nv_bfloat162 implementation for now to fix build issue diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h index 9db98061bbd66..811b1be7d4315 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -12,9 +12,13 @@ struct BlockInfo { template __device__ BlockInfo(const Params& params, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), - sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? 
-1 : params.cu_seqlens_k[bidb]), - actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q), - actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) { + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } template @@ -30,6 +34,8 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; const int actual_seqlen_k; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 9394a19c9897a..0aaf5e5f1ba28 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -45,6 +45,7 @@ struct Qkv_params { struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; @@ -56,9 +57,10 @@ struct Flash_fwd_params : public Qkv_params { // The pointer to the softmax sum. 
void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; // The scaling factors for the kernel. float scale_softmax; @@ -70,9 +72,26 @@ struct Flash_fwd_params : public Qkv_params { int* __restrict__ blockmask; + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + bool is_bf16 = false; bool is_causal; + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + int num_splits; // For split-KV version + const cudaDeviceProp* dprops; }; @@ -80,6 +99,8 @@ struct Flash_fwd_params : public Qkv_params { template void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); } // namespace flash -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 87831d1eddfe9..805a73be96778 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -34,24 +34,37 @@ void set_params_fprop(Flash_fwd_params& params, void* p_d, void* softmax_lse_d, float softmax_scale, - bool is_causal) { + bool is_causal, + bool kv_bsnh = true) { // Set the pointers and strides. 
params.q_ptr = q; params.k_ptr = k; params.v_ptr = v; params.o_ptr = out; - // All stride are in elements, not bytes. - params.q_row_stride = num_heads * head_size; - params.k_row_stride = num_heads_k * head_size; - params.v_row_stride = num_heads * head_size; - params.q_head_stride = head_size; - params.k_head_stride = head_size; - params.v_head_stride = head_size; - params.o_row_stride = num_heads * head_size; - params.o_head_stride = head_size; params.is_bf16 = false; + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) @@ -90,6 +103,7 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax_log2 = softmax_scale * M_LOG2E; params.is_causal = is_causal; + params.is_seqlens_k_cumulative = true; } size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { @@ -97,14 +111,85 @@ size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { return bytes; } -void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * 
seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } }); }); } +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, + int max_splits, bool new_kv, bool is_sm8x) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) + : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); + const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. 
+ const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -119,7 +204,11 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, - bool is_causal) { + bool is_causal, + int num_splits, + void* softmax_lse_accum, // num_splits x 
batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -139,7 +228,26 @@ Status mha_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + kv_bsnh); + + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } run_mha_fwd(params, stream); return Status::OK(); @@ -192,6 +300,101 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); } +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. 
+Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size + void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded +) { + if (seqlen_q == 1) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + // head_size is rounded up to a multiple of 32 and both sequence lengths to a multiple of 128 before being handed to set_params_fprop. + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + params.dprops = &dprops; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + past_bsnh); + + // New K/V are registered only when BOTH pointers are given; the strides set below describe a contiguous batch_size x seqlen_k_new x num_heads_k x head_size buffer. + if (k != nullptr && v != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k; + params.vnew_ptr = v; + // All strides are in elements, not bytes.
+ params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + // Recap of the setup above: when seqlens_k_ is supplied it is stored in cu_seqlens_k and is_seqlens_k_cumulative is cleared; the accumulation buffers are used only if num_splits > 1 and both buffers were supplied, otherwise both pointers are nulled. + // NOTE(review): the split kernel is forced whenever k is non-null, even if v is null (in which case knew_ptr/vnew_ptr were cleared above) -- confirm callers always pass k and v together. + run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + + return Status::OK(); +} + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 2ae46d34c373a..0a0328edb0059 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -34,6 +34,7 @@ namespace onnxruntime { namespace flash { + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size @@ -48,7 +49,11 @@ Status mha_fwd(const cudaDeviceProp& dprops, int seqlen_q, int seqlen_k, float softmax_scale, - bool is_causal); + bool is_causal, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x
num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -68,7 +73,36 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal); +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size + void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded +); + size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); + +int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git
a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index b5af31e432d42..eb1c794d6df54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -79,7 +79,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T flash::reduce_sum(scores, scores_sum); } else { cute::Tensor scores_max_prev = make_fragment_like(scores_max); - copy(scores_max, scores_max_prev); + cute::copy(scores_max, scores_max_prev); flash::template reduce_max(scores, scores_max); // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); @@ -109,7 +109,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T template inline __device__ void write_softmax_to_gmem( - cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_thr_copy_P) { + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) cute::Layout l = tOrP.layout(); cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); @@ -117,7 +117,7 @@ inline __device__ void write_softmax_to_gmem( CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); #pragma unroll for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { - copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); } }; @@ -147,6 +147,45 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal) { n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, 
kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= 0) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } } // We iterate over the blocks in reverse order. 
This is because the last block is the only one @@ -504,6 +543,494 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// +// Computes one kBlockM-row tile (m_block) of attention for batch bidb / head bidh over the K/V blocks [n_block_min, n_block_max) owned by split n_split_idx of num_n_splits. +// The tile's output and log-sum-exp go to params.oaccum_ptr / params.softmax_lseaccum_ptr when Split, otherwise to params.o_ptr / params.softmax_lse_ptr. +// With Append_KV, tiles that reach past binfo.seqlen_k_cache are additionally passed to copy_w_min_idx together with the params.k_ptr / params.v_ptr tensors (presumably writing the new rows into the cache -- confirm against utils.h). +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + // NOTE(review): for !Split this selects GmemTiledCopyOaccum, whereas SmemTiledCopyO further down selects the non-accum SmemCopyAtomO for !Split -- confirm the branch order here is intended. + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ?
0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = n_split_idx * n_blocks_per_split; + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + // (When Split is false the same tensors alias params.o_ptr / params.softmax_lse_ptr and the LSE written below is +inf instead of -inf.) + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ?
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // Only the thread whose column coordinate is 0 writes the LSE for each in-range row. +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? -INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block.
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr.
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt
= thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if
(!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ?
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + if constexpr (Append_KV) { + // if (cute::thread0()) { print(tKgK); } + // if (cute::thread0()) { print(tKsK); } + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + // __syncthreads(); + // if (cute::thread0()) { print(tKgK); } + // __syncthreads(); + } + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (Append_KV) { + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before
the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + + if (n_block > n_block_min) { + // Advance gK + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (Append_KV) { + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, + binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + //
isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (Append_KV) { + tVgVnew.data() = tVgVnew.data() +
(-int(kBlockN * params.vnew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (Append_KV) { + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, + binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); + // `sum != sum` is a NaN test: rows whose sum is 0 or NaN keep a scale of 1 and get an infinite LSE (-inf when Split, +inf otherwise); all other rows are scaled by 1 / sum. +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o =
binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; @@ -524,6 +1051,187 @@ inline __device__ void compute_attn(const Params& params) { } //////////////////////////////////////////////////////////////////////////////////////////////////// + +// Maps the launch grid onto (m_block, batch, head, split) and forwards to compute_attn_1rowblock_splitkv: with Split, blockIdx.y is the split index and blockIdx.z encodes batch * params.h + head; without Split, blockIdx.y is the batch, blockIdx.z the head, and a single split (index 0 of 1) is used. +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ?
gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kBlockM = 16; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); + static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); + static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. 
+ constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 
0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = 
make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } +// Load Oaccum in then scale and accumulate to O +#pragma unroll 2 + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is 
using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + } // namespace flash } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index e633ef4d45fbb..e0be6b828f85d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) { flash::compute_attn(params); } +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { + flash::compute_attn_splitkv(params); +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +} + template void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { constexpr size_t smem_size = Kernel_traits::kSmemSize; @@ -25,8 +36,6 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. 
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { @@ -40,9 +49,7 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { // ORT_ENFORCE(cudaFuncSetAttribute( // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); @@ -51,6 +58,72 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { }); } +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? 
params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, 
cudaStream_t stream) { + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr int kBlockM = 64; // Fixed for all head dimensions + if (!is_sm8x) { // A100, H100 + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); + } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); + } +} + template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { constexpr int Headdim = 32; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000000..68ae2ea759813 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000000..94564a6aba8f3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. 
+// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000000..ec9e9e738c5b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 0000000000000..e6c4ff5d95584 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000000..552966852cdbe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000000..e9f191a4828d6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000000..d628a556680ad --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000000..88b6cc0fb1e22 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 0c967faa85c45..134f159e258c4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -111,7 +111,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; @@ -139,18 +140,28 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy>; using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + cute::Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store + cute::Layout>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; + using GmemLayoutAtomP = cute::Layout, cute::Int>, + cute::Stride, _1>>; using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store + cute::Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = 
std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -289,13 +300,13 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; + // static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? 
kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 49ee687419d0e..02042e183f808 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -96,46 +96,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ float2 half2_unpack(uint32_t a); - -template <> -inline __device__ float2 half2_unpack<__half>(uint32_t a) { - return __half22float2(reinterpret_cast<__half2(&)>(a)); -} - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template <> -inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { - return __bfloat1622float2(reinterpret_cast<__nv_bfloat162(&)>(a)); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert two half2's or bf162's into float, then take their dot product. -template -inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { - float2 af = flash::half2_unpack(a); - float2 bf = flash::half2_unpack(b); - return af.x * bf.x + af.y * bf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Converted two vectors of 8 half's or bf16's into float, then take their dot product. 
-template -inline __device__ float hmulsum8(const uint4 a, const uint4 b) { - float sum; - sum = flash::hfma2_to_float(a.x, b.x); - sum += flash::hfma2_to_float(a.y, b.y); - sum += flash::hfma2_to_float(a.z, b.z); - sum += flash::hfma2_to_float(a.w, b.w); - return sum; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MaxOp { __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } @@ -245,7 +205,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -261,9 +224,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { static_assert(mma_shape_K == 8 || mma_shape_K == 16); constexpr int MMA_N_divisor = mma_shape_K == 8 ? 
1 : 2; auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - get<0, 1>(l), - get<1, 1, 1>(l)); + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -338,7 +305,7 @@ CUTE_HOST_DEVICE void cp_async_wait() { template -inline __device__ void copy(TiledCopy thr_copy, Tensor const& S, +inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); @@ -354,13 +321,80 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const& #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - copy(thr_copy, S(_, m, k), D(_, m, k)); + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, + Tensor const& S1, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int row_idx_switch = 0) { + CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M 
+ CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } +#pragma unroll + for (int m = 0; m < size<1>(S0); ++m) { + auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S0); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { - clear(D(_, m, k)); + cute::clear(D(_, m, k)); } } } else if (Clear_OOB_MN) { - clear(D(_, m, _)); + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && 
get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc new file mode 100644 index 0000000000000..65d19d4473872 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention.h" +#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +// #include "contrib_ops/cpu/utils/console_dumper.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 5), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : CudaKernel(info) { + int64_t num_heads = 0; + int64_t 
kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 1) == 1; + is_past_bsnh_ = info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); +#else + disable_flash_attention_ = true; +#endif +} + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* past_seq_len = context->Input(5); + + auto& device_prop = GetDeviceProp(); + GroupQueryAttentionParameters parameters; + typedef typename ToCudaType::MappedType CudaT; + GroupQueryAttentionData data; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + ¶meters, + num_heads_, + kv_num_heads_, + past_seq_len, + is_past_bsnh_, + scale_, + device_prop.maxThreadsPerBlock)); + parameters.is_unidirectional = is_unidirectional_; + int sequence_length = parameters.sequence_length; + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, 
parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.kv_num_heads); + // Allocate buffers + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + size_t seqlens_k_bytes = 0; + if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + // split kv buffers + parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount, 128, false, + device_prop.major == 8 && device_prop.minor > 0); + if (parameters.num_splits > 1) { + // softmax_lse_accum buffer + softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( + parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(parameters.head_size, 32); + out_accum_bytes = onnxruntime::flash::get_out_accum_size( + parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); + } + // seqlens_k buffer + if (past_key != nullptr) { + seqlens_k_bytes = sizeof(int) * parameters.batch_size; + } + } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, 
context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + + // only kernel implemented for gqa right now + ORT_ENFORCE(use_flash_attention); + + data.query = reinterpret_cast(query->Data()); + data.key = reinterpret_cast(key->Data()); + data.value = reinterpret_cast(value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); + data.output = reinterpret_cast(output->MutableData()); + data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (nullptr == present_value) ? 
nullptr : reinterpret_cast(present_value->MutableData()); + data.use_flash_attention = use_flash_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } + if (seqlens_k_buffer != nullptr) { + data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); + } + + cublasHandle_t cublas = GetCublasHandle(context); + + return QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h new file mode 100644 index 0000000000000..72c9814fad670 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+
+#include <memory>
+#include "core/providers/cuda/cuda_kernel.h"
+#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+// CUDA kernel class for the GroupQueryAttention contrib op; constructor and ComputeInternal are defined out of line.
+template <typename T>
+class GroupQueryAttention final : public CudaKernel {
+ public:
+  GroupQueryAttention(const OpKernelInfo& info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+  int num_heads_ = 0;                     // number of attention heads
+  int kv_num_heads_ = 0;                  // different for k and v for group query attention
+  int past_sequence_length_ = 0;          // zero-initialized so the member never holds an indeterminate value
+  bool is_unidirectional_ = true;         // causal
+  bool is_past_bsnh_ = true;              // layout of past key/value: BSNH when true, BNSH otherwise
+  float scale_ = 0.0f;                    // softmax scale; 0 selects 1/sqrt(head_size) in QkvToContext
+  bool disable_flash_attention_ = false;  // assigned by the constructor
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
new file mode 100644
index 0000000000000..be8f5ca0ae3e9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -0,0 +1,253 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace group_query_attention_helper {
+// Validates GroupQueryAttention inputs and, if 'parameters' is non-null, fills it as GroupQueryAttentionParameters.
+inline Status CheckInputs(const Tensor* query,
+                          const Tensor* key,
+                          const Tensor* value,
+                          const Tensor* past_key,
+                          const Tensor* past_value,
+                          void* parameters,
+                          int num_heads,
+                          int kv_num_heads,
+                          const Tensor* past_seq_len,
+                          bool is_past_bsnh,
+                          float scale) {
+  // Note: Here S* is max_sequence_length, S- is past_sequence_length, S+ is kv_sequence_length
+  //     past_key                : (B, S*, N_k, H) or (B, N_k, S*, H) or (B, S-, N_k, H) or (B, N_k, S-, H)
+  //     past_value              : (B, S*, N_k, H) or (B, N_k, S*, H) or (B, S-, N_k, H) or (B, N_k, S-, H)
+  // no packing for q/k/v:
+  //     query            (Q)       : (B, S, D)
+  //     key              (K)       : (B, S+, D_kv)
+  //     value            (V)       : (B, S+, D_kv)
+
+  AttentionQkvFormat qkv_format = Q_K_V_BSNH;
+  AttentionQkvFormat past_kv_format = Q_K_V_BSNH;
+
+  const auto& query_dims = query->Shape().GetDims();
+  // 'key' and 'value' may still be null at this point, so their shapes are only read inside the
+  // null checks further down instead of being bound to references here.
+
+  if (query_dims.size() != 3) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
+                           query_dims.size());
+  }
+
+  int batch_size = static_cast<int>(query_dims[0]);
+  int sequence_length = static_cast<int>(query_dims[1]);
+  int q_hidden_size = static_cast<int>(query_dims[2]);
+  int head_size = static_cast<int>(q_hidden_size) / num_heads;
+
+  int kv_sequence_length = sequence_length;
+  // kv_hidden_size is assigned once 'key' has been validated as a non-null 3-D tensor; every
+  // read of it happens after that validation.
+  int kv_hidden_size = 0;
+
+  int max_sequence_length = 0;
+  if (past_key != nullptr && past_value != nullptr) {
+    const auto& past_key_dims = past_key->Shape().GetDims();
+    const auto& past_value_dims = past_value->Shape().GetDims();
+
+    if (past_key_dims.size() != 4) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' is expected to have 4 dimensions, got ",
+                             past_key_dims.size());
+    }
+    if (past_value_dims.size() != 4) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' is expected to have 4 dimensions, got ",
+                             past_value_dims.size());
+    }
+
+    if (past_key_dims[0] != batch_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' dimension 0 should be batch_size, got ",
+                             past_key_dims[0]);
+    }
+    if (past_value_dims[0] != batch_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' dimension 0 should be batch_size, got ",
+                             past_value_dims[0]);
+    }
+
+    // BNSH
+    if (!is_past_bsnh) {
+      past_kv_format = Q_K_V_BNSH;
+      if (past_key_dims[2] != past_value_dims[2]) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence "
+                               "length or past sequence length), got ",
+                               past_key_dims[2]);
+      }
+      if (past_key_dims[1] != kv_num_heads) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'past_key' shall have kv_num_heads");
+      }
+      if (past_value_dims[1] != kv_num_heads) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'past_value' shall have kv_num_heads");
+      }
+      // We assume all sequence in past kv are left-padded to max or past sequence length
+      max_sequence_length = static_cast<int>(past_key_dims[2]);
+      // BSNH
+    } else {
+      past_kv_format = Q_K_V_BSNH;
+      if (past_key_dims[1] != past_value_dims[1]) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "BSNH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence "
+                               "length or past sequence length), got ",
+                               past_key_dims[1]);
+      }
+      if (past_key_dims[2] != kv_num_heads) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'past_key' shall have kv_num_heads");
+      }
+      if (past_value_dims[2] != kv_num_heads) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Input 'past_value' shall have kv_num_heads");
+      }
+      // We assume all sequence in past kv are left-padded to max or past sequence length
+      max_sequence_length = static_cast<int>(past_key_dims[1]);
+    }
+
+    if (past_key_dims[3] != head_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' dimension 3 should be same as head_size, got ",
+                             past_key_dims[3]);
+    }
+    if (past_value_dims[3] != head_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' dimension 3 should be same as head_size, got ",
+                             past_value_dims[3]);
+    }
+  } else if (past_key != nullptr || past_value != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'past_key' and 'past_value' shall be both present or both absent");
+  }
+
+  if (key != nullptr) {
+    const auto& key_dims = key->Shape().GetDims();
+    if (key_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+                             key_dims.size());
+    }
+    if (query_dims[0] != key_dims[0]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'key' shall have same dim 0 (batch size)");
+    }
+
+    if (num_heads % kv_num_heads != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+                             num_heads % kv_num_heads);
+    }
+    if (value != nullptr && value->Shape().GetDims().size() == 3 && key_dims[2] != value->Shape().GetDims()[2]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)");
+    }
+    kv_hidden_size = static_cast<int>(key_dims[2]);
+    qkv_format = Q_K_V_BSNH;
+    kv_sequence_length = static_cast<int>(key_dims[1]);
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Missing key tensor.");
+  }
+
+  if (value != nullptr) {
+    const auto& value_dims = value->Shape().GetDims();
+    if (value_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+                             value_dims.size());
+    }
+
+    if (query_dims[0] != value_dims[0]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'value' shall have same dim 0 (batch_size)");
+    }
+
+    if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
+    }
+
+    if (value_dims[2] != kv_hidden_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+    }
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Missing value tensor.");
+  }
+
+  // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly.
+  int32_t past_sequence_length = 0;
+  int present_sequence_length = 0;
+  if (past_seq_len != nullptr) {
+    if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "past_sequence_length tensor must be of one element when using past kv.");
+    }
+    if (past_seq_len->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+      past_sequence_length = *((*past_seq_len).template Data<int32_t>());
+    } else {
+      past_sequence_length = static_cast<int32_t>(*((*past_seq_len).template Data<int64_t>()));
+    }
+    present_sequence_length = max_sequence_length;
+  } else if (past_key != nullptr) {
+    past_sequence_length = max_sequence_length;  // this is the length of past_key tensor
+    present_sequence_length = past_sequence_length + kv_sequence_length;
+  }
+
+  if (parameters != nullptr) {
+    GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
+    output_parameters->batch_size = batch_size;
+    output_parameters->sequence_length = sequence_length;
+    output_parameters->past_sequence_length = past_sequence_length;
+    output_parameters->kv_sequence_length = kv_sequence_length;
+    output_parameters->present_sequence_length = present_sequence_length;
+    output_parameters->max_sequence_length = max_sequence_length;
+    output_parameters->hidden_size = q_hidden_size;
+    output_parameters->num_heads = num_heads;
+    output_parameters->head_size = q_hidden_size / num_heads;
+    output_parameters->kv_hidden_size = kv_hidden_size;
+    output_parameters->kv_num_heads = kv_num_heads;
+    output_parameters->is_unidirectional = true;
+    output_parameters->scale = scale;
+    output_parameters->qkv_format = qkv_format;
+    output_parameters->past_kv_format = past_kv_format;
+  }
+
+  return Status::OK();
+}
+// Overload that first rejects num_heads > max_threads_per_block, then runs the shape checks of the overload above.
+template <typename T>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   void* parameters,
+                   int num_heads,
+                   int kv_num_heads,
+                   const T* past_seq_len,
+                   bool is_past_bsnh,
+                   float scale,
+                   int max_threads_per_block) {
+  if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
+  }
+
+  return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, past_seq_len, is_past_bsnh, scale);
+}
+
+}  // namespace group_query_attention_helper
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
new file mode 100644
index 0000000000000..ab3029ca34886
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -0,0 +1,279 @@
+/*
+ The implementation of this file is based on our Multi-Head Attention impl.cu file,
+ which is based on qkvToContext plugin in TensorRT demo:
+ https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
+
+Copyright 2019 NVIDIA Corporation
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Modifications:
+// (1) support GPT-2 past state, unidirectional mask (causal)
+// (2) use flash attention kernel from (https://github.com/Dao-AILab/flash-attention)
+// (3) support different number of heads for Q and KV
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cassert>
+#include <cuda_fp16.h>
+#include <cub/cub.cuh>
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "contrib_ops/cuda/bert/attention_softmax.h"
+#include "contrib_ops/cuda/bert/transformer_common.h"
+#include "contrib_ops/cuda/bert/add_bias_transpose.h"
+#include "contrib_ops/cpu/bert/attention_base.h"
+#include "contrib_ops/cuda/bert/bert_padding.h"
+#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
+#include "contrib_ops/cuda/bert/attention_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// Kernel for seqlens_k: writes the same seqlen into each of the batch_size entries
+__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
+  int id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id < batch_size) seqlens_k[id] = seqlen;
+}
+
+// Kernel to append new and past kv in either BSNH or BNSH format
+// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file
+template <typename T>
+__global__ void ConcatNewToPastKV(const int new_seqlen,
+                                  const T* past_kv,
+                                  const T* new_kv,
+                                  T* present_kv,
+                                  const bool is_bsnh) {  // refers to past; otherwise bnsh
+  const int h = threadIdx.x;
+  const int n = threadIdx.y;
+  const int s = blockIdx.x;
+  const int b = blockIdx.y;
+
+  const int present_seqlen = gridDim.x;
+  const int num_heads = blockDim.y;
+  const int H = blockDim.x;
+
+  const int present_batch_stride = present_seqlen * num_heads * H;
+  const int row_stride = is_bsnh ? num_heads * H : H;
+  const int present_head_stride = is_bsnh ? H : present_seqlen * H;
+
+  // past_kv:     BPNH or BNPH
+  // new_kv:      BLNH or BNLH
+  // present_kv:  BTNH or BNTH, where T = P + L
+  const int past_seqlen = present_seqlen - new_seqlen;
+
+  int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
+  if (s < past_seqlen) {
+    const int past_batch_stride = past_seqlen * num_heads * H;
+    const int past_head_stride = is_bsnh ? H : past_seqlen * H;
+    const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
+    present_kv[out_offset] = past_kv[in_offset];
+  } else if (s < present_seqlen) {
+    // Note: new KV always BSNH
+    const int new_batch_stride = new_seqlen * num_heads * H;
+    const int new_row_stride = num_heads * H;
+    const int new_head_stride = H;
+    const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
+    present_kv[out_offset] = new_kv[in_offset];
+  }
+}
+// Variant of ConcatNewToPastKV where each thread loops over h in steps of blockDim.x, so H may exceed blockDim.x.
+template <typename T>
+__global__ void ConcatNewToPastKVLarge(const int new_seqlen,
+                                       const int H,
+                                       const T* past_kv,
+                                       const T* new_kv,
+                                       T* present_kv,
+                                       const bool is_bsnh) {
+  // Use when (H*)*num_heads > 1024
+  int h = threadIdx.x;
+  const int n = threadIdx.y;
+  const int s = blockIdx.x;
+  const int b = blockIdx.y;
+
+  const int present_seqlen = gridDim.x;
+  const int num_heads = blockDim.y;
+  const int thread_stride = blockDim.x;
+
+  const int present_batch_stride = present_seqlen * num_heads * H;
+  const int row_stride = is_bsnh ? num_heads * H : H;
+  const int present_head_stride = is_bsnh ? H : present_seqlen * H;
+
+  // past_kv:     BPNH or BNPH
+  // new_kv:      BLNH or BNLH
+  // present_kv:  BTNH or BNTH, where T = P + L
+  const int past_seqlen = present_seqlen - new_seqlen;
+
+  while (h < H) {
+    int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
+    if (s < past_seqlen) {
+      const int past_batch_stride = past_seqlen * num_heads * H;
+      const int past_head_stride = is_bsnh ? H : past_seqlen * H;
+      const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
+      present_kv[out_offset] = past_kv[in_offset];
+    } else if (s < present_seqlen) {
+      const int new_batch_stride = new_seqlen * num_heads * H;
+      const int new_row_stride = num_heads * H;
+      const int new_head_stride = H;
+      const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
+      present_kv[out_offset] = new_kv[in_offset];
+    }
+    h += thread_stride;
+  }
+}
+// Dispatches GroupQueryAttention to the flash attention kernels; any other configuration returns INVALID_ARGUMENT.
+template <typename T>
+Status QkvToContext(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<T>& data) {
+  assert(data.use_flash_attention);
+
+#if USE_FLASH_ATTENTION
+  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
+  const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int kv_sequence_length = parameters.kv_sequence_length;
+  const int present_sequence_length = parameters.present_sequence_length;
+  const int num_heads = parameters.num_heads;
+  const int kv_num_heads = parameters.kv_num_heads;
+  const int head_size = parameters.head_size;
+  AttentionQkvFormat past_kv_format = parameters.past_kv_format;
+
+  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(head_size)) : parameters.scale;
+  if (data.use_flash_attention) {
+    assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+    assert(parameters.num_heads % parameters.kv_num_heads == 0);
+
+    void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+    void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
+    void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
+
+    bool is_causal = parameters.is_unidirectional;
+
+    if (data.past_key == nullptr && data.present_key == nullptr) {
+      ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+          device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
+          parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size,
+          parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits,
+          reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum)));
+
+    } else if (data.past_key == data.present_key) {
+      // Assume past and present kv share buffer.
+      assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+      assert(parameters.past_sequence_length >= 0);
+      assert(data.past_value != nullptr);
+
+      void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+      void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+
+      // Launch kernel to copy seqlen
+      int thr_per_blk = 256;
+      int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;  // integer ceiling division
+      repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
+
+      DUMP_TENSOR_INIT();
+      DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size);
+
+      bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+      ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+          device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
+          reinterpret_cast<void*>(data.seqlens_k), batch_size, num_heads, kv_num_heads,
+          head_size, sequence_length, present_sequence_length, kv_sequence_length,
+          scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+          reinterpret_cast<void*>(data.out_accum)));
+
+    } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) {
+      assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+      // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inefficient
+      if (head_size % 4 != 0) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4");
+      }
+      const int H = head_size / 4;  // head size counted in groups of 4 elements; the copies below move float2 units
+      if (H * kv_num_heads <= max_threads_per_block) {
+        const dim3 grid(present_sequence_length, batch_size, 1);
+        const dim3 block(H, kv_num_heads, 1);
+        ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                              reinterpret_cast<const float2*>(data.past_key),
+                                                              reinterpret_cast<const float2*>(data.key),
+                                                              reinterpret_cast<float2*>(data.present_key),
+                                                              past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+        ConcatNewToPastKV<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                              reinterpret_cast<const float2*>(data.past_value),
+                                                              reinterpret_cast<const float2*>(data.value),
+                                                              reinterpret_cast<float2*>(data.present_value),
+                                                              past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+      } else {
+        const dim3 grid(present_sequence_length, batch_size, 1);
+        const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
+        ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                                   H,
+                                                                   reinterpret_cast<const float2*>(data.past_key),
+                                                                   reinterpret_cast<const float2*>(data.key),
+                                                                   reinterpret_cast<float2*>(data.present_key),
+                                                                   past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+        ConcatNewToPastKVLarge<float2><<<grid, block, 0, stream>>>(kv_sequence_length,
+                                                                   H,
+                                                                   reinterpret_cast<const float2*>(data.past_value),
+                                                                   reinterpret_cast<const float2*>(data.value),
+                                                                   reinterpret_cast<float2*>(data.present_value),
+                                                                   past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+      }
+
+      void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+      void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+      // Copy seqlen only when a seqlens_k buffer exists: this branch accepts past_key == nullptr and seqlens_k may be null.
+      int thr_per_blk = 256;
+      int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;  // integer ceiling division
+      if (data.seqlens_k != nullptr) {
+        repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k, parameters.past_sequence_length, batch_size);
+      }
+      bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+      ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+          device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast<void*>(data.softmax_lse),
+          batch_size, num_heads, kv_num_heads, head_size,
+          sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits,
+          reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum), past_bsnh));
+    }
+
+    DUMP_TENSOR_INIT();
+    DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
+
+    return Status::OK();
+  }
+#endif
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet.");
+}
+
+template struct GroupQueryAttentionData<half>;
+
+template Status QkvToContext<half>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<half>& data);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
new file mode 100644
index 0000000000000..0bad9eeb61231
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +struct GroupQueryAttentionData { + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; + int* seqlens_k = nullptr; + T* output = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + bool use_flash_attention = false; +}; + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index de3c3fb6ca065..f00239460071b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -924,55 +924,55 @@ Status LongformerQkvToContext( if (disable_compact_memory) { ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxSimpleKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + 
element_size)); } else { ORT_ENFORCE(max_num_global <= window); ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - max_num_global, - compact_global_q, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + max_num_global, + compact_global_q, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + element_size)); } // The temp_output is BxNxSxH, transpose it to final output BxSxNxH diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index f2ee076a8a03d..bfecacf4fb717 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -90,7 +90,6 @@ __global__ void SkipLayerNormKernel( // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); - for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; @@ -130,10 +129,10 @@ __global__ void SkipLayerNormKernelSmall( *input_val = *reinterpret_cast(&input[idx]); VecT* skip_val = reinterpret_cast(&skip_v); - if (skip_broadcasted){ - *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - }else{ - *skip_val = *reinterpret_cast(&skip[idx]); + if (skip_broadcasted) { + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); + } else { + *skip_val = *reinterpret_cast(&skip[idx]); } if (hasBias) { diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 6b5356c9912fc..71ee5ae1ddbe6 100644 --- 
a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -71,6 +71,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
@@ -221,6 +222,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index e5956a575d73d..e8d4785adf429 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -233,6 +233,59 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
   }
 }
 
+void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
+  // Output 0 has shape (batch_size, sequence_length, hidden_size), i.e. it is copied from the query shape.
+
+  // Q, K and V:
+  //   Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+  //   Input 1 (key) has shape (batch_size, kv_sequence_length, kv_hidden_size)
+  //   Input 2 (value) has shape (batch_size, kv_sequence_length, kv_hidden_size)
+
+  // Type inference
+  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+  // Shape inference
+  if (hasInputShape(ctx, 0)) {
+    auto& query_shape = getInputShape(ctx, 0);
+    auto& query_dims = query_shape.dim();
+
+    if (query_dims.size() != 3) {
+      fail_shape_inference("Inputs 0 (query) shall be 3 dimensions");
+    }
+
+    if (hasInputShape(ctx, 2)) {
+      auto& value_shape = getInputShape(ctx, 2);
+      auto& value_dims = value_shape.dim();
+      if (value_dims.size() != 3) {
+        fail_shape_inference("Inputs 2 (value) shall be 3 dimensions");
+      }
+
+      ONNX_NAMESPACE::TensorShapeProto output_shape;
+      *output_shape.add_dim() = query_dims[0];
+      *output_shape.add_dim() = query_dims[1];
+      *output_shape.add_dim() = query_dims[2];
+      updateOutputShape(ctx, 0, output_shape);
+      return;  // NOTE(review): exits before the present_key/present_value inference below, so that code only runs when the query shape is unknown
+    } else {
+      fail_shape_inference("Missing input 2 (value)");
+    }
+  }
+
+  if (ctx.getNumOutputs() > 1) {  // has present output
+    if (hasInputShape(ctx, past_key_index)) {
+      auto& past_shape = getInputShape(ctx, past_key_index);
+      auto& past_dims = past_shape.dim();
+      if (past_dims.size() != 4) {
+        fail_shape_inference("The past_key input shall be 4 dimensions");
+      }
+      ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
+      ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
+      ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
+      ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
+    }
+  }
+}
+
 constexpr const char* Attention_ver1_doc = R"DOC(
 Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
 
@@ -823,7 +876,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 "T")
         .Output(1,
                 "present_key",
-                "past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). "
+                "present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). "
                 "If past_present_share_buffer is set, "
                 "its shape is (batch_size, num_heads, max_sequence_length, head_size), "
                 "while effective_seq_length = (past_sequence_length + kv_sequence_length).",
@@ -831,7 +884,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 OpSchema::Optional)
         .Output(2,
                 "present_value",
-                "past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). "
+                "present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). "
                 "If past_present_share_buffer is set, "
                 "its shape is (batch_size, num_heads, max_sequence_length, head_size), "
                 "while effective_seq_length = (past_sequence_length + kv_sequence_length).",
@@ -930,6 +983,84 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
           MultiHeadAttentionTypeAndShapeInference(ctx, 6);
         }));
 
+constexpr const char* GroupQueryAttention_ver1_doc = R"DOC(
+Group Query Self/Cross Attention.
+
+Supports different number of heads for q and kv.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+    GroupQueryAttention, 1,
+    OpSchema()
+        .SetDoc(GroupQueryAttention_ver1_doc)
+        .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT)
+        .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT)
+        .Attr("unidirectional",
+              "Whether every token can only attend to previous tokens. Default value is 1.",
+              AttributeProto::INT,
+              static_cast<int64_t>(1))
+        .Attr("is_past_bsnh",
+              "Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).",
+              AttributeProto::INT,
+              static_cast<int64_t>(1))
+        .Attr("scale",
+              "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+              AttributeProto::FLOAT,
+              OPTIONAL_VALUE)
+        .Input(0,
+               "query",
+               "Query with shape (batch_size, sequence_length, hidden_size)",
+               "T")
+        .Input(1,
+               "key",
+               "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
+               "T")
+        .Input(2,
+               "value",
+               "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
+               "T")
+        .Input(3,
+               "past_key",
+               "past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key "
+               "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
+               "T",
+               OpSchema::Optional)
+        .Input(4,
+               "past_value",
+               "past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value "
+               "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
+               "T",
+               OpSchema::Optional)
+        .Input(5,
+               "past_sequence_length",
+               "When buffered past_key and past_value is used (present_key uses same tensor as past_key), required "
+               "to specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.",
+               "M",
+               OpSchema::Optional)
+        .Output(0,
+                "output",
+                "3D output tensor with shape (batch_size, sequence_length, hidden_size)",
+                "T")
+        .Output(1,
+                "present_key",
+                "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key "
+                "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + "
+                "kv_sequence_length.",
+                "T",
+                OpSchema::Optional)
+        .Output(2,
+                "present_value",
+                "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value "
+                "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length + "
+                "kv_sequence_length.",
+                "T",
+                OpSchema::Optional)
+        .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.")
+        .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.")
+        .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+          GroupQueryAttentionTypeAndShapeInference(ctx, 3);
+        }));
+
 constexpr const char* Longformer_Attention_doc = R"DOC(
 Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token
 attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 3c31997286254..2007654e88242 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -83,6 +83,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4);
 #endif
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
@@ -182,6 +183,7 @@ class OpSet_Microsoft_ver1 {
 #endif
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention)>());
+    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py
new file mode 100644
index 0000000000000..a9bef025a70bb
--- /dev/null
+++ b/onnxruntime/test/python/transformers/benchmark_gqa.py
@@ -0,0 +1,339 @@
+# 
-------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""
+Benchmark performance of GroupQueryAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux:
+python benchmark_gqa.py
+"""
+
+import math
+import random
+import statistics
+import time
+
+import torch
+from onnx import TensorProto, helper
+
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
+
+
+class InputFormats:  # layout of the past/present KV tensors; selects the is_past_bsnh attribute
+    QKV_BSNH = 0
+    QKV_BNSH = 1
+
+
+class Config:  # problem sizes for one benchmark case
+    batch_size = 0
+    sequence_length = 0
+    kv_sequence_length = 0
+    past_sequence_length = 0
+    num_heads = 0
+    kv_num_heads = 0
+    head_size = 0
+
+    def __init__(self, b, s, s2, sp, n, n2, h):  # batch, q len, kv len, past len, q heads, kv heads, head size
+        self.batch_size = b
+        self.sequence_length = s
+        self.kv_sequence_length = s2
+        self.past_sequence_length = sp
+        self.num_heads = n
+        self.kv_num_heads = n2
+        self.head_size = h
+
+
+def create_group_query_attention_graph_past(
+    config, causal=False, past_kv_format=InputFormats.QKV_BSNH, share_buffer=True
+):  # returns the serialized single-node GroupQueryAttention model (bytes)
+    past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length
+    present_kv_seqlen = (
+        config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length
+    )
+    nodes = [
+        helper.make_node(
+            "GroupQueryAttention",
+            [
+                "query",
+                "key",
+                "value",
+                "past_key",
+                "past_value",
+                "past_sequence_length" if share_buffer else "",
+            ],
+            ["output", "present_key", "present_value"],
+            "GroupQueryAttention_0",
+            num_heads=config.num_heads,
+            kv_num_heads=config.kv_num_heads,
+            unidirectional=1 if causal else 0,
+            is_past_bsnh=1 if past_kv_format == InputFormats.QKV_BSNH else 0,
+            domain="com.microsoft",
+        ),
+    ]
+
+    graph_input = [
+        helper.make_tensor_value_info(
+            "query",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                config.sequence_length,
+                config.num_heads * config.head_size,
+            ],
+        ),
+        helper.make_tensor_value_info(
+            "key",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                config.sequence_length,
+                config.kv_num_heads * config.head_size,
+            ],
+        ),
+        helper.make_tensor_value_info(
+            "value",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                config.sequence_length,
+                config.kv_num_heads * config.head_size,
+            ],
+        ),
+        helper.make_tensor_value_info(
+            "past_key",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
+                config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
+                config.head_size,
+            ],
+        ),
+        helper.make_tensor_value_info(
+            "past_value",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
+                config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen,
+                config.head_size,
+            ],
+        ),
+    ]
+    if share_buffer:
+        graph_input += [
+            helper.make_tensor_value_info(
+                "past_sequence_length",
+                TensorProto.INT32,
+                [1],
+            )
+        ]
+
+    graph_output = [
+        helper.make_tensor_value_info(
+            "output",
+            TensorProto.FLOAT16,
+            [config.batch_size, config.sequence_length, config.num_heads * config.head_size],
+        ),
+        helper.make_tensor_value_info(
+            "present_key",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
+                config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
+                config.head_size,
+            ],
+        ),
+        helper.make_tensor_value_info(
+            "present_value",
+            TensorProto.FLOAT16,
+            [
+                config.batch_size,
+                present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads,
+                config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen,
+                config.head_size,
+            ],
+        ),
+    ]
+
+    graph = helper.make_graph(
+        nodes,
+        "GroupQueryAttention_Graph",
+        graph_input,
+        graph_output,
+    )
+
+    model = helper.make_model(graph)
+    return model.SerializeToString()
+
+
+def create_gqa_session(
+    config: Config,
+    causal: bool = False,
+    past_format=InputFormats.QKV_BSNH,
+    share_buffer: bool = True,
+) -> InferenceSession:  # builds the one-node model and loads it on the CUDA execution provider
+    onnx_model_str = create_group_query_attention_graph_past(config, causal, past_format, share_buffer)
+    sess_options = SessionOptions()
+    ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
+    return ort_session
+
+
+def bind_io(io_binding, input_dict, device, share_buffer=True):  # device: CUDA device id that holds the past KV OrtValues
+    io_binding.bind_cpu_input("query", input_dict["query"])
+    io_binding.bind_cpu_input("key", input_dict["key"])
+    io_binding.bind_cpu_input("value", input_dict["value"])
+    io_binding.bind_input(
+        "past_key", "cuda", device, "float16", input_dict["past_key"].shape(), input_dict["past_key"].data_ptr()
+    )
+    io_binding.bind_input(
+        "past_value",
+        "cuda",
+        device,
+        "float16",
+        input_dict["past_value"].shape(),
+        input_dict["past_value"].data_ptr(),
+    )
+    io_binding.bind_output("output")
+    if share_buffer:  # present_* outputs alias the past_* buffers, so the cache is updated in place
+        io_binding.bind_cpu_input("past_sequence_length", input_dict["past_sequence_length"])
+        io_binding.bind_output(
+            "present_key",
+            device_type="cuda",
+            device_id=device,
+            element_type="float16",
+            shape=input_dict["past_key"].shape(),
+            buffer_ptr=input_dict["past_key"].data_ptr(),
+        )
+        io_binding.bind_output(
+            "present_value",
+            device_type="cuda",
+            device_id=device,
+            element_type="float16",
+            shape=input_dict["past_value"].shape(),
+            buffer_ptr=input_dict["past_value"].data_ptr(),
+        )
+    else:
+        io_binding.bind_output("present_key")
+        io_binding.bind_output("present_value")
+
+
+def measure_latency(ort_session, io_binding):  # wall-clock seconds for one run with the given bindings
+    start = time.perf_counter()
+    _ = ort_session.run_with_iobinding(io_binding)
+    end = time.perf_counter()
+    return end - start
+
+
+def flops(batch, q_seqlen, kv_seqlen, head_size, num_heads):  # QK^T and PV matmuls, 2 FLOPs per multiply-add
+    return 4 * batch * q_seqlen * kv_seqlen * num_heads * head_size
+
+
+def tflops_per_second(flop, time):  # `time` is the latency in seconds (it shadows the time module inside this function)
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+def benchmark_op(session, io_binding, repeats=100):  # mean latency in seconds over `repeats` runs after one warm-up
+    # warm up session
+    _ = measure_latency(session, io_binding)
+
+    latency_list = []
+    for _ in range(repeats):
+        latency = measure_latency(session, io_binding)
+        latency_list.append(latency)
+    return statistics.mean(latency_list)
+
+
+def run_tflops_test(dtype=torch.float16, repeats: int = 100):  # sweeps shapes and prints BSNH vs BNSH latency/TFLOPS
+    device_id = torch.cuda.current_device()
+    device = torch.device("cuda", device_id)
+    print("---- GQA BSNH vs GQA BNSH ----")
+    print("op\tbatch\ts_kv\theads\th_dim\tms\tTFLOPS")
+    mean_bsnh_lat = 0
+    mean_bnsh_lat = 0
+    num_trials = 0
+    share_buffer = True
+    random.seed(69)
+    for b in [1, 3, 8, 16]:
+        for s_q, s_kv in [(1, 128), (128, 256), (512, 512), (128, 1024), (1, 2048)]:
+            for n_q, n_kv in [(8, 8), (16, 8), (32, 32), (12, 3), (128, 64)]:
+                for h in [32, 64, 128]:
+                    sp = random.randint(1, s_kv - s_q) if s_kv - s_q > 0 else 0  # keep sp + s_q within the s_kv-row buffer
+                    config = Config(b, s_q, s_kv, sp, n_q, n_kv, h)
+
+                    bsnh_session = create_gqa_session(
+                        config,
+                        causal=False,
+                        past_format=InputFormats.QKV_BSNH,
+                        share_buffer=share_buffer,
+                    )
+                    bnsh_session = create_gqa_session(
+                        config,
+                        causal=False,
+                        past_format=InputFormats.QKV_BNSH,
+                        share_buffer=share_buffer,
+                    )
+
+                    q = torch.randn(b, s_q, n_q * h, device=device, dtype=dtype)
+                    kv = torch.randn(b, s_q, 2, n_kv * h, device=device, dtype=dtype)
+                    k, v = kv.unbind(dim=2)
+
+                    past_kv = torch.rand(b, s_kv if share_buffer else sp, 2, n_kv, h, device=device, dtype=dtype)
+                    past_k, past_v = past_kv.unbind(dim=2)
+
+                    input_dict_bsnh = {
+                        "query": q.detach().cpu().numpy(),
+                        "key": k.detach().cpu().numpy(),
+                        "value": v.detach().cpu().numpy(),
+                        "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", device_id),
+                        "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", device_id),
+                    }
+                    input_dict_bnsh = {
+                        "query": q.detach().cpu().numpy(),
+                        "key": k.detach().cpu().numpy(),
+                        "value": v.detach().cpu().numpy(),
+                        "past_key": OrtValue.ortvalue_from_numpy(
+                            past_k.transpose(1, 2).detach().cpu().numpy(), "cuda", device_id
+                        ),
+                        "past_value": OrtValue.ortvalue_from_numpy(
+                            past_v.transpose(1, 2).detach().cpu().numpy(), "cuda", device_id
+                        ),
+                    }
+                    if share_buffer:
+                        input_dict_bsnh["past_sequence_length"] = (
+                            torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
+                        )
+                        input_dict_bnsh["past_sequence_length"] = (
+                            torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy()
+                        )
+
+                    io_binding_bsnh = bsnh_session.io_binding()
+                    io_binding_bnsh = bnsh_session.io_binding()
+                    bind_io(io_binding_bsnh, input_dict_bsnh, device_id, share_buffer)
+                    bind_io(io_binding_bnsh, input_dict_bnsh, device_id, share_buffer)
+                    average_gqa_bsnh_latency = benchmark_op(bsnh_session, io_binding_bsnh, repeats)
+                    average_gqa_bnsh_latency = benchmark_op(bnsh_session, io_binding_bnsh, repeats)
+
+                    del bsnh_session
+                    del bnsh_session
+
+                    # compute TFLOPS per second
+                    bsnh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bsnh_latency)
+                    print(f"bsnh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bsnh_latency * 1000:.2f}\t{bsnh_speed:.2f}")
+                    bnsh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bnsh_latency)
+                    print(f"bnsh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bnsh_latency * 1000:.2f}\t{bnsh_speed:.2f}")
+                    print("---------")
+                    if average_gqa_bsnh_latency > 10 * average_gqa_bnsh_latency:  # drop outlier trials from the averages
+                        continue
+                    num_trials += 1
+                    mean_bsnh_lat += average_gqa_bsnh_latency
+                    mean_bnsh_lat += average_gqa_bnsh_latency
+    mean_bsnh_lat /= max(num_trials, 1)  # every trial may have been skipped as an outlier
+    mean_bnsh_lat /= max(num_trials, 1)
+    print(f"average bsnh latency:\t{mean_bsnh_lat}")
+    print(f"average bnsh latency:\t{mean_bnsh_lat}")
+
+
+if __name__ == "__main__":
+    run_tflops_test()
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index f90a9475b4588..04351cd6e6782 100644
---
a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -10,6 +10,7 @@ # license information. # ------------------------------------------------------------------------- import math +import random import numpy import torch @@ -17,23 +18,32 @@ from einops import rearrange, repeat from onnx import TensorProto, helper -from onnxruntime import InferenceSession, SessionOptions +from onnxruntime import InferenceSession, OrtValue, SessionOptions torch.manual_seed(0) +class Formats: + BSNH = 0 + BNSH = 1 + + class Config: batch_size = 0 sequence_length = 0 kv_sequence_length = 0 + past_sequence_length = 0 num_heads = 0 + kv_num_heads = 0 head_size = 0 - def __init__(self, b, s, s2, n, h): + def __init__(self, b, s, s2, sp, n, n2, h): self.batch_size = b self.sequence_length = s self.kv_sequence_length = s2 + self.past_sequence_length = sp self.num_heads = n + self.kv_num_heads = n2 self.head_size = h @@ -149,6 +159,196 @@ def create_multihead_attention_graph(config): return model.SerializeToString() +def create_group_query_attention_graph_no_past(config, causal=False): + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key", + "value", + ], + ["output"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + unidirectional=1 if causal else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + + graph_output = [ + 
helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_past(config, causal=False, past_kv_format=Formats.BSNH, share_buffer=True): + past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length + ) + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key", + "value", + "past_key", + "past_value", + "past_sequence_length" if share_buffer else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + unidirectional=1 if causal else 0, + is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + 
TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] + if share_buffer: + graph_input += [ + helper.make_tensor_value_info( + "past_sequence_length", + TensorProto.INT32, + [1], + ) + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): assert mode in ["full", "random", "third"] if mode == "full": @@ -314,6 +514,7 @@ def generate_token_offset(cu_seqlens, max_seqlen): return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) +# TODO(aciddelgado): rename def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): onnx_model_str = create_packed_multihead_attention_graph(config) qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) @@ -329,7 +530,7 @@ def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config return output -def flash_attn_func(q, k, v, config, causal=False): +def 
mha_func(q, k, v, config): onnx_model_str = create_multihead_attention_graph(config) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) @@ -342,10 +543,108 @@ def flash_attn_func(q, k, v, config, causal=False): sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) ort_output = ort_session.run(None, ort_inputs) + ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) return output +def gqa_no_past_func(q, k, v, config, causal=True): + onnx_model_str = create_group_query_attention_graph_no_past(config, causal) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + ort_output = ort_session.run(None, ort_inputs) + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output + + +def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, causal=False, share_buffer=True): + onnx_model_str = create_group_query_attention_graph_past(config, causal, past_kv_format, share_buffer) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + past_k = k.clone() + past_v = v.clone() + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": 
OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "past_sequence_length": torch.tensor([config.past_sequence_length], dtype=torch.int32) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_input( + "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + numpy.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("past_sequence_length", ort_inputs["past_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": past_k.detach().cpu().numpy(), + "past_value": past_v.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) 
+ io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + + +def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + return col_idx > row_idx + sk - sq + + def attention_ref( q, k, @@ -390,10 +689,17 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) + # causal_mask = torch.triu( + # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 + # ) + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) + if causal: # Some rows are completely masked out so we fill them with zero instead of NaN + attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop 
= attention.masked_fill(~dropout_mask, 0.0) else: @@ -422,7 +728,7 @@ def attention_qkvpacked_ref( ) -def parity_check( +def parity_check_mha( config, packed, rtol=1e-3, @@ -456,7 +762,7 @@ def parity_check( k = torch.randn( config.batch_size, config.kv_sequence_length, - config.num_heads, + config.kv_num_heads, config.head_size, device="cuda", dtype=torch.float16, @@ -465,27 +771,362 @@ def parity_check( v = torch.randn( config.batch_size, config.kv_sequence_length, - config.num_heads, + config.kv_num_heads, config.head_size, device="cuda", dtype=torch.float16, requires_grad=False, ) - out = flash_attn_func(q, k, v, config) + out = mha_func(q, k, v, config) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None) + out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) out_ref = out_ref.detach().cpu().numpy() + + # Compare results + print( + " B:", + config.batch_size, + " S:", + config.sequence_length, + " N:", + config.num_heads, + " kvN:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_no_past( + config, + causal=False, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + # 
Pytorch to compare + out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + out_ref = out_ref.detach().cpu().numpy() + # Flash function + out = gqa_no_past_func(q, k, v, config, causal=causal) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Compare results + print( + " causal:", + causal, + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past( + config, + causal=False, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + 
dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, causal, True) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == 
Formats.BSNH else "BNSH", + " causal:", + causal, + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past_no_buff( + config, + causal=False, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + k_cache_rep = repeat(k_cache_ref, "b s h 
d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = None + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, causal, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # print(present_k[0, 0, config.past_sequence_length, :10]) + # print(k_cache_ref[0, 0, config.past_sequence_length, :10]) + # print(k_cache_ref.shape) + + # print(present_k - k_cache_ref.detach().cpu().numpy()) + + # Make sure past-present buffer updating correctly + if past_format == Formats.BSNH: + assert numpy.allclose( + present_k, + k_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + assert numpy.allclose( + present_v, + v_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + else: + assert numpy.allclose( + present_k, + k_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + assert numpy.allclose( + present_v, + v_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Compare results print( + "Unbuffered", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " causal:", + causal, " B:", config.batch_size, " S:", config.sequence_length, + " kv S:", + config.kv_sequence_length, " N:", config.num_heads, + " kv N:", + config.kv_num_heads, " h:", config.head_size, " Mean Error:", @@ -506,8 +1147,8 @@ def parity_check( for s in [97, 128, 200, 256, 257, 384, 512, 768, 
1024, 1025, 2048]: for n in [6]: for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s, n, h) - parity_check(config, True) + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) print("-------- TEST MHA ---------") for b in [5]: for s, s2 in [ @@ -524,5 +1165,60 @@ def parity_check( ]: for n in [6]: for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s2, n, h) - parity_check(config, False) + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + print("-------- TEST GQA ---------") + for b in [5]: + for s, s2 in [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ]: + for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: + for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + print("-------- TEST GQA PAST ---------") + random.seed(69) + for b in [2]: + for s, s2 in [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 512), + (16, 128 * 512), + (128, 128), + ]: + for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: + for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c07766b5e5d34..d6a5761d54042 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -898,7 +898,7 @@ def number_of_nvcc_threads(args): 
# Standard_NC4as_T4_v3 has 4 CPUs and 28 GB memory. When parallel=4 and nvcc_threads=2, # total nvcc threads is 4 * 2, which is barely able to build in 28 GB memory so we will use nvcc_threads=1. memory_per_thread = 4 * 1024 * 1024 * 1024 - fmha_cu_files = 4 if is_windows() else 8 + fmha_cu_files = 4 if is_windows() else 16 fmha_parallel_jobs = min(fmha_cu_files, number_of_parallel_jobs(args)) nvcc_threads = max(1, int(available_memory / (memory_per_thread * fmha_parallel_jobs))) print( @@ -2269,7 +2269,9 @@ def generate_documentation(source_dir, build_dir, configs, validate): have_diff = False def diff_file(path, regenerate_qualifiers=""): - diff = subprocess.check_output(["git", "diff", path], cwd=source_dir).decode("utf-8") + diff = subprocess.check_output(["git", "diff", "--ignore-blank-lines", path], cwd=source_dir).decode( + "utf-8" + ) if diff: nonlocal have_diff have_diff = True