From 05566a001090dd4c62b1f0276e4762dd02419886 Mon Sep 17 00:00:00 2001
From: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Date: Thu, 16 Nov 2023 15:01:06 -0800
Subject: [PATCH] Aciddelgado/gqa local (#18375)
### Description
Implement a preliminary version of local (sliding window) attention.
It is currently only supported by Flash Attention (sm >= 80, Linux), and
for now only sliding-window attention with a large cached KV is supported.
### Motivation and Context
This change makes it possible to run Mistral and other models that use
sliding window attention.
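
For reference, here is a minimal sketch (not part of this patch) of how the new `local_window_size` attribute could be set when building a GroupQueryAttention node with the ONNX Python helpers. The input/output names, head counts, and window value below are illustrative placeholders, not values taken from this PR.

```python
# Illustrative sketch only: wiring up local_window_size on a GroupQueryAttention
# node via onnx.helper. Names and sizes are hypothetical.
from onnx import helper

gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value", "past_key", "past_value",
            "seqlens_k", "total_sequence_length"],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,            # attention heads for q
    kv_num_heads=8,          # attention heads for k and v
    local_window_size=4096,  # left window size; -1 (default) disables local attention
)
```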
---
docs/ContribOperators.md | 4 +-
.../contrib_ops/cpu/bert/attention_common.h | 4 +-
.../cuda/bert/flash_attention/flash.h | 15 +
.../cuda/bert/flash_attention/flash_api.cc | 44 +-
.../cuda/bert/flash_attention/flash_api.h | 7 +-
.../bert/flash_attention/flash_fwd_kernel.h | 375 +++++++++---------
.../flash_fwd_launch_template.h | 117 +++---
.../cuda/bert/flash_attention/kernel_traits.h | 9 +-
.../cuda/bert/flash_attention/softmax.h | 23 +-
.../cuda/bert/flash_attention/utils.h | 164 ++++++--
.../cuda/bert/group_query_attention.cc | 14 +-
.../cuda/bert/group_query_attention.h | 3 +-
.../cuda/bert/group_query_attention_impl.cu | 67 +---
.../core/graph/contrib_ops/bert_defs.cc | 10 +-
.../python/transformers/test_flash_attn.py | 363 ++++++++---------
15 files changed, 682 insertions(+), 537 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9c31978c66486..da900e5c59405 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Group Query Self/Cross Attention.
- Supports different number of heads for q and kv.
+ Supports different numbers of heads for q and kv. Only supports causal or local attention.
#### Version
@@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- kv_num_heads : int (required)
- Number of attention heads for k and v
+- local_window_size : int
+- Left window size for local attention (like Mistral). Default value is -1, meaning unused. (A sketch of the resulting mask follows this attribute list.)
- num_heads : int (required)
- Number of attention heads for q
- scale : float
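
As a rough illustration of what the `local_window_size` attribute above means, here is a NumPy sketch (not part of the patch), assuming seqlen_q == seqlen_k and a right window of 0 as in the causal case: query i may attend to keys j with i - local_window_size <= j <= i.

```python
import numpy as np

def local_causal_mask(seqlen, window_size_left):
    # Sketch of the causal + sliding-window mask enabled by local_window_size:
    # query i attends to keys j with i - window_size_left <= j <= i.
    # window_size_left == -1 means no left limit, i.e. plain causal attention.
    q = np.arange(seqlen)[:, None]
    k = np.arange(seqlen)[None, :]
    allowed = k <= q  # causal part (right window of 0)
    if window_size_left >= 0:
        allowed &= k >= q - window_size_left
    return allowed

# Example: with a window of 2, token 4 attends only to tokens 2, 3 and 4.
print(local_causal_mask(6, 2).astype(int))
```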
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index b693b58c7c40a..a7f83469a768d 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters {
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
+ int local_window_size;
bool kv_share_buffer;
- bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
- bool left_padding; // copies last token to last index if true
+ bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
index 89e2351428d40..cbe536c6ce45a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params {
int seqlen_q_rounded = 0;
int seqlen_k_rounded = 0;
int d_rounded = 0;
+ int rotary_dim = 0;
// The scaling factors for the kernel.
float scale_softmax = 0.0;
@@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params {
index_t knew_head_stride = 0;
index_t vnew_head_stride = 0;
+ // The cos and sin matrices for rotary embedding.
+ void* __restrict__ rotary_cos_ptr = nullptr;
+ void* __restrict__ rotary_sin_ptr = nullptr;
+
+ // The indices to index into the KV cache.
+ int* __restrict__ cache_batch_idx = nullptr;
+
+ // Local window size
+ int window_size_left = -1;
+ int window_size_right = -1;
+
bool is_bf16 = false;
bool is_causal = false;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative = true;
+
+ bool is_rotary_interleaved = false;
+
int num_splits = 0; // For split-KV version
const cudaDeviceProp* dprops = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index 89a27c4d2b0d3..76190aad68fdb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params,
void* softmax_lse_d,
float softmax_scale,
bool is_causal,
- bool kv_bsnh = true) {
+ bool kv_bsnh = true,
+ int window_size_left = -1,
+ int window_size_right = -1) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
@@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params,
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API separates
+ // local and causal: when a local window size is given, we turn off is_causal and express causality through the window sizes instead.
params.is_causal = is_causal;
+ if (is_causal && (window_size_left >= 0 || window_size_right != 0)) {
+ params.is_causal = false;
+ }
+ if (window_size_left < 0 && window_size_right >= 0) {
+ window_size_left = seqlen_k;
+ }
+ if (window_size_left >= 0 && window_size_right < 0) {
+ window_size_right = seqlen_k;
+ }
+ params.window_size_left = window_size_left;
+ params.window_size_right = window_size_right;
+
params.is_seqlens_k_cumulative = true;
}
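
Restated as a short Python sketch for readability (mirroring the C++ added above, not the actual API): causal attention combined with a window is re-expressed as a purely local window, and one-sided windows are widened to cover the whole sequence on the open side.

```python
def normalize_window(is_causal, window_size_left, window_size_right, seqlen_k):
    # Mirrors the window normalization added to set_params_fprop above (sketch only).
    if is_causal and (window_size_left >= 0 or window_size_right != 0):
        is_causal = False  # causality is expressed through the window instead
    if window_size_left < 0 and window_size_right >= 0:
        window_size_left = seqlen_k   # open left side: attend to everything on the left
    if window_size_left >= 0 and window_size_right < 0:
        window_size_right = seqlen_k  # open right side: attend to everything on the right
    return is_causal, window_size_left, window_size_right
```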
@@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- bool kv_bsnh) {
+ bool kv_bsnh,
+ int local_window_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal,
- kv_bsnh);
+ kv_bsnh,
+ local_window_size,
+ is_causal ? 0 : -1);
params.dprops = &dprops;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
@@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
nullptr,
softmax_lse,
softmax_scale,
- is_causal);
+ is_causal,
+ true,
+ -1,
+ is_causal ? 0 : -1);
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
@@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
bool past_bsnh, // otherwise bnsh
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
- void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
-) {
- if (seqlen_q == 1) {
- is_causal = false;
- } // causal=true is the same as causal=false in this case
+ void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+ int local_window_size) {
+ // if (seqlen_q == 1) {
+ // is_causal = false;
+ // } // causal=true is the same as causal=false in this case
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
@@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal,
- past_bsnh);
+ past_bsnh,
+ local_window_size,
+ is_causal ? 0 : -1);
params.dprops = &dprops;
if (k != nullptr && v != nullptr) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 58f4304251872..efc1f565c4fa0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- bool kv_bsnh = true);
+ bool kv_bsnh = true,
+ int local_window_size = -1);
Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
@@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
bool past_bsnh, // otherwise bnsh
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
- void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
-);
+ void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
+ int local_window_size = -1);
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
index eb1c794d6df54..028233f66850f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -29,47 +29,6 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
-CUTE_HOST_DEVICE auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(cute::Layout, cute::Int>,
- cute::Stride<_1, cute::Int>>{},
- make_layout(cute::size<2>(TileShape_MNK{})));
-
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-CUTE_HOST_DEVICE auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(cute::Layout, cute::Int>,
- cute::Stride<_1, cute::Int>>{},
- // TODO: Shouldn't this be size<1>?
- make_layout(cute::size<2>(TileShape_MNK{})));
- // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
template
inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum,
Tensor2& acc_o, float softmax_scale_log2) {
@@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
@@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+ const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
- if (Is_causal) {
- n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+ if (Is_causal || Is_local) {
+ n_block_max = std::min(n_block_max,
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// We exit early and write 0 to gO and gLSE.
// Otherwise we might read OOB elements from gK and gV.
- if (n_block_max <= 0) {
+ if (n_block_max <= n_block_min) {
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o),
@@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
-
cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
cute::Shape, cute::Int>{},
make_stride(params.q_row_stride, _1{}));
@@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
- constexpr int n_masking_steps = !Is_causal
+ constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
- : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N)
@@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
- if (!Is_causal) {
+ if (!Is_causal && !Is_local) {
if (!Is_even_MN) {
flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
}
} else {
// I can't get the stride from idx_row
- flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
- // m_block * kBlockM + get<0>(idx_row(0)),
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q,
- kNWarps * 16);
+ flash::apply_mask_local</*HasWSLeft=*/Is_local>(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ // m_block * kBlockM + get<0>(idx_row(0)),
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right);
}
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
- ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
- : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
cute::Tensor rP = flash::convert_type(scores);
@@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// This check is at the end of the loop since we always have at least 1 iteration
- if (n_masking_steps > 1 && n_block <= 0) {
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
- for (; n_block >= 0; --n_block) {
+ for (; n_block >= n_block_min; --n_block) {
cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
@@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
- cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+ if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+ flash::apply_mask_local(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right);
+ }
+ softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
cute::Tensor rP = flash::convert_type(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
@@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
- const int n_block_min = n_split_idx * n_blocks_per_split;
+ const int n_block_min = !Is_local
+ ? n_split_idx * n_blocks_per_split
+ : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
- if (Is_causal) {
+ if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max,
- cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
}
if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
// We exit early and write 0 to gOaccum and -inf to gLSEaccum.
@@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
- const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
- const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
- const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
- const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+ const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+ const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
Shape, Int>{},
@@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v),
Shape, Int>{},
make_stride(params.v_row_stride, _1{}));
- // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
- // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
- // This maps to accessing the first 64 rows of knew_ptr.
- Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
- Shape, Int>{},
- make_stride(params.knew_row_stride, _1{}));
- // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
- Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
- Shape, Int>{},
- make_stride(params.vnew_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
typename Kernel_traits::SmemLayoutQ{});
@@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
- Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
- Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
- Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
- Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
+ Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::TiledMma tiled_mma;
@@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
}
// Prologue
+ // Copy from Knew to K, optionally apply rotary embedding.
+ typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+ auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+ auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+ if constexpr (Append_KV) {
+ // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+ // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+ // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+ Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+ Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+ Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+ Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+ // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
+ // if (cute::thread(8, 0)) { print_tensor(gCos); }
+ // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
+
+ const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+ const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+ // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+ // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+ // This maps to accessing the first 64 rows of knew_ptr.
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+ Shape, Int>{},
+ make_stride(params.knew_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+ Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+ Shape, Int>{},
+ make_stride(params.vnew_row_stride, _1{}));
+ Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
+
+ const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+ for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+ flash::copy_w_min_idx(
+ tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+ if (params.rotary_dim == 0) {
+ flash::copy_w_min_idx(
+ tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ } else {
+ if (params.is_rotary_interleaved) {
+ // Don't clear OOB_K because we're writing to global memory
+ flash::copy_rotary_interleaved(
+ tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim);
+ tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+ tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+ } else {
+ // Don't clear OOB_K because we're writing to global memory
+ flash::copy_rotary_contiguous(
+ tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim);
+ tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+ tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+ }
+ }
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+ }
+ // Need this before we can read in K again, so that we'll see the updated K values.
+ __syncthreads();
+ if (n_block_max > n_block_copy_min) {
+ tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+ tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+ }
+ }
+ // Read Q from gmem to smem, optionally apply rotary embedding.
Tensor tQrQ = make_fragment_like(tQgQ);
- // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
- flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
- binfo.actual_seqlen_q - m_block * kBlockM);
+ if (!Append_KV || params.rotary_dim == 0) {
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+ flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
+ } else {
+ const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+ // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
+ // We do this by setting the row stride of gCos / gSin to 0.
+ Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape, Int>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+ Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+ Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+ Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+ if (params.is_rotary_interleaved) {
+ flash::copy_rotary_interleaved(
+ tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+ 0, params.d, params.rotary_dim);
+ } else {
+ flash::copy_rotary_contiguous(
+ tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+ 0, params.d, params.rotary_dim);
+ }
+ }
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// flash::cp_async_wait<0>();
@@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
- constexpr int n_masking_steps = !Is_causal
+ constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
- : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N)
@@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
flash::cp_async_wait<0>();
__syncthreads();
- if constexpr (Append_KV) {
- // if (cute::thread0()) { print(tKgK); }
- // if (cute::thread0()) { print(tKsK); }
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx(
- tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
- }
- // __syncthreads();
- // if (cute::thread0()) { print(tKgK); }
- // __syncthreads();
- }
-
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- if (Append_KV) {
- tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
- }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
+ flash::copy(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
}
cute::cp_async_fence();
@@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
- if (!Is_causal) {
+ if (!Is_causal && !Is_local) {
if (!Is_even_MN) {
flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
}
} else {
- flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q,
- kNWarps * 16);
+ flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right);
}
flash::cp_async_wait<0>();
@@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
// __syncthreads();
- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx(
- tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
- }
- }
-
if (n_block > n_block_min) {
// Advance gK
- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- if (Append_KV) {
- tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
- }
- // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
- binfo.seqlen_k_cache - (n_block - 1) * kBlockN);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
// We have key_padding_mask so we'll need to Check_inf
masking_step == 0
- ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
- : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert scores from fp32 to fp16/bf16
@@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx(
- tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
- }
- }
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- if (Append_KV) {
- tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
- }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm(
@@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
flash::cp_async_wait<0>();
__syncthreads();
- if constexpr (Append_KV) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
- if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
- flash::copy_w_min_idx(
- tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN);
- }
- }
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- if (Append_KV) {
- tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
- }
- flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
- gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
- binfo.seqlen_k_cache - (n_block - 1) * kBlockN);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+ flash::apply_mask_local(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right);
+ }
+ softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
inline __device__ void compute_attn(const Params& params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
- flash::compute_attn_1rowblock(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
inline __device__ void compute_attn_splitkv(const Params& params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) {
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1;
- flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+ flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
inline __device__ void combine_attn_seqk_parallel(const Params& params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
constexpr int kMaxSplits = 1 << Log_max_splits;
- constexpr int kBlockM = 16;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
+ constexpr int kNThreads = Kernel_traits::kNThreads;
static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
- // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer");
- static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
- static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
+ static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+ static_assert(kNThreads == 128, "We assume that each block has 128 threads");
// Shared memory.
// kBlockM + 1 instead of kBlockM to reduce bank conflicts.
@@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
make_stride(params.b * params.h * params.seqlen_q, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse),
Shape>{}, Stride<_1>{});
- constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
+ constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
// Read the LSE values from gmem and store them in shared memory, then transpose them.
- constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
+ constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
#pragma unroll
for (int l = 0; l < kNLsePerThread; ++l) {
const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
@@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum),
Shape, Int>{},
Stride, _1>{});
- typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+ constexpr int kBlockN = kNThreads / kBlockM;
+ using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>;
+ using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{},
+ GmemLayoutAtomOaccum{},
+ Layout>{})); // Val layout, 4 vals per store
+ GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
Tensor tOrO = make_tensor(shape(tOgOaccum));
@@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
}
}
-// Load Oaccum in then scale and accumulate to O
-#pragma unroll 2
+ // Load Oaccum in then scale and accumulate to O
for (int split = 0; split < params.num_splits; ++split) {
flash::copy</*Is_even_MN=*/false, Is_even_K>(
gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index 82dfa59b8f8e7..87d189a803f8a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -10,29 +10,30 @@
namespace onnxruntime {
namespace flash {
-template
+template
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+ static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- flash::compute_attn(params);
+ flash::compute_attn(params);
#else
(void)params;
#endif
}
-template
+template
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- flash::compute_attn_splitkv(params);
+ flash::compute_attn_splitkv(params);
#else
(void)params;
#endif
}
-template
+template
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(Log_max_splits >= 1);
- flash::combine_attn_seqk_parallel(params);
+ flash::combine_attn_seqk_parallel(params);
#else
(void)params;
#endif
@@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- // Will only return softmax if dropout, to reduce compilation time.
- auto kernel = &flash_fwd_kernel;
- // auto kernel = &flash_fwd_kernel;
- if (smem_size >= 48 * 1024) {
- cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
- // ORT_ENFORCE(cudaFuncSetAttribute(
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- }
- // int ctas_per_sm;
- // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
- // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
- kernel<<>>(params);
+ BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+ // Will only return softmax if dropout, to reduce compilation time.
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+ // If Is_local, set Is_causal to false
+ auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ;
+ // auto kernel = &flash_fwd_kernel;
+ if (smem_size >= 48 * 1024) {
+ cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ // ORT_ENFORCE(cudaFuncSetAttribute(
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ }
+ // int ctas_per_sm;
+ // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+ // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+ kernel<<>>(params);
+ });
});
});
}
@@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH(params.num_splits > 1, Split, [&] {
- BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
- // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
- // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
- auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ;
- // auto kernel = &flash_fwd_splitkv_kernel;
- // auto kernel = &flash_fwd_splitkv_kernel;
- if (smem_size >= 48 * 1024) {
- cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
- }
- kernel<<>>(params);
+ BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+ BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+ BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+ // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+ // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+ auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ;
+ // auto kernel = &flash_fwd_splitkv_kernel;
+ // auto kernel = &flash_fwd_splitkv_kernel;
+ if (smem_size >= 48 * 1024) {
+ cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ }
+ kernel<<>>(params);
+ });
});
});
});
});
});
if (params.num_splits > 1) {
- dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+ // We want kBlockM to be as small as possible for more parallelism.
+ // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+ // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+ constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+ dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 4) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 8) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 16) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 32) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 64) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
} else if (params.num_splits <= 128) {
- flash_fwd_splitkv_combine_kernel<<>>(params);
+ flash_fwd_splitkv_combine_kernel<<>>(params);
}
});
}
@@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream)
template
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr int Headdim = 32;
+ constexpr static int Headdim = 32;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd, Is_causal>(params, stream);
});
@@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
template
void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr int Headdim = 64;
+ constexpr static int Headdim = 64;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
@@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) {
template
void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr int Headdim = 128;
- const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+ constexpr static int Headdim = 128;
+ bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
@@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) {
template
void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr int Headdim = 160;
- const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
+ constexpr static int Headdim = 160;
+ bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) {
template
void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr size_t Headdim = 224;
- constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64);
- size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin;
+ constexpr static int Headdim = 224;
+ int max_smem_per_block = params.dprops->sharedMemPerBlockOptin;
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- if (max_smem_per_block >= threshold) { // 112 KB
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd, Is_causal>(params, stream);
} else {
run_flash_fwd, Is_causal>(params, stream);
@@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) {
template
void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) {
- constexpr size_t Headdim = 256;
- constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64);
- constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64);
+ constexpr static int Headdim = 256;
size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor;
size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin;
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
- if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) {
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd, Is_causal>(params, stream);
} else {
run_flash_fwd, Is_causal>(params, stream);
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
index 134f159e258c4..1c0ed7f2fc2e8 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h
@@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base {
cute::Stride<_16, _1>>>;
   using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                                                        GmemLayoutAtomOaccum{},
-                                                       cute::Layout<cute::Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+                                                       Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+ using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+  using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                                                          GmemLayoutAtomRotcossin{},
+                                                          Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
+  using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                                              GmemLayoutAtomRotcossin{},
+                                                              Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
index 842edf3a98a86..8017f83bbb01d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
@@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_
}
}
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
-                                         const int max_seqlen_k, const int row_idx_offset_,
-                                         const int max_seqlen_q, const int warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
@@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
- const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+ const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+ const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
- if (col_idx >= col_idx_limit) {
+ if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
@@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i
}
}
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset_,
+ const int max_seqlen_q, const int warp_row_stride) {
+ // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+  apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+ max_seqlen_q, warp_row_stride, -1, 0);
+}
+
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& idx_rowcol,
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
index 02042e183f808..271112c5e890a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h
@@ -307,7 +307,7 @@ template
inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S,
Tensor& D, Tensor const& identity_MN,
- Tensor const& predicate_K, int max_MN = 0) {
+ Tensor const& predicate_K, const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
@@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
-inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0,
- Tensor const& S1,
+inline __device__ void copy_w_min_idx(Tensor const& S,
Tensor& D, Tensor const& identity_MN,
Tensor const& predicate_K,
- const int max_MN = 0, const int row_idx_switch = 0) {
- CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+ const int max_MN = 0, const int min_MN = 0) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
- CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
- CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
- CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
- // There's no case where !Clear_OOB_K && Clear_OOB_MN
- static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
-// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
-// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
#pragma unroll
- for (int m = 0; m < size<1>(S0); ++m) {
- auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
- if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+ for (int m = 0; m < size<1>(S); ++m) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
#pragma unroll
- for (int k = 0; k < size<2>(S0); ++k) {
+ for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
- cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+ cute::copy(S(_, m, k), D(_, m, k));
+ }
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+inline __device__ void copy_rotary_interleaved(Tensor const& S,
+ Tensor& D,
+ Tensor const& Cos,
+ Tensor const& Sin,
+ Tensor const& identity_MN,
+ const int max_MN, const int min_MN,
+ const int dim, const int rotary_dim) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
+ static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
+ Tensor rCos = make_fragment_like(Cos);
+ Tensor rSin = make_fragment_like(Sin);
+ Tensor rS = make_fragment_like(S);
+#pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+ cute::copy(S(_, m, k), rS(_, m, k));
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+ cute::copy(Cos(_, m, k), rCos(_, m, k));
+ cute::copy(Sin(_, m, k), rSin(_, m, k));
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+#pragma unroll
+ for (int i = 0; i < size<0>(rS) / 2; ++i) {
+ float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+ float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+ S_fp32(2 * i) = real;
+ S_fp32(2 * i + 1) = imag;
+ }
+ // A separate fragment copy seems to be needed for the convert_type back to the original element type to work.
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
+ cute::copy(S_fp32, S_fp32_copy);
+ using T = typename Engine0::value_type;
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
+ cute::copy(S_og_type, rS(_, m, k));
+ }
+ cute::copy(rS(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
- } else if (Clear_OOB_MN) {
- cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
-inline __device__ void copy_w_min_idx(Tensor const& S,
- Tensor& D, Tensor const& identity_MN,
- Tensor const& predicate_K,
- const int max_MN = 0, const int min_MN = 0) {
+inline __device__ void copy_rotary_contiguous(Tensor const& S,
+ Tensor& D,
+ Tensor const& Cos,
+ Tensor const& Sin,
+ Tensor const& identity_MN,
+ const int max_MN, const int min_MN,
+ const int dim, const int rotary_dim) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
- CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
- CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
- CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
-// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
+ Tensor rCos = make_fragment_like(Cos);
+ Tensor rSin = make_fragment_like(Sin);
+ Tensor rS = make_fragment_like(S);
+ Tensor rS_other = make_fragment_like(rS(_, 0, 0));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
- // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
-// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
- if (Is_even_K || predicate_K(k)) {
- cute::copy(S(_, m, k), D(_, m, k));
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+ cute::copy(S(_, m, k), rS(_, m, k));
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+ const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+ Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+ cute::copy(gS_other, rS_other);
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+ Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+ Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+ cute::copy(gCos, rCos(_, m, k));
+ cute::copy(gSin, rSin(_, m, k));
+ // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+ Tensor S_other_fp32 = convert_type<float>(rS_other);
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+#pragma unroll
+ for (int i = 0; i < size<0>(rS); ++i) {
+ S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+ }
+ // A separate fragment copy seems to be needed for the convert_type back to the original element type to work.
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
+ cute::copy(S_fp32, S_fp32_copy);
+ using T = typename Engine0::value_type;
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
+ cute::copy(S_og_type, rS(_, m, k));
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+ }
+ cute::copy(rS(_, m, k), D(_, m, k));
+ } else if (Clear_OOB_K) {
+ cute::clear(D(_, m, k));
}
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index f21dff08e0350..93892169f6c79 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
- is_unidirectional_ = true;
- // left_padding_ = info.GetAttrOrDefault<int64_t>("left_padding_last_token", 0) == 1;
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
+ local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
scale_ = info.GetAttrOrDefault("scale", 0.0f);
#if USE_FLASH_ATTENTION
@@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
is_past_bsnh_,
scale_,
device_prop.maxThreadsPerBlock));
- parameters.is_unidirectional = is_unidirectional_;
- // parameters.left_padding = left_padding_;
+ parameters.local_window_size = local_window_size_;
int sequence_length = parameters.sequence_length;
TensorShapeVector output_shape(3);
@@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
+ local_window_size_ == -1 &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
}
+ if (k_buffer != nullptr) {
+ data.k = reinterpret_cast<CudaT*>(k_buffer.get());
+ data.v = reinterpret_cast<CudaT*>(v_buffer.get());
+ }
+ if (fmha_buffer != nullptr) {
+ data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
+ }
cublasHandle_t cublas = GetCublasHandle(context);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index aade0436dc141..54a8127e29e7b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel {
protected:
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
- // bool left_padding_; // shifts last token to end of buffer
- bool is_unidirectional_; // causal
+ int local_window_size_;
bool is_past_bsnh_;
float scale_;
bool disable_flash_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 2d158155eeba9..b22ccb68c1e7b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
return CUDA_CALL(cudaGetLastError());
}
-// // Kernel to append new kv to kv buffer in place
-// template <typename T>
-// __global__ void LeftPadLast(const int max_seqlen,
-// T* kv_buff,
-// const int* seqlens_k) { // refers to kv buff; otherwise bnsh
-// const int h = threadIdx.x;
-// const int n = blockIdx.x;
-// const int b = blockIdx.y;
-
-// const int num_heads = gridDim.x;
-// const int H = blockDim.x;
-
-// const int present_batch_stride = max_seqlen * num_heads * H;
-// const int present_row_stride = num_heads * H;
-// const int present_head_stride = H;
-
-// // kv_buff: BTNH or BNTH with buffered memory for new
-// // new_kv: BLNH
-
-// const int s = seqlens_k[b];
-
-// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h;
-// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h;
-// kv_buff[out_offset] = kv_buff[in_offset];
-// }
-
-// // Concat new to kv buffer in place
-// template <typename T>
-// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters,
-// GroupQueryAttentionData& data,
-// cudaStream_t stream,
-// const int max_threads_per_block) {
-// const int batch_size = parameters.batch_size;
-// const int sequence_length = parameters.sequence_length;
-// const int num_heads = parameters.num_heads;
-// const int head_size = parameters.head_size;
-
-// // Indicates past sequence_length of each sequence
-// const int* seqlens_k = reinterpret_cast<const int*>(data.seqlens_k);
-
-// const int H = head_size / 4;
-// const dim3 grid(num_heads, batch_size, 1);
-// const dim3 block(H, 1, 1);
-// LeftPadLast<<<grid, block, 0, stream>>>(sequence_length,
-//                                         reinterpret_cast<T*>(data.output),
-// seqlens_k);
-// return CUDA_CALL(cudaGetLastError());
-// }
-
////////// Launch Kernels
#if USE_FLASH_ATTENTION
@@ -541,7 +492,7 @@ Status FlashAttention(
void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
void* value = reinterpret_cast