diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 888bcdbb9e21b..2a16bdbf7b55d 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -42,6 +42,7 @@ Do not modify directly.*
* com.microsoft.GreedySearch
* com.microsoft.GridSample
* com.microsoft.GroupNorm
+ * com.microsoft.GroupQueryAttention
* com.microsoft.Inverse
* com.microsoft.Irfft
* com.microsoft.LongformerAttention
@@ -1170,9 +1171,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
present_key (optional) : T
-past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
present_value (optional) : T
-past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
#### Type Constraints
@@ -2268,6 +2269,69 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GroupQueryAttention**
+
+ Group Query Self/Cross Attention.
+
+ Supports different numbers of heads for q and kv.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- is_past_bsnh : int
+  Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).
+- kv_num_heads : int (required)
+  Number of attention heads for k and v.
+- num_heads : int (required)
+  Number of attention heads for q.
+- scale : float
+  Custom scale will be used if specified. Default value is 1/sqrt(head_size).
+- unidirectional : int
+  Whether every token can only attend to previous tokens. Default value is 1.
+
+
+#### Inputs (3 - 6)
+
+
+- query : T
+  Query with shape (batch_size, sequence_length, hidden_size)
+- key : T
+  Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
+- value : T
+  Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
+- past_key (optional) : T
+  Past state for key, in either BSNH or BNSH format. When past_key uses the same tensor as present_key (KV cache), it is of length max_sequence_length; otherwise it is of length past_sequence_length.
+- past_value (optional) : T
+  Past state for value, in either BSNH or BNSH format. When past_value uses the same tensor as present_value (KV cache), it is of length max_sequence_length; otherwise it is of length past_sequence_length.
+- past_sequence_length (optional) : M
+  When buffered past_key and past_value are used (present_key uses the same tensor as past_key), past_sequence_length is required to be specified (it may be 0). Otherwise, past_sequence_length is inferred from past_key.
+
+
+#### Outputs (1 - 3)
+
+
+- output : T
+  3D output tensor with shape (batch_size, sequence_length, hidden_size)
+- present_key (optional) : T
+  Present state for key, in either BSNH or BNSH format. When past_key uses the same tensor as present_key (KV buffer), it is of length max_sequence_length; otherwise it is of length past_sequence_length + kv_sequence_length.
+- present_value (optional) : T
+  Present state for value, in either BSNH or BNSH format. When past_value uses the same tensor as present_value (KV buffer), it is of length max_sequence_length; otherwise it is of length past_sequence_length + kv_sequence_length.
+
+
+#### Type Constraints
+
+
+- T : tensor(float16)
+  Constrain input and output to float tensors.
+- M : tensor(int32), tensor(int64)
+  Constrain past sequence length to int tensor.
+
+
+
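As an illustration of the signature above, here is a minimal Python sketch that builds a `GroupQueryAttention` node with `onnx.helper`. The tensor names, head counts, and attribute values are illustrative assumptions, not part of the operator definition.

```python
# Minimal sketch: constructing a com.microsoft.GroupQueryAttention node.
# Shapes and attribute values below are examples only.
from onnx import helper

num_heads, kv_num_heads = 32, 8  # q heads vs. shared k/v heads (grouped-query attention)

gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value", "past_key", "past_value", "past_sequence_length"],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",      # contrib operator domain
    num_heads=num_heads,         # required: number of attention heads for q
    kv_num_heads=kv_num_heads,   # required: number of attention heads for k and v
    unidirectional=1,            # causal attention (default)
    is_past_bsnh=1,              # past KV cache in BSNH layout (default)
)
print(gqa_node)
```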
### **com.microsoft.Inverse**
#### Version
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 14b6b339c11f3..ce9d8aabfede3 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -840,6 +840,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 4c9c15d07a9b8..5184dd99309b1 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -82,6 +82,26 @@ struct PackedAttentionParameters {
bool broadcast_res_pos_bias;
};
+// Parameters deduced from node attributes and inputs/outputs.
+struct GroupQueryAttentionParameters {
+ int batch_size;
+ int sequence_length;
+ int past_sequence_length; // actual sequence length of past_key and past_value
+ int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present)
+ int present_sequence_length; // past_sequence_length + kv_sequence_length
+ int max_sequence_length; // allocated length of past_key and past_value
+ int hidden_size;
+ int num_heads;
+ int head_size;
+ int kv_hidden_size;
+ int kv_num_heads;
+ bool is_unidirectional; // causal
+ float scale;
+ int num_splits; // number of splits for splitkv
+ AttentionQkvFormat qkv_format;
+ AttentionQkvFormat past_kv_format;
+};
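To make the relationships between these fields concrete, the following Python sketch derives them from example query/key and cache shapes as described in the operator documentation; all concrete numbers here are illustrative assumptions.

```python
# Illustrative only: how GroupQueryAttentionParameters-style values relate to tensor shapes.
num_heads, kv_num_heads, head_size = 32, 8, 128

batch_size, sequence_length = 2, 1      # e.g. decoding one new token per sequence
past_sequence_length = 511              # tokens already in the KV cache
kv_sequence_length = sequence_length    # new k/v tokens in this step
max_sequence_length = 2048              # allocated length of past_key/past_value

hidden_size = num_heads * head_size                                   # 4096
kv_hidden_size = kv_num_heads * head_size                             # 1024
present_sequence_length = past_sequence_length + kv_sequence_length   # 512

query_shape = (batch_size, sequence_length, hidden_size)
key_shape = (batch_size, kv_sequence_length, kv_hidden_size)
past_key_bsnh_shape = (batch_size, max_sequence_length, kv_num_heads, head_size)

assert present_sequence_length <= max_sequence_length
```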
+
namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index d846f55f1e28d..626e4c0b87a3c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o
T* k_smem = q_smem + rotary_embedding_dim;
const int half_rotary_dim = rotary_embedding_dim / 2;
- const int half_idx = (head_idx) / half_rotary_dim;
- const int intra_half_idx = (head_idx) % half_rotary_dim;
- const int smem_pitch = half_rotary_dim;
+ const int half_idx = (head_idx) / half_rotary_dim;
+ const int intra_half_idx = (head_idx) % half_rotary_dim;
+ const int smem_pitch = half_rotary_dim;
if (do_rotary) {
*reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
@@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
}
}
-
template
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
// Format 3 for cutlass memory efficient attention
@@ -651,7 +650,7 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
- if (qk_head_size != 64 && qk_head_size !=128) {
+ if (qk_head_size != 64 && qk_head_size != 128) {
ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu
index 2af748d8d4a62..32ed961a68049 100644
--- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu
@@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
const int* attention_masks,
const int batch_size,
const int sequence_length) {
- typedef cub::BlockReduce BlockReduce;
- __shared__ typename BlockReduce::TempStorage temp_storage;
-
- const int batch_id = blockIdx.x;
- const int* batch_mask = attention_masks + (batch_id * sequence_length);
- const bool leftmost_non_zero = (batch_mask[0] != 0);
- int biggest_position = 0;
-
- for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
- if (leftmost_non_zero == (batch_mask[i] != 0)) {
- biggest_position = i;
- } else {
- break;
- }
+ typedef cub::BlockReduce BlockReduce;
+ __shared__ typename BlockReduce::TempStorage temp_storage;
+
+ const int batch_id = blockIdx.x;
+ const int* batch_mask = attention_masks + (batch_id * sequence_length);
+ const bool leftmost_non_zero = (batch_mask[0] != 0);
+ int biggest_position = 0;
+
+ for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
+ if (leftmost_non_zero == (batch_mask[i] != 0)) {
+ biggest_position = i;
+ } else {
+ break;
}
+ }
- int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
+ int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
- if (threadIdx.x == 0) {
- int batch_offset = batch_id * sequence_length;
- trt_mha_padding_offset[2 * batch_id] = batch_offset;
- trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
- if (batch_id == gridDim.x - 1) {
- trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
- }
+ if (threadIdx.x == 0) {
+ int batch_offset = batch_id * sequence_length;
+ trt_mha_padding_offset[2 * batch_id] = batch_offset;
+ trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
+ if (batch_id == gridDim.x - 1) {
+ trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
+ }
}
// only support simple left padding with mask 0s on leading left,
diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
index a2dfca8cd6f09..ae53eca541fa5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_
}
inline Status ComputeMaskIndex(cudaStream_t stream,
- const int sequence_length,
- const int batch_size,
- const int* mask,
- int* mask_index) {
+ const int sequence_length,
+ const int batch_size,
+ const int* mask,
+ int* mask_index) {
// Mask idx is of length batch_size and assumes the valid region is contiguous starting
// from the beginning of the sequence
@@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel(
}
if (nullptr == position_ids) {
position_id = blockIdx.x;
- } else if (broadcast_position_ids){
+ } else if (broadcast_position_ids) {
position_id = position_ids[sequence_position % gridDim.x];
} else {
position_id = position_ids[sequence_position];
@@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel(
void* embedding_sum,
const int* position_ids,
const bool broadcast_position_ids) {
-
if (mask_index != nullptr) {
if (nullptr == input_mask) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream));
} else {
ORT_RETURN_IF_ERROR(
- ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index)));
+ ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index)));
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
index 1b0de47a834ec..c9498eb1bcd7b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int
template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
- const float* input, const float* bias, float* output, bool /*use_half2*/) {
+ const float* input, const float* bias, float* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<<>>(A, B, C, input_length, bias_length,
@@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int
template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
- const half* input, const half* bias, half* output, bool use_half2) {
+ const half* input, const half* bias, half* output, bool use_half2) {
constexpr int blockSize = 256;
if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
@@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int
template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
- const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
+ const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
// remove nv_bfloat162 implementation for now to fix build issue
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
index 9db98061bbd66..811b1be7d4315 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
@@ -12,9 +12,13 @@ struct BlockInfo {
template
__device__ BlockInfo(const Params& params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
- sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
- actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q),
- actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) {
+ sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]),
+ actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ ,
+ seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])),
+ actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
}
template
@@ -30,6 +34,8 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int seqlen_k_cache;
const int actual_seqlen_k;
};
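A small Python restatement (not the device code) of the sequence-length bookkeeping above: when `is_seqlens_k_cumulative` is false, `cu_seqlens_k` holds per-batch cache lengths directly, and any appended `knew` tokens are added on top. The example values are illustrative.

```python
# Sketch of BlockInfo's seqlen_k handling (illustrative).
def seqlens_for_batch(bidb, seqlen_k, cu_seqlens_k, is_seqlens_k_cumulative,
                      seqlen_knew, has_knew):
    if cu_seqlens_k is None:
        seqlen_k_cache = seqlen_k                                      # fixed length for all batches
    elif is_seqlens_k_cumulative:
        seqlen_k_cache = cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]   # varlen prefix sums
    else:
        seqlen_k_cache = cu_seqlens_k[bidb]                            # per-batch KV-cache lengths
    actual_seqlen_k = seqlen_k_cache + (seqlen_knew if has_knew else 0)
    return seqlen_k_cache, actual_seqlen_k

# KV-cache path: cu_seqlens_k stores past lengths directly; one new token is appended.
print(seqlens_for_batch(0, seqlen_k=2048, cu_seqlens_k=[17, 40],
                        is_seqlens_k_cumulative=False, seqlen_knew=1, has_knew=True))  # (17, 18)
```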
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
index 9394a19c9897a..0aaf5e5f1ba28 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -45,6 +45,7 @@ struct Qkv_params {
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
+ void* __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
@@ -56,9 +57,10 @@ struct Flash_fwd_params : public Qkv_params {
// The pointer to the softmax sum.
void* __restrict__ softmax_lse_ptr;
+ void* __restrict__ softmax_lseaccum_ptr;
// The dimensions.
- int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
@@ -70,9 +72,26 @@ struct Flash_fwd_params : public Qkv_params {
int* __restrict__ blockmask;
+ // The K_new and V_new matrices.
+ void* __restrict__ knew_ptr;
+ void* __restrict__ vnew_ptr;
+
+ // The stride between rows of the Q, K and V matrices.
+ index_t knew_batch_stride;
+ index_t vnew_batch_stride;
+ index_t knew_row_stride;
+ index_t vnew_row_stride;
+ index_t knew_head_stride;
+ index_t vnew_head_stride;
+
bool is_bf16 = false;
bool is_causal;
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ bool is_seqlens_k_cumulative;
+ int num_splits; // For split-KV version
+
const cudaDeviceProp* dprops;
};
@@ -80,6 +99,8 @@ struct Flash_fwd_params : public Qkv_params {
template
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
+template
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream);
} // namespace flash
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index 87831d1eddfe9..805a73be96778 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -34,24 +34,37 @@ void set_params_fprop(Flash_fwd_params& params,
void* p_d,
void* softmax_lse_d,
float softmax_scale,
- bool is_causal) {
+ bool is_causal,
+ bool kv_bsnh = true) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
params.v_ptr = v;
params.o_ptr = out;
- // All stride are in elements, not bytes.
- params.q_row_stride = num_heads * head_size;
- params.k_row_stride = num_heads_k * head_size;
- params.v_row_stride = num_heads * head_size;
- params.q_head_stride = head_size;
- params.k_head_stride = head_size;
- params.v_head_stride = head_size;
- params.o_row_stride = num_heads * head_size;
- params.o_head_stride = head_size;
params.is_bf16 = false;
+ // All stride are in elements, not bytes.
+ if (kv_bsnh) {
+ params.q_row_stride = num_heads * head_size;
+ params.k_row_stride = num_heads_k * head_size;
+ params.v_row_stride = num_heads_k * head_size;
+ params.q_head_stride = head_size;
+ params.k_head_stride = head_size;
+ params.v_head_stride = head_size;
+ params.o_row_stride = num_heads * head_size;
+ params.o_head_stride = head_size;
+ } else {
+ params.q_row_stride = num_heads * head_size;
+ params.k_row_stride = head_size;
+ params.v_row_stride = head_size;
+ params.q_head_stride = head_size;
+ params.k_head_stride = seqlen_k * head_size;
+ params.v_head_stride = seqlen_k * head_size;
+ params.o_row_stride = num_heads * head_size;
+ params.o_head_stride = head_size;
+ }
+
if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
@@ -90,6 +103,7 @@ void set_params_fprop(Flash_fwd_params& params,
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.is_causal = is_causal;
+ params.is_seqlens_k_cumulative = true;
}
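The `kv_bsnh` branch in `set_params_fprop` above selects element strides for the two supported KV layouts. The Python sketch below (illustrative, not part of the change) shows how a logical (batch, token, head, dim) element of K maps to a flat offset under each layout; the shapes are examples.

```python
# Flat element offsets for K under BSNH vs. BNSH layouts, mirroring the stride choices above.
def k_offset_bsnh(b, s, n, d, seqlen_k, num_heads_k, head_size):
    batch_stride = seqlen_k * num_heads_k * head_size
    row_stride = num_heads_k * head_size   # advance one token
    head_stride = head_size                # advance one head
    return b * batch_stride + s * row_stride + n * head_stride + d

def k_offset_bnsh(b, s, n, d, seqlen_k, num_heads_k, head_size):
    batch_stride = seqlen_k * num_heads_k * head_size
    row_stride = head_size                 # advance one token within a head
    head_stride = seqlen_k * head_size     # advance one head
    return b * batch_stride + s * row_stride + n * head_stride + d

print(k_offset_bsnh(1, 3, 2, 5, seqlen_k=8, num_heads_k=4, head_size=16))  # 741
print(k_offset_bnsh(1, 3, 2, 5, seqlen_k=8, num_heads_k=4, head_size=16))  # 821
```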
size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) {
@@ -97,14 +111,85 @@ size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) {
return bytes;
}
-void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) {
+size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) {
+ size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads;
+ return bytes;
+}
+
+size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) {
+ size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded;
+ return bytes;
+}
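The two helpers above size the split-KV scratch buffers. A quick sketch of the arithmetic for an example configuration (the numbers are illustrative; both buffers hold float32 values, 4 bytes per element):

```python
# Illustrative sizing of the split-KV accumulation buffers.
num_splits, batch_size, num_heads, seqlen_q, head_size_rounded = 4, 2, 32, 1, 128
FLOAT_BYTES = 4

softmax_lse_accum_bytes = FLOAT_BYTES * num_splits * batch_size * seqlen_q * num_heads
out_accum_bytes = softmax_lse_accum_bytes * head_size_rounded

print(softmax_lse_accum_bytes, out_accum_bytes)  # 1024 131072
```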
+
+void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
- run_mha_fwd_(params, stream);
+ if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+ run_mha_fwd_(params, stream);
+ } else {
+ run_mha_fwd_splitkv_dispatch(params, stream);
+ }
});
});
}
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 85%
+// of the best efficiency.
+int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs,
+ int max_splits, bool new_kv, bool is_sm8x) {
+ // This needs to match with run_mha_fwd_splitkv_dispatch
+ const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64))
+ : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64));
+ const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n;
+ // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+ // In any case we don't expect seqlen_q to be larger than 64 for inference.
+ const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
+ int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks;
+ // If we have enough to almost fill the SMs, then just use 1 split
+ if (batch_nheads_mblocks >= 0.8f * num_SMs) {
+ return 1;
+ }
+ max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+ float max_efficiency = 0.f;
+ std::vector<float> efficiency;
+ efficiency.reserve(max_splits);
+ auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+ // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+ // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+ // (i.e. it's 11 splits anyway).
+ // So we check if the number of blocks per split is the same as the previous num_splits.
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+ return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+ };
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+ if (!is_split_eligible(num_splits)) {
+ efficiency.push_back(0.f);
+ } else {
+ float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+ float eff = n_waves / ceil(n_waves);
+ // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+ if (eff > max_efficiency) {
+ max_efficiency = eff;
+ }
+ efficiency.push_back(eff);
+ }
+ }
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+ if (!is_split_eligible(num_splits)) {
+ continue;
+ }
+ if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+ // printf("num_splits chosen = %d\n", num_splits);
+ return num_splits;
+ }
+ }
+ return 1;
+}
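To make the occupancy trade-off described in the comment concrete, here is a simplified Python restatement of the efficiency term for the 48-block / 108-SM example. It mirrors only the `n_waves / ceil(n_waves)` calculation, not the split-eligibility logic or the 0.8 early-exit check.

```python
# Simplified restatement of the occupancy efficiency used by num_splits_heuristic.
import math

def efficiency(batch_nheads_mblocks, num_splits, num_sms):
    n_waves = batch_nheads_mblocks * num_splits / num_sms
    return n_waves / math.ceil(n_waves)

# Example from the comment: batch * n_heads = 48 blocks on 108 SMs.
for s in (1, 2, 3):
    print(s, round(efficiency(48, s, 108), 2))
# 1 -> 0.44, 2 -> 0.89, 3 -> 0.67; the heuristic then picks the smallest eligible
# split count whose efficiency is at least 85% of the best one found.
```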
+
Status mha_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
@@ -119,7 +204,11 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
float softmax_scale,
- bool is_causal) {
+ bool is_causal,
+ int num_splits,
+ void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
+ void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+ bool kv_bsnh) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -139,7 +228,26 @@ Status mha_fwd(const cudaDeviceProp& dprops,
nullptr,
softmax_lse,
softmax_scale,
- is_causal);
+ is_causal,
+ kv_bsnh);
+
+ params.knew_ptr = nullptr;
+ params.vnew_ptr = nullptr;
+ params.knew_batch_stride = 0;
+ params.vnew_batch_stride = 0;
+ params.knew_row_stride = 0;
+ params.vnew_row_stride = 0;
+ params.knew_head_stride = 0;
+ params.vnew_head_stride = 0;
+
+ params.num_splits = num_splits;
+ if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
+ params.softmax_lseaccum_ptr = softmax_lse_accum;
+ params.oaccum_ptr = out_accum;
+ } else {
+ params.softmax_lseaccum_ptr = nullptr;
+ params.oaccum_ptr = nullptr;
+ }
run_mha_fwd(params, stream);
return Status::OK();
@@ -192,6 +300,101 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0);
}
+// This API is used when past key and value are present. Since they are cached, they are assumed to have an allocated
+// sequence length of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_.
+Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+ void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+ void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* out, // batch_size x seqlen_q x num_heads x head_size
+ void* softmax_lse, // batch_size x num_heads x seqlen_q
+ void* seqlens_k_, // batch_size
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int seqlen_q,
+ int seqlen_k,
+ int seqlen_k_new,
+ const float softmax_scale,
+ bool is_causal,
+ bool past_bsnh, // otherwise bnsh
+ int num_splits,
+ void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
+ void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+) {
+ if (seqlen_q == 1) {
+ is_causal = false;
+ } // causal=true is the same as causal=false in this case
+
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = round_multiple(head_size, 32);
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+ Flash_fwd_params params;
+ params.dprops = &dprops;
+ set_params_fprop(params,
+ batch_size,
+ seqlen_q, seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q, kcache, vcache, out,
+ /*cu_seqlens_q_d=*/nullptr,
+ /*cu_seqlens_k_d=*/nullptr,
+ /*p_ptr=*/nullptr,
+ softmax_lse,
+ softmax_scale,
+ is_causal,
+ past_bsnh);
+
+ if (k != nullptr && v != nullptr) {
+ params.seqlen_knew = seqlen_k_new;
+ params.knew_ptr = k;
+ params.vnew_ptr = v;
+ // All stride are in elements, not bytes.
+ params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.knew_row_stride = num_heads_k * head_size;
+ params.vnew_row_stride = num_heads_k * head_size;
+ params.knew_head_stride = head_size;
+ params.vnew_head_stride = head_size;
+ } else {
+ params.seqlen_knew = 0;
+ params.knew_ptr = nullptr;
+ params.vnew_ptr = nullptr;
+ params.knew_batch_stride = 0;
+ params.vnew_batch_stride = 0;
+ params.knew_row_stride = 0;
+ params.vnew_row_stride = 0;
+ params.knew_head_stride = 0;
+ params.vnew_head_stride = 0;
+ }
+
+ params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;
+ if (seqlens_k_ != nullptr) {
+ params.cu_seqlens_k = static_cast(seqlens_k_);
+ }
+
+ params.num_splits = num_splits;
+ if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
+ params.softmax_lseaccum_ptr = softmax_lse_accum;
+ params.oaccum_ptr = out_accum;
+ } else {
+ params.softmax_lseaccum_ptr = nullptr;
+ params.oaccum_ptr = nullptr;
+ }
+
+ // Only split kernel supports appending to KV cache
+ run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
+
+ return Status::OK();
+}
+
} // namespace flash
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 2ae46d34c373a..0a0328edb0059 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -34,6 +34,7 @@
namespace onnxruntime {
namespace flash {
+
Status mha_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
@@ -48,7 +49,11 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int seqlen_q,
int seqlen_k,
float softmax_scale,
- bool is_causal);
+ bool is_causal,
+ int num_splits = 0,
+ void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
+ void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+ bool kv_bsnh = true);
Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
@@ -68,7 +73,36 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
float softmax_scale,
bool is_causal);
+Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+ void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
+ void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
+ void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
+ void* out, // batch_size x seqlen_q x num_heads x head_size
+ void* softmax_lse, // batch_size x num_heads x seqlen_q
+ void* seqlens_k_, // batch_size
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int seqlen_q,
+ int seqlen_k,
+ int seqlen_k_new,
+ const float softmax_scale,
+ bool is_causal,
+ bool past_bsnh, // otherwise bnsh
+ int num_splits = 0,
+ void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
+ void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+);
+
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
+size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q);
+size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded);
+
+int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x);
bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
index b5af31e432d42..eb1c794d6df54 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -79,7 +79,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
cute::Tensor scores_max_prev = make_fragment_like(scores_max);
- copy(scores_max, scores_max_prev);
+ cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
@@ -109,7 +109,7 @@ inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, T
template
inline __device__ void write_softmax_to_gmem(
- cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_thr_copy_P) {
+ cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
cute::Layout l = tOrP.layout();
cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
@@ -117,7 +117,7 @@ inline __device__ void write_softmax_to_gmem(
CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) {
- copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+ cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
@@ -147,6 +147,45 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+ // We exit early and write 0 to gO and gLSE.
+ // Otherwise we might read OOB elements from gK and gV.
+ if (n_block_max <= 0) {
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o),
+ Shape, Int>{},
+ make_stride(params.o_row_stride, _1{}));
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse),
+ Shape>{}, Stride<_1>{});
+
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+ Tensor tOrO = make_tensor(shape(tOgO));
+ clear(tOrO);
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+ Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) {
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+ }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+#pragma unroll
+ for (int m = 0; m < size<1>(tOgO); ++m) {
+ const int row = get<0>(tOcO(0, m, 0));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) {
+ gLSE(row) = INFINITY;
+ }
+ }
+ return;
+ }
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
@@ -504,6 +543,494 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
+template
+inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
+ using Element = typename Kernel_traits::Element;
+ using ElementAccum = typename Kernel_traits::ElementAccum;
+ using index_t = typename Kernel_traits::index_t;
+
+ // Shared memory.
+ extern __shared__ char smem_[];
+
+ // The thread index.
+ const int tidx = threadIdx.x;
+
+ constexpr int kBlockM = Kernel_traits::kBlockM;
+ constexpr int kBlockN = Kernel_traits::kBlockN;
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
+ constexpr int kNWarps = Kernel_traits::kNWarps;
+
+ using GmemTiledCopyO = std::conditional_t<
+ !Split,
+ typename Kernel_traits::GmemTiledCopyOaccum,
+ typename Kernel_traits::GmemTiledCopyO>;
+ using ElementO = std::conditional_t;
+
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+ // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+ const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
+ const int n_block_min = n_split_idx * n_blocks_per_split;
+ int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
+ if (Is_causal) {
+ n_block_max = std::min(n_block_max,
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+ }
+ if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
+ // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
+ // Otherwise we might read OOB elements from gK and gV,
+ // or get wrong results when we combine gOaccum from different blocks.
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
+ const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+ Shape, Int>{},
+ make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+ Shape>{}, Stride<_1>{});
+
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+ Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
+ clear(tOrOaccum);
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
+ Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) {
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+ }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy(
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+#pragma unroll
+ for (int m = 0; m < size<1>(tOgOaccum); ++m) {
+ const int row = get<0>(tOcO(0, m, 0));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) {
+ gLSEaccum(row) = Split ? -INFINITY : INFINITY;
+ }
+ }
+ return;
+ }
+
+ // We iterate over the blocks in reverse order. This is because the last block is the only one
+ // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
+ // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
+
+ const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+ // We move K and V to the last block.
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+ const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+ const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+
+ Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
+ Shape, Int>{},
+ make_stride(params.q_row_stride, _1{}));
+ Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
+ Shape, Int>{},
+ make_stride(params.k_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
+ Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v),
+ Shape, Int>{},
+ make_stride(params.v_row_stride, _1{}));
+ // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+ // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+ // This maps to accessing the first 64 rows of knew_ptr.
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+ Shape, Int>{},
+ make_stride(params.knew_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+ Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+ Shape, Int>{},
+ make_stride(params.vnew_row_stride, _1{}));
+
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
+ typename Kernel_traits::SmemLayoutQ{});
+ Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
+ Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+ Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+ Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+
+ Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+ Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+ Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+ Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
+ Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
+ Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+
+ typename Kernel_traits::TiledMma tiled_mma;
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
+ Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
+ Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
+ Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+
+ Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K
+
+ //
+ // Copy Atom retiling
+ //
+
+ auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+ Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+ Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+ auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+ Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+ // TODO: this might need to change if we change the mma instruction in SM70
+ Tensor scores_max = make_tensor(Shape(acc_o)>>{});
+ Tensor scores_sum = make_fragment_like(scores_max);
+
+ //
+ // PREDICATES
+ //
+
+ // // Allocate predicate tensors for m and n
+ // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
+ // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
+
+ // Construct identity layout for sQ and sK
+ Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+ // Repeat the partitioning with identity layouts
+ Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+ Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+ // Allocate predicate tensors for k
+ Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
+ Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK)));
+
+ // Set predicates for k bounds
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < size(tQpQ); ++k) {
+ tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+ }
+#pragma unroll
+ for (int k = 0; k < size(tKVpKV); ++k) {
+ tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+ }
+ }
+
+ // Prologue
+
+ Tensor tQrQ = make_fragment_like(tQgQ);
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+ flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
+
+ int n_block = n_block_max - 1;
+ // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
+ gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ cute::cp_async_fence();
+
+ // flash::cp_async_wait<0>();
+ // __syncthreads();
+ // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+ // __syncthreads();
+
+ clear(acc_o);
+
+ // For performance reason, we separate out two kinds of iterations:
+ // those that need masking on S, and those that don't.
+ // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
+ // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
+ // We will have at least 1 "masking" iteration.
+
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+ constexpr int n_masking_steps = !Is_causal
+ ? 1
+ : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+#pragma unroll
+ for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+ Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+
+ if constexpr (Append_KV) {
+ // if (cute::thread0()) { print(tKgK); }
+ // if (cute::thread0()) { print(tKsK); }
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+ if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+ flash::copy_w_min_idx(
+ tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ }
+ // __syncthreads();
+ // if (cute::thread0()) { print(tKgK); }
+ // __syncthreads();
+ }
+
+ // Advance gV
+ if (masking_step > 0) {
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ if (Append_KV) {
+ tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+ }
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+ gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN);
+ } else {
+ // Clear the smem tiles to account for predicated off loads
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ }
+ cute::cp_async_fence();
+
+ flash::gemm(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K);
+ // if (cute::thread0()) { print(acc_s); }
+
+ // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+ // if (cute::thread0()) { print(scores); }
+ // We don't put the masking before the matmul S = Q K^T because we don't clear sK
+ // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
+ // can produce Inf / NaN.
+ if (!Is_causal) {
+ if (!Is_even_MN) {
+ flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
+ }
+ } else {
+ flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16);
+ }
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+ // __syncthreads();
+
+ // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
+ if constexpr (Append_KV) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+ if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+ flash::copy_w_min_idx(
+ tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ }
+ }
+
+ if (n_block > n_block_min) {
+ // Advance gK
+ // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ if (Append_KV) {
+ tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+ }
+ // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+ gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+ binfo.seqlen_k_cache - (n_block - 1) * kBlockN);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ // We have key_padding_mask so we'll need to Check_inf
+ masking_step == 0
+ ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
+
+ // Convert scores from fp32 to fp16/bf16
+ Tensor rP = flash::convert_type(scores);
+ // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout()));
+
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ // if (cute::thread0()) { print(scores); }
+
+ // This check is at the end of the loop since we always have at least 1 iteration
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
+ --n_block;
+ break;
+ }
+ }
+
+ // These are the iterations where we don't need masking on S
+ for (; n_block >= n_block_min; --n_block) {
+ Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ if constexpr (Append_KV) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+ if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+ flash::copy_w_min_idx(
+ tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ }
+ }
+ // Advance gV
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ if (Append_KV) {
+ tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+ }
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+ gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN);
+ cute::cp_async_fence();
+
+ flash::gemm(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K);
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ if constexpr (Append_KV) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+ if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+ flash::copy_w_min_idx(
+ tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ }
+ }
+ if (n_block > n_block_min) {
+ // Advance gK
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ if (Append_KV) {
+ tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+ }
+ flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+ gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+ binfo.seqlen_k_cache - (n_block - 1) * kBlockN);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+ softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+ Tensor rP = flash::convert_type(scores);
+ // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout()));
+
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ }
+
+ // Epilogue
+
+ // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+ Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+ // if (cute::thread0()) { print(acc_o_rowcol); }
+ Tensor lse = make_fragment_like(scores_sum);
+#pragma unroll
+ for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+ float sum = scores_sum(mi);
+ float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+ lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
+ float scale = inv_sum;
+#pragma unroll
+ for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
+ acc_o_rowcol(mi, ni) *= scale;
+ }
+ }
+ // if (cute::thread0()) { print(lse); }
+ // if (cute::thread0()) { print(acc_o_rowcol); }
+
+ Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
+ // Partition sO to match the accumulator partitioning
+ using SmemTiledCopyO = std::conditional_t<
+ !Split,
+ typename Kernel_traits::SmemCopyAtomO,
+ typename Kernel_traits::SmemCopyAtomOaccum>;
+ auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
+ auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor rO = flash::convert_type<ElementO>(acc_o);
+ Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
+ Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+ // sOaccum is larger than sQ, so we need to syncthreads here
+ // TODO: allocate enough smem for sOaccum
+ if constexpr (Split) {
+ __syncthreads();
+ }
+
+ cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
+
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
+ const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+ // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
+
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+
+ __syncthreads();
+
+ Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+ cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
+
+ Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
+ static_assert(decltype(size<0>(taccOcO))::value == 4);
+ // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
+ if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+ for (int mi = 0; mi < size(lse); ++mi) {
+ const int row = get<0>(taccOcO_row(mi));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+ gLSEaccum(row) = lse(mi);
+ }
+ }
+ }
+
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) {
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+ }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+ // __syncthreads();
+ // if (cute::thread0()) { print(tOgOaccum); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
template
inline __device__ void compute_attn(const Params& params) {
const int m_block = blockIdx.x;
@@ -524,6 +1051,187 @@ inline __device__ void compute_attn(const Params& params) {
}
////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+inline __device__ void compute_attn_splitkv(const Params& params) {
+ const int m_block = blockIdx.x;
+ // The block index for the batch.
+ const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
+ // The block index for the head.
+ const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+ const int n_split_idx = Split ? blockIdx.y : 0;
+ const int num_n_splits = Split ? gridDim.y : 1;
+ flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
+inline __device__ void combine_attn_seqk_parallel(const Params& params) {
+ using Element = typename Kernel_traits::Element;
+ using ElementAccum = typename Kernel_traits::ElementAccum;
+ using index_t = typename Kernel_traits::index_t;
+ constexpr int kMaxSplits = 1 << Log_max_splits;
+ constexpr int kBlockM = 16;
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+ static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
+ // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend later");
+ static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
+ static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
+
+ // Shared memory.
+ // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
+ __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
+
+ // The thread and block index.
+ const int tidx = threadIdx.x;
+ const int bidx = blockIdx.x;
+
+ const index_t row_offset_lse = bidx * kBlockM;
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) + row_offset_lse),
+ Shape<Int<kMaxSplits>, Int<kBlockM>>{},
+ make_stride(params.b * params.h * params.seqlen_q, _1{}));
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+ constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
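+ // Illustrative arithmetic (added comment): with kMaxSplits = 128, kBlockM = 16 and 128 threads,
+ // kNLsePerThread = (128 * 16 + 127) / 128 = 16 LSE values handled per thread.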
+
+ // Read the LSE values from gmem and store them in shared memory, then transpose them.
+ constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
+#pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
+ const int col = tidx % kBlockM;
+ ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+ if (row < kMaxSplits) {
+ sLSE[row][col] = lse;
+ }
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+ }
+ // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
+ __syncthreads();
+ Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
+ constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
+ // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
+ // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
+ // 16 rows, so each time we load we can load 8 rows).
+ // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
+ // static_assert(kThreadsPerSplit <= 32);
+ static_assert(kRowsPerLoadTranspose <= 32);
+ static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
+#pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+ const int col = tidx / kRowsPerLoadTranspose;
+ lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+ }
+
+ // Compute the logsumexp of the LSE along the split dimension.
+ ElementAccum lse_max = lse_accum(0);
+#pragma unroll
+ for (int l = 1; l < kNLsePerThread; ++l) {
+ lse_max = max(lse_max, lse_accum(l));
+ }
+ MaxOp<float> max_op;
+ lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
+ lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
+ float lse_sum = expf(lse_accum(0) - lse_max);
+#pragma unroll
+ for (int l = 1; l < kNLsePerThread; ++l) {
+ lse_sum += expf(lse_accum(l) - lse_max);
+ }
+ SumOp<float> sum_op;
+ lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
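+ // Added note: the merged log-sum-exp over all splits is lse_logsum = lse_max + log(sum_i exp(lse_i - lse_max)),
+ // computed with the usual max-subtraction trick; each split's output is later rescaled by exp(lse_i - lse_logsum).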
+ // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
+ // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
+ ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
+ if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+ gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+ }
+// Store the scales exp(lse - lse_logsum) in shared memory.
+#pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+ const int col = tidx / kRowsPerLoadTranspose;
+ if (row < params.num_splits && col < kBlockM) {
+ sLSE[row][col] = expf(lse_accum(l) - lse_logsum);
+ }
+ }
+ __syncthreads();
+
+ const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.oaccum_ptr) + row_offset_oaccum),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ Stride<Int<kHeadDim>, _1>{});
+ typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
+ Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
+ Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+ clear(tOrO);
+
+ // Predicates
+ Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
+ // Repeat the partitioning with identity layouts
+ Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
+ Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < size(tOpOaccum); ++k) {
+ tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
+ }
+ }
+// Load Oaccum in then scale and accumulate to O
+#pragma unroll 2
+ for (int split = 0; split < params.num_splits; ++split) {
+ flash::copy</*Is_even_MN=*/false, Is_even_K>(
+ gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM);
+#pragma unroll
+ for (int m = 0; m < size<1>(tOrOaccum); ++m) {
+ int row = get<0>(tOcOaccum(0, m, 0));
+ ElementAccum lse_scale = sLSE[split][row];
+#pragma unroll
+ for (int k = 0; k < size<2>(tOrOaccum); ++k) {
+#pragma unroll
+ for (int i = 0; i < size<0>(tOrOaccum); ++i) {
+ tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
+ }
+ }
+ // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
+ }
+ tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
+ }
+ // if (cute::thread0()) { print(tOrO); }
+
+ Tensor rO = flash::convert_type<Element>(tOrO);
+// Write to gO
+#pragma unroll
+ for (int m = 0; m < size<1>(rO); ++m) {
+ const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
+ if (idx < params.b * params.h * params.seqlen_q) {
+ const int batch_idx = idx / (params.h * params.seqlen_q);
+ const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
+ // The index to the rows of Q
+ const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
+ auto o_ptr = reinterpret_cast<Element*>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
+#pragma unroll
+ for (int k = 0; k < size<2>(rO); ++k) {
+ if (Is_even_K || tOpOaccum(k)) {
+ const int col = get<1>(tOcOaccum(0, m, k));
+ Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
+ Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
+ // TODO: Should check if this is using vectorized store, but it seems pretty fast
+ copy(rO(_, m, k), gO);
+ // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
+ // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k);
+ }
+ }
+ }
+ }
+}
+
} // namespace flash
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index e633ef4d45fbb..e0be6b828f85d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn(params);
}
+template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+ flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+}
+
+template <typename Kernel_traits, int Log_max_splits, bool Is_even_K>
+__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+ static_assert(Log_max_splits >= 1);
+ flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
+}
+
template
void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
@@ -25,8 +36,6 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
- // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
- // for cu_seqlens_q as well.
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
@@ -40,9 +49,7 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
// ORT_ENFORCE(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
- int ctas_per_sm;
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+ // int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
@@ -51,6 +58,72 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
});
}
+template <typename Kernel_traits>
+void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
+ static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+ static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
+ constexpr size_t smem_size = Kernel_traits::kSmemSize;
+ const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+ dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
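+ // Added note: with num_splits > 1, blockIdx.y indexes the KV split and blockIdx.z enumerates
+ // batch * heads; otherwise blockIdx.y is the batch index and blockIdx.z the head index
+ // (see compute_attn_splitkv in flash_fwd_kernel.h).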
+ const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+ const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+ BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+ BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+ BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+ BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+ // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+ // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+ auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ;
+ // auto kernel = &flash_fwd_splitkv_kernel;
+ // auto kernel = &flash_fwd_splitkv_kernel;
+ if (smem_size >= 48 * 1024) {
+ cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ }
+ kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ });
+ });
+ });
+ });
+ });
+ if (params.num_splits > 1) {
+ dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
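+ // Added note: each combine block reduces 16 rows of the flattened batch * heads * seqlen_q
+ // dimension across up to 2^Log_max_splits partial results.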
+ BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+ if (params.num_splits <= 2) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 4) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 8) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 16) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 32) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 64) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 128) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ }
+ });
+ }
+}
+
+template <typename T, int Headdim>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
+ bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+ constexpr int kBlockM = 64; // Fixed for all head dimensions
+ if (!is_sm8x) { // A100, H100
+ // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+ // and for headdim 192 with block size 64 x 128.
+ constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
+ run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+ } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
+ constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+ run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+ }
+}
+
template <typename T>
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
constexpr int Headdim = 32;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu
new file mode 100644
index 0000000000000..68ae2ea759813
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu
new file mode 100644
index 0000000000000..94564a6aba8f3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu
new file mode 100644
index 0000000000000..ec9e9e738c5b3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu
new file mode 100644
index 0000000000000..e6c4ff5d95584
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu
new file mode 100644
index 0000000000000..552966852cdbe
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu
new file mode 100644
index 0000000000000..e9f191a4828d6
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu
new file mode 100644
index 0000000000000..d628a556680ad
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu
new file mode 100644
index 0000000000000..88b6cc0fb1e22
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
index 0c967faa85c45..134f159e258c4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
@@ -111,7 +111,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape, Int>{}));
- using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+ using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+ using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = cute::size(SmemLayoutQ{});
static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2;
@@ -139,18 +140,28 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy>;
using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
- Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
+ cute::Layout<cute::Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
- Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ cute::Layout<cute::Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
- using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
- Stride<Int<kGmemThreadsPerRowP>, _1>>;
+ using GmemLayoutAtomP = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRowP>, cute::Int<kGmemThreadsPerRowP>>,
+ cute::Stride<cute::Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
- Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ cute::Layout<cute::Shape<_1, _8>>{})); // Val layout, 8 vals per store
+
+ using GmemLayoutAtomOaccum = std::conditional_t<
+ kBlockKSmem == 32,
+ cute::Layout<cute::Shape<_16, _8>, // Thread layout, 8 threads per row
+ cute::Stride<_8, _1>>,
+ cute::Layout<cute::Shape<_8, _16>, // Thread layout, 16 threads per row
+ cute::Stride<_16, _1>>>;
+ using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ GmemLayoutAtomOaccum{},
+ cute::Layout<cute::Shape<_1, _4>>{})); // Val layout, 4 vals per store
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
@@ -289,13 +300,13 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{});
static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{});
- static constexpr int kSmemdPsumCount = kBlockM;
+ // static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
- static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
+ // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
index 49ee687419d0e..02042e183f808 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
@@ -96,46 +96,6 @@ inline __device__ uint32_t convert_relu2(const float2 x) {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
-inline __device__ float2 half2_unpack(uint32_t a);
-
-template <>
-inline __device__ float2 half2_unpack<__half>(uint32_t a) {
- return __half22float2(reinterpret_cast<__half2(&)>(a));
-}
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-template <>
-inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
- return __bfloat1622float2(reinterpret_cast<__nv_bfloat162(&)>(a));
-}
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert two half2's or bf162's into float, then take their dot product.
-template
-inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
- float2 af = flash::half2_unpack(a);
- float2 bf = flash::half2_unpack(b);
- return af.x * bf.x + af.y * bf.y;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
-template
-inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
- float sum;
- sum = flash::hfma2_to_float(a.x, b.x);
- sum += flash::hfma2_to_float(a.y, b.y);
- sum += flash::hfma2_to_float(a.z, b.z);
- sum += flash::hfma2_to_float(a.w, b.w);
- return sum;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
template <typename T>
struct MaxOp {
__device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }
@@ -245,7 +205,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
- return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+ // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+ // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+ // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+ return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -261,9 +224,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
- return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
- get<0, 1>(l),
- get<1, 1, 1>(l));
+ // TD [2023-08-13]: Same error as above on Cutlass 3.2
+ // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+ // get<0, 1>(l),
+ // get<1, 1, 1>(l));
+ return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+ get<1>(get<0>(l)),
+ get<1>(get<1>(get<1>(l))));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -338,7 +305,7 @@ CUTE_HOST_DEVICE void cp_async_wait() {
template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const& S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const& S,
Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
Tensor<Engine3, Layout3> const& predicate_K, int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -354,13 +321,80 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor const&
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
- copy(thr_copy, S(_, m, k), D(_, m, k));
+ cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+ } else if (Clear_OOB_K) {
+ cute::clear(D(_, m, k));
+ }
+ }
+ } else if (Clear_OOB_MN) {
+ cute::clear(D(_, m, _));
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_2_sources = true, bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const& S0,
+ Tensor<Engine0, Layout0> const& S1,
+ Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+ Tensor<Engine3, Layout3> const& predicate_K,
+ const int max_MN = 0, const int row_idx_switch = 0) {
+ CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
+ // There's no case where !Clear_OOB_K && Clear_OOB_MN
+ static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+#pragma unroll
+ for (int m = 0; m < size<1>(S0); ++m) {
+ auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
+ if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+ for (int k = 0; k < size<2>(S0); ++k) {
+ if (Is_even_K || predicate_K(k)) {
+ cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
- clear(D(_, m, k));
+ cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
- clear(D(_, m, _));
+ cute::clear(D(_, m, _));
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const& S,
+ Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+ Tensor<Engine3, Layout3> const& predicate_K,
+ const int max_MN = 0, const int min_MN = 0) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+#pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+#pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || predicate_K(k)) {
+ cute::copy(S(_, m, k), D(_, m, k));
+ }
+ }
}
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
new file mode 100644
index 0000000000000..65d19d4473872
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -0,0 +1,185 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/platform/env_var_utils.h"
+#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
+#include "contrib_ops/cuda/bert/group_query_attention.h"
+#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
+// #include "contrib_ops/cpu/utils/console_dumper.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ GroupQueryAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
+ .TypeConstraint("M", {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}) \
+ .MayInplace(3, 1) \
+ .MayInplace(4, 2) \
+ .InputMemoryType(OrtMemTypeCPUInput, 5), \
+ GroupQueryAttention);
+
+// REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+template <typename T>
+GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
+ : CudaKernel(info) {
+ int64_t num_heads = 0;
+ int64_t kv_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+ ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
+ num_heads_ = static_cast<int>(num_heads);
+ kv_num_heads_ = static_cast<int>(kv_num_heads);
+ is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 1) == 1;
+ is_past_bsnh_ = info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
+ scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+
+#if USE_FLASH_ATTENTION
+ disable_flash_attention_ = sizeof(T) != 2 ||
+ ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
+#else
+ disable_flash_attention_ = true;
+#endif
+}
+
+template <typename T>
+Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* query = context->Input<Tensor>(0);
+ const Tensor* key = context->Input<Tensor>(1);
+ const Tensor* value = context->Input<Tensor>(2);
+ const Tensor* past_key = context->Input<Tensor>(3);
+ const Tensor* past_value = context->Input<Tensor>(4);
+ const Tensor* past_seq_len = context->Input<Tensor>(5);
+
+ auto& device_prop = GetDeviceProp();
+ GroupQueryAttentionParameters parameters;
+ typedef typename ToCudaType<T>::MappedType CudaT;
+ GroupQueryAttentionData<CudaT> data;
+
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
+ key,
+ value,
+ past_key,
+ past_value,
+ ¶meters,
+ num_heads_,
+ kv_num_heads_,
+ past_seq_len,
+ is_past_bsnh_,
+ scale_,
+ device_prop.maxThreadsPerBlock));
+ parameters.is_unidirectional = is_unidirectional_;
+ int sequence_length = parameters.sequence_length;
+
+ TensorShapeVector output_shape(3);
+ output_shape[0] = static_cast<int64_t>(parameters.batch_size);
+ output_shape[1] = static_cast<int64_t>(sequence_length);
+ output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
+ Tensor* output = context->Output(0, output_shape);
+
+ std::vector<int64_t> present_dims;
+ if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+ present_dims = {
+ parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size};
+ } else { // BNSH
+ present_dims = {
+ parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size};
+ }
+ TensorShape present_shape(present_dims);
+ Tensor* present_key = context->Output(1, present_shape);
+ Tensor* present_value = context->Output(2, present_shape);
+
+#if USE_FLASH_ATTENTION
+ bool use_flash_attention = !disable_flash_attention_ &&
+ onnxruntime::flash::is_supported(device_prop,
+ parameters.head_size,
+ parameters.num_heads,
+ parameters.kv_num_heads);
+ // Allocate buffers
+ size_t softmax_lse_bytes = 0;
+ size_t softmax_lse_accum_bytes = 0;
+ size_t out_accum_bytes = 0;
+ size_t seqlens_k_bytes = 0;
+ if (use_flash_attention) {
+ softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
+ // split kv buffers
+ parameters.num_splits = onnxruntime::flash::num_splits_heuristic(
+ parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
+ parameters.head_size, device_prop.multiProcessorCount, 128, false,
+ device_prop.major == 8 && device_prop.minor > 0);
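+ // Added note: num_splits_heuristic chooses how many KV-sequence chunks to process in parallel
+ // (bounded by the 128 passed above) so enough thread blocks are launched to occupy the SMs;
+ // when more than one split is used, the accumulation buffers below are required.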
+ if (parameters.num_splits > 1) {
+ // softmax_lse_accum buffer
+ softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(
+ parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
+ // out_accum buffer
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = round_multiple(parameters.head_size, 32);
+ out_accum_bytes = onnxruntime::flash::get_out_accum_size(
+ parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded);
+ }
+ // seqlens_k buffer
+ if (past_key != nullptr) {
+ seqlens_k_bytes = sizeof(int) * parameters.batch_size;
+ }
+ }
+ auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
+ auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
+ auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
+ auto seqlens_k_buffer = GetScratchBuffer<void>(seqlens_k_bytes, context->GetComputeStream());
+#else
+ constexpr bool use_flash_attention = false;
+ auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
+ auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
+ auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
+ auto seqlens_k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
+#endif
+
+ // only kernel implemented for gqa right now
+ ORT_ENFORCE(use_flash_attention);
+
+ data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
+ data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
+ data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
+ data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
+ data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
+ data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
+ data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast