GQA Rotary and Packed QKV with Flash (#18906)
### Description
These changes add rotary embedding and packed QKV input to the GroupQueryAttention (GQA) op. For now, the changes are only supported with Flash Attention (SM >= 80), but they should soon be supported with Memory Efficient Attention as well.



### Motivation and Context
Fusing rotary embedding into this attention op should yield some performance gain. Packed QKV should likewise benefit certain models, such as Llama2, that profit from running ops on a single fused QKV matrix rather than on separate Q, K, and V tensors. A minimal usage sketch is shown below.
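A minimal sketch (Python, `onnx.helper`) of how a GroupQueryAttention node could be built with the new packed-QKV and rotary inputs. The input/output order and attribute names follow the operator spec updated in this commit; the tensor names and head counts are hypothetical placeholders, and key/value are left empty because the packed QKV is carried entirely in the first input.

```python
from onnx import helper

# Hypothetical sizes for illustration only.
num_heads, kv_num_heads, head_size = 32, 8, 128

gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "packed_qkv",             # (batch, seq_len, (num_heads + 2 * kv_num_heads) * head_size)
        "",                       # key   - omitted when QKV is packed
        "",                       # value - omitted when QKV is packed
        "past_key",               # BNSH k-cache
        "past_value",             # BNSH v-cache
        "seqlens_k",              # int32, shape (batch,)
        "total_sequence_length",  # int32 scalar, past + new
        "cos_cache",              # (max_sequence_length, head_size / 2)
        "sin_cache",              # (max_sequence_length, head_size / 2)
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=num_heads,
    kv_num_heads=kv_num_heads,
    do_rotary=1,           # apply rotary embedding inside the op
    rotary_interleaved=0,  # half-rotation layout rather than interleaved pairs
)
```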

---------

Co-authored-by: Yufeng Li <[email protected]>
2 people authored and rachguo committed Jan 30, 2024
1 parent d101450 commit 9343dac
Showing 15 changed files with 1,517 additions and 272 deletions.
16 changes: 12 additions & 4 deletions docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>

#### Inputs
#### Inputs (7 - 9)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> : T</dt>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of total sequence length (past + new).</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>
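To make the new shapes concrete, here is a small numpy sketch of the packed-QKV width d and of cos/sin caches with the documented (max_sequence_length, head_size / 2) shape. The sizes are hypothetical, and the inverse-frequency construction is the standard RoPE recipe, assumed here rather than taken from this commit.

```python
import numpy as np

# Hypothetical sizes.
num_heads, kv_num_heads, head_size = 32, 8, 128
max_sequence_length = 4096

# Packed QKV last dimension: d = num_heads*head_size + 2*kv_num_heads*head_size.
d = num_heads * head_size + 2 * kv_num_heads * head_size   # 6144 for these sizes

# cos_cache / sin_cache: (max_sequence_length, head_size / 2), built here with the
# usual RoPE inverse-frequency table (theta = 10000 is an assumption).
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2, dtype=np.float32) / head_size))
angles = np.outer(np.arange(max_sequence_length, dtype=np.float32), inv_freq)
cos_cache = np.cos(angles).astype(np.float16)   # (4096, 64)
sin_cache = np.sin(angles).astype(np.float16)   # (4096, 64)
```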

#### Outputs
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
};

namespace attention {
51 changes: 34 additions & 17 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size) {
// if (seqlen_q == 1) {
// is_causal = false;
// } // causal=true is the same as causal=false in this case

int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// In kv-cache case, seqlen_k_max as kv sequence length
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_causal ? 0 : -1);
params.dprops = &dprops;

if (k != nullptr && v != nullptr) {
if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
params.knew_ptr = k;
params.vnew_ptr = v;
params.knew_ptr = k_new;
params.vnew_ptr = v_new;
// All stride are in elements, not bytes.
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
if (is_packed_qkv) {
params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
} else {
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
}
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}

if (rotary_cos != nullptr) {
params.rotary_cos_ptr = rotary_cos;
params.rotary_sin_ptr = rotary_sin;
params.is_rotary_interleaved = is_rotary_interleaved;
params.rotary_dim = (head_size / 16) * 16;
}

params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
@@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
}

// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);

return Status::OK();
}
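As a sanity check on the packed-QKV branch above, a short sketch that recomputes the strides (in elements) the code assigns, plus the rotary_dim rounding, for illustrative sizes. It just mirrors the arithmetic in flash_api.cc; it is not part of the kernel.

```python
# Illustrative sizes; the formulas mirror the packed-QKV branch in flash_api.cc.
num_heads, num_heads_k, head_size = 32, 8, 128
seqlen_q = seqlen_k_new = 1          # packed QKV implies Q and new K/V share one buffer

# Separate (unpacked) new-KV strides: each row holds only that tensor's heads.
knew_row_stride = num_heads_k * head_size                      # 1024
knew_batch_stride = seqlen_k_new * knew_row_stride             # 1024

# Packed QKV: every row holds Q, K and V back to back, so Q and the new K/V all
# share the same row stride, and the batch stride spans the whole packed buffer.
packed_row_stride = (num_heads + 2 * num_heads_k) * head_size  # 6144
packed_batch_stride = (seqlen_q * num_heads * head_size
                       + 2 * seqlen_k_new * num_heads_k * head_size)  # 6144

# Rotary dimension is rounded down to a multiple of 16, as in the code above.
rotary_dim = (head_size // 16) * 16                            # 128
```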
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1);
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

26 changes: 24 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -47,6 +47,8 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

#if USE_FLASH_ATTENTION
@@ -62,6 +64,9 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_memory_efficient_attention_ = true;
#endif
if (!disable_flash_attention_) {
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
}
}

template <typename T>
Expand All @@ -73,6 +78,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);

auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -84,6 +91,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
value,
past_key,
past_value,
cos_cache,
sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
@@ -93,7 +102,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;

TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@@ -139,6 +154,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
do_rotary_ == false &&
key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -182,8 +199,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* present_value = context->Output(2, present_shape);

data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -229,6 +246,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
// Rotary
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}

cublasHandle_t cublas = GetCublasHandle(context);

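The kernel above reads cos_cache and sin_cache as inputs 7 and 8 and only dereferences them when do_rotary is set. Below is a small sketch of how the auxiliary integer inputs might be prepared for a decode step, following the doc text in this commit (seqlens_k carries past sequence lengths for token generation; total_sequence_length is past + new); the per-batch values are hypothetical, and taking the batch maximum for the scalar is an assumption.

```python
import numpy as np

# Hypothetical decode (token-generation) step with two sequences in the batch.
past_lengths = np.array([10, 7], dtype=np.int32)   # tokens already in the KV cache
new_tokens = 1

seqlens_k = past_lengths                                           # int32, shape (batch_size,)
total_sequence_length = np.int32(past_lengths.max() + new_tokens)  # int32 scalar (past + new)
```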
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
bool is_unidirectional_;
bool is_past_bsnh_;
bool do_rotary_;
bool rotary_interleaved_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;
};

} // namespace cuda
