diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 407e08c96a891..734506681ab60 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -482,7 +482,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
@@ -508,7 +508,7 @@ Do not modify directly.*
 |QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
-|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(float)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(float), tensor(float16)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
index 7b84971585f9f..c8fe9c77d8ff8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv,                  // Input: Q/K/V dat
   constexpr size_t element_size = sizeof(T);
   ProcessBroadcastSpanFuncs add_funcs{
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+        per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
       }};  // For element-wise add

   // Allocate space for output of Q(BS, D) + bias(D)
@@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv,       // Input: Q/K/V data - query is
   constexpr size_t element_size = sizeof(T);
   ProcessBroadcastSpanFuncs add_funcs{
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+        per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+        per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
       }};  // For element-wise add

   // Get Q's bias from combined bias
@@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context,
                                                int batch_size, int num_heads, int sequence_length, int head_size,
                                                const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

+template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
+                                               int batch_size, int num_heads, int sequence_length, int head_size,
+                                               const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
 template Status MaybeTransposeToBNSH(AllocatorPtr allocator,
                                      int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH(AllocatorPtr allocator,
                                      int batch_size, int num_heads, int sequence_length, int head_size,
                                      const Tensor* in, OrtValue& out);

+template Status MaybeTransposeToBNSH(AllocatorPtr allocator,
+                                     int batch_size, int num_heads, int sequence_length, int head_size,
+                                     const Tensor* in, OrtValue& out);
+
 }  // namespace contrib
 }  // namespace onnxruntime
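The new `template Status MaybeTransposeToBNSHAndAddBias(...)` and `MaybeTransposeToBNSH(...)` lines are explicit instantiations for the half-precision path: these helper templates are defined in this .cc file, so every element type that other translation units call them with has to be instantiated here, otherwise the MLFloat16 kernels would fail to link. A minimal, self-contained sketch of that pattern (toy names, not the ONNX Runtime declarations):

```cpp
// explicit_instantiation_sketch.cpp -- illustrative only; the real helpers live in
// attention_utils.cc and are declared in a separate header that other kernels include.
#include <cstdio>

// Template *defined* in this translation unit (like attention_utils.cc).
template <typename T>
int TransposeToBNSH(const T* in, int n) {
  int checksum = 0;
  for (int i = 0; i < n; ++i) checksum += static_cast<int>(in[i]);  // stand-in for the real transpose work
  return checksum;
}

// Without these explicit instantiations, callers in other .cc files that only see the
// declaration would get undefined-symbol errors for the new element type.
template int TransposeToBNSH<float>(const float*, int);
template int TransposeToBNSH<unsigned short>(const unsigned short*, int);  // fp16 bits stand-in

int main() {
  float a[3] = {1.f, 2.f, 3.f};
  std::printf("%d\n", TransposeToBNSH(a, 3));
  return 0;
}
```

The same reasoning applies to the `RunRotaryEmbedding` instantiation added further down in rotary_embedding.cc.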
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index bfec9aef56727..ccaeb6654e286 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,7 +75,7 @@ class GQAAttentionBase {
   int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

   // Compute the attention score.
-  size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+  size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
   auto attention_probs = allocator->Alloc(bytes);
   BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

@@ -87,16 +87,17 @@ class GQAAttentionBase {
   bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

   const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-  ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+  ComputeAttentionProbs(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
                         sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                        present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
+                        present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

   // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
   const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-  ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+  ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+                          seqlens_k->Data<int32_t>(),
                           batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
                           hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
-                          is_prompt, tp);
+                          is_prompt, tp, allocator);

   return Status::OK();
 }

@@ -106,7 +107,7 @@ class GQAAttentionBase {
   // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
   // attention_probs(B, N, S, T) = Softmax(attention_probs)
   template <typename T>
-  void ComputeAttentionProbs(T* attention_probs,      // output buffer with size BxNxSxT
+  void ComputeAttentionProbs(float* attention_probs,  // output buffer with size BxNxSxT
                              const T* Q,              // Q data. Its size is BxNxSxH
                              const T* K,              // k data. Its size is BxNxLxH
                              const int32_t* seqlens_k,  // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@ class GQAAttentionBase {
                              const bool past_present_share_buffer,  // whether present key and value share the same buffer
                              const bool packed_qkv,   // whether Q, K, V are packed
                              const bool is_prompt,    // whether it is prompt
-                             ThreadPool* tp) const {  // thread pool
+                             ThreadPool* tp,          // thread pool
+                             AllocatorPtr allocator) const {  // allocator for temporary buffer
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_key,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
@@ -164,7 +168,7 @@ class GQAAttentionBase {
       const size_t past_chunk_length = past_seqlen * head_size;

       const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
-      T* output = attention_probs + output_offset;
+      float* output = attention_probs + output_offset;

       const T* k;
       if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
         q = Q + q_input_chunk_length * i;
       }

-      math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
-                                  static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
-                                  static_cast<int>(present_buffer_sequence_length), nullptr);
+      if constexpr (std::is_same<T, float>::value) {
+        math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+                                        static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/,
+                                        output, static_cast<int>(present_buffer_sequence_length), nullptr);
+      } else {
+        size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+        auto q_k_fp32 = allocator->Alloc(bytes);
+        BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
+
+        float* q_fp32 = static_cast<float*>(q_k_fp32);
+        MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+
+        float* k_fp32 = q_fp32 + head_size * sequence_length;
+        MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);
+
+        math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
+                                        static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*bata*/,
+                                        output, static_cast<int>(present_buffer_sequence_length), nullptr);
+      }

       // compute Softmax
-      T* output_softmax = output;
+      float* output_softmax = output;
       for (size_t seq = 0; seq < sequence_length; seq++) {
         size_t seq_causal_length = past_seqlen + seq + 1;
         if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
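The else-branch above is the core of the fp16 strategy in this kernel: tensors stay MLFloat16 at the operator boundary, but Q and K are widened into a single fp32 scratch block (via MlasConvertHalfToFloatBuffer) and the existing fp32 GemmEx computes the attention scores, which stay in fp32 through the softmax. That is also why the attention_probs scratch buffer earlier in the diff is now sized with sizeof(float) regardless of T. A rough standalone sketch of that convert-then-GEMM idea; the half-to-float helper and the naive GEMM below are stand-ins written for illustration, not the MLAS/ORT APIs:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for MlasConvertHalfToFloatBuffer: widen IEEE fp16 bits to fp32.
float HalfBitsToFloat(uint16_t h) {
  const int exp = (h >> 10) & 0x1F;
  const int frac = h & 0x3FF;
  float v;
  if (exp == 0) {
    v = std::ldexp(static_cast<float>(frac), -24);               // subnormals and zero
  } else if (exp == 31) {
    v = frac ? std::nanf("") : INFINITY;                         // NaN / infinity
  } else {
    v = std::ldexp(static_cast<float>(frac + 1024), exp - 25);   // normal: (1.frac) * 2^(exp-15)
  }
  return (h & 0x8000) ? -v : v;
}

void ConvertHalfToFloatBuffer(const uint16_t* src, float* dst, size_t count) {
  for (size_t i = 0; i < count; ++i) dst[i] = HalfBitsToFloat(src[i]);
}

// Naive stand-in for the fp32 GEMM: probs(S x T) = alpha * Q(S x H) * K(T x H)^T.
void GemmQKT(const float* q, const float* k, float* probs,
             size_t S, size_t T, size_t H, float alpha) {
  for (size_t s = 0; s < S; ++s)
    for (size_t t = 0; t < T; ++t) {
      float acc = 0.f;
      for (size_t h = 0; h < H; ++h) acc += q[s * H + h] * k[t * H + h];
      probs[s * T + t] = alpha * acc;
    }
}

int main() {
  const size_t S = 2, T = 3, H = 4;
  std::vector<uint16_t> q_fp16(S * H, 0x3C00);  // 1.0 in fp16
  std::vector<uint16_t> k_fp16(T * H, 0x4000);  // 2.0 in fp16

  // One fp32 scratch block holds both Q and K, mirroring the single allocation in the kernel.
  std::vector<float> scratch(H * (S + T));
  float* q_fp32 = scratch.data();
  float* k_fp32 = q_fp32 + S * H;
  ConvertHalfToFloatBuffer(q_fp16.data(), q_fp32, S * H);
  ConvertHalfToFloatBuffer(k_fp16.data(), k_fp32, T * H);

  std::vector<float> probs(S * T);
  GemmQKT(q_fp32, k_fp32, probs.data(), S, T, H, 1.0f / std::sqrt(static_cast<float>(H)));
  std::printf("probs[0] = %f\n", probs[0]);  // 4 * (1 * 2) / sqrt(4) = 4
  return 0;
}
```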
@@ -237,7 +257,7 @@ class GQAAttentionBase {

   template <typename T>
   void ComputeVxAttentionScore(T* output,                     // buffer for the result with size BxSxNxH
-                               const T* attention_probs,      // Attention probs with size BxNxSxT
+                               const float* attention_probs,  // Attention probs with size BxNxSxT
                                const T* V,                    // V value with size BxN_kvxSxH
                                const int32_t* seqlens_k,      // total - 1 sequence lengths tensor
                                const size_t batch_size,       // batch size
@@ -251,7 +271,8 @@ class GQAAttentionBase {
                                const bool past_present_share_buffer,  // whether present key and value share the same buffer
                                const bool packed_qkv,  // whether Q, K, V are packed
                                const bool is_prompt,   // whether it is prompt
-                               ThreadPool* tp) const {
+                               ThreadPool* tp,
+                               AllocatorPtr allocator) const {
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
@@ -261,7 +282,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_value,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@ class GQAAttentionBase {
     unit_cost.bytes_loaded += bytes_to_copy_trans_all;
     unit_cost.bytes_stored += bytes_to_copy_trans_all;

+    size_t output_fp32_bytes = 0;
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
+    }
+    auto output_fp32 = allocator->Alloc(output_fp32_bytes);
+    BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));
+
     ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
@@ -305,15 +335,39 @@ class GQAAttentionBase {
                                    i / kv_num_heads_factor);
         }

-        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
         ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

-        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
-                                    attention_probs + attention_probs_offset,
-                                    static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
-                                    0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
+        if constexpr (std::is_same<T, float>::value) {
+          T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        } else {
+          size_t bytes = head_size * total_seqlen * sizeof(float);
+          auto v_fp32 = allocator->Alloc(bytes);
+          BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
+
+          float* v_fp32_ptr = static_cast<float*>(v_fp32);
+          MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);
+
+          float* output_fp32_current = static_cast<float*>(output_fp32) +
+                                       (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        }
       }
     });
+
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
+                                   output,
+                                   SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
+    }
   }
 };
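The output path mirrors this in reverse: for MLFloat16 the probabilities-times-V GEMM writes into a per-call fp32 buffer (output_fp32), and the result is narrowed back to fp16 once, after the parallel loop, via MlasConvertFloatToHalfBuffer. A small sketch of that narrowing step with a simplified hand-written converter (an assumption for illustration, not the MLAS routine; it ignores NaN payloads and fine-grained rounding at the fp16 boundaries):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified float -> fp16-bits conversion (round to nearest). A stand-in for the
// batch converter the kernel calls once on the whole fp32 output buffer.
uint16_t FloatToHalfBits(float f) {
  const uint16_t sign = std::signbit(f) ? 0x8000 : 0;
  f = std::fabs(f);
  if (std::isnan(f)) return static_cast<uint16_t>(sign | 0x7E00);
  if (f > 65504.f) return static_cast<uint16_t>(sign | 0x7C00);       // overflow -> inf (simplified)
  if (f < 6.103515625e-05f) {                                         // below smallest normal fp16
    return static_cast<uint16_t>(sign | std::lround(f / 5.9604644775390625e-08f));  // subnormal steps of 2^-24
  }
  int exp;
  const float mant = std::frexp(f, &exp);                             // f = mant * 2^exp, mant in [0.5, 1)
  int half_exp = exp + 14;                                            // biased fp16 exponent
  int frac = static_cast<int>(std::lround((mant * 2.f - 1.f) * 1024.f));
  if (frac == 1024) { frac = 0; ++half_exp; }                         // mantissa rounding carried over
  return static_cast<uint16_t>(sign | (half_exp << 10) | frac);
}

void ConvertFloatToHalfBuffer(const float* src, uint16_t* dst, size_t count) {
  for (size_t i = 0; i < count; ++i) dst[i] = FloatToHalfBits(src[i]);
}

int main() {
  // Pretend this is the fp32 scratch output of the probs x V GEMM for one head.
  std::vector<float> output_fp32 = {0.25f, 1.5f, -3.0f, 65000.f};
  std::vector<uint16_t> output_fp16(output_fp32.size());
  ConvertFloatToHalfBuffer(output_fp32.data(), output_fp16.data(), output_fp32.size());
  for (uint16_t bits : output_fp16) std::printf("0x%04X\n", bits);  // 0x3400, 0x3E00, 0xC200, ...
  return 0;
}
```

Doing the narrowing once outside ThreadPool::TryParallelFor keeps the per-head GEMMs free of fp16 conversions on the output side.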
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 2a38e4a1ac636..a1ed35e54b008 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -22,16 +22,20 @@ namespace onnxruntime {
 namespace contrib {

 // These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
-    GroupQueryAttention,
-    kMSDomain,
-    1,
-    float,
-    kCpuExecutionProvider,
-    KernelDefBuilder()
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
-    GroupQueryAttention<float>);
+#define REGISTER_KERNEL_TYPED(T)                                        \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
+      GroupQueryAttention,                                              \
+      kMSDomain,                                                        \
+      1,                                                                \
+      T,                                                                \
+      kCpuExecutionProvider,                                            \
+      KernelDefBuilder()                                                \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()), \
+      GroupQueryAttention<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)

 template <typename T>
 GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
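The registration change itself is mechanical: the single float-only ONNX_OPERATOR_TYPED_KERNEL_EX call becomes a REGISTER_KERNEL_TYPED(T) macro expanded once per supported element type, and rotary_embedding.cc below gets the same treatment. A toy sketch of the macro-per-type idea outside ONNX Runtime (the registry, names, and types here are made up for illustration; KernelRegistry / KernelDefBuilder are far richer):

```cpp
#include <cstdio>
#include <functional>
#include <map>
#include <string>

// Toy kernel "registry": maps "OpName:Type" to a callable.
std::map<std::string, std::function<void()>>& Registry() {
  static std::map<std::string, std::function<void()>> registry;
  return registry;
}

template <typename T>
struct GroupQueryAttentionKernel {
  static void Run() { std::printf("running GQA kernel for a %zu-byte element type\n", sizeof(T)); }
};

struct Float16 { unsigned short bits; };  // stand-in for MLFloat16

// One macro, expanded once per supported type -- the shape of the change in this file.
#define REGISTER_KERNEL_TYPED(T, name)                                                   \
  [[maybe_unused]] static const bool registered_##name =                                 \
      (Registry()["GroupQueryAttention:" #name] = &GroupQueryAttentionKernel<T>::Run, true);

REGISTER_KERNEL_TYPED(float, float32)
REGISTER_KERNEL_TYPED(Float16, float16)

int main() {
  for (auto& entry : Registry()) entry.second();  // both type variants are now discoverable
  return 0;
}
```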
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 6732f8b96cce2..cbfd2f0949363 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -13,16 +13,20 @@ namespace onnxruntime {
 namespace contrib {

 // These ops are internal-only, so register outside of onnx
-ONNX_OPERATOR_TYPED_KERNEL_EX(
-    RotaryEmbedding,
-    kMSDomain,
-    1,
-    float,
-    kCpuExecutionProvider,
-    KernelDefBuilder()
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
-    RotaryEmbedding<float>);
+#define REGISTER_KERNEL_TYPED(T)                                        \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                        \
+      RotaryEmbedding,                                                  \
+      kMSDomain,                                                        \
+      1,                                                                \
+      T,                                                                \
+      kCpuExecutionProvider,                                            \
+      KernelDefBuilder()                                                \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
+      RotaryEmbedding<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)

 template <typename T>
 RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
@@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
       const T* sin_data = sin_cache + cache_offset;

       int cache_idx = 0;
-      T sign = 0;
+      bool sign = false;
       int j = 0;
       for (int i = 0; i < rotary_emb_dim; i++) {
         if (interleaved) {
           cache_idx = (i / 2) % half_rotary_emb_dim;
-          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
-          j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
+          sign = i & 1;
+          j = sign ? i - 1 : i + 1;  // i - sign
         } else {
           cache_idx = i % half_rotary_emb_dim;
-          sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+          sign = (i >= half_rotary_emb_dim);
           j = (i + half_rotary_emb_dim) % rotary_emb_dim;
         }
-        output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+        float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
+        float input_data_j = static_cast<float>(input_data[j]);
+        float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
+        if (sign) {
+          output_data_i += input_data_j * sin_data_cache_idx;
+        } else {
+          output_data_i -= input_data_j * sin_data_cache_idx;
+        }
+        output_data[i] = static_cast<T>(output_data_i);
       }
       for (int i = rotary_emb_dim; i < head_size; i++) {
         output_data[i] = input_data[i];
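The rewritten loop above keeps the usual rotary-embedding math but accumulates in float and replaces the signed multiplier with a bool: interleaved mode rotates neighbouring pairs (2k, 2k+1), non-interleaved mode pairs element i with element i + dim/2, and the partner term is added or subtracted depending on which half of the pair i falls in. A small float-only sketch that mirrors the cache_idx / sign / j selection in the diff (plain C++, none of the kernel's types, caching, or threading):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Rotate the first `rotary_dim` elements of one head: out[i] is built from input[i] and its
// rotation partner input[j], matching the cache_idx / sign / j selection in the kernel loop.
void RotaryRotate(const float* input, const float* cos_cache, const float* sin_cache,
                  float* output, int rotary_dim, bool interleaved) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < rotary_dim; ++i) {
    int cache_idx;
    bool add;  // true -> + input[j] * sin, false -> - input[j] * sin
    int j;
    if (interleaved) {
      cache_idx = (i / 2) % half;
      add = (i % 2) != 0;          // odd positions add the partner term
      j = add ? i - 1 : i + 1;
    } else {
      cache_idx = i % half;
      add = i >= half;             // second half adds the partner term
      j = (i + half) % rotary_dim;
    }
    const float partner = input[j] * sin_cache[cache_idx];
    output[i] = input[i] * cos_cache[cache_idx] + (add ? partner : -partner);
  }
}

int main() {
  const int dim = 4;
  std::vector<float> x = {1.f, 0.f, 0.f, 1.f};
  const float theta = 0.5f;
  std::vector<float> cos_cache = {std::cos(theta), std::cos(theta)};
  std::vector<float> sin_cache = {std::sin(theta), std::sin(theta)};

  std::vector<float> out(dim);
  RotaryRotate(x.data(), cos_cache.data(), sin_cache.data(), out.data(), dim, /*interleaved=*/false);
  for (float v : out) std::printf("%f ", v);  // (x0*c - x2*s, x1*c - x3*s, x2*c + x0*s, x3*c + x1*s)
  std::printf("\n");
  return 0;
}
```

For MLFloat16 inputs the kernel performs exactly this arithmetic after casting each operand to float, and only casts the final sum back to T when storing output_data[i].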
@@ -102,6 +114,10 @@ template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
                                    const int64_t* position_ids, const float* cos_cache, const float* sin_cache,
                                    float* output, bool interleaved);

+template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
+                                   const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache,
+                                   MLFloat16* output, bool interleaved);
+
 template <typename T>
 Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index dcd1f5ec22b52..e75d485830ca5 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -288,8 +290,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index dc21d4e4a5890..08ec5de328b9d 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -29,6 +29,12 @@
 GREEN = "\033[32m"
 RESET = "\033[0m"

+ORT_TYPE = TensorProto.FLOAT
+TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32
+NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32
+RTOL = 3e-2 if
ORT_TYPE == TensorProto.FLOAT16 else 1e-3 +ATOL = RTOL + class Formats: BSNH = 0 @@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.q_sequence_length, @@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -264,7 +270,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -378,7 +384,7 @@ def create_group_query_attention_graph_past( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -391,7 +397,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -401,7 +407,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if 
past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -424,7 +430,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -433,7 +439,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -445,7 +451,7 @@ def create_group_query_attention_graph_past( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -453,7 +459,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.kv_sequence_length + (0 if share_buffer else config.sequence_length), (math.floor(config.head_size / 16) * 16) // 2, @@ -464,12 +470,12 @@ def create_group_query_attention_graph_past( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -479,7 +485,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True): config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) key_padding_mask = generate_random_padding_mask( @@ -722,13 +728,13 @@ def gqa_prompt_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -835,13 +841,13 @@ def gqa_past_func( io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( - "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) io_binding.bind_input( "past_value", "cpu", 0, - numpy.float32, + NUMPY_TYPE, ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) @@ -1017,9 +1023,11 @@ def attention_ref( attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, 
): q = torch.randn( config.batch_size, @@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1076,7 +1084,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt( config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt( kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...").to(dtype=TORCH_TYPE) k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded @@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved @@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "No buff", @@ -1458,8 +1466,8 @@ def parity_check_gqa_past( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): q = torch.randn( config.batch_size, @@ 
-1467,7 +1475,7 @@ def parity_check_gqa_past( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1476,7 +1484,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1485,7 +1493,7 @@ def parity_check_gqa_past( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1494,7 +1502,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = torch.randn( @@ -1503,7 +1511,7 @@ def parity_check_gqa_past( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1534,8 +1542,8 @@ def parity_check_gqa_past( rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1624,11 +1632,11 @@ def parity_check_gqa_past( out = out.detach().cpu().numpy() # Make sure past-present buffer updating correctly - assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) - assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True) # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "KV-buffer", @@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff( packed=False, softcap=0.0, use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, ): torch.manual_seed(69) q = torch.randn( @@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff( config.num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) k = torch.randn( @@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) v = torch.randn( @@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_k = torch.randn( @@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) new_v = 
torch.randn( @@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff( config.kv_num_heads, config.head_size, device="cpu", - dtype=torch.float32, + dtype=TORCH_TYPE, requires_grad=False, ) @@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff( angle = ( torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi ) - cos = torch.cos(angle).to(dtype=torch.float32) - sin = torch.sin(angle).to(dtype=torch.float32) + cos = torch.cos(angle).to(dtype=TORCH_TYPE) + sin = torch.sin(angle).to(dtype=TORCH_TYPE) rot = LlamaMSRotaryEmbedding() q_ro = rot( q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved @@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff( out = out.detach().cpu().numpy() # Compare results - all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) + all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET print( "NO buff", @@ -1983,8 +1991,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -1996,8 +2004,8 @@ def test_gqa_past(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, @@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self): config, local=local, past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, + rtol=RTOL, + atol=ATOL, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed,