Skip to content

Commit

Permalink
Aciddelgado/gqa local (#18375)
Browse files Browse the repository at this point in the history
### Description
Implement preliminary version of local (sliding window) attention.
Currently only supported by Flash Attention (sm >= 80, Linux). Currently
only supports sliding attention with a large cached kv.



### Motivation and Context
This change enables to run Mistral and other models which use sliding
window attention.
  • Loading branch information
aciddelgado authored Nov 16, 2023
1 parent 6a4e448 commit adb56df
Show file tree
Hide file tree
Showing 15 changed files with 682 additions and 537 deletions.
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr

Group Query Self/Cross Attention.

Supports different number of heads for q and kv.
Supports different number of heads for q and kv. Only supports causal or local attention.

#### Version

Expand All @@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters {
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool left_padding; // copies last token to last index if true
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params {
int seqlen_q_rounded = 0;
int seqlen_k_rounded = 0;
int d_rounded = 0;
int rotary_dim = 0;

// The scaling factors for the kernel.
float scale_softmax = 0.0;
Expand All @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params {
index_t knew_head_stride = 0;
index_t vnew_head_stride = 0;

// The cos and sin matrices for rotary embedding.
void* __restrict__ rotary_cos_ptr = nullptr;
void* __restrict__ rotary_sin_ptr = nullptr;

// The indices to index into the KV cache.
int* __restrict__ cache_batch_idx = nullptr;

// Local window size
int window_size_left = -1;
int window_size_right = -1;

bool is_bf16 = false;
bool is_causal = false;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative = true;

bool is_rotary_interleaved = false;

int num_splits = 0; // For split-KV version

const cudaDeviceProp* dprops = nullptr;
Expand Down
44 changes: 34 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params,
void* softmax_lse_d,
float softmax_scale,
bool is_causal,
bool kv_bsnh = true) {
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
Expand Down Expand Up @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params,
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;

// In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates
// local and causal, meaning when we have local window size
params.is_causal = is_causal;
if (is_causal && (window_size_left >= 0 || window_size_right != 0)) {
params.is_causal = false;
}
if (window_size_left < 0 && window_size_right >= 0) {
window_size_left = seqlen_k;
}
if (window_size_left >= 0 && window_size_right < 0) {
window_size_right = seqlen_k;
}
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;

params.is_seqlens_k_cumulative = true;
}

Expand Down Expand Up @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
bool kv_bsnh) {
bool kv_bsnh,
int local_window_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
Expand All @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal,
kv_bsnh);
kv_bsnh,
local_window_size,
is_causal ? 0 : -1);
params.dprops = &dprops;
params.knew_ptr = nullptr;
params.vnew_ptr = nullptr;
Expand Down Expand Up @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
nullptr,
softmax_lse,
softmax_scale,
is_causal);
is_causal,
true,
-1,
is_causal ? 0 : -1);
params.dprops = &dprops;
params.num_splits = 0;
params.softmax_lseaccum_ptr = nullptr;
Expand Down Expand Up @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
bool past_bsnh, // otherwise bnsh
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
) {
if (seqlen_q == 1) {
is_causal = false;
} // causal=true is the same as causal=false in this case
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size) {
// if (seqlen_q == 1) {
// is_causal = false;
// } // causal=true is the same as causal=false in this case

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
Expand All @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
softmax_lse,
softmax_scale,
is_causal,
past_bsnh);
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
params.dprops = &dprops;

if (k != nullptr && v != nullptr) {
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
bool kv_bsnh = true);
bool kv_bsnh = true,
int local_window_size = -1);

Status mha_varlen_fwd(const cudaDeviceProp& dprops,
cudaStream_t stream,
Expand Down Expand Up @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
bool past_bsnh, // otherwise bnsh
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
);
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

Expand Down
Loading

0 comments on commit adb56df

Please sign in to comment.