diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9c31978c66486..da900e5c59405 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. - Supports different number of heads for q and kv. + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
kv_num_heads : int (required)
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
scale : float
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index b693b58c7c40a..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + int local_window_size; bool kv_share_buffer; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor - bool left_padding; // copies last token to last index if true + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 89e2351428d40..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_q_rounded = 0; int seqlen_k_rounded = 0; int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride = 0; index_t vnew_head_stride = 0; + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + bool is_bf16 = false; bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + int num_splits = 0; // For split-KV version const cudaDeviceProp* dprops = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 89a27c4d2b0d3..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, - bool kv_bsnh = true) { + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // local and causal, meaning when we have local window size params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh) { + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - kv_bsnh); + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ? 0 : -1); params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -) { - if (seqlen_q == 1) { - is_causal = false; - } // causal=true is the same as causal=false in this case + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - past_bsnh); + past_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; if (k != nullptr && v != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 58f4304251872..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh = true); + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -); + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index eb1c794d6df54..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? - make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), @@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. - Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); @@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + // Read Q from gmem to smem, optionally apply rotary embedding. Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // flash::cp_async_wait<0>(); @@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (cute::thread0()) { print(tKgK); } - // if (cute::thread0()) { print(tKsK); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - // __syncthreads(); - // if (cute::thread0()) { print(tKgK); } - // __syncthreads(); - } - // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); @@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } - if (n_block > n_block_min) { // Advance gK - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params& params) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. // kBlockM + 1 instead of kBlockM to reduce bank conflicts. @@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; @@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } -// Load Oaccum in then scale and accumulate to O -#pragma unroll 2 + // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 82dfa59b8f8e7..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,29 +10,30 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); #else (void)params; #endif @@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); } @@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } }); } @@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. @@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 134f159e258c4..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base { cute::Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 02042e183f808..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -307,7 +307,7 @@ template inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, - Tensor const& S1, +inline __device__ void copy_w_min_idx(Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, - const int max_MN = 0, const int row_idx_switch = 0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); -// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { + for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(S(_, m, k), D(_, m, k)); + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f21dff08e0350..93892169f6c79 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = true; - // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); - parameters.is_unidirectional = is_unidirectional_; - // parameters.left_padding = left_padding_; + parameters.local_window_size = local_window_size_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { bool use_memory_efficient_attention = !use_flash_attention && !disable_memory_efficient_attention_ && + local_window_size_ == -1 && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index aade0436dc141..54a8127e29e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - // bool left_padding_; // shifts last token to end of buffer - bool is_unidirectional_; // causal + int local_window_size_; bool is_past_bsnh_; float scale_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d158155eeba9..b22ccb68c1e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } -// // Kernel to append new kv to kv buffer in place -// template -// __global__ void LeftPadLast(const int max_seqlen, -// T* kv_buff, -// const int* seqlens_k) { // refers to kv buff; otherwise bnsh -// const int h = threadIdx.x; -// const int n = blockIdx.x; -// const int b = blockIdx.y; - -// const int num_heads = gridDim.x; -// const int H = blockDim.x; - -// const int present_batch_stride = max_seqlen * num_heads * H; -// const int present_row_stride = num_heads * H; -// const int present_head_stride = H; - -// // kv_buff: BTNH or BNTH with buffered memory for new -// // new_kv: BLNH - -// const int s = seqlens_k[b]; - -// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; -// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; -// kv_buff[out_offset] = kv_buff[in_offset]; -// } - -// // Concat new to kv buffer in place -// template -// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, -// GroupQueryAttentionData& data, -// cudaStream_t stream, -// const int max_threads_per_block) { -// const int batch_size = parameters.batch_size; -// const int sequence_length = parameters.sequence_length; -// const int num_heads = parameters.num_heads; -// const int head_size = parameters.head_size; - -// // Indicates past sequence_length of each sequence -// const int* seqlens_k = reinterpret_cast(data.seqlens_k); - -// const int H = head_size / 4; -// const dim3 grid(num_heads, batch_size, 1); -// const dim3 block(H, 1, 1); -// LeftPadLast<<>>(sequence_length, -// reinterpret_cast(data.output), -// seqlens_k); -// return CUDA_CALL(cudaGetLastError()); -// } - ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -541,7 +492,7 @@ Status FlashAttention( void* key = reinterpret_cast(const_cast(data.key)); void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = parameters.is_unidirectional; + bool is_causal = true; // Note: seqlens_k is past sequence length for flash if (parameters.is_prompt) { @@ -579,7 +530,7 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } else { // Not share buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient @@ -611,13 +562,9 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, 0, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -704,9 +651,11 @@ Status EfficientAttention( p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; - p.causal = parameters.is_unidirectional; + p.causal = true; p.scale = scale; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; p.query = query; p.key = key; p.value = value; @@ -721,10 +670,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dcde2ddeb8270..a99bb36984538 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -991,7 +991,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. -Supports different number of heads for q and kv. +Supports different number of heads for q and kv. Only supports causal or local attention. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1004,10 +1004,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - // .Attr("left_padding_last_token", - // "Copy last token to last index of buffer. Default is 0; 1 when true.", - // AttributeProto::INT, - // OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 99f62ffdb9f53..8a839875de2a2 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -183,7 +183,9 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_prompt( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length nodes = [ @@ -202,6 +204,7 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -297,6 +300,26 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN config.head_size, ], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -310,7 +333,9 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN return model.SerializeToString() -def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_past( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length @@ -331,6 +356,7 @@ def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -636,8 +662,12 @@ def mha_func(q, k, v, config): return output -def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) +def gqa_prompt_func( + q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True +): + onnx_model_str = create_group_query_attention_graph_prompt( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None @@ -706,8 +736,12 @@ def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_forma return output, present_k, present_v -def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) +def gqa_past_func( + q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 +): + onnx_model_str = create_group_query_attention_graph_past( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() @@ -796,6 +830,28 @@ def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_paddi return col_idx > row_idx + sk - sq +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + def attention_ref( q, k, @@ -805,6 +861,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -817,6 +874,8 @@ def attention_ref( key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) @@ -826,6 +885,8 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + if causal: + window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() @@ -839,12 +900,24 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) - scores.masked_fill_(causal_mask, float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - if causal: # Some rows are completely masked out so we fill them with zero instead of NaN - attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) @@ -853,7 +926,6 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -957,6 +1029,8 @@ def parity_check_mha( def parity_check_gqa_prompt( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1007,6 +1081,15 @@ def parity_check_gqa_prompt( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1033,14 +1116,18 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_prompt_func( + q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1052,6 +1139,10 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1080,6 +1171,8 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1112,6 +1205,15 @@ def parity_check_gqa_prompt_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = new_k.clone() v_cache_ref = new_v.clone() @@ -1132,14 +1234,18 @@ def parity_check_gqa_prompt_no_buff( new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_prompt_func( + q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1179,6 +1285,8 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1228,6 +1336,14 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) # Pytorch to compare k_cache_ref = k.clone() @@ -1253,14 +1369,18 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1274,6 +1394,10 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " causal:", + causal, + " local:", + local, " B:", config.batch_size, " S:", @@ -1300,6 +1424,8 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1351,6 +1477,15 @@ def parity_check_gqa_past_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1378,14 +1513,18 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1401,142 +1540,10 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", - "past kv format:", - "BSNH" if past_format == Formats.BSNH else "BNSH", - " B:", - config.batch_size, - " S:", - config.sequence_length, - " kv S:", - config.kv_sequence_length, - " N:", - config.num_heads, - " kv N:", - config.kv_num_heads, - " h:", - config.head_size, - " Mean Error:", - numpy.mean(numpy.abs(out - out_ref)), - numpy.allclose( - out, - out_ref, - rtol=rtol, - atol=atol, - equal_nan=True, - ), - ) - - -def parity_check_gqa_past_no_buff_no_mask( - config, - past_format=Formats.BSNH, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = None - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - # Make sure past-present buffer updating correctly - if past_format == Formats.BSNH: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - else: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - - # Compare results - print( - "Unbuffered", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1663,10 +1670,11 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1725,24 +1733,25 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() - # test_gqa = TestGQA() - # test_gqa.test_gqa_past()