diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 624cda1d37f73..e7b537d6894c8 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- do_rotary : int
+- Whether to use rotary position embedding. Default value is 0.
- kv_num_heads : int (required)
- Number of attention heads for k and v
- local_window_size : int
- left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
- num_heads : int (required)
- Number of attention heads for q
+- rotary_interleaved : int
+- Rotate using interleaved pattern. Default value is 0 (False).
- scale : float
- Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs
+#### Inputs (7 - 9)
- query : T
-- Query with shape (batch_size, sequence_length, hidden_size)
-- key : T
+- Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+- key (optional) : T
- Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-- value : T
+- value (optional) : T
- Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
- past_key (optional) : T
- past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- 1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
- total_sequence_length : M
- Scalar tensor of total sequence length (past + new).
+- cos_cache (optional) : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+- sin_cache (optional) : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
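As a reading aid for the schema above, here is a minimal NumPy sketch of how the new inputs fit together; the base frequency of 10000 and all sizes are illustrative assumptions, not values taken from this change:

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, sequence_length, max_sequence_length = 2, 8, 128
num_heads, kv_num_heads, head_size = 8, 4, 64

# cos_cache / sin_cache: (max_sequence_length, head_size / 2), matching the schema above.
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))  # (head_size / 2,)
angles = np.outer(np.arange(max_sequence_length), inv_freq)             # (max_seq, head_size / 2)
cos_cache = np.cos(angles).astype(np.float16)
sin_cache = np.sin(angles).astype(np.float16)

# Packed QKV: the query input carries Q, K and V concatenated on the last axis.
d = num_heads * head_size + 2 * kv_num_heads * head_size
packed_qkv = np.random.randn(batch_size, sequence_length, d).astype(np.float16)

# Slicing the last axis recovers the individual projections.
q = packed_qkv[..., :num_heads * head_size]
k = packed_qkv[..., num_heads * head_size:(num_heads + kv_num_heads) * head_size]
v = packed_qkv[..., (num_heads + kv_num_heads) * head_size:]
assert k.shape == v.shape == (batch_size, sequence_length, kv_num_heads * head_size)
```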
#### Outputs
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 3b695af2839b6..31cca232fde34 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index da489a6901512..8afeb874750b4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
+ bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
+ bool do_rotary;
+ bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
+ int zeros_count;
+ int* zero_ptr;
};
namespace attention {
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index d6eb87228bb4a..2c296bf4f8483 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
- void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
- void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
- void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
- void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
+ void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
+ void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
+ void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
+ void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- int local_window_size) {
- // if (seqlen_q == 1) {
- // is_causal = false;
- // } // causal=true is the same as causal=false in this case
-
+ int local_window_size,
+ bool is_rotary_interleaved,
+ bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+ // In the kv-cache case, seqlen_k_max is used as the kv sequence length
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_causal ? 0 : -1);
params.dprops = &dprops;
- if (k != nullptr && v != nullptr) {
+ if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
- params.knew_ptr = k;
- params.vnew_ptr = v;
+ params.knew_ptr = k_new;
+ params.vnew_ptr = v_new;
// All stride are in elements, not bytes.
- params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
- params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
- params.knew_row_stride = num_heads_k * head_size;
- params.vnew_row_stride = num_heads_k * head_size;
+ if (is_packed_qkv) {
+ params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ } else {
+ params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.knew_row_stride = num_heads_k * head_size;
+ params.vnew_row_stride = num_heads_k * head_size;
+ }
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}
+ if (rotary_cos != nullptr) {
+ params.rotary_cos_ptr = rotary_cos;
+ params.rotary_sin_ptr = rotary_sin;
+ params.is_rotary_interleaved = is_rotary_interleaved;
+ params.rotary_dim = (head_size / 16) * 16;
+ }
+
params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
@@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
}
// Only split kernel supports appending to KV cache
- run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
+ run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
return Status::OK();
}
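The packed-QKV branch above sets the row and batch strides so the kernel can walk the interleaved Q/K/V buffer directly. A small NumPy sketch of the arithmetic, under the assumption that seqlen_q equals seqlen_k_new (which holds when Q, K and V arrive packed in one tensor):

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, seqlen, num_heads, num_heads_k, head_size = 2, 4, 8, 2, 16

row = (num_heads + 2 * num_heads_k) * head_size  # elements per token in the packed buffer
packed = np.arange(batch_size * seqlen * row, dtype=np.float32).reshape(batch_size, seqlen, row)
flat = packed.reshape(-1)

# These mirror q_row_stride / q_batch_stride in the packed case; with seqlen_q == seqlen_k_new,
# seqlen_q * num_heads * head_size + 2 * seqlen_k_new * num_heads_k * head_size == seqlen * row.
q_row_stride = row
q_batch_stride = seqlen * row

b, s = 1, 2
start = b * q_batch_stride + s * q_row_stride
assert np.array_equal(flat[start:start + num_heads * head_size],
                      packed[b, s, :num_heads * head_size])
```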
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 3d75d6834b8e0..387d1cf9d84fe 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
+ void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
+ void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- int local_window_size = -1);
+ int local_window_size = -1,
+ bool is_rotary_interleaved = false,
+ bool is_packed_qkv = false);
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index fd6fb79742cac..fe56f84f0a886 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
+ do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+ rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
#if USE_FLASH_ATTENTION
@@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_memory_efficient_attention_ = true;
#endif
+ if (!disable_flash_attention_) {
+ zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
+ }
}
template <typename T>
@@ -73,6 +78,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
+ const Tensor* cos_cache = context->Input<Tensor>(7);
+ const Tensor* sin_cache = context->Input<Tensor>(8);
auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
value,
past_key,
past_value,
+ cos_cache,
+ sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
@@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
+ parameters.is_unidirectional = is_unidirectional_;
+ parameters.zeros_count = kZerosCount;
+ parameters.zero_ptr = zeros_.get();
+ // parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;
+ parameters.do_rotary = do_rotary_;
+ parameters.rotary_interleaved = rotary_interleaved_;
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
+ do_rotary_ == false &&
+ key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* present_value = context->Output(2, present_shape);
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
- data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
- data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
+ data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
+ data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
+ // Rotary
+ if (parameters.do_rotary) {
+ data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
+ data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
+ }
cublasHandle_t cublas = GetCublasHandle(context);
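The do_rotary and rotary_interleaved attributes read above select between the two usual rotary layouts. The following NumPy reference shows what each mode computes for a single position; it is a sketch of the math, not the CUDA path the kernel actually takes:

```python
import numpy as np

def rotate_half_style(x, cos, sin):
    # Non-interleaved (rotary_interleaved == 0): first half rotated against second half.
    d = cos.shape[-1]
    x1, x2 = x[..., :d], x[..., d:2 * d]
    out = x.copy()
    out[..., :d] = x1 * cos - x2 * sin
    out[..., d:2 * d] = x1 * sin + x2 * cos
    return out

def rotate_interleaved_style(x, cos, sin):
    # Interleaved (rotary_interleaved == 1): even/odd pairs rotated together.
    d = cos.shape[-1]
    x1, x2 = x[..., 0:2 * d:2], x[..., 1:2 * d:2]
    out = x.copy()
    out[..., 0:2 * d:2] = x1 * cos - x2 * sin
    out[..., 1:2 * d:2] = x1 * sin + x2 * cos
    return out

# cos/sin for one position come from cos_cache[pos] and sin_cache[pos] (head_size / 2 values each).
head_size = 8
x = np.random.randn(head_size).astype(np.float32)
cos = np.cos(np.linspace(0.0, 1.0, head_size // 2)).astype(np.float32)
sin = np.sin(np.linspace(0.0, 1.0, head_size // 2)).astype(np.float32)
assert rotate_half_style(x, cos, sin).shape == rotate_interleaved_style(x, cos, sin).shape == x.shape
```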
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index 54a8127e29e7b..15573ece166fc 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
+ bool is_unidirectional_;
bool is_past_bsnh_;
+ bool do_rotary_;
+ bool rotary_interleaved_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
+ static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
+ IAllocatorUniquePtr<int> zeros_;
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index 2cb9955807f26..853e1a710cb24 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
+ const Tensor* cos_cache,
+ const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query,
bool is_past_bsnh,
float scale) {
// Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length
- // past_key : (B, N_k, S*, H) or (B, N_k, S-, H)
- // past_value : (B, N_k, S*, H) or (B, N_k, S-, H)
+ // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
+ // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
// no packing for q/k/v:
- // query (Q) : (B, S, D)
- // key (K) : (B, S, D_kv)
- // value (V) : (B, S, D_kv)
+ // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
+ // key (K) : (B, S, D_kv) or nullptr
+ // value (V) : (B, S, D_kv) or nullptr
ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
-
+ const bool is_packed_qkv = key == nullptr;
const auto& query_dims = query->Shape().GetDims();
- const auto& key_dims = key->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query,
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int q_hidden_size = static_cast<int>(query_dims[2]);
- int head_size = static_cast<int>(q_hidden_size) / num_heads;
+ int head_size = 0;
+
+ if (num_heads % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+ num_heads % kv_num_heads);
+ }
- int kv_hidden_size = static_cast<int>(key_dims[2]);
+ int kv_hidden_size = 0;
+ // Check key and value when not packed
+ if (!is_packed_qkv) {
+ head_size = static_cast<int>(q_hidden_size) / num_heads;
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ const auto& key_dims = key->Shape().GetDims();
+ if (key_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+ key_dims.size());
+ } else if (query_dims[0] != key_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != key_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 1 (sequence length)");
+ }
+ kv_hidden_size = static_cast<int>(key_dims[2]);
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+ value_dims.size());
+ } else if (query_dims[0] != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != value_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 1 (sequence length)");
+ } else if (value_dims[2] != kv_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+ }
+ } else {
+ // Check packed qkv
+ head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ q_hidden_size = head_size * num_heads;
+ kv_hidden_size = head_size * kv_num_heads;
+ }
+ // Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
@@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}
- if (key_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
- key_dims.size());
- }
- if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- }
-
- if (num_heads % kv_num_heads != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
- num_heads % kv_num_heads);
- }
-
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
- value_dims.size());
- }
-
- if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch_size)");
- }
-
- if (static_cast<int64_t>(sequence_length) != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)");
- }
-
- if (value_dims[2] != kv_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
- }
-
// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
@@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query,
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
+ if (cos_cache != nullptr && sin_cache != nullptr) {
+ const auto& cos_dims = cos_cache->Shape().GetDims();
+ const auto& sin_dims = sin_cache->Shape().GetDims();
+
+ if (head_size % 16 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size shall be a multiple of 16. Got head_size % 16 == ",
+ head_size % 16);
+ }
+ if (cos_dims[0] != present_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 0 must be of present_sequence_length.");
+ }
+ if (sin_dims[0] != present_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 0 must be of present_sequence_length.");
+ }
+ if (cos_dims[1] != (head_size / 16) * 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ if (sin_dims[1] != (head_size / 16) * 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ } else if (cos_cache != nullptr || sin_cache != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
+ }
+
bool is_prompt = sequence_length != 1;
if (parameters != nullptr) {
@@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query,
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
- output_parameters->head_size = q_hidden_size / num_heads;
+ output_parameters->head_size = head_size;
output_parameters->kv_hidden_size = kv_hidden_size;
output_parameters->kv_num_heads = kv_num_heads;
+ output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
@@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
+ const Tensor* cos_cache,
+ const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
+ return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
}
} // namespace group_query_attention_helper
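The head_size bookkeeping in CheckInputs above differs between the packed and unpacked paths. A simplified Python restatement of that logic (a sketch only; error handling and the remaining shape checks are omitted):

```python
def infer_gqa_sizes(q_hidden_size, key_present, num_heads, kv_num_heads):
    """Return (head_size, kv_hidden_size) the way the helper derives them."""
    if num_heads % kv_num_heads != 0:
        raise ValueError("num_heads must be a multiple of kv_num_heads")
    if key_present:
        head_size = q_hidden_size // num_heads                       # unpacked: hidden == N * H
    else:
        head_size = q_hidden_size // (num_heads + 2 * kv_num_heads)  # packed: hidden == (N + 2 * N_k) * H
    if head_size % 8 != 0:
        raise ValueError("head_size must be a multiple of 8")
    return head_size, head_size * kv_num_heads

# Packed example: (32 + 2 * 8) * 64 == 3072, so head_size is recovered as 64.
assert infer_gqa_sizes(3072, key_present=False, num_heads=32, kv_num_heads=8) == (64, 512)
# Unpacked example: hidden_size 2048 with 32 query heads gives the same head_size.
assert infer_gqa_sizes(2048, key_present=True, num_heads=32, kv_num_heads=8) == (64, 512)
```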
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 5b0f5d0cfe601..d88e9a49fb5ee 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -151,9 +151,10 @@ template
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
cudaStream_t stream,
- const int max_threads_per_block) {
+ const int max_threads_per_block,
+ const bool past_only = false) {
const int batch_size = parameters.batch_size;
- const int kv_sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = past_only ? 0 : parameters.sequence_length;
const int past_sequence_length = parameters.seqlen_past_kv_cache;
const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int kv_num_heads = parameters.kv_num_heads;
@@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
return CUDA_CALL(cudaGetLastError());
}
-
__global__ void PastToTotalSeqlen(int32_t* seqlens_k,
int32_t* seqlens_k_buff,
const int add_seqlen) {
@@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
// Convert Past to Total sequence length tensor
Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
- const int threads_per_block) {
+ const int threads_per_block) {
if (parameters.is_prompt) {
return Status::OK();
}
@@ -482,91 +482,63 @@ Status FlashAttention(
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.sequence_length;
- const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
-
- void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
- void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
- void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
-
bool is_causal = true;
-
bool is_bf16 = std::is_same<T, BFloat16>::value;
- // Note: seqlens_k is past sequence length for flash
- if (parameters.is_prompt) {
- // Launch kernel to copy seqlen
- constexpr int thr_per_blk = 256;
- int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
- repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
- }
-
- void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
-
- if (parameters.kv_share_buffer) {
- // Share buffer case
- if (data.past_key == nullptr || data.past_key != data.present_key) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Past and present kv shall share the same tensor when kv_share_buffer is on.");
- }
-
- if (parameters.is_prompt) {
- ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
- key = nullptr;
- value = nullptr;
- seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
- }
-
- void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
- void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1);
+ void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+ void* key;
+ void* value;
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
- device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
- seqlens_k, batch_size, num_heads, kv_num_heads,
- head_size, sequence_length, present_sequence_length, kv_sequence_length,
- scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
- reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+ if (!parameters.is_packed_qkv) {
+ key = reinterpret_cast<void*>(const_cast<T*>(data.key));
+ value = reinterpret_cast<void*>(const_cast<T*>(data.value));
} else {
- // Not share buffer case
- // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
- if (data.past_key != nullptr && data.past_key == data.present_key) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Past and present kv share the same tensor but kv_share_buffer is not on.");
- }
-
- ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+ const size_t key_offset = static_cast<size_t>(num_heads * head_size);
+ const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
+ key = reinterpret_cast<T*>(query) + key_offset;
+ value = reinterpret_cast<T*>(key) + value_offset;
+ }
- if (!parameters.is_prompt) {
- ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
+ void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
+ if (parameters.is_prompt) {
+ // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value
+ // user should use seqlens_k to index into output to get new tokens
+ if (batch_size <= parameters.zeros_count) {
+ seqlens_k = parameters.zero_ptr;
+ } else {
+ // Launch kernel to create larger seqlen tensor when batch_size > 256
+ constexpr int thr_per_blk = 256;
+ int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
+ repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
+ seqlens_k = data.seqlens_k_total;
}
-
- seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
-
- void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
- void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- DUMP_TENSOR_INIT();
- DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1);
- DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size);
- DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size);
- DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size);
-
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
- device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast<void*>(data.softmax_lse),
- seqlens_k, batch_size, num_heads, kv_num_heads,
- head_size, sequence_length, present_sequence_length, 0,
- scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
- reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+ } else if (!parameters.kv_share_buffer) { // copy past kv to present kv
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
}
+ void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+ void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+ void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
+ void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
+
+ bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+ device_prop, stream, query, present_key, present_value, key, value, data.output,
+ reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
+ batch_size, num_heads, kv_num_heads, head_size, sequence_length,
+ parameters.seqlen_present_kv_cache, kv_sequence_length,
+ scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+ reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
+ parameters.is_packed_qkv));
+
+ // if (parameters.left_padding && parameters.is_prompt) {
+ // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
+ // }
+
DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
@@ -672,7 +644,6 @@ Status EfficientAttention(
p.has_custom_right_padding = true;
run_memory_efficient_attention(p);
- DUMP_TENSOR_INIT();
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
return Status::OK();
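The zero scratch buffer and the seqlens_k handling above exist because the flash kv-cache kernel appends new K/V at position seqlens_k[b] for each batch entry. A NumPy sketch of that contract (append_kv is a hypothetical helper, not an ORT function):

```python
import numpy as np

def append_kv(k_cache, k_new, seqlens_k):
    # New keys for batch b land at cache positions [seqlens_k[b], seqlens_k[b] + seqlen_new).
    seqlen_new = k_new.shape[1]
    for b, past_len in enumerate(seqlens_k):
        k_cache[b, past_len:past_len + seqlen_new] = k_new[b]
    return k_cache

batch, max_seq, kv_heads, head = 2, 16, 1, 4
k_cache = np.zeros((batch, max_seq, kv_heads, head), dtype=np.float32)

# Prompt case: seqlens_k is all zeros (hence the 256-entry zero buffer), so the prompt
# K/V fill the cache starting at position 0.
prompt_k = np.random.randn(batch, 5, kv_heads, head).astype(np.float32)
append_kv(k_cache, prompt_k, seqlens_k=[0, 0])

# Token-generation case: seqlens_k holds past lengths, so the new token is appended
# right after the cached entries.
step_k = np.random.randn(batch, 1, kv_heads, head).astype(np.float32)
append_kv(k_cache, step_k, seqlens_k=[5, 5])
```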
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index de32d7ea93163..1bf91f9c875eb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -21,6 +21,8 @@ struct GroupQueryAttentionData {
const T* past_key = nullptr;
const T* past_value = nullptr;
int* seqlens_k = nullptr;
+ const T* cos_cache = nullptr;
+ const T* sin_cache = nullptr;
// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 7f34647f1faef..8583474a1e391 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = query_dims[2];
updateOutputShape(ctx, 0, output_shape);
- } else {
- fail_shape_inference("Missing input 2 (value)");
}
}
if (ctx.getNumOutputs() > 1) { // has present output
if (hasInputShape(ctx, past_key_index)) {
+ // auto& query_shape = getInputShape(ctx, 0);
+ // auto& query_dims = query_shape.dim();
auto& past_shape = getInputShape(ctx, past_key_index);
auto& past_dims = past_shape.dim();
if (past_dims.size() != 4) {
@@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
}
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
- ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
- ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
+ // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
}
}
}
@@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"left_window_size for local attention (like Mistral). Default value is -1 meaning unused.",
AttributeProto::INT,
static_cast<int64_t>(-1))
+ .Attr("do_rotary",
+ "Whether to use rotary position embedding. Default value is 0.",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Attr("rotary_interleaved",
+ "Rotate using interleaved pattern. Default value is 0 (False).",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Input(0,
"query",
- "Query with shape (batch_size, sequence_length, hidden_size)",
+ "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape"
+ "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
- "T")
+ "T",
+ OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
- "T")
+ "T",
+ OpSchema::Optional)
.Input(3,
"past_key",
"past state key with support for format BNSH. When past_key uses same tensor as present_key"
@@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"total_sequence_length",
"Scalar tensor of total sequence length (past + new).",
"M")
+ .Input(7,
+ "cos_cache",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "T",
+ OpSchema::Optional)
+ .Input(8,
+ "sin_cache",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py
new file mode 100644
index 0000000000000..42bff9c92b41b
--- /dev/null
+++ b/onnxruntime/test/python/transformers/rotary_flash.py
@@ -0,0 +1,693 @@
+# Copyright (c) 2023, Tri Dao.
+
+
+from typing import Optional, Tuple, Union
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange, repeat
+
+##### TRITON KERNEL FOR ROTARY #####
+
+
+# @triton.autotune(
+# configs=[
+# triton.Config({"block_m": 2}),
+# triton.Config({"block_m": 4}),
+# triton.Config({"block_m": 8}),
+# triton.Config({"block_m": 16}),
+# ],
+# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
+# )
+@triton.jit
+def rotary_kernel(
+ out_, # Pointers to matrices
+ x_,
+ cos_,
+ sin_,
+ CU_SEQLENS,
+ SEQLEN_OFFSETS, # this could be int or a pointer
+ # Matrix dimensions
+ seqlen,
+ nheads,
+ rotary_dim,
+ seqlen_ro,
+ CACHE_KEY_SEQLEN,
+ # strides
+ stride_out_batch,
+ stride_out_seqlen,
+ stride_out_nheads,
+ stride_out_headdim,
+ stride_x_batch,
+ stride_x_seqlen,
+ stride_x_nheads,
+ stride_x_headdim,
+ # Meta-parameters
+ block_k: tl.constexpr,
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ INTERLEAVED: tl.constexpr,
+ CONJUGATE: tl.constexpr,
+ block_m: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+ rotary_dim_half = rotary_dim // 2
+
+ if not IS_VARLEN:
+ x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+ out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+ else:
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+ out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
+
+ if pid_m * block_m >= seqlen:
+ return
+ rm = pid_m * block_m + tl.arange(0, block_m)
+ if not IS_SEQLEN_OFFSETS_TENSOR:
+ rm_cs = rm + SEQLEN_OFFSETS
+ else:
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
+ rk = tl.arange(0, block_k)
+ rk_half = tl.arange(0, block_k // 2)
+
+ if not INTERLEAVED:
+ # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_
+ x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
+ cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
+ sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
+ cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(
+ tl.float32
+ )
+ sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(
+ tl.float32
+ )
+ x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)
+ x1 = tl.load(
+ x_ + rotary_dim_half * stride_x_headdim,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ o0 = x0 * cos - x1 * sin
+ o1 = x0 * sin + x1 * cos
+ # write back result
+ out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
+ tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
+ tl.store(
+ out_ + rotary_dim_half * stride_out_headdim,
+ o1,
+ mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
+ )
+ else:
+ # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow.
+ # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...].
+ # Loading x0 will be fast but x1 will be slow.
+ # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...].
+ # Then we do the calculation and use tl.where to pick put the right outputs for the even
+ # and for the odd indices.
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
+ rk_repeat = tl.arange(0, block_k) // 2
+ x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
+ x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
+ cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ cos = tl.load(
+ cos_,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=1.0,
+ ).to(tl.float32)
+ sin = tl.load(
+ sin_,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=0.0,
+ ).to(tl.float32)
+ x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
+ x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
+ if CONJUGATE:
+ sin = -sin
+ x0_cos = x0 * cos
+ x1_sin = x1 * sin
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+ out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
+ tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ interleaved=False,
+ inplace=False,
+ conjugate=False,
+) -> torch.Tensor:
+ """
+ Arguments:
+ x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim).
+ cos: (seqlen_ro, rotary_dim / 2)
+ sin: (seqlen_ro, rotary_dim / 2)
+ seqlen_offsets: integer or integer tensor of size (batch,)
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Returns:
+ y: (batch, seqlen, nheads, headdim)
+ """
+ is_varlen = cu_seqlens is not None
+ if not is_varlen:
+ batch, seqlen, nheads, headdim = x.shape
+ else:
+ assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+ total_seqlen, nheads, headdim = x.shape
+ batch_p_1 = cu_seqlens.shape[0]
+ batch = batch_p_1 - 1
+ seqlen = max_seqlen
+ seqlen_ro, rotary_dim = cos.shape
+ assert sin.shape == cos.shape
+ rotary_dim *= 2
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+ assert headdim <= 256, "Only support headdim <= 256"
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+ assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+ assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+ cos, sin = cos.contiguous(), sin.contiguous()
+ if isinstance(seqlen_offsets, torch.Tensor):
+ assert seqlen_offsets.shape == (batch,)
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+ seqlen_offsets = seqlen_offsets.contiguous()
+ else:
+ assert seqlen_offsets + seqlen <= seqlen_ro
+
+ output = torch.empty_like(x) if not inplace else x
+ if rotary_dim < headdim and not inplace:
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+ block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
+ grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa
+ block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(x.device.index):
+ rotary_kernel[grid](
+ output, # data ptrs
+ x,
+ cos,
+ sin,
+ cu_seqlens,
+ seqlen_offsets,
+ seqlen, # shapes
+ nheads,
+ rotary_dim,
+ seqlen_ro,
+ seqlen // 128, # key for triton cache (limit number of compilations)
+ output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ output.stride(-3), # seqlen_stride or total_seqlen_stride
+ output.stride(-2), # nheads_stride
+ output.stride(-1), # headdim_stride
+ x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ x.stride(-3), # seqlen stride or total_seqlen_stride
+ x.stride(-2), # nheads stride
+ x.stride(-1), # headdim stride
+ block_k,
+ isinstance(seqlen_offsets, torch.Tensor),
+ is_varlen,
+ interleaved,
+ conjugate,
+ block_m,
+ )
+ return output
+
+
+##### ROTARY API #####
+
+
+def rotate_half(x, interleaved=False):
+ if not interleaved:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+ else:
+ x1, x2 = x[..., ::2], x[..., 1::2]
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
+
+
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+ """
+ x: (batch_size, seqlen, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+ """
+ ro_dim = cos.shape[-1] * 2
+ assert ro_dim <= x.shape[-1]
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ return torch.cat(
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
+ dim=-1,
+ )
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ cos,
+ sin,
+ interleaved=False,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ out = apply_rotary(
+ x,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ interleaved=interleaved,
+ inplace=inplace,
+ )
+ if isinstance(seqlen_offsets, int):
+ ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.interleaved = interleaved
+ ctx.inplace = inplace
+ ctx.max_seqlen = max_seqlen
+ return out if not inplace else x
+
+ @staticmethod
+ def backward(ctx, do):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+ if not ctx.interleaved and not ctx.inplace:
+ do = do.clone()
+ dx = apply_rotary(
+ do,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ interleaved=ctx.interleaved,
+ inplace=ctx.inplace,
+ conjugate=True,
+ )
+ return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(
+ x,
+ cos,
+ sin,
+ interleaved=False,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+):
+ """
+ Arguments:
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+ of 1st half and 2nd half (GPT-NeoX style).
+ inplace: if True, apply rotary embedding in-place.
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Return:
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding to the first rotary_dim of x.
+ """
+ return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
+
+
+# For backward compatibility
+apply_rotary_emb_func = apply_rotary_emb
+
+
+class ApplyRotaryEmbQKV(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ qkv,
+ cos,
+ sin,
+ cos_k=None,
+ sin_k=None,
+ interleaved=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ ):
+ batch, seqlen, three, nheads, headdim = qkv.shape
+ assert three == 3
+ if cos_k is None and sin_k is None and qkv.is_contiguous():
+ # Call 1 kernel instead of 2 kernels
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+ # dimensions, we get the same tensor
+ # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+ qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+ apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
+ else:
+ cos_k = cos if cos_k is None else cos_k
+ sin_k = sin if sin_k is None else sin_k
+ q, k = qkv[:, :, 0], qkv[:, :, 1]
+ apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
+ apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
+ ctx.save_for_backward(cos, sin, cos_k, sin_k)
+ if isinstance(seqlen_offsets, int):
+ ctx.save_for_backward(cos, sin, cos_k, sin_k)
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.interleaved = interleaved
+ return qkv
+
+ @staticmethod
+ def backward(ctx, dqkv):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin, cos_k, sin_k = ctx.saved_tensors
+ if cos_k is None and sin_k is None and dqkv.is_contiguous():
+ # Call 1 kernel instead of 2 kernels
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+ # dimensions, we get the same tensor
+ dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+ apply_rotary(
+ dqk,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ interleaved=ctx.interleaved,
+ inplace=True,
+ conjugate=True,
+ )
+ else:
+ cos_k = cos if cos_k is None else cos_k
+ sin_k = sin if sin_k is None else sin_k
+ dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+ apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True)
+ apply_rotary(
+ dk,
+ cos_k,
+ sin_k,
+ seqlen_offsets,
+ interleaved=ctx.interleaved,
+ inplace=True,
+ conjugate=True,
+ )
+ return dqkv, None, None, None, None, None, None
+
+
+def apply_rotary_emb_qkv_(
+ qkv,
+ cos,
+ sin,
+ cos_k=None,
+ sin_k=None,
+ interleaved=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+):
+ """
+ Arguments:
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2)
+ cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+ 1st half and 2nd half (GPT-NeoX style).
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ Return:
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
+ """
+ return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
+
+
+class ApplyRotaryEmbKV(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
+ batch, seqlen, two, nheads, headdim = kv.shape
+ assert two == 2
+ k = kv[:, :, 0]
+ apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
+ if isinstance(seqlen_offsets, int):
+ ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.interleaved = interleaved
+ return kv
+
+ @staticmethod
+ def backward(ctx, dkv):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin = ctx.saved_tensors
+ apply_rotary(
+ dkv[:, :, 0],
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ interleaved=ctx.interleaved,
+ inplace=True,
+ conjugate=True,
+ )
+ return dkv, None, None, None, None
+
+
+apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply
+
+
+def apply_rotary_emb_kv_(
+ kv,
+ cos,
+ sin,
+ interleaved=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+):
+ """
+ Arguments:
+ kv: (batch_size, seqlen, 2, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2)
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+ 1st half and 2nd half (GPT-NeoX style).
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ Return:
+ kv: (batch_size, seqlen, 2, nheads, headdim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding *inplace* to the first rotary_dim of K.
+ """
+ return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ """
+ The rotary position embeddings from RoFormer_ (Su et. al).
+ A crucial insight from the method is that the query and keys are
+ transformed by rotation matrices which depend on the relative positions.
+
+ Other implementations are available in the Rotary Transformer repo_ and in
+ GPT-NeoX_; GPT-NeoX was an inspiration.
+
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ base=10000.0,
+ interleaved=False,
+ scale_base=None,
+ pos_idx_in_fp32=True,
+ device=None,
+ ):
+ """
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+ of 1st half and 2nd half (GPT-NeoX style).
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+ otherwise they might be in lower precision.
+ This option was added because previously (before 2023-07-02), when we construct
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+ self.inv_freq would be bf16, and the position indices are also in bf16.
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+ embeddings for some positions will coincide.
+ To maintain compatibility with models previously trained in pure bf16,
+ we add this option.
+ """
+ super().__init__()
+ self.dim = dim
+ self.base = float(base)
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
+ # Generate and save the inverse frequency buffer (non trainable)
+ inv_freq = self._compute_inv_freq(device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.interleaved = interleaved
+ self.scale_base = scale_base
+ scale = (
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+ if scale_base is not None
+ else None
+ )
+ self.register_buffer("scale", scale, persistent=False)
+
+ self._seq_len_cached = 0
+ self._cos_cached = None
+ self._sin_cached = None
+ self._cos_k_cached = None
+ self._sin_k_cached = None
+
+ def _compute_inv_freq(self, device=None):
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+ # Reset the tables if the sequence length has changed,
+ # if we're on a new device (possibly due to tracing for instance),
+ # or if we're switching from inference mode to training
+ if (
+ seqlen > self._seq_len_cached
+ or self._cos_cached is None
+ or self._cos_cached.device != device
+ or self._cos_cached.dtype != dtype
+ or (self.training and self._cos_cached.is_inference())
+ ):
+ self._seq_len_cached = seqlen
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+ if self.pos_idx_in_fp32:
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
+ # cos & sin output to change significantly.
+ # We want to recompute self.inv_freq if it was not loaded in fp32
+ if self.inv_freq.dtype != torch.float32:
+ inv_freq = self._compute_inv_freq(device=device)
+ else:
+ inv_freq = self.inv_freq
+ else:
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+ inv_freq = self.inv_freq
+ # Don't use einsum; it converts fp32 to fp16 under AMP
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ freqs = torch.outer(t, inv_freq)
+ if self.scale is None:
+ self._cos_cached = torch.cos(freqs).to(dtype)
+ self._sin_cached = torch.sin(freqs).to(dtype)
+ else:
+ power = (
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+ ) / self.scale_base
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+ # We want the multiplication by scale to happen in fp32
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+ def forward(
+ self,
+ qkv: torch.Tensor,
+ kv: Optional[torch.Tensor] = None,
+ seqlen_offset: Union[int, torch.Tensor] = 0,
+ max_seqlen: Optional[int] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
+ otherwise qkv is just q, of shape (batch, seqlen, nheads, headdim).
+ kv: (batch, seqlen, 2, nheads, headdim)
+ seqlen_offset: (batch_size,) or int. Each sequence in qkv is shifted by this amount.
+ Most commonly used in inference when we have a KV cache.
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
+ Applies the rotary embedding *inplace* to qkv and / or kv.
+ """
+ seqlen = qkv.shape[1]
+ if max_seqlen is not None:
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+ elif isinstance(seqlen_offset, int):
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+ if kv is None:
+ if self.scale is None:
+ return apply_rotary_emb_qkv_(
+ qkv,
+ self._cos_cached,
+ self._sin_cached,
+ interleaved=self.interleaved,
+ seqlen_offsets=seqlen_offset,
+ )
+ else:
+ return apply_rotary_emb_qkv_(
+ qkv,
+ self._cos_cached,
+ self._sin_cached,
+ self._cos_k_cached,
+ self._sin_k_cached,
+ interleaved=self.interleaved,
+ seqlen_offsets=seqlen_offset,
+ )
+ else:
+ q = qkv
+ q = apply_rotary_emb_func(
+ q,
+ self._cos_cached,
+ self._sin_cached,
+ interleaved=self.interleaved,
+ inplace=True,
+ seqlen_offsets=seqlen_offset,
+ )
+ if self.scale is None:
+ kv = apply_rotary_emb_kv_(
+ kv,
+ self._cos_cached,
+ self._sin_cached,
+ interleaved=self.interleaved,
+ seqlen_offsets=seqlen_offset,
+ )
+ else:
+ kv = apply_rotary_emb_kv_(
+ kv,
+ self._cos_k_cached,
+ self._sin_k_cached,
+ interleaved=self.interleaved,
+ seqlen_offsets=seqlen_offset,
+ )
+ return q, kv
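For orientation, here is a minimal standalone sketch (separate from the patch, and not the fused flash-attention kernel) of how cos/sin tables like the ones cached by `_update_cos_sin_cache` are applied to a `(batch, seqlen, nheads, headdim)` tensor, for both the half-split (GPT-NeoX style) and interleaved (GPT-J style) layouts. The name `rotary_reference` is purely illustrative.

```python
import torch

def rotary_reference(x, cos, sin, interleaved=False):
    """Apply rotary embedding to x of shape (batch, seqlen, nheads, headdim).

    cos, sin: (seqlen, rotary_dim // 2). If rotary_dim < headdim, the trailing
    dimensions are passed through unrotated, mirroring the class above.
    """
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # Broadcast the tables over batch and heads: (1, seqlen, 1, rotary_dim // 2).
    cos = cos[: x.shape[1]].view(1, x.shape[1], 1, -1)
    sin = sin[: x.shape[1]].view(1, x.shape[1], 1, -1)
    if interleaved:
        # GPT-J style: rotate (even, odd) pairs of dimensions.
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
        rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
    else:
        # GPT-NeoX style: rotate the first half against the second half.
        x1, x2 = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)
```

Building the tables as `freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)` with `cos, sin = freqs.cos(), freqs.sin()` and calling `rotary_reference(q, cos, sin)` should match the non-interleaved (scale-free) path of the class above up to numerical precision.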
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index 8a839875de2a2..90d28872d3cc8 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn.py
@@ -20,6 +20,7 @@
from bert_padding import pad_input, unpad_input
from einops import rearrange, repeat
from onnx import TensorProto, helper
+from rotary_flash import apply_rotary_emb
from onnxruntime import InferenceSession, OrtValue, SessionOptions
@@ -184,7 +185,13 @@ def create_multihead_attention_graph(config):
def create_group_query_attention_graph_prompt(
- config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1
+ config,
+ past_kv_format=Formats.BSNH,
+ share_buffer=True,
+ local_window_size=-1,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
):
past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0
present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length
@@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt(
"GroupQueryAttention",
[
"query",
- "key",
- "value",
+ "key" if not packed else "",
+ "value" if not packed else "",
"past_key" if share_buffer else "",
"past_value" if share_buffer else "",
"seqlens_k",
"total_sequence_length",
+ "cos_cache" if rotary else "",
+ "sin_cache" if rotary else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
local_window_size=local_window_size,
+ do_rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
# is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
# kv_share_buffer=1 if share_buffer else 0,
domain="com.microsoft",
@@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt(
[
config.batch_size,
config.q_sequence_length,
- config.num_heads * config.head_size,
- ],
- ),
- helper.make_tensor_value_info(
- "key",
- TensorProto.FLOAT16,
- [
- config.batch_size,
- config.kv_sequence_length,
- config.kv_num_heads * config.head_size,
- ],
- ),
- helper.make_tensor_value_info(
- "value",
- TensorProto.FLOAT16,
- [
- config.batch_size,
- config.kv_sequence_length,
- config.kv_num_heads * config.head_size,
+ (config.num_heads * config.head_size)
+ if not packed
+ else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size),
],
),
helper.make_tensor_value_info(
@@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt(
[1],
),
]
+ if not packed:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "key",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.kv_sequence_length,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "value",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.kv_sequence_length,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ ]
if share_buffer:
graph_input += [
helper.make_tensor_value_info(
@@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt(
],
),
]
+ if rotary:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "cos_cache",
+ TensorProto.FLOAT16,
+ [
+ config.buffer_sequence_length if share_buffer else config.kv_sequence_length,
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "sin_cache",
+ TensorProto.FLOAT16,
+ [
+ config.buffer_sequence_length if share_buffer else config.kv_sequence_length,
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ ]
graph_output = [
helper.make_tensor_value_info(
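As a quick sanity check on the `cos_cache` / `sin_cache` shapes declared above (a hypothetical helper mirroring the test's arithmetic, not an ORT API): the rotary dimension is the head size rounded down to a multiple of 16, and each cache row stores one value per rotated pair of dimensions, so the last dimension is half of that.

```python
import math

def rotary_cache_width(head_size: int) -> int:
    # Mirrors (math.floor(head_size / 16) * 16) // 2 from the graph inputs above.
    rotary_dim = math.floor(head_size / 16) * 16
    return rotary_dim // 2

# e.g. 128 -> 64, 80 -> 40, 40 -> 16 (for head_size 40 only the first
# 32 dimensions of each head are rotated; the remainder pass through).
assert rotary_cache_width(128) == 64
assert rotary_cache_width(80) == 40
assert rotary_cache_width(40) == 16
```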
@@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt(
def create_group_query_attention_graph_past(
- config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1
+ config,
+ past_kv_format=Formats.BSNH,
+ share_buffer=True,
+ local_window_size=-1,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
):
past_kv_seqlen = config.kv_sequence_length
present_kv_seqlen = (
@@ -345,18 +386,22 @@ def create_group_query_attention_graph_past(
"GroupQueryAttention",
[
"query",
- "key",
- "value",
+ "key" if not packed else "",
+ "value" if not packed else "",
"past_key",
"past_value",
"seqlens_k",
"total_sequence_length",
+ "cos_cache" if rotary else "",
+ "sin_cache" if rotary else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
local_window_size=local_window_size,
+ do_rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
# is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
# kv_share_buffer=1 if share_buffer else 0,
domain="com.microsoft",
@@ -370,25 +415,9 @@ def create_group_query_attention_graph_past(
[
config.batch_size,
config.sequence_length,
- config.num_heads * config.head_size,
- ],
- ),
- helper.make_tensor_value_info(
- "key",
- TensorProto.FLOAT16,
- [
- config.batch_size,
- config.sequence_length,
- config.kv_num_heads * config.head_size,
- ],
- ),
- helper.make_tensor_value_info(
- "value",
- TensorProto.FLOAT16,
- [
- config.batch_size,
- config.sequence_length,
- config.kv_num_heads * config.head_size,
+ (config.num_heads * config.head_size)
+ if not packed
+ else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size),
],
),
helper.make_tensor_value_info(
@@ -411,8 +440,6 @@ def create_group_query_attention_graph_past(
config.head_size,
],
),
- ]
- graph_input += [
helper.make_tensor_value_info(
"seqlens_k",
TensorProto.INT32,
@@ -424,6 +451,46 @@ def create_group_query_attention_graph_past(
[1],
),
]
+ if not packed:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "key",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.sequence_length,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "value",
+ TensorProto.FLOAT16,
+ [
+ config.batch_size,
+ config.sequence_length,
+ config.kv_num_heads * config.head_size,
+ ],
+ ),
+ ]
+ if rotary:
+ graph_input += [
+ helper.make_tensor_value_info(
+ "cos_cache",
+ TensorProto.FLOAT16,
+ [
+ config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ helper.make_tensor_value_info(
+ "sin_cache",
+ TensorProto.FLOAT16,
+ [
+ config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
+ (math.floor(config.head_size / 16) * 16) // 2,
+ ],
+ ),
+ ]
graph_output = [
helper.make_tensor_value_info(
@@ -663,21 +730,38 @@ def mha_func(q, k, v, config):
def gqa_prompt_func(
- q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True
+ q,
+ k,
+ v,
+ config,
+ new_k,
+ new_v,
+ cos=None,
+ sin=None,
+ seqlens_k=None,
+ window_size=-1,
+ past_kv_format=Formats.BSNH,
+ share_buffer=True,
+ rotary_interleaved=False,
):
onnx_model_str = create_group_query_attention_graph_prompt(
- config, past_kv_format, share_buffer, local_window_size=window_size
+ config,
+ past_kv_format,
+ share_buffer,
+ local_window_size=window_size,
+ rotary=cos is not None,
+ rotary_interleaved=rotary_interleaved,
+ packed=new_k is None,
)
q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
past_k = k.clone() if share_buffer else None
past_v = v.clone() if share_buffer else None
- new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1))
- new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1))
+ if new_k is not None:
+ new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1))
+ new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1))
if share_buffer:
ort_inputs = {
"query": q.detach().cpu().numpy(),
- "key": new_k.detach().cpu().numpy(),
- "value": new_v.detach().cpu().numpy(),
"past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0),
"past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0),
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
@@ -686,9 +770,17 @@ def gqa_prompt_func(
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
io_binding = ort_session.io_binding()
+ if new_k is not None:
+ ort_inputs["key"] = new_k.detach().cpu().numpy()
+ ort_inputs["value"] = new_v.detach().cpu().numpy()
+ io_binding.bind_cpu_input("key", ort_inputs["key"])
+ io_binding.bind_cpu_input("value", ort_inputs["value"])
+ if cos is not None:
+ ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
+ ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
+ io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
+ io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
- io_binding.bind_cpu_input("key", ort_inputs["key"])
- io_binding.bind_cpu_input("value", ort_inputs["value"])
io_binding.bind_input(
"past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
)
@@ -713,17 +805,23 @@ def gqa_prompt_func(
else:
ort_inputs = {
"query": q.detach().cpu().numpy(),
- "key": new_k.detach().cpu().numpy(),
- "value": new_v.detach().cpu().numpy(),
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
"total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
}
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
io_binding = ort_session.io_binding()
+ if new_k is not None:
+ ort_inputs["key"] = new_k.detach().cpu().numpy()
+ ort_inputs["value"] = new_v.detach().cpu().numpy()
+ io_binding.bind_cpu_input("key", ort_inputs["key"])
+ io_binding.bind_cpu_input("value", ort_inputs["value"])
+ if cos is not None:
+ ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
+ ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
+ io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
+ io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
- io_binding.bind_cpu_input("key", ort_inputs["key"])
- io_binding.bind_cpu_input("value", ort_inputs["value"])
io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"])
io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"])
io_binding.bind_output("output")
@@ -737,21 +835,38 @@ def gqa_prompt_func(
def gqa_past_func(
- q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1
+ q,
+ k,
+ v,
+ config,
+ new_k,
+ new_v,
+ cos=None,
+ sin=None,
+ seqlens_k=None,
+ past_kv_format=Formats.BSNH,
+ share_buffer=True,
+ window_size=-1,
+ rotary_interleaved=False,
):
onnx_model_str = create_group_query_attention_graph_past(
- config, past_kv_format, share_buffer, local_window_size=window_size
+ config,
+ past_kv_format,
+ share_buffer,
+ local_window_size=window_size,
+ rotary=cos is not None,
+ rotary_interleaved=rotary_interleaved,
+ packed=new_k is None,
)
q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
past_k = k.clone()
past_v = v.clone()
- new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1))
- new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1))
+ if new_k is not None:
+ new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1))
+ new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1))
if share_buffer:
ort_inputs = {
"query": q.detach().cpu().numpy(),
- "key": new_k.detach().cpu().numpy(),
- "value": new_v.detach().cpu().numpy(),
"past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0),
"past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0),
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
@@ -763,9 +878,17 @@ def gqa_past_func(
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
io_binding = ort_session.io_binding()
+ if new_k is not None:
+ ort_inputs["key"] = new_k.detach().cpu().numpy()
+ ort_inputs["value"] = new_v.detach().cpu().numpy()
+ io_binding.bind_cpu_input("key", ort_inputs["key"])
+ io_binding.bind_cpu_input("value", ort_inputs["value"])
+ if cos is not None:
+ ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
+ ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
+ io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
+ io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
- io_binding.bind_cpu_input("key", ort_inputs["key"])
- io_binding.bind_cpu_input("value", ort_inputs["value"])
io_binding.bind_input(
"past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
)
@@ -790,8 +913,6 @@ def gqa_past_func(
else:
ort_inputs = {
"query": q.detach().cpu().numpy(),
- "key": new_k.detach().cpu().numpy(),
- "value": new_v.detach().cpu().numpy(),
"past_key": past_k.detach().cpu().numpy(),
"past_value": past_v.detach().cpu().numpy(),
"seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32),
@@ -805,9 +926,17 @@ def gqa_past_func(
sess_options = SessionOptions()
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"])
io_binding = ort_session.io_binding()
+ if new_k is not None:
+ ort_inputs["key"] = new_k.detach().cpu().numpy()
+ ort_inputs["value"] = new_v.detach().cpu().numpy()
+ io_binding.bind_cpu_input("key", ort_inputs["key"])
+ io_binding.bind_cpu_input("value", ort_inputs["value"])
+ if cos is not None:
+ ort_inputs["cos_cache"] = cos.detach().cpu().numpy()
+ ort_inputs["sin_cache"] = sin.detach().cpu().numpy()
+ io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"])
+ io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
- io_binding.bind_cpu_input("key", ort_inputs["key"])
- io_binding.bind_cpu_input("value", ort_inputs["value"])
io_binding.bind_cpu_input("past_key", ort_inputs["past_key"])
io_binding.bind_cpu_input("past_value", ort_inputs["past_value"])
io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"])
@@ -1029,9 +1158,12 @@ def parity_check_mha(
def parity_check_gqa_prompt(
config,
- causal=False,
+ causal=True,
local=False,
past_format=Formats.BSNH,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
rtol=1e-3,
atol=1e-3,
):
@@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt(
dtype=torch.float16,
requires_grad=False,
)
+ # print(k.shape)
+ # print(new_k.shape)
window_size = (-1, -1)
left_window_size = -1
@@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt(
# device="cuda",
# )
# cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length
+ rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size)
+
+ if rotary:
+ rotary_fraction = 1.0
+ rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+ angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+ cos = torch.cos(angle).to(dtype=torch.float16)
+ sin = torch.sin(angle).to(dtype=torch.float16)
+ if causal or local:
+ q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+ else:
+ q_ro = rearrange(
+ apply_rotary_emb(
+ rearrange(q, "b s h d -> b 1 (s h) d"),
+ cos,
+ sin,
+ seqlen_offsets=rotary_seqlens,
+ interleaved=rotary_interleaved,
+ ),
+ "b 1 (s h) d -> b s h d",
+ s=config.q_sequence_length,
+ )
+ # q_ro = q
+ k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+ else:
+ cos, sin = None, None
+ q_ro, k_ro = q, new_k
+
rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s")
arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size)
kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1")
update_mask = arange < kv_seqlens_expanded
- k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
+ k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
key_padding_mask = arange < cache_seqlens_expanded
out_ref, _ = attention_ref(
- q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+ q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
- out, present_k, present_v = gqa_prompt_func(
- q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True
- )
+ if packed:
+ packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+ out, present_k, present_v = gqa_prompt_func(
+ packed_qkv,
+ k,
+ v,
+ config,
+ None,
+ None,
+ cos,
+ sin,
+ cache_seqlens,
+ left_window_size,
+ past_format,
+ True,
+ rotary_interleaved,
+ )
+ else:
+ out, present_k, present_v = gqa_prompt_func(
+ q,
+ k,
+ v,
+ config,
+ new_k,
+ new_v,
+ cos,
+ sin,
+ cache_seqlens,
+ left_window_size,
+ past_format,
+ True,
+ rotary_interleaved,
+ )
out = torch.squeeze(out, 0)
out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
+ # print(cache_seqlens[0])
+ # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0])
+ # print((out - out_ref)[0, :, 0, 0])
+
# Make sure past-present buffer updating correctly
assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt(
# Compare results
print(
"KV-buffer",
+ " packed:",
+ packed,
" causal:",
causal,
" local:",
local,
+ " rotary:",
+ rotary,
+ " rotary_interleaved:",
+ rotary_interleaved,
"past kv format:",
"BSNH" if past_format == Formats.BSNH else "BNSH",
" B:",
@@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt(
def parity_check_gqa_prompt_no_buff(
config,
- causal=False,
+ causal=True,
local=False,
past_format=Formats.BSNH,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
rtol=1e-3,
atol=1e-3,
):
@@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff(
# device="cuda",
# )
# cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length
+ rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size)
+
+ if rotary:
+ rotary_fraction = 1.0
+ rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+ angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+ cos = torch.cos(angle).to(dtype=torch.float16)
+ sin = torch.sin(angle).to(dtype=torch.float16)
+ if causal or local:
+ q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+ else:
+ q_ro = rearrange(
+ apply_rotary_emb(
+ rearrange(q, "b s h d -> b 1 (s h) d"),
+ cos,
+ sin,
+ seqlen_offsets=rotary_seqlens,
+ interleaved=rotary_interleaved,
+ ),
+ "b 1 (s h) d -> b s h d",
+ s=config.q_sequence_length,
+ )
+ # q_ro = q
+ k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+ else:
+ cos, sin = None, None
+ q_ro, k_ro = q, k_cache_ref
+ k_cache_ref = k_ro
+
brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
new_mask = brange < cache_seqlens_expanded
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
out_ref, _ = attention_ref(
- q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
+ q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
- out, present_k, present_v = gqa_prompt_func(
- q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False
- )
+ if packed:
+ packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+ out, present_k, present_v = gqa_prompt_func(
+ packed_qkv,
+ None,
+ None,
+ config,
+ None,
+ None,
+ cos,
+ sin,
+ cache_seqlens,
+ left_window_size,
+ past_format,
+ False,
+ rotary_interleaved,
+ )
+ else:
+ out, present_k, present_v = gqa_prompt_func(
+ q,
+ None,
+ None,
+ config,
+ new_k,
+ new_v,
+ cos,
+ sin,
+ cache_seqlens,
+ left_window_size,
+ past_format,
+ False,
+ rotary_interleaved,
+ )
out = torch.squeeze(out, 0)
out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
@@ -1256,7 +1520,17 @@ def parity_check_gqa_prompt_no_buff(
# Compare results
print(
- "KV-buffer",
+ "No buff",
+ " packed:",
+ packed,
+ " causal:",
+ causal,
+ " local:",
+ local,
+ " rotary:",
+ rotary,
+ " rotary_interleaved:",
+ rotary_interleaved,
"past kv format:",
"BSNH" if past_format == Formats.BSNH else "BNSH",
" B:",
@@ -1285,9 +1559,12 @@ def parity_check_gqa_prompt_no_buff(
def parity_check_gqa_past(
config,
- causal=False,
+ causal=True,
local=False,
past_format=Formats.BSNH,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
rtol=1e-3,
atol=1e-3,
):
@@ -1336,6 +1613,7 @@ def parity_check_gqa_past(
dtype=torch.float16,
requires_grad=False,
)
+
window_size = (-1, -1)
left_window_size = -1
if local:
@@ -1359,18 +1637,45 @@ def parity_check_gqa_past(
dtype=torch.int32,
device="cuda",
)
+
+ if rotary:
+ rotary_fraction = 1.0
+ rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+ angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+ cos = torch.cos(angle).to(dtype=torch.float16)
+ sin = torch.sin(angle).to(dtype=torch.float16)
+ if causal or local:
+ q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+ else:
+ q_ro = rearrange(
+ apply_rotary_emb(
+ rearrange(q, "b s h d -> b 1 (s h) d"),
+ cos,
+ sin,
+ seqlen_offsets=cache_seqlens,
+ interleaved=rotary_interleaved,
+ ),
+ "b 1 (s h) d -> b s h d",
+ s=config.sequence_length,
+ )
+ # q_ro = q
+ k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+ else:
+ cos, sin = None, None
+ q_ro, k_ro = q, new_k
+
arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
)
- k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
+ k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
out_ref, _ = attention_ref(
- q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+ q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1378,13 +1683,46 @@ def parity_check_gqa_past(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
- out, present_k, present_v = gqa_past_func(
- q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size
- )
+ if packed:
+ packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+ out, present_k, present_v = gqa_past_func(
+ packed_qkv,
+ k,
+ v,
+ config,
+ None,
+ None,
+ cos,
+ sin,
+ cache_seqlens,
+ past_format,
+ True,
+ left_window_size,
+ rotary_interleaved,
+ )
+ else:
+ out, present_k, present_v = gqa_past_func(
+ q,
+ k,
+ v,
+ config,
+ new_k,
+ new_v,
+ cos,
+ sin,
+ cache_seqlens,
+ past_format,
+ True,
+ left_window_size,
+ rotary_interleaved,
+ )
out = torch.squeeze(out, 0)
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
+ # print(cache_seqlens[0])
+ # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :])
+
# Make sure past-present buffer updating correctly
assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1394,10 +1732,16 @@ def parity_check_gqa_past(
"KV-buffer",
"past kv format:",
"BSNH" if past_format == Formats.BSNH else "BNSH",
+ " packed:",
+ packed,
" causal:",
causal,
" local:",
local,
+ " rotary:",
+ rotary,
+ " rotary_interleaved:",
+ rotary_interleaved,
" B:",
config.batch_size,
" S:",
@@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff(
causal=False,
local=False,
past_format=Formats.BSNH,
+ rotary=False,
+ rotary_interleaved=False,
+ packed=False,
rtol=1e-3,
atol=1e-3,
):
@@ -1503,18 +1850,47 @@ def parity_check_gqa_past_no_buff(
device="cuda",
)
cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length
+
+ if rotary:
+ rotary_fraction = 1.0
+ rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+ angle = (
+ torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+ )
+ cos = torch.cos(angle).to(dtype=torch.float16)
+ sin = torch.sin(angle).to(dtype=torch.float16)
+ if causal or local:
+ q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+ else:
+ q_ro = rearrange(
+ apply_rotary_emb(
+ rearrange(q, "b s h d -> b 1 (s h) d"),
+ cos,
+ sin,
+ seqlen_offsets=cache_seqlens,
+ interleaved=rotary_interleaved,
+ ),
+ "b 1 (s h) d -> b s h d",
+ s=config.sequence_length,
+ )
+ # q_ro = q
+ k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+ else:
+ cos, sin = None, None
+ q_ro, k_ro = q, new_k
+
arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
)
- k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
+ k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
out_ref, _ = attention_ref(
- q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+ q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
)
out_ref = out_ref.detach().cpu().numpy()
if past_format == Formats.BNSH:
@@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff(
v_cache_ref = v_cache_ref.transpose(1, 2)
# Flash function
- out, present_k, present_v = gqa_past_func(
- q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size
- )
+ if packed:
+ packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+ out, present_k, present_v = gqa_past_func(
+ packed_qkv,
+ k,
+ v,
+ config,
+ None,
+ None,
+ cos,
+ sin,
+ cache_seqlens,
+ past_format,
+ False,
+ window_size=left_window_size,
+ rotary_interleaved=rotary_interleaved,
+ )
+ else:
+ out, present_k, present_v = gqa_past_func(
+ q,
+ k,
+ v,
+ config,
+ new_k,
+ new_v,
+ cos,
+ sin,
+ cache_seqlens,
+ past_format,
+ False,
+ window_size=left_window_size,
+ rotary_interleaved=rotary_interleaved,
+ )
out = torch.squeeze(out, 0)
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
out = out.detach().cpu().numpy()
+ # print(cache_seqlens[0])
+ # print((out - out_ref)[0])
+ # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0])
+
# Make sure past-present buffer updating correctly
# assert numpy.allclose(
# present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True
@@ -1540,10 +1950,16 @@ def parity_check_gqa_past_no_buff(
# Compare results
print(
"NO buff",
+ " packed:",
+ packed,
" causal:",
causal,
" local:",
local,
+ " rotary:",
+ rotary,
+ " rotary_interleaved:",
+ rotary_interleaved,
"past kv format:",
"BSNH" if past_format == Formats.BSNH else "BNSH",
" B:",
@@ -1671,10 +2087,25 @@ def test_gqa_no_past(self):
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
- for past_kv_format in [Formats.BNSH]:
- config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
- parity_check_gqa_prompt(config, local=local, past_format=past_kv_format)
- parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format)
+ for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+ for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+ config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
+ parity_check_gqa_prompt(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
+ parity_check_gqa_prompt_no_buff(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
def test_gqa_past(self):
if not torch.cuda.is_available():
@@ -1684,7 +2115,6 @@ def test_gqa_past(self):
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- TEST GQA PAST (TOKEN GEN) ---------")
- print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
@@ -1706,6 +2136,7 @@ def test_gqa_past(self):
num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)
+ print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
for b in batches:
for s, s2 in seqs:
for n, n2 in num_h:
@@ -1734,23 +2165,30 @@ def test_gqa_past(self):
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
- for past_kv_format in [Formats.BNSH]:
- sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
- config = Config(b, s, s2, sp, n, n2, h)
- parity_check_gqa_past(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
- )
- parity_check_gqa_past_no_buff(
- config,
- local=local,
- past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
- )
+ for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+ for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+ sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+ config = Config(b, s, s2, sp, n, n2, h)
+ parity_check_gqa_past(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
+ parity_check_gqa_past_no_buff(
+ config,
+ local=local,
+ past_format=past_kv_format,
+ rtol=1e-3,
+ atol=1e-3,
+ rotary=rotary,
+ rotary_interleaved=rotary_interleaved,
+ packed=packed,
+ )
if __name__ == "__main__":
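For reference, a small sketch of the packed-QKV convention the parity checks above exercise (hypothetical helper names, not ORT or test APIs), assuming the same layout as `torch.concatenate([q, new_k, new_v], dim=2)` in the tests:

```python
import torch

def pack_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (batch, seqlen, num_heads, head_size); k, v: (batch, seqlen, kv_num_heads, head_size).
    # The packed "query" input has hidden size (num_heads + 2 * kv_num_heads) * head_size.
    packed = torch.cat((q, k, v), dim=2)
    return packed.reshape(packed.shape[0], packed.shape[1], -1)

def unpack_qkv(packed: torch.Tensor, num_heads: int, kv_num_heads: int, head_size: int):
    b, s, _ = packed.shape
    qkv = packed.reshape(b, s, num_heads + 2 * kv_num_heads, head_size)
    q = qkv[:, :, :num_heads]
    k = qkv[:, :, num_heads : num_heads + kv_num_heads]
    v = qkv[:, :, num_heads + kv_num_heads :]
    return q, k, v
```

Round-tripping `unpack_qkv(pack_qkv(q, k, v), ...)` should recover the original tensors, which is handy when debugging packed-versus-unpacked parity failures.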
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 1034a82cb2854..6e5cd7b57e403 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
numpy_init_version = numpy.__version__
pb_init_version = google.protobuf.__version__
run_subprocess(
- [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR
+ [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"],
+ cwd=SCRIPT_DIR,
)
run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
# Restore initial numpy/protobuf version in case other tests use it
diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt
similarity index 94%
rename from tools/ci_build/requirements.txt
rename to tools/ci_build/requirements-transformers-test.txt
index 57fc8f08336d2..a5279781462a7 100644
--- a/tools/ci_build/requirements.txt
+++ b/tools/ci_build/requirements-transformers-test.txt
@@ -3,7 +3,8 @@ packaging
protobuf==3.20.2
numpy==1.24.0 ; python_version < '3.12'
numpy==1.26.0 ; python_version >= '3.12'
+torch
coloredlogs==15.0
transformers==4.36.0
psutil
-einops
\ No newline at end of file
+einops