diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 45306c852a906..ed9e2a0567d2f 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float16), tensor(bfloat16)
+- T : tensor(float), tensor(float16), tensor(bfloat16)
- Constrain input and output to float tensors.
- M : tensor(int32)
- Constrain integer type.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 5f19c16cba616..df5897529baae 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -512,6 +512,7 @@ Do not modify directly.*
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)|
|SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index af902a713eaa2..a6782daa58f1a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -68,7 +68,6 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;
int num_heads_; // number of attention heads
- int kv_num_heads_; // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index fc4905cd31819..dd52001c2ac6b 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -3,9 +3,8 @@
#pragma once
-#include "attention_base.h"
-#include "attention_helper.h"
-
+#include "contrib_ops/cpu/bert/attention_base.h"
+#include "contrib_ops/cpu/bert/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index 6b0c5f395cab0..137612a4bf902 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -3,8 +3,8 @@
#pragma once
-#include "attention_base.h"
-#include "attention_helper.h"
+#include "contrib_ops/cpu/bert/attention_base.h"
+#include "contrib_ops/cpu/bert/attention_helper.h"
#include "core/common/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -14,14 +14,31 @@
namespace onnxruntime {
namespace contrib {
-class GQAAttentionBase : public AttentionBase {
+class GQAAttentionBase {
protected:
- GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
- : AttentionBase(info, require_same_hidden_size) {}
+ GQAAttentionBase(const OpKernelInfo& info, bool has_local) {
+ int64_t num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+ num_heads_ = static_cast(num_heads);
- int local_window_size_;
- bool do_rotary_;
+ int64_t kv_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+ kv_num_heads_ = static_cast(kv_num_heads);
+
+ scale_ = info.GetAttrOrDefault("scale", 0.0f);
+
+ do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1;
+ rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1;
+
+ local_window_size_ = has_local ? static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1;
+ }
+
+ int num_heads_; // number of attention heads of Q
+ int kv_num_heads_; // number of attention heads of K or V
+ float scale_; // the scaling factor applied before softmax
+ bool do_rotary_; // whether or not to use rotary embeddings
bool rotary_interleaved_;
+ int local_window_size_;
template
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index cad9274e68149..97388a9d6bce8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "group_query_attention.h"
-#include "group_query_attention_helper.h"
-#include "attention_utils.h"
-#include "rotary_embedding.h"
-#include "rotary_embedding_helper.h"
+#include "contrib_ops/cpu/bert/group_query_attention.h"
+#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
+#include "contrib_ops/cpu/bert/rotary_helper.h"
+#include "contrib_ops/cpu/bert/attention_utils.h"
+#include "contrib_ops/cpu/bert/rotary_embedding.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
@@ -33,19 +34,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
GroupQueryAttention);
template
-GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
- int64_t num_heads = 0;
- int64_t kv_num_heads = 0;
- ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
- ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
- num_heads_ = static_cast(num_heads);
- kv_num_heads_ = static_cast(kv_num_heads);
-
- mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
- local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1));
- do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1;
- rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1;
-}
+GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
+ : OpKernel(info), GQAAttentionBase(info, true) {}
template
Status GroupQueryAttention::Compute(OpKernelContext* context) const {
@@ -174,14 +164,14 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
- ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV(tp,
- parameters.batch_size,
- parameters.sequence_length,
- parameters.num_heads,
- parameters.kv_num_heads,
- parameters.head_size,
- v_input,
- v_rotary));
+ ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp,
+ parameters.batch_size,
+ parameters.sequence_length,
+ parameters.num_heads,
+ parameters.kv_num_heads,
+ parameters.head_size,
+ v_input,
+ v_rotary));
}
}
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index a7de02452aa58..7ffb72fe55d25 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query,
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
}
-
-template
-Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
- int batch_size,
- int sequence_length,
- int num_heads,
- int kv_num_heads,
- int head_size,
- const T* input,
- T* output) {
- int seq_stride = head_size;
- int head_stride = sequence_length * seq_stride;
- int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
-
- const int loop_len = batch_size * sequence_length * kv_num_heads;
- const double cost = static_cast(head_size);
- ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
- for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
- const int b = static_cast((ptr / kv_num_heads) / sequence_length);
- const int s = static_cast((ptr / kv_num_heads) % sequence_length);
- const int n = static_cast(ptr % kv_num_heads);
- const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
- const T* input_data = input + block_offset;
- T* output_data = output + block_offset;
- for (int i = 0; i < head_size; i++) {
- output_data[i] = input_data[i];
- }
- }
- });
- return Status::OK();
-}
-
} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h
new file mode 100644
index 0000000000000..714d962dfb34e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_helper.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rotary_helper {
+
+template
+Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
+ int batch_size,
+ int sequence_length,
+ int num_heads,
+ int kv_num_heads,
+ int head_size,
+ const T* input,
+ T* output) {
+ int seq_stride = head_size;
+ int head_stride = sequence_length * seq_stride;
+ int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
+
+ const int loop_len = batch_size * sequence_length * kv_num_heads;
+ const double cost = static_cast(head_size);
+ ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+ const int b = static_cast((ptr / kv_num_heads) / sequence_length);
+ const int s = static_cast((ptr / kv_num_heads) % sequence_length);
+ const int n = static_cast(ptr % kv_num_heads);
+ const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
+ const T* input_data = input + block_offset;
+ T* output_data = output + block_offset;
+ for (int i = 0; i < head_size; i++) {
+ output_data[i] = input_data[i];
+ }
+ }
+ });
+ return Status::OK();
+}
+
+} // namespace rotary_helper
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index e8ca4370135cc..90a51fda0b188 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -21,6 +21,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
@@ -281,6 +282,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
new file mode 100644
index 0000000000000..e337f41a8688d
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
@@ -0,0 +1,226 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/sparse/sparse_attention.h"
+#include "contrib_ops/cpu/sparse/sparse_attention_helper.h"
+#include "contrib_ops/cpu/bert/rotary_helper.h"
+#include "contrib_ops/cpu/bert/attention_utils.h"
+#include "contrib_ops/cpu/bert/rotary_embedding.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/common/safeint.h"
+#include "core/platform/threadpool.h"
+
+#include
+#include
+
+using onnxruntime::concurrency::ThreadPool;
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ SparseAttention,
+ kMSDomain,
+ 1,
+ float,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType())
+ .TypeConstraint("M", DataTypeImpl::GetTensorType()),
+ SparseAttention);
+
+template
+SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) {
+}
+
+template
+Status SparseAttention::Compute(OpKernelContext* context) const {
+ const Tensor* query = context->Input(0);
+ const Tensor* key = context->Input(1);
+ const Tensor* value = context->Input(2);
+ const Tensor* past_key = context->Input(3);
+ const Tensor* past_value = context->Input(4);
+ const Tensor* block_row_indices = context->Input(5);
+ const Tensor* block_col_indices = context->Input(6);
+ const Tensor* total_seq_len = context->Input(7);
+ const Tensor* total_key_lengths = context->Input(8);
+ const Tensor* cos_cache = context->Input(9);
+ const Tensor* sin_cache = context->Input(10);
+
+ SparseAttentionParameters parameters = {};
+
+ // Parameters from node attribute shall be set before calling CheckInputs
+ parameters.sparse_block_size = sparse_block_size_;
+ parameters.num_heads = num_heads_;
+ parameters.kv_num_heads = kv_num_heads_;
+ parameters.scale = scale_;
+ parameters.do_rotary = do_rotary_;
+ parameters.rotary_interleaved = rotary_interleaved_;
+ ORT_RETURN_IF_ERROR(sparse_attention_helper::CheckInputs(¶meters,
+ query,
+ key,
+ value,
+ past_key,
+ past_value,
+ cos_cache,
+ sin_cache,
+ block_row_indices,
+ block_col_indices,
+ total_key_lengths,
+ total_seq_len));
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ int q_hidden_size = parameters.hidden_size;
+
+ std::vector output_shape(3);
+ output_shape[0] = static_cast(batch_size);
+ output_shape[1] = static_cast(sequence_length);
+ output_shape[2] = static_cast(q_hidden_size);
+ Tensor* output = context->Output(0, output_shape);
+
+ constexpr bool past_present_share_buffer = true; // Only supports share buffer for past and present for now.
+ parameters.past_present_share_buffer = past_present_share_buffer;
+
+ int head_size = parameters.head_size;
+ const int cache_length = past_present_share_buffer
+ ? parameters.max_cache_sequence_length
+ : parameters.total_sequence_length;
+ std::vector present_k_shape({static_cast(batch_size),
+ static_cast(kv_num_heads_),
+ static_cast(cache_length),
+ static_cast(head_size)});
+ std::vector present_v_shape({static_cast(batch_size),
+ static_cast(kv_num_heads_),
+ static_cast(cache_length),
+ static_cast(head_size)});
+ Tensor* present_key = context->Output(1, present_k_shape);
+ Tensor* present_value = context->Output(2, present_v_shape);
+
+ // Check past and present share buffer.
+ if (past_present_share_buffer) {
+ ORT_ENFORCE(past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw());
+ }
+
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+ auto element_type = DataTypeImpl::GetType();
+ OrtValue Q;
+ OrtValue K;
+ OrtValue V;
+
+ const bool packed_qkv = parameters.is_packed_qkv;
+ if (packed_qkv) {
+ ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH(
+ allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
+ } else {
+ ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH(
+ allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
+ ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH(
+ allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
+ ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH(
+ allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
+ }
+
+ if (do_rotary_) {
+ rotary_embedding_helper::RotaryParameters rotary_params = {};
+ rotary_params.batch_size = batch_size;
+ rotary_params.sequence_length = sequence_length;
+ rotary_params.hidden_size = q_hidden_size;
+ rotary_params.head_size = head_size;
+ rotary_params.rotary_embedding_dim = parameters.rotary_dim;
+ rotary_params.num_heads = num_heads_;
+ rotary_params.max_sequence_length = sequence_length; // unused
+ rotary_params.seq_stride = head_size;
+ rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
+ rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) *
+ rotary_params.head_stride;
+ rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
+ rotary_params.transposed = true;
+ auto* tp = context->GetOperatorThreadPool();
+
+ const bool is_prompt = parameters.total_sequence_length == parameters.sequence_length;
+ std::vector pos_ids(is_prompt ? 1 : batch_size * sequence_length);
+ if (is_prompt) {
+ pos_ids[0] = static_cast(0);
+ } else if (sequence_length == 1) {
+ for (int b = 0; b < batch_size; b++) {
+ pos_ids[b] = static_cast(total_key_lengths->Data()[b]) - 1;
+ }
+ } else {
+ // This supports a rare case that sequence_length > 1 when it is not prompt.
+ for (int b = 0; b < batch_size; b++) {
+ for (int s = 0; s < sequence_length; s++) {
+ pos_ids[b * sequence_length + s] = static_cast(total_key_lengths->Data()[b]) -
+ (sequence_length - s);
+ }
+ }
+ }
+
+ const T* q_input;
+ const T* k_input;
+ T* q_rotary;
+ T* k_rotary;
+ if (packed_qkv) {
+ OrtValue RotaryQKV;
+ TensorShape qkv_shape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size});
+ Tensor::InitOrtValue(element_type, qkv_shape, allocator, RotaryQKV);
+ q_input = Q.Get().Data();
+ k_input = q_input + num_heads_ * sequence_length * head_size;
+ q_rotary = RotaryQKV.GetMutable()->MutableData();
+ k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
+ Q = RotaryQKV;
+ } else {
+ OrtValue RotaryQ;
+ TensorShape q_shape({batch_size, num_heads_, sequence_length, head_size});
+ Tensor::InitOrtValue(element_type, q_shape, allocator, RotaryQ);
+ OrtValue RotaryK;
+ TensorShape k_shape({batch_size, kv_num_heads_, sequence_length, head_size});
+ Tensor::InitOrtValue(element_type, k_shape, allocator, RotaryK);
+ q_input = Q.Get().Data();
+ k_input = K.Get().Data();
+ q_rotary = RotaryQ.GetMutable()->MutableData();
+ k_rotary = RotaryK.GetMutable()->MutableData();
+ Q = RotaryQ;
+ K = RotaryK;
+ }
+
+ ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input,
+ pos_ids.data(), cos_cache->Data(),
+ sin_cache->Data(), q_rotary, rotary_interleaved_));
+
+ rotary_params.num_heads = kv_num_heads_;
+ rotary_params.hidden_size = parameters.kv_hidden_size;
+ if (!packed_qkv) {
+ rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
+ }
+ ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input,
+ pos_ids.data(), cos_cache->Data(),
+ sin_cache->Data(), k_rotary, rotary_interleaved_));
+ if (packed_qkv) {
+ const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
+ T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
+ ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV(tp,
+ parameters.batch_size,
+ parameters.sequence_length,
+ parameters.num_heads,
+ parameters.kv_num_heads,
+ parameters.head_size,
+ v_input,
+ v_rotary));
+ }
+ }
+
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ // Compute the attention score and apply the score to V
+ return ApplyAttention(Q.Get().Data(), packed_qkv ? nullptr : K.Get().Data(),
+ packed_qkv ? nullptr : V.Get().Data(), past_key, past_value,
+ output, present_key, present_value,
+ total_key_lengths, block_row_indices, block_col_indices, parameters, allocator, context);
+}
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h
new file mode 100644
index 0000000000000..4267d85c0e35d
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "contrib_ops/cpu/sparse/sparse_attention_base.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template
+class SparseAttention final : public OpKernel, public SparseAttentionBase {
+ public:
+ SparseAttention(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
new file mode 100644
index 0000000000000..cf66bd8407126
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
@@ -0,0 +1,390 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "contrib_ops/cpu/bert/attention_helper.h"
+
+#include "core/common/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+#include "contrib_ops/cpu/utils/dump_tensor.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class SparseAttentionBase {
+ protected:
+ SparseAttentionBase(const OpKernelInfo& info) {
+ int64_t num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+ num_heads_ = static_cast(num_heads);
+
+ int64_t kv_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+ kv_num_heads_ = static_cast(kv_num_heads);
+
+ scale_ = info.GetAttrOrDefault("scale", 0.0f);
+
+ do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1;
+ rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1;
+
+ int64_t sparse_block_size = 0;
+ ORT_ENFORCE(info.GetAttr("sparse_block_size", &sparse_block_size).IsOK());
+ sparse_block_size_ = static_cast(sparse_block_size);
+ }
+
+ int num_heads_; // number of attention heads of Q
+ int kv_num_heads_; // number of attention heads of K or V
+ float scale_; // the scaling factor applied before softmax
+ bool do_rotary_; // whether or not to use rotary embeddings
+ bool rotary_interleaved_;
+ int sparse_block_size_;
+
+ template
+ Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
+ const T* K, // K data with shape BxN_kvxSxH
+ const T* V, // V data with shape BxN_kvxSxH
+ const Tensor* past_key, // past K input tensor
+ const Tensor* past_value, // past V input tensor
+ Tensor* output, // output tensor
+ Tensor* present_key, // present K output tensor
+ Tensor* present_value, // present V output tensor
+ const Tensor* total_key_lengths, // total key lengths tensor
+ const Tensor* block_row_indices, // block row indices
+ const Tensor* block_col_indices, // block column indices
+ SparseAttentionParameters& parameters, // attention parameters
+ AllocatorPtr allocator, // allocator for temporary tensors
+ OpKernelContext* context) const {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int head_size = parameters.head_size;
+ const bool packed_qkv = parameters.is_packed_qkv;
+
+ int past_buffer_sequence_length = static_cast(past_key->Shape().GetDims()[2]);
+ int present_buffer_sequence_length = static_cast(present_key->Shape().GetDims()[2]);
+
+ // Allocate a buffer to store Softmax(QK)
+ size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * parameters.total_sequence_length * sizeof(T);
+ auto attention_probs = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
+
+ bool past_present_share_buffer = parameters.past_present_share_buffer;
+ assert(past_present_share_buffer);
+
+ auto* tp = context->GetOperatorThreadPool();
+
+ const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
+ ComputeAttentionProbs(
+ static_cast(attention_probs), Q, k, total_key_lengths->Data(),
+ batch_size, sequence_length, parameters.total_sequence_length,
+ past_buffer_sequence_length, present_buffer_sequence_length, head_size,
+ past_key->Data(), present_key->MutableData(), past_present_share_buffer, packed_qkv,
+ block_row_indices->Data(), block_col_indices->Data(), parameters, tp);
+
+ // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
+ const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
+ ComputeVxAttentionScore(
+ output->MutableData(), static_cast(attention_probs), v,
+ total_key_lengths->Data(),
+ batch_size, sequence_length, parameters.total_sequence_length,
+ past_buffer_sequence_length, present_buffer_sequence_length, head_size, parameters.hidden_size,
+ past_value->Data(), present_value->MutableData(), past_present_share_buffer, packed_qkv, tp);
+
+ return Status::OK();
+ }
+
+ private:
+ // Helper function to compute the attention probs. It does 2 things:
+ // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
+ // attention_probs(B, N, S, T) = Softmax(attention_probs)
+ template
+ void ComputeAttentionProbs(
+ T* attention_probs, // output buffer with size BxNxSxT
+ const T* Q, // query start pointer
+ const T* K, // key start pointer
+ const int32_t* total_key_lengths, // total key sequence lengths (past + new)
+ int batch_size, // batch size
+ int sequence_length, // sequence length of query or new key
+ int total_sequence_length, // maximum past_sequence_length + sequence_length
+ int past_buffer_sequence_length, // sequence length of past_key or past_value
+ int present_buffer_sequence_length, // sequence length of present_key or present_value
+ int head_size, // head size of query
+ const T* past_key, // past key
+ T* present_key, // present key
+ bool past_present_share_buffer, // whether past_key and present_key share the buffer
+ bool packed_qkv, // whether Q, K, V are packed
+ const int32_t* block_row_indices, // block row indices
+ const int32_t* block_col_indices, // block column indices
+ SparseAttentionParameters& parameters, // parameters
+ ThreadPool* tp) const { // thread pool
+ const bool is_prompt = (total_sequence_length == sequence_length);
+ const ptrdiff_t packed_batch_stride =
+ packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+ : SafeInt(0);
+ const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
+ const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H
+ const size_t kv_input_chunk_length = q_input_chunk_length;
+ const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size;
+ const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size;
+
+ const int loop_len = batch_size * num_heads_;
+ const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_;
+
+ TensorOpCost unit_cost;
+ const ptrdiff_t probs_matrix_bytes =
+ SafeInt(sequence_length) * total_sequence_length * sizeof(T);
+ unit_cost.compute_cycles =
+ static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length);
+ unit_cost.bytes_loaded =
+ static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T));
+ unit_cost.bytes_stored = static_cast(probs_matrix_bytes);
+
+ unit_cost.bytes_loaded += static_cast(probs_matrix_bytes);
+ unit_cost.bytes_stored += static_cast(probs_matrix_bytes);
+
+ // Cost to concatenate current key to cache (assume past and present share buffer).
+ double bytes_to_copy_key = static_cast(sizeof(T) * sequence_length * head_size);
+ unit_cost.bytes_loaded += bytes_to_copy_key;
+ unit_cost.bytes_stored += bytes_to_copy_key;
+
+ DUMP_CPU_TENSOR_INIT();
+ DUMP_CPU_TENSOR("block_row_indices", block_row_indices, parameters.num_sparse_layout, parameters.stride_row_indices);
+ DUMP_CPU_TENSOR("block_col_indices", block_col_indices, parameters.num_sparse_layout, parameters.stride_col_indices);
+
+ // Check whether each layout has sparse (has zero in lower triangular)
+ std::vector layout_has_sparse(parameters.num_sparse_layout);
+ for (int layout_index = 0; layout_index < parameters.num_sparse_layout; layout_index++) {
+ int nonzero_elements = block_row_indices[(layout_index + 1) * parameters.stride_row_indices - 1];
+ int dense_nonzero = (parameters.stride_row_indices * (parameters.stride_row_indices - 1)) / 2;
+ layout_has_sparse[layout_index] = nonzero_elements < dense_nonzero;
+ DUMP_STRING("layout_has_sparse[", layout_index, "]=", layout_has_sparse[layout_index]);
+ }
+
+ ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",loop_len=", loop_len, ",begin=", begin, ",end=", end);
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const int batch_index = static_cast(i) / num_heads_;
+ const int head_index = static_cast(i) % num_heads_;
+ const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length);
+ const size_t past_chunk_length = static_cast(past_seq_len) * head_size;
+ const int total_seq_len = total_key_lengths[batch_index];
+
+ const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length;
+ T* output = attention_probs + output_offset;
+
+ const T* k;
+ if (packed_qkv) {
+ k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+ } else {
+ k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
+ }
+
+ // Concatenate past_k + k -> present_k
+ // TODO: avoid copying mutiple times for a group.
+ k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
+ past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+ i / kv_num_heads_factor);
+
+ // Compute Q*K' + AttentionMask
+ // original transposed each iteration
+ // A: Q (B x N x) S x H (B x N x) S x H S x H
+ // B: K' (B x N x) T x H (B x N x) H x T H x T
+ // C: attention_probs (B x N x) S x T (B x N x) S x T S x T
+ const T* q;
+ if (packed_qkv) {
+ q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index;
+ } else {
+ q = Q + q_input_chunk_length * i;
+ }
+
+ DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index,
+ ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv);
+ DUMP_CPU_TENSOR("Q", q, sequence_length, head_size);
+ DUMP_CPU_TENSOR("K", k, total_seq_len, head_size);
+
+ math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seq_len, head_size, alpha, q,
+ head_size, k, head_size, 0.0f /*bata*/, output, total_seq_len,
+ nullptr);
+
+ DUMP_CPU_TENSOR("QK", output, sequence_length, total_seq_len);
+
+ // Compute Softmax for causal and output result in place.
+ T* output_softmax = output;
+
+ int layout_id = head_index % parameters.num_sparse_layout;
+ bool is_sparse_layout = layout_has_sparse[layout_id];
+
+ DUMP_STRING("layout_id=", layout_id, ",is_sparse_layout=", is_sparse_layout);
+
+ if (!is_sparse_layout) { // dense
+ for (int q_id = 0; q_id < sequence_length; q_id++) {
+ int causal_length = past_seq_len + q_id + 1;
+ ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr);
+ for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) {
+ output_softmax[remain_seq_id] = 0.f;
+ }
+ output_softmax += total_seq_len;
+ }
+ } else { // sparse
+ int q_id = 0;
+ bool has_sparse = false;
+ std::vector mask(parameters.max_sequence_length);
+
+ const int32_t* layout_row_indices = block_row_indices + layout_id * parameters.stride_row_indices;
+ const int32_t* layout_col_indices = block_col_indices + layout_id * parameters.stride_col_indices;
+ do {
+ int q_abs_position = past_seq_len + q_id;
+ int causal_length = q_abs_position + 1;
+
+ // Update mask when query token is the first or at the boundary of sparse block.
+ if (q_id == 0 || q_abs_position % parameters.sparse_block_size == 0) {
+ int row_in_sparse_layout = q_abs_position / parameters.sparse_block_size;
+ int start_in_col_indices = layout_row_indices[row_in_sparse_layout];
+ int end_in_col_indices = layout_row_indices[row_in_sparse_layout + 1];
+ int nonzero_blocks = end_in_col_indices - start_in_col_indices;
+ has_sparse = (nonzero_blocks != row_in_sparse_layout + 1);
+
+ DUMP_STRING("q_id=", q_id,
+ ",q_abs_position=", q_abs_position,
+ ",sparse_block_size=", parameters.sparse_block_size,
+ ",row_in_sparse_layout=", row_in_sparse_layout,
+ ",start_in_col_indices=", start_in_col_indices,
+ ",end_in_col_indices=", end_in_col_indices,
+ ",nonzero_blocks=", nonzero_blocks,
+ ",has_sparse=", has_sparse);
+
+ // Expand attention mask for current row of q_id
+ if (has_sparse) {
+ int block_aligned_length = q_abs_position / parameters.sparse_block_size * parameters.sparse_block_size + parameters.sparse_block_size;
+ DUMP_STRING("block_aligned_length=", block_aligned_length);
+
+ std::fill_n(mask.begin(), block_aligned_length, 0);
+ for (int j = start_in_col_indices; j < end_in_col_indices; j++) {
+ int col_in_sparse_layout = layout_col_indices[j];
+
+ int offset = col_in_sparse_layout * parameters.sparse_block_size;
+ for (int s = 0; s < parameters.sparse_block_size; s++, offset++) {
+ mask[offset] = 1;
+ }
+ }
+
+ DUMP_CPU_TENSOR("mask", mask, block_aligned_length);
+ }
+ }
+
+ // Update inline according to attention mask.
+ if (has_sparse) {
+ for (int s = 0; s < causal_length; s++) {
+ if (mask[s] == 0)
+ output_softmax[s] = std::numeric_limits::lowest();
+ }
+ }
+ ComputeAttentionSoftmaxInplace(output_softmax, 1, causal_length, nullptr);
+
+ for (int remain_seq_id = causal_length; remain_seq_id < total_seq_len; remain_seq_id++) {
+ output_softmax[remain_seq_id] = 0.f;
+ }
+
+ output_softmax += total_seq_len;
+ q_id++;
+
+ } while (q_id < sequence_length);
+ }
+
+ DUMP_CPU_TENSOR("softmax", output, sequence_length, total_seq_len);
+ }
+ });
+ }
+
+ template
+ void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
+ const T* attention_probs, // Softmax of Q*K' with size BxNxSxT
+ const T* V, // v value with size BxN_kvxSxH
+ const int32_t* total_key_lengths, // total sequence lengths
+ int batch_size, // batch size
+ int sequence_length, // sequence length
+ int total_sequence_length, // maximum past_sequence_length + sequence_length
+ int past_buffer_sequence_length, // sequence length in past state
+ int present_buffer_sequence_length, // sequence length in past state
+ int head_size, // head size of Q, K, V
+ int hidden_size, // hidden size of Output
+ const T* past_value, // past value only
+ T* present_value, // present value only
+ bool past_present_share_buffer, // whether past_key and present_key share the buffer
+ bool packed_qkv, // whether Q, K, V are packed
+ ThreadPool* tp) const {
+ const bool is_prompt = sequence_length == total_sequence_length;
+ const ptrdiff_t packed_batch_stride =
+ packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+ : SafeInt(0);
+ const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
+
+ const int kv_input_chunk_length = sequence_length * head_size; // S x H
+ const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size;
+ const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size;
+
+ // The cost of Gemm.
+ TensorOpCost unit_cost;
+ // Here we use total_sequence_length to estimate total_key_lengths[batch_index] used in GEMM.
+ unit_cost.compute_cycles =
+ static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length);
+ unit_cost.bytes_loaded = static_cast(SafeInt(sequence_length + head_size) *
+ total_sequence_length * sizeof(T));
+ unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T));
+
+ if (present_value) {
+ double bytes_to_copy_value = static_cast(sizeof(T) * sequence_length * head_size);
+ unit_cost.bytes_loaded += bytes_to_copy_value;
+ unit_cost.bytes_stored += bytes_to_copy_value;
+ }
+
+ DUMP_CPU_TENSOR_INIT();
+
+ ThreadPool::TryParallelFor(
+ tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ DUMP_STRING("batch_size=", batch_size, ",num_heads=", num_heads_, ",begin=", begin, ",end=", end);
+
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const int batch_index = static_cast(i / num_heads_);
+ const int head_index = static_cast(i % num_heads_);
+ const int past_seq_len = is_prompt ? 0 : (static_cast(total_key_lengths[batch_index]) - sequence_length);
+ const size_t past_chunk_length = static_cast(past_seq_len) * head_size;
+ const int total_seq_len = total_key_lengths[batch_index];
+
+ DUMP_STRING("i=", i, ",batch_index=", batch_index, ",head_index=", head_index,
+ ",past_seq_len=", past_seq_len, ",total_seq_len=", total_seq_len, ",packed_qkv=", packed_qkv);
+
+ const T* v;
+ if (packed_qkv) {
+ v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+ } else {
+ v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
+ }
+
+ // Concatenate past_v + v -> present_v
+ v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
+ past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+ i / kv_num_heads_factor);
+
+ DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size);
+
+ T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+ ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_seq_len * i;
+
+ DUMP_CPU_TENSOR("attention_probs", attention_probs + attention_probs_offset, sequence_length, total_seq_len);
+
+ math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seq_len,
+ 1.f, /*alpha*/
+ attention_probs + attention_probs_offset, total_seq_len, v,
+ head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
+
+ DUMP_CPU_TENSOR("out", attention_probs + attention_probs_offset, sequence_length, head_size);
+ }
+ });
+ }
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
similarity index 98%
rename from onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
rename to onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
index a5f1d50e618af..ca69370b4ce17 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
@@ -21,7 +21,7 @@ Status CheckInputs(void* params,
const Tensor* sin_cache,
const Tensor* block_row_indices,
const Tensor* block_col_indices,
- const Tensor* seqlens_k_total,
+ const Tensor* total_key_lengths,
const Tensor* total_seq_len) {
// No packing for q/k/v:
// query (batch_size, sequence_length, num_heads * head_size)
@@ -36,7 +36,7 @@ Status CheckInputs(void* params,
// past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
// block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size
// block_col_indices (num_layout, max_nnz)
- // seqlens_k_total (batch_size) when do_rotary is True, optional otherwise
+ // total_key_lengths (batch_size)
// total_seq_len (1)
// cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
// sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true.
@@ -197,7 +197,7 @@ Status CheckInputs(void* params,
}
// Check the shape of total_key_sequence_lengths. We do not check the values here.
- const auto& k_len_dim = seqlens_k_total->Shape().GetDims();
+ const auto& k_len_dim = total_key_lengths->Shape().GetDims();
if (k_len_dim.size() != 1 && k_len_dim[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"key_total_sequence_lengths must have shape (batch_size).");
diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
index 7d3f6eb9295d8..865a1dc29ce47 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
@@ -3,7 +3,7 @@
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
#include "contrib_ops/cuda/sparse/sparse_attention.h"
-#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
+#include "contrib_ops/cpu/sparse/sparse_attention_helper.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
#include "core/platform/env_var_utils.h"
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 2a14ba1db4bb7..7272a949f7218 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1254,7 +1254,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"present_value",
"Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).",
"T")
- .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
SparseAttentionTypeAndShapeInference(ctx, 3);
diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index f33a56ee4e1f9..f18bcdba65579 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -8,14 +8,16 @@
"""
import math
import unittest
-from typing import Optional
+from typing import Optional, Union
import torch
+from benchmark_mha import InputFormats
from onnx import TensorProto, helper
+from parameterized import parameterized
from torch import Tensor
-from onnxruntime import InferenceSession, SessionOptions
-from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
+from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+from onnxruntime.transformers.io_binding_helper import CudaSession
ENABLE_DEBUG = False
@@ -34,6 +36,7 @@ def __init__(
softmax_scale: Optional[float],
do_rotary: bool,
rotary_interleaved: bool,
+ provider: str = "CUDAExecutionProvider",
device="cuda",
dtype=torch.float16,
share_buffer: bool = True,
@@ -62,11 +65,13 @@ def __init__(
self.do_rotary = do_rotary
self.rotary_interleaved = rotary_interleaved
+
+ self.provider = provider
self.device = device
+ self.dtype = dtype
self.share_buffer = share_buffer
self.is_packed_qkv = is_packed_qkv
- self.dtype = dtype
def shape_dict(self):
shapes = {
@@ -106,7 +111,7 @@ def get_cos_sin_cache(self, dtype):
def random_inputs(self):
device = self.device
# Since bfloat16 is not supported in ORT python I/O binding API, we always use float16 as model inputs.
- dtype = torch.float16
+ dtype = torch.float16 if self.dtype == torch.bfloat16 else self.dtype
# Always use non-packed qkv to generate same inputs for Torch and ORT.
packed = self.is_packed_qkv # Save the original value.
@@ -153,7 +158,9 @@ def __init__(
softmax_scale=None,
do_rotary: bool = False,
rotary_interleaved: bool = False,
+ provider: str = "CUDAExecutionProvider",
device="cuda",
+ dtype=torch.float16,
local_window_size: int = -1,
attention_mask=None,
is_packed_qkv=False,
@@ -162,17 +169,19 @@ def __init__(
):
super().__init__(
"GroupQueryAttention",
- batch_size,
- sequence_length,
- max_sequence_length,
- past_sequence_length,
- num_heads,
- kv_num_heads,
- head_size,
- softmax_scale,
- do_rotary,
- rotary_interleaved,
- device,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ max_sequence_length=max_sequence_length,
+ past_sequence_length=past_sequence_length,
+ num_heads=num_heads,
+ kv_num_heads=kv_num_heads,
+ head_size=head_size,
+ softmax_scale=softmax_scale,
+ do_rotary=do_rotary,
+ rotary_interleaved=rotary_interleaved,
+ provider=provider,
+ device=device,
+ dtype=dtype,
is_packed_qkv=is_packed_qkv,
max_cache_sequence_length=max_cache_sequence_length,
max_rotary_sequence_length=max_rotary_sequence_length,
@@ -220,24 +229,28 @@ def __init__(
softmax_scale=None,
do_rotary: bool = False,
rotary_interleaved: bool = False,
+ provider: str = "CUDAExecutionProvider",
device="cuda",
+ dtype=torch.float16,
is_packed_qkv=False,
max_cache_sequence_length=None,
max_rotary_sequence_length=None,
):
super().__init__(
"SparseAttention",
- batch_size,
- sequence_length,
- max_sequence_length,
- past_sequence_length,
- num_heads,
- kv_num_heads,
- head_size,
- softmax_scale,
- do_rotary,
- rotary_interleaved,
- device,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ max_sequence_length=max_sequence_length,
+ past_sequence_length=past_sequence_length,
+ num_heads=num_heads,
+ kv_num_heads=kv_num_heads,
+ head_size=head_size,
+ softmax_scale=softmax_scale,
+ do_rotary=do_rotary,
+ rotary_interleaved=rotary_interleaved,
+ provider=provider,
+ device=device,
+ dtype=dtype,
is_packed_qkv=is_packed_qkv,
max_cache_sequence_length=max_cache_sequence_length,
max_rotary_sequence_length=max_rotary_sequence_length,
@@ -288,17 +301,19 @@ def random_inputs(self):
def get_comparable_ort_gqa_config(self, use_local=False) -> GroupQueryAttentionConfig:
return GroupQueryAttentionConfig(
- self.batch_size,
- self.sequence_length,
- self.max_sequence_length,
- self.past_sequence_length,
- self.num_heads,
- self.kv_num_heads,
- self.head_size,
- self.softmax_scale,
- self.do_rotary,
- self.rotary_interleaved,
- self.device,
+ batch_size=self.batch_size,
+ sequence_length=self.sequence_length,
+ max_sequence_length=self.max_sequence_length,
+ past_sequence_length=self.past_sequence_length,
+ num_heads=self.num_heads,
+ kv_num_heads=self.kv_num_heads,
+ head_size=self.head_size,
+ softmax_scale=self.softmax_scale,
+ do_rotary=self.do_rotary,
+ rotary_interleaved=self.rotary_interleaved,
+ provider=self.provider,
+ device=self.device,
+ dtype=self.dtype,
local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1,
is_packed_qkv=self.is_packed_qkv,
max_cache_sequence_length=self.max_cache_sequence_length,
@@ -314,17 +329,19 @@ def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttenti
attention_mask = attention_mask[:, :, -self.sequence_length :, :]
return GroupQueryAttentionConfig(
- self.batch_size,
- self.sequence_length,
- self.max_sequence_length,
- self.past_sequence_length,
- self.num_heads,
- self.kv_num_heads,
- self.head_size,
- self.softmax_scale,
- self.do_rotary,
- self.rotary_interleaved,
- self.device,
+ batch_size=self.batch_size,
+ sequence_length=self.sequence_length,
+ max_sequence_length=self.max_sequence_length,
+ past_sequence_length=self.past_sequence_length,
+ num_heads=self.num_heads,
+ kv_num_heads=self.kv_num_heads,
+ head_size=self.head_size,
+ softmax_scale=self.softmax_scale,
+ do_rotary=self.do_rotary,
+ rotary_interleaved=self.rotary_interleaved,
+ provider=self.provider,
+ device=self.device,
+ dtype=self.dtype,
attention_mask=attention_mask,
is_packed_qkv=False, # torch reference implementation does not support packed qkv.
max_cache_sequence_length=self.max_cache_sequence_length,
@@ -375,7 +392,7 @@ def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size):
def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
# ORT Python I/O binding API does not support bf16, so always use fp16 as graph inputs/outputs.
- io_float_type = TensorProto.FLOAT16
+ io_float_type = TensorProto.FLOAT if config.dtype == torch.float32 else TensorProto.FLOAT16
suffix = "_bf16" if config.dtype == torch.bfloat16 else ""
nodes = [
@@ -487,9 +504,9 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig):
def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
- assert config.dtype == torch.float16
+ assert config.dtype in [torch.float16, torch.float32]
- float_type = TensorProto.FLOAT16
+ float_type = TensorProto.FLOAT16 if config.dtype in [torch.float16] else TensorProto.FLOAT
nodes = [
helper.make_node(
"GroupQueryAttention",
@@ -599,7 +616,10 @@ def group_query_attention_reference(
attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)
result = attn_output.transpose(1, 2).contiguous()
- torch.cuda.synchronize()
+
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
return result
@@ -671,25 +691,42 @@ def infer(self):
)
+def create_ort_session(
+ config: Union[SparseAttentionConfig, GroupQueryAttentionConfig], session_options=None, enable_cuda_graph=False
+) -> CudaSession:
+ if isinstance(config, SparseAttentionConfig):
+ onnx_model_str = create_sparse_attention_onnx_model(config)
+ else:
+ onnx_model_str = create_group_query_attention_onnx_model(config)
+
+ if config.provider == "CUDAExecutionProvider":
+ device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
+ provider_options = CudaSession.get_cuda_provider_options(
+ device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream
+ )
+ providers = [(config.provider, provider_options), "CPUExecutionProvider"]
+ else:
+ providers = ["CPUExecutionProvider"]
+
+ ort_session = InferenceSession(onnx_model_str, session_options, providers=providers)
+ # Note that CudaSession could work with both CUDA and CPU providers.
+ cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph)
+ shape_dict = config.shape_dict()
+ cuda_session.allocate_buffers(shape_dict)
+
+ buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
+ for input_name, output_name in buffer_sharing.items():
+ cuda_session.set_buffer_sharing(input_name, output_name)
+
+ return cuda_session
+
+
class OrtGroupQueryAttention:
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
def __init__(self, config: GroupQueryAttentionConfig):
- cuda_provider_options = CudaSession.get_cuda_provider_options(
- torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
- )
- onnx_model_str = create_group_query_attention_onnx_model(config)
- self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
- self.gpu_binding_manager = GpuBindingManager(
- ort_session=self.ort_session,
- device=config.device,
- stream=torch.cuda.current_stream().cuda_stream,
- max_cuda_graphs=2,
- )
- buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
- self.gpu_binding = self.gpu_binding_manager.get_binding(
- config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
- )
+ self.session = create_ort_session(config)
+
self.feed_dict = config.random_inputs()
if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -709,28 +746,14 @@ def __init__(self, config: GroupQueryAttentionConfig):
print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"])
def infer(self):
- return self.gpu_binding.infer(self.feed_dict)
+ return self.session.infer(self.feed_dict)
class OrtSparseAttention:
"""A wrapper of ORT SparseAttention to test relevance and performance."""
def __init__(self, config: SparseAttentionConfig):
- cuda_provider_options = CudaSession.get_cuda_provider_options(
- torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
- )
- onnx_model_str = create_sparse_attention_onnx_model(config)
- self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
- self.gpu_binding_manager = GpuBindingManager(
- ort_session=self.ort_session,
- device=config.device,
- stream=torch.cuda.current_stream().cuda_stream,
- max_cuda_graphs=2,
- )
- buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
- self.gpu_binding = self.gpu_binding_manager.get_binding(
- config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
- )
+ self.session = create_ort_session(config)
self.feed_dict = config.random_inputs()
if ENABLE_DEBUG and not config.is_packed_qkv:
@@ -753,19 +776,196 @@ def __init__(self, config: SparseAttentionConfig):
print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"])
def infer(self):
- return self.gpu_binding.infer(self.feed_dict)
+ return self.session.infer(self.feed_dict)
+
+
+def get_provider_support_info(provider: str, use_kv_cache: bool):
+ if provider == "CUDAExecutionProvider":
+ formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H]
+ device_id = torch.cuda.current_device()
+ device = torch.device("cuda", device_id)
+ dtype = torch.float16
+ else:
+ assert provider == "CPUExecutionProvider"
+ formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.QKV_BSN3H]
+ device = torch.device("cpu")
+ dtype = torch.float
+ return device, dtype, formats
+
+
+def has_cuda_support():
+ if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers():
+ major, minor = torch.cuda.get_device_capability()
+ sm = major * 10 + minor
+ return sm in [75, 80, 86, 89, 90]
+
+ return False
+
+
+def get_simple_test_case(provider: str, has_past_kv: bool):
+ """A simple test case for debugging purpose."""
+ device, dtype, _formats = get_provider_support_info(provider, False)
+ if provider == "CPUExecutionProvider":
+ # A simple case for debugging purpose.
+ max_sequence_length = 16
+ sequence_length = 15
+ packed_qkv = False
+ config = SparseAttentionConfig(
+ batch_size=1,
+ sequence_length=1 if has_past_kv else sequence_length,
+ max_sequence_length=max_sequence_length,
+ past_sequence_length=sequence_length if has_past_kv else 0,
+ num_heads=4,
+ kv_num_heads=2,
+ head_size=8,
+ sparse_block_size=4,
+ num_layout=2,
+ local_blocks=2,
+ vert_stride=2,
+ softmax_scale=0.0,
+ provider=provider,
+ device=device,
+ dtype=dtype,
+ is_packed_qkv=packed_qkv,
+ max_cache_sequence_length=max_sequence_length,
+ )
+ yield config
+
+
+def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rotary=False):
+ if provider == "CUDAExecutionProvider" and not has_cuda_support():
+ return
+ yield
+
+ device, dtype, formats = get_provider_support_info(provider, False)
+ batch_sizes = [1, 2, 3]
+ sequence_lengths = [1, 64, 127, 128, 192, 256]
+ heads = [4, 8, 16]
+
+ # SparseAttention CUDA kernel only supports head size 128
+ head_sizes = [128] if provider == "CUDAExecutionProvider" else [128, 256]
+
+ if comprehensive:
+ for batch_size in batch_sizes:
+ for sequence_length in sequence_lengths:
+ for num_heads in heads:
+ for head_size in head_sizes:
+ for format in formats:
+ packed_qkv = format == InputFormats.QKV_BSN3H
+
+ non_prompt_len = 1
+ if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+ # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+ non_prompt_len = batch_size
+
+ query_sequence_length = non_prompt_len if has_past_kv else sequence_length
+ config = SparseAttentionConfig(
+ batch_size=batch_size,
+ sequence_length=query_sequence_length,
+ max_sequence_length=256,
+ past_sequence_length=(
+ min(256 - query_sequence_length, sequence_length) if has_past_kv else 0
+ ),
+ num_heads=num_heads,
+ kv_num_heads=num_heads // 2,
+ head_size=head_size,
+ sparse_block_size=64,
+ num_layout=2,
+ local_blocks=2,
+ vert_stride=2,
+ softmax_scale=1.8 / (128**0.5),
+ provider=provider,
+ device=device,
+ dtype=dtype,
+ is_packed_qkv=packed_qkv,
+ do_rotary=do_rotary,
+ rotary_interleaved=sequence_length <= 128,
+ max_cache_sequence_length=None if sequence_length >= 128 else 128,
+ )
+ yield config
+ else:
+ test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes))
+ for i in range(test_cases):
+ batch_size = batch_sizes[i % len(batch_sizes)]
+ sequence_length = sequence_lengths[i % len(sequence_lengths)]
+ num_heads = heads[i % len(heads)]
+ head_size = head_sizes[i % len(head_sizes)]
+ format = formats[i % len(formats)]
+ packed_qkv = format == InputFormats.QKV_BSN3H
+
+ non_prompt_len = 1
+ if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+ # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+ non_prompt_len = batch_size
+
+ query_sequence_length = non_prompt_len if has_past_kv else sequence_length
+
+ config = SparseAttentionConfig(
+ batch_size=batch_size,
+ sequence_length=query_sequence_length,
+ max_sequence_length=256,
+ past_sequence_length=min(256 - query_sequence_length, sequence_length) if has_past_kv else 0,
+ num_heads=num_heads,
+ kv_num_heads=num_heads // 2,
+ head_size=head_size,
+ sparse_block_size=64,
+ num_layout=2,
+ local_blocks=2,
+ vert_stride=2,
+ softmax_scale=1.8 / (128**0.5),
+ provider=provider,
+ device=device,
+ dtype=dtype,
+ is_packed_qkv=packed_qkv,
+ do_rotary=do_rotary,
+ rotary_interleaved=sequence_length <= 128,
+ max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer.
+ )
+ yield config
+
+
+# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine.
+comprehensive_mode = False
class TestSparseAttention(unittest.TestCase):
- @unittest.skipUnless(torch.cuda.is_available(), "cuda not available")
+ @unittest.skipUnless(has_cuda_support(), "cuda not available")
def test_sparse_attention(self):
major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor
+ self.run_relevance_test(sm)
- if sm not in [75, 80, 86, 89, 90]:
- self.skipTest("SparseAttention is not supported on this GPU")
+ @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True)
+ def test_simple_token_cpu(self, config: SparseAttentionConfig):
+ self.run_one_relevance_test(config)
- self.run_relevance_test(sm)
+ @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True)
+ def test_simple_prompt_cpu(self, config: SparseAttentionConfig):
+ self.run_one_relevance_test(config)
+
+ @parameterized.expand(
+ get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True
+ )
+ def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig):
+ # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense.
+ if config.sparse_block_size * config.local_blocks > config.total_sequence_length:
+ self.run_one_relevance_test(config)
+
+ @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
+ def test_sparse_att_token_gpu(self, config):
+ self.run_one_relevance_test(config)
+
+ @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
+ def test_sparse_att_token_cpu(self, config):
+ self.run_one_relevance_test(config)
+
+ @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
+ def test_sparse_att_prompt_cpu(self, config):
+ self.run_one_relevance_test(config)
+
+ @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
+ def test_sparse_att_prompt_gpu(self, config):
+ self.run_one_relevance_test(config)
def run_one_relevance_test(self, config: SparseAttentionConfig):
if (not config.do_rotary) and config.total_sequence_length <= 2048:
@@ -774,6 +974,10 @@ def run_one_relevance_test(self, config: SparseAttentionConfig):
obj = TorchGroupQueryAttention(gqa_config)
expected_out = obj.infer()
else:
+ if config.dtype == torch.bfloat16:
+ # Skip test since create_group_query_attention_onnx_model does not support bfloat16 right now.
+ return
+
# Run QGA by ORT (support packed QKV, rotary and very long sequence, but no mask so dense only).
gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
obj = OrtGroupQueryAttention(gqa_config)
@@ -881,6 +1085,8 @@ def run_relevance_no_past_128k(self, sm: int, device):
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
+ provider="CUDAExecutionProvider",
+ dtype=torch.float16,
device=device,
is_packed_qkv=packed_qkv,
)
@@ -907,6 +1113,8 @@ def run_relevance_past_128k(self, sm: int, device):
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
+ provider="CUDAExecutionProvider",
+ dtype=torch.float16,
device=device,
is_packed_qkv=packed_qkv,
)
@@ -921,7 +1129,8 @@ def run_relevance_test(self, sm: int):
device = torch.device("cuda", device_id)
with torch.no_grad():
# Test long sequence when GPU memory is enough (need about 12 GB for 128K sequence length)
- if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024:
+ # The 128k tests fails randomly in T4 GPU, increase memory threshold for now.
+ if torch.cuda.get_device_properties(device_id).total_memory > 20 * 1024 * 1024 * 1024:
self.run_relevance_no_past_128k(sm, device)
self.run_relevance_past_128k(sm, device)
self.run_relevance_no_past(sm, device)