diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 407e08c96a891..734506681ab60 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -482,7 +482,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br>*out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br>*in* max_length:**I**<br>*in* min_length:**I**<br>*in* repetition_penalty:**T**<br>*in* vocab_mask:**I**<br>*in* prefix_vocab_mask:**I**<br>*in* attention_mask:**I**<br>*out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br>*in* Grid:**T1**<br>*out* Y:**T2**|1+|**T1** = tensor(float)<br>**T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br>*in* key:**T**<br>*in* value:**T**<br>*in* past_key:**T**<br>*in* past_value:**T**<br>*in* seqlens_k:**M**<br>*in* total_sequence_length:**M**<br>*in* cos_cache:**T**<br>*in* sin_cache:**T**<br>*out* output:**T**<br>*out* present_key:**T**<br>*out* present_value:**T**|1+|**M** = tensor(int32)<br>**T** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br>*in* key:**T**<br>*in* value:**T**<br>*in* past_key:**T**<br>*in* past_value:**T**<br>*in* seqlens_k:**M**<br>*in* total_sequence_length:**M**<br>*in* cos_cache:**T**<br>*in* sin_cache:**T**<br>*out* output:**T**<br>*out* present_key:**T**<br>*out* present_value:**T**|1+|**M** = tensor(int32)<br>**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br>*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br>*in* B:**T2**<br>*in* absmax:**T1**<br>*out* Y:**T1**|1+|**T1** = tensor(float)<br>**T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br>*in* B:**T2**<br>*in* B_shape:**T3**<br>*out* Y:**T1**|1+|**T1** = tensor(float)<br>**T2** = tensor(uint8)<br>**T3** = tensor(int64)|
@@ -508,7 +508,7 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br>*in* y_scale:**T1**<br>*in* y_zero_point:**T2**<br>*out* y:**T2**|1+|**T1** = tensor(float)<br>**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br>*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br>*in* limit:**T**<br>*in* delta:**T**<br>*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
-|RotaryEmbedding|*in* input:**T**<br>*in* position_ids:**M**<br>*in* cos_cache:**T**<br>*in* sin_cache:**T**<br>*out* output:**T**|1+|**M** = tensor(int64)<br>**T** = tensor(float)|
+|RotaryEmbedding|*in* input:**T**<br>*in* position_ids:**M**<br>*in* cos_cache:**T**<br>*in* sin_cache:**T**<br>*out* output:**T**|1+|**M** = tensor(int64)<br>**T** = tensor(float), tensor(float16)|
|SampleOp|*in* X:**T**<br>*out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br>*in* max_length:**I**<br>*in* min_length:**I**<br>*in* repetition_penalty:**T**<br>*in* vocab_mask:**I**<br>*in* prefix_vocab_mask:**I**<br>*in* attention_mask:**I**<br>*in* presence_mask:**I**<br>*in* seed:**I**<br>*out* sequences:**I**<br>*out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br>*in* skip:**T**<br>*in* gamma:**T**<br>*in* beta:**T**<br>*in* bias:**T**<br>*out* output:**T**<br>*out* mean:**U**<br>*out* inv_std_var:**U**<br>*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
index 7b84971585f9f..c8fe9c77d8ff8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Allocate space for output of Q(BS, D) + bias(D)
@@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
+ per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Get Q's bias from combined bias
@@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);
+template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, OrtValue& out);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index bfec9aef56727..ccaeb6654e286 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,7 +75,7 @@ class GQAAttentionBase {
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
// Compute the attention score.
- size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+ size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -87,16 +87,17 @@ class GQAAttentionBase {
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
- ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+ ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
- present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
+ present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
- ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+ ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+ seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
- is_prompt, tp);
+ is_prompt, tp, allocator);
return Status::OK();
}
@@ -106,7 +107,7 @@ class GQAAttentionBase {
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
- void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
+ void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
- ThreadPool* tp) const { // thread pool
+ ThreadPool* tp, // thread pool
+ AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@ class GQAAttentionBase {
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
if (!past_present_share_buffer) {
- memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+ memset((void*)present_key,
+ 0,
+ batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}
const size_t loop_len = batch_size * num_heads_;
@@ -164,7 +168,7 @@ class GQAAttentionBase {
const size_t past_chunk_length = past_seqlen * head_size;
const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
- T* output = attention_probs + output_offset;
+ float* output = attention_probs + output_offset;
const T* k;
if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
q = Q + q_input_chunk_length * i;
}
- math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
- static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
- static_cast<int>(present_buffer_sequence_length), nullptr);
+ if constexpr (std::is_same<T, float>::value) {
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+ static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/,
+ output, static_cast<int>(present_buffer_sequence_length), nullptr);
+ } else {
+ size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+ auto q_k_fp32 = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
+
+ float* q_fp32 = static_cast<float*>(q_k_fp32);
+ MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+
+ float* k_fp32 = q_fp32 + head_size * sequence_length;
+ MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);
+
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
+ static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*bata*/,
+ output, static_cast<int>(present_buffer_sequence_length), nullptr);
+ }
// compute Softmax
- T* output_softmax = output;
+ float* output_softmax = output;
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
@@ -237,7 +257,7 @@ class GQAAttentionBase {
template <typename T>
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
- const T* attention_probs, // Attention probs with size BxNxSxT
+ const float* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const size_t batch_size, // batch size
@@ -251,7 +271,8 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
- ThreadPool* tp) const {
+ ThreadPool* tp,
+ AllocatorPtr allocator) const {
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -261,7 +282,9 @@ class GQAAttentionBase {
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H
if (!past_present_share_buffer) {
- memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+ memset((void*)present_value,
+ 0,
+ batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}
const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@ class GQAAttentionBase {
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;
+ size_t output_fp32_bytes = 0;
+ if constexpr (std::is_same<T, MLFloat16>::value) {
+ output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
+ }
+ auto output_fp32 = allocator->Alloc(output_fp32_bytes);
+ BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));
+
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const size_t batch_index = i / num_heads_;
@@ -305,15 +335,39 @@ class GQAAttentionBase {
i / kv_num_heads_factor);
}
- T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
- math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
- attention_probs + attention_probs_offset,
- static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
- 0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
+ if constexpr (std::is_same<T, float>::value) {
+ T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+ 1.f, /*alpha*/ attention_probs + attention_probs_offset,
+ static_cast<int>(present_buffer_sequence_length), v,
+ static_cast<int>(head_size), 0.0f /*beta*/, output_current,
+ static_cast<int>(hidden_size), nullptr);
+ } else {
+ size_t bytes = head_size * total_seqlen * sizeof(float);
+ auto v_fp32 = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
+
+ float* v_fp32_ptr = static_cast<float*>(v_fp32);
+ MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);
+
+ float* output_fp32_current = static_cast<float*>(output_fp32) +
+ (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+ 1.f, /*alpha*/ attention_probs + attention_probs_offset,
+ static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
+ static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
+ static_cast<int>(hidden_size), nullptr);
+ }
}
});
+
+ if constexpr (std::is_same<T, MLFloat16>::value) {
+ MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
+ output,
+ SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
+ }
}
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 2a38e4a1ac636..a1ed35e54b008 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -22,16 +22,20 @@ namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
- GroupQueryAttention,
- kMSDomain,
- 1,
- float,
- kCpuExecutionProvider,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType())
- .TypeConstraint("M", DataTypeImpl::GetTensorType()),
- GroupQueryAttention);
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ GroupQueryAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()), \
+ GroupQueryAttention);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 6732f8b96cce2..cbfd2f0949363 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -13,16 +13,20 @@ namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
- RotaryEmbedding,
- kMSDomain,
- 1,
- float,
- kCpuExecutionProvider,
- KernelDefBuilder()
- .TypeConstraint("T", DataTypeImpl::GetTensorType())
- .TypeConstraint("M", DataTypeImpl::GetTensorType()),
- RotaryEmbedding);
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RotaryEmbedding, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()), \
+ RotaryEmbedding);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
@@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
- T sign = 0;
+ bool sign = false;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
- sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
- j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
+ sign = i & 1;
+ j = sign ? i - 1 : i + 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
- sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+ sign = (i >= half_rotary_emb_dim);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
- output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+ float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
+ float input_data_j = static_cast<float>(input_data[j]);
+ float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
+ if (sign) {
+ output_data_i += input_data_j * sin_data_cache_idx;
+ } else {
+ output_data_i -= input_data_j * sin_data_cache_idx;
+ }
+ output_data[i] = static_cast<T>(output_data_i);
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
@@ -102,6 +114,10 @@ template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryPar
const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
bool interleaved);
+template Status RunRotaryEmbedding<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
+ const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache,
+ MLFloat16* output, bool interleaved);
+
template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index dcd1f5ec22b52..e75d485830ca5 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -288,8 +290,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index dc21d4e4a5890..08ec5de328b9d 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -29,6 +29,12 @@
GREEN = "\033[32m"
RESET = "\033[0m"
+ORT_TYPE = TensorProto.FLOAT
+TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32
+NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32
+RTOL = 3e-2 if ORT_TYPE == TensorProto.FLOAT16 else 1e-3
+ATOL = RTOL
+
class Formats:
BSNH = 0
@@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt(
graph_input = [
helper.make_tensor_value_info(
"query",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.q_sequence_length,
@@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt(
graph_input += [
helper.make_tensor_value_info(
"key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.kv_sequence_length,
@@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.kv_sequence_length,
@@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt(
graph_input += [
helper.make_tensor_value_info(
"past_key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"past_value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt(
graph_input += [
helper.make_tensor_value_info(
"cos_cache",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.buffer_sequence_length if share_buffer else config.kv_sequence_length,
(math.floor(config.head_size / 16) * 16) // 2,
@@ -264,7 +270,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"sin_cache",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.buffer_sequence_length if share_buffer else config.kv_sequence_length,
(math.floor(config.head_size / 16) * 16) // 2,
@@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt(
graph_output = [
helper.make_tensor_value_info(
"output",
- TensorProto.FLOAT,
+ ORT_TYPE,
[config.batch_size, config.q_sequence_length, config.num_heads * config.head_size],
),
helper.make_tensor_value_info(
"present_key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"present_value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"present_key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt(
),
helper.make_tensor_value_info(
"present_value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -378,7 +384,7 @@ def create_group_query_attention_graph_past(
graph_input = [
helper.make_tensor_value_info(
"query",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.sequence_length,
@@ -391,7 +397,7 @@ def create_group_query_attention_graph_past(
),
helper.make_tensor_value_info(
"past_key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -401,7 +407,7 @@ def create_group_query_attention_graph_past(
),
helper.make_tensor_value_info(
"past_value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -424,7 +430,7 @@ def create_group_query_attention_graph_past(
graph_input += [
helper.make_tensor_value_info(
"key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.sequence_length,
@@ -433,7 +439,7 @@ def create_group_query_attention_graph_past(
),
helper.make_tensor_value_info(
"value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
config.sequence_length,
@@ -445,7 +451,7 @@ def create_group_query_attention_graph_past(
graph_input += [
helper.make_tensor_value_info(
"cos_cache",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
(math.floor(config.head_size / 16) * 16) // 2,
@@ -453,7 +459,7 @@ def create_group_query_attention_graph_past(
),
helper.make_tensor_value_info(
"sin_cache",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
(math.floor(config.head_size / 16) * 16) // 2,
@@ -464,12 +470,12 @@ def create_group_query_attention_graph_past(
graph_output = [
helper.make_tensor_value_info(
"output",
- TensorProto.FLOAT,
+ ORT_TYPE,
[config.batch_size, config.sequence_length, config.num_heads * config.head_size],
),
helper.make_tensor_value_info(
"present_key",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -479,7 +485,7 @@ def create_group_query_attention_graph_past(
),
helper.make_tensor_value_info(
"present_value",
- TensorProto.FLOAT,
+ ORT_TYPE,
[
config.batch_size,
present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True):
config.num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
key_padding_mask = generate_random_padding_mask(
@@ -722,13 +728,13 @@ def gqa_prompt_func(
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_input(
- "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
+ "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
)
io_binding.bind_input(
"past_value",
"cpu",
0,
- numpy.float32,
+ NUMPY_TYPE,
ort_inputs["past_value"].shape(),
ort_inputs["past_value"].data_ptr(),
)
@@ -835,13 +841,13 @@ def gqa_past_func(
io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
io_binding.bind_cpu_input("query", ort_inputs["query"])
io_binding.bind_input(
- "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
+ "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
)
io_binding.bind_input(
"past_value",
"cpu",
0,
- numpy.float32,
+ NUMPY_TYPE,
ort_inputs["past_value"].shape(),
ort_inputs["past_value"].data_ptr(),
)
@@ -1017,9 +1023,11 @@ def attention_ref(
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
+
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt(
packed=False,
softcap=0.0,
use_smooth_softmax=False,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
):
q = torch.randn(
config.batch_size,
@@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt(
config.num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
k = torch.randn(
@@ -1076,7 +1084,7 @@ def parity_check_gqa_prompt(
config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
v = torch.randn(
@@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt(
config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_k = torch.randn(
@@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_v = torch.randn(
@@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
@@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt(
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
- cos = torch.cos(angle).to(dtype=torch.float32)
- sin = torch.sin(angle).to(dtype=torch.float32)
+ cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+ sin = torch.sin(angle).to(dtype=TORCH_TYPE)
rot = LlamaMSRotaryEmbedding()
q_ro = rot(
q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved
@@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt(
kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size)
kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1")
update_mask = arange < kv_seqlens_expanded
- k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
- v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
+ k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE)
+ v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE)
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
key_padding_mask = arange < cache_seqlens_expanded
@@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt(
out = out.detach().cpu().numpy()
# Make sure past-present buffer updating correctly
- assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
- assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+ assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+ assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
# Compare results
- all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+ all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
print(
"KV-buffer",
@@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff(
packed=False,
softcap=0.0,
use_smooth_softmax=False,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
):
q = torch.randn(
config.batch_size,
@@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff(
config.num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_k = torch.randn(
@@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_v = torch.randn(
@@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
@@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff(
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
- cos = torch.cos(angle).to(dtype=torch.float32)
- sin = torch.sin(angle).to(dtype=torch.float32)
+ cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+ sin = torch.sin(angle).to(dtype=TORCH_TYPE)
rot = LlamaMSRotaryEmbedding()
q_ro = rot(
q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved
@@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff(
out = out.detach().cpu().numpy()
# Make sure past-present buffer updating correctly
- assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
- assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+ assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+ assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
# Compare results
- all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+ all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
print(
"No buff",
@@ -1458,8 +1466,8 @@ def parity_check_gqa_past(
packed=False,
softcap=0.0,
use_smooth_softmax=False,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
):
q = torch.randn(
config.batch_size,
@@ -1467,7 +1475,7 @@ def parity_check_gqa_past(
config.num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
k = torch.randn(
@@ -1476,7 +1484,7 @@ def parity_check_gqa_past(
config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
v = torch.randn(
@@ -1485,7 +1493,7 @@ def parity_check_gqa_past(
config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_k = torch.randn(
@@ -1494,7 +1502,7 @@ def parity_check_gqa_past(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_v = torch.randn(
@@ -1503,7 +1511,7 @@ def parity_check_gqa_past(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
@@ -1534,8 +1542,8 @@ def parity_check_gqa_past(
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
- cos = torch.cos(angle).to(dtype=torch.float32)
- sin = torch.sin(angle).to(dtype=torch.float32)
+ cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+ sin = torch.sin(angle).to(dtype=TORCH_TYPE)
rot = LlamaMSRotaryEmbedding()
q_ro = rot(
q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved
@@ -1624,11 +1632,11 @@ def parity_check_gqa_past(
out = out.detach().cpu().numpy()
# Make sure past-present buffer updating correctly
- assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
- assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+ assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+ assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
# Compare results
- all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+ all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
print(
"KV-buffer",
@@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff(
packed=False,
softcap=0.0,
use_smooth_softmax=False,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
):
torch.manual_seed(69)
q = torch.randn(
@@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff(
config.num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
k = torch.randn(
@@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff(
config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
v = torch.randn(
@@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff(
config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_k = torch.randn(
@@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
new_v = torch.randn(
@@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff(
config.kv_num_heads,
config.head_size,
device="cpu",
- dtype=torch.float32,
+ dtype=TORCH_TYPE,
requires_grad=False,
)
@@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff(
angle = (
torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
)
- cos = torch.cos(angle).to(dtype=torch.float32)
- sin = torch.sin(angle).to(dtype=torch.float32)
+ cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+ sin = torch.sin(angle).to(dtype=TORCH_TYPE)
rot = LlamaMSRotaryEmbedding()
q_ro = rot(
q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved
@@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff(
out = out.detach().cpu().numpy()
# Compare results
- all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+ all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
print(
"NO buff",
@@ -1983,8 +1991,8 @@ def test_gqa_past(self):
config,
local=local,
past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
@@ -1996,8 +2004,8 @@ def test_gqa_past(self):
config,
local=local,
past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
@@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self):
config,
local=local,
past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
@@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self):
config,
local=local,
past_format=past_kv_format,
- rtol=1e-3,
- atol=1e-3,
+ rtol=RTOL,
+ atol=ATOL,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,