[CUDA] GroupQueryAttention operator using FlashAttention (#17674)
### Description
Added the Group Query Attention op, which supports a number of attention
heads for Q that is an integer multiple of the number of heads for KV. For
now, this op can only use the FlashAttention kernel, so it is only supported
on sm>=80 GPUs on Linux.
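
For intuition, here is a minimal NumPy sketch of the grouped-query attention math (illustrative only; the actual op runs the FlashAttention kernel and also handles the KV cache described in the operator spec below):

```python
import numpy as np

def gqa_reference(q, k, v, num_heads, kv_num_heads, scale=None):
    """Reference grouped-query attention: q is (B, S, num_heads*H), k/v are (B, S_kv, kv_num_heads*H)."""
    batch, seq, _ = q.shape
    kv_seq = k.shape[1]
    head_size = q.shape[2] // num_heads
    scale = 1.0 / np.sqrt(head_size) if scale is None else scale
    group = num_heads // kv_num_heads  # Q heads served by each KV head

    q = q.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)
    k = k.reshape(batch, kv_seq, kv_num_heads, head_size).transpose(0, 2, 1, 3)
    v = v.reshape(batch, kv_seq, kv_num_heads, head_size).transpose(0, 2, 1, 3)

    # Broadcast each KV head to its group of Q heads.
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)

    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    # Causal mask: query i may attend to keys 0 .. (kv_seq - seq) + i.
    mask = np.triu(np.ones((seq, kv_seq), dtype=bool), k=kv_seq - seq + 1)
    scores = np.where(mask, -np.inf, scores)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ v  # (B, num_heads, S, H)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, num_heads * head_size)
```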

Results from onnxruntime/test/python/transformers/benchmark_gqa.py show an
average ~37% speed-up over Decoder Masked Multi-Head Attention (dmmha in the
table below), with even greater gains at long past sequence lengths.

```
op      batch   s_kv    heads   h_dim   ms      TFLOPS
gqa     16      2048    8       32      0.34    0.10
dmmha   16      2048    8       32      0.39    0.09
---------
gqa     16      2048    8       64      0.45    0.15
dmmha   16      2048    8       64      0.61    0.11
---------
gqa     16      2048    8       128     0.54    0.25
dmmha   16      2048    8       128     0.83    0.16
---------
gqa     16      2048    16      32      0.45    0.15
dmmha   16      2048    16      32      0.69    0.10
---------
gqa     16      2048    16      64      0.69    0.19
dmmha   16      2048    16      64      0.83    0.16
---------
gqa     16      2048    16      128     0.71    0.38
dmmha   16      2048    16      128     1.28    0.21
---------
gqa     16      2048    32      32      0.58    0.23
dmmha   16      2048    32      32      0.77    0.17
---------
gqa     16      2048    32      64      0.58    0.46
dmmha   16      2048    32      64      1.25    0.21
---------
gqa     16      2048    32      128     0.76    0.71
dmmha   16      2048    32      128     2.15    0.25
---------
gqa     16      2048    64      32      0.68    0.39
dmmha   16      2048    64      32      1.23    0.22
---------
gqa     16      2048    64      64      0.77    0.70
dmmha   16      2048    64      64      2.11    0.25
---------
gqa     16      2048    64      128     1.10    0.97
dmmha   16      2048    64      128     4.06    0.26
---------
gqa     16      2048    128     32      1.00    0.54
dmmha   16      2048    128     32      2.09    0.26
---------
gqa     16      2048    128     64      1.10    0.97
dmmha   16      2048    128     64      4.08    0.26
```


### Motivation and Context
As of now, this op targets LLaMA-style models, since it supports kv-caching
and different numbers of heads for Q and KV (grouped-query attention). We
plan to add support for more platforms, input formats, etc. in the future.
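
As a usage sketch, a single-node ONNX graph built with onnx.helper; the shapes, opset versions, and the choice to omit the optional past/present tensors are assumptions for illustration, not a verified configuration:

```python
import onnx
from onnx import TensorProto, helper

# Assumed configuration: 32 Q heads sharing 8 KV heads, head_size 128.
batch, seq, head_size, num_heads, kv_num_heads = 1, 256, 128, 32, 8

gqa = helper.make_node(
    "GroupQueryAttention",
    inputs=["query", "key", "value"],   # optional past_key / past_value / past_sequence_length omitted
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
    kv_num_heads=kv_num_heads,
)

graph = helper.make_graph(
    [gqa],
    "gqa_example",
    inputs=[
        helper.make_tensor_value_info("query", TensorProto.FLOAT16, [batch, seq, num_heads * head_size]),
        helper.make_tensor_value_info("key", TensorProto.FLOAT16, [batch, seq, kv_num_heads * head_size]),
        helper.make_tensor_value_info("value", TensorProto.FLOAT16, [batch, seq, kv_num_heads * head_size]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [batch, seq, num_heads * head_size])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "gqa_example.onnx")
```

Running such a model would require the CUDA execution provider on an sm>=80 GPU, per the description above.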

---------

Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: [email protected] <[email protected]>
3 people authored Oct 9, 2023
1 parent ba72bb6 commit 406cd32
Showing 36 changed files with 3,448 additions and 191 deletions.
68 changes: 66 additions & 2 deletions docs/ContribOperators.md
@@ -42,6 +42,7 @@ Do not modify directly.*
* <a href="#com.microsoft.GreedySearch">com.microsoft.GreedySearch</a>
* <a href="#com.microsoft.GridSample">com.microsoft.GridSample</a>
* <a href="#com.microsoft.GroupNorm">com.microsoft.GroupNorm</a>
* <a href="#com.microsoft.GroupQueryAttention">com.microsoft.GroupQueryAttention</a>
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
@@ -1170,9 +1171,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dd>present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
</dl>

#### Type Constraints
@@ -2268,6 +2269,69 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.GroupQueryAttention"></a><a name="com.microsoft.groupqueryattention">**com.microsoft.GroupQueryAttention**</a>

Group Query Self/Cross Attention.

Supports a different number of heads for q and kv.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>is_past_bsnh</tt> : int</dt>
<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
</dl>

#### Inputs (3 - 6)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BSNH or BNSH. When past_key uses the same tensor as present_key (k-v cache), it is of length max_sequence_length; otherwise, it is of length past_sequence_length.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state value with support for format BSNH or BNSH. When past_value uses the same tensor as present_value (k-v cache), it is of length max_sequence_length; otherwise, it is of length past_sequence_length.</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When buffered past_key and past_value are used (present_key uses the same tensor as past_key), it is required to specify past_sequence_length (which can be 0). Otherwise, past_sequence_length is inferred from past_key.</dd>
</dl>

#### Outputs (1 - 3)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses the same tensor as present_key (k-v buffer), it is of length max_sequence_length; otherwise, it is of length past_sequence_length + kv_sequence_length.</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses the same tensor as present_value (k-v buffer), it is of length max_sequence_length; otherwise, it is of length past_sequence_length + kv_sequence_length.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32), tensor(int64)</dt>
<dd>Constrain past sequence length to int tensor.</dd>
</dl>
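
To make the shared past/present buffer convention above concrete, a small NumPy sketch of a BNSH cache of capacity max_sequence_length (the variable names and the append step are illustrative assumptions, not the kernel's copy logic):

```python
import numpy as np

batch, kv_num_heads, head_size = 1, 8, 128
max_sequence_length = 2048      # allocated cache capacity
past_sequence_length = 17       # tokens already cached (passed via the M input)
kv_sequence_length = 1          # new tokens in this step

# With a shared buffer, past_key and present_key are the same BNSH tensor of
# length max_sequence_length; only the first past_sequence_length + kv_sequence_length
# rows are meaningful after this step.
kv_cache_key = np.zeros((batch, kv_num_heads, max_sequence_length, head_size), dtype=np.float16)
new_key = np.random.rand(batch, kv_sequence_length, kv_num_heads, head_size).astype(np.float16)

kv_cache_key[:, :, past_sequence_length:past_sequence_length + kv_sequence_length, :] = \
    new_key.transpose(0, 2, 1, 3)  # BSNH -> BNSH
effective_seq_length = past_sequence_length + kv_sequence_length
```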


### <a name="com.microsoft.Inverse"></a><a name="com.microsoft.inverse">**com.microsoft.Inverse**</a>

#### Version
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -840,6 +840,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -82,6 +82,26 @@ struct PackedAttentionParameters {
bool broadcast_res_pos_bias;
};

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length;
int past_sequence_length; // actual sequence length of past_key and past_value
int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present)
int present_sequence_length; // past_sequence_length + kv_sequence_length
int max_sequence_length; // allocated length of past_key and past_value
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
bool is_unidirectional; // causal
float scale;
int num_splits; // number of splits for splitkv
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
};
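
A hedged Python restatement of how these fields relate to the op's inputs (assuming BSNH past and packed hidden sizes; the real shape validation lives in the operator implementation, which is not part of this excerpt):

```python
# Sketch only; assumed input shapes:
#   query    (batch_size, sequence_length, num_heads * head_size)
#   key      (batch_size, kv_sequence_length, kv_num_heads * head_size)
#   past_key (batch_size, past_sequence_length, kv_num_heads, head_size)  # BSNH
def deduce_gqa_parameters(query_shape, key_shape, past_key_shape, num_heads, kv_num_heads):
    batch_size, sequence_length, hidden_size = query_shape
    _, kv_sequence_length, kv_hidden_size = key_shape
    head_size = hidden_size // num_heads
    past_sequence_length = 0 if past_key_shape is None else past_key_shape[1]
    return {
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "kv_sequence_length": kv_sequence_length,
        "past_sequence_length": past_sequence_length,
        "present_sequence_length": past_sequence_length + kv_sequence_length,
        "hidden_size": hidden_size,
        "head_size": head_size,
        "kv_hidden_size": kv_hidden_size,
        "num_heads": num_heads,
        "kv_num_heads": kv_num_heads,
    }
```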

namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
9 changes: 4 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o
T* k_smem = q_smem + rotary_embedding_dim;

const int half_rotary_dim = rotary_embedding_dim / 2;
const int half_idx = (head_idx) / half_rotary_dim;
const int intra_half_idx = (head_idx) % half_rotary_dim;
const int smem_pitch = half_rotary_dim;
const int half_idx = (head_idx) / half_rotary_dim;
const int intra_half_idx = (head_idx) % half_rotary_dim;
const int smem_pitch = half_rotary_dim;

if (do_rotary) {
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
@@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
}
}


template <typename T>
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
// Format 3 for cutlass memory efficient attention
@@ -651,7 +650,7 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
if (qk_head_size != 64 && qk_head_size !=128) {
if (qk_head_size != 64 && qk_head_size != 128) {
ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
44 changes: 22 additions & 22 deletions onnxruntime/contrib_ops/cuda/bert/bert_padding.cu
@@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
const int* attention_masks,
const int batch_size,
const int sequence_length) {
typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int batch_id = blockIdx.x;
const int* batch_mask = attention_masks + (batch_id * sequence_length);
const bool leftmost_non_zero = (batch_mask[0] != 0);
int biggest_position = 0;

for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
if (leftmost_non_zero == (batch_mask[i] != 0)) {
biggest_position = i;
} else {
break;
}
typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

const int batch_id = blockIdx.x;
const int* batch_mask = attention_masks + (batch_id * sequence_length);
const bool leftmost_non_zero = (batch_mask[0] != 0);
int biggest_position = 0;

for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
if (leftmost_non_zero == (batch_mask[i] != 0)) {
biggest_position = i;
} else {
break;
}
}

int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);

if (threadIdx.x == 0) {
int batch_offset = batch_id * sequence_length;
trt_mha_padding_offset[2 * batch_id] = batch_offset;
trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
if (batch_id == gridDim.x - 1) {
trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
if (threadIdx.x == 0) {
int batch_offset = batch_id * sequence_length;
trt_mha_padding_offset[2 * batch_id] = batch_offset;
trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
if (batch_id == gridDim.x - 1) {
trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
}
}

// only support simple left padding with mask 0s on leading left,
13 changes: 6 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_
}

inline Status ComputeMaskIndex(cudaStream_t stream,
const int sequence_length,
const int batch_size,
const int* mask,
int* mask_index) {
const int sequence_length,
const int batch_size,
const int* mask,
int* mask_index) {
// Mask idx is of length batch_size and assumes the valid region is contiguous starting
// from the beginning of the sequence

@@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel(
}
if (nullptr == position_ids) {
position_id = blockIdx.x;
} else if (broadcast_position_ids){
} else if (broadcast_position_ids) {
position_id = position_ids[sequence_position % gridDim.x];
} else {
position_id = position_ids[sequence_position];
@@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel(
void* embedding_sum,
const int* position_ids,
const bool broadcast_position_ids) {

if (mask_index != nullptr) {
if (nullptr == input_mask) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream));
} else {
ORT_RETURN_IF_ERROR(
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));
}
}

6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu
@@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const float* input, const float* bias, float* output, bool /*use_half2*/) {
const float* input, const float* bias, float* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
FastGeluKernel<float, blockSize><<<gridSize, blockSize, 0, stream>>>(A, B, C, input_length, bias_length,
@@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const half* input, const half* bias, half* output, bool use_half2) {
const half* input, const half* bias, half* output, bool use_half2) {
constexpr int blockSize = 256;
if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) {
const int n = input_length / 2;
@@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int

template <>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;

// remove nv_bfloat162 implementation for now to fix build issue
12 changes: 9 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
@@ -12,9 +12,13 @@ struct BlockInfo {
template <typename Params>
__device__ BlockInfo(const Params& params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) {
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
,
seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])),
actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
}

template <typename index_t>
@@ -30,6 +34,8 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};
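
A host-side sketch of the two cu_seqlens_k conventions described in the comment above (illustration only, not the device code):

```python
# For a batch of 3 sequences with K lengths [5, 3, 7]:
#   cumulative:     cu_seqlens_k = [0, 5, 8, 15] -> seqlen_k(b) = cu[b + 1] - cu[b]
#   non-cumulative: cu_seqlens_k = [5, 3, 7]     -> seqlen_k(b) = cu[b] (per-batch cache length)
def seqlen_k_cache(cu_seqlens_k, bidb, is_seqlens_k_cumulative):
    if is_seqlens_k_cumulative:
        return cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]
    return cu_seqlens_k[bidb]
```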

25 changes: 23 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -45,6 +45,7 @@ struct Qkv_params {
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void* __restrict__ o_ptr;
void* __restrict__ oaccum_ptr;

// The stride between rows of O.
index_t o_batch_stride;
@@ -56,9 +57,10 @@ struct Flash_fwd_params : public Qkv_params {

// The pointer to the softmax sum.
void* __restrict__ softmax_lse_ptr;
void* __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;

// The scaling factors for the kernel.
float scale_softmax;
@@ -70,16 +72,35 @@

int* __restrict__ blockmask;

// The K_new and V_new matrices.
void* __restrict__ knew_ptr;
void* __restrict__ vnew_ptr;

// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;

bool is_bf16 = false;
bool is_causal;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
int num_splits; // For split-KV version

const cudaDeviceProp* dprops;
};
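
The new oaccum_ptr/softmax_lseaccum_ptr fields and num_splits appear to support a split-KV forward, where each split over the KV sequence yields a partial output plus its log-sum-exp and the partials are merged afterwards. A hedged NumPy sketch of that merge step (not the actual CUDA combine kernel):

```python
import numpy as np

def combine_kv_splits(o_parts, lse_parts):
    """Merge per-split attention outputs using their log-sum-exp values.

    o_parts:   list of (seq_q, head_size) partial outputs, one per KV split
    lse_parts: list of (seq_q,) log-sum-exp of the scores handled by that split
    """
    lse = np.stack(lse_parts)            # (num_splits, seq_q)
    o = np.stack(o_parts)                # (num_splits, seq_q, head_size)
    weights = np.exp(lse - lse.max(axis=0))
    weights /= weights.sum(axis=0)       # each split's share of the full softmax mass
    return (weights[..., None] * o).sum(axis=0)
```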

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
template <typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream);

} // namespace flash
} // namespace onnxruntime
} // namespace onnxruntime
