[CPU] SparseAttention op #21110

Merged · 26 commits · Jul 4, 2024
Changes from 22 commits
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Type Constraints

<dl>
-<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain integer type.</dd>
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -68,7 +68,6 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;

int num_heads_; // number of attention heads
-int kv_num_heads_; // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
5 changes: 2 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -3,9 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"

#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
31 changes: 24 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -3,8 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"

#include "core/common/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -14,14 +14,31 @@
namespace onnxruntime {
namespace contrib {

-class GQAAttentionBase : public AttentionBase {
+class GQAAttentionBase {
protected:
-GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
-: AttentionBase(info, require_same_hidden_size) {}
+GQAAttentionBase(const OpKernelInfo& info, bool has_local) {
+int64_t num_heads = 0;
+ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+num_heads_ = static_cast<int>(num_heads);

-int local_window_size_;
-bool do_rotary_;
+int64_t kv_num_heads = 0;
+ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+kv_num_heads_ = static_cast<int>(kv_num_heads);
+
+scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+
+do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
+
+local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
+}
+
+int num_heads_;     // number of attention heads of Q
+int kv_num_heads_;  // number of attention heads of K or V
+float scale_;       // the scaling factor applied before softmax
+bool do_rotary_;    // whether or not to use rotary embeddings
bool rotary_interleaved_;
+int local_window_size_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
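The refactor above moves all shared attribute parsing (num_heads, kv_num_heads, scale, do_rotary, rotary_interleaved, and optionally local_window_size) into the GQAAttentionBase constructor, so any kernel using these attributes can reuse it. A minimal sketch of what a derived kernel then reduces to — the class name is hypothetical and the `has_local` value is an assumption, not taken from this diff:

```cpp
// Hypothetical kernel reusing the refactored GQAAttentionBase (sketch only).
// The base-class constructor performs all shared attribute parsing; passing
// has_local=false leaves local_window_size_ at -1 (disabled).
class MySparseStyleAttention final : public OpKernel, public GQAAttentionBase {
 public:
  explicit MySparseStyleAttention(const OpKernelInfo& info)
      : OpKernel(info), GQAAttentionBase(info, /*has_local=*/false) {}

  Status Compute(OpKernelContext* context) const override;
};
```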
42 changes: 16 additions & 26 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "group_query_attention.h"
#include "group_query_attention_helper.h"
#include "attention_utils.h"
#include "rotary_embedding.h"
#include "rotary_embedding_helper.h"
#include "contrib_ops/cpu/bert/group_query_attention.h"
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
#include "contrib_ops/cpu/bert/rotary_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
@@ -33,19 +34,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
GroupQueryAttention<float>);

template <typename T>
-GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
-int64_t num_heads = 0;
-int64_t kv_num_heads = 0;
-ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
-ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
-num_heads_ = static_cast<int>(num_heads);
-kv_num_heads_ = static_cast<int>(kv_num_heads);
-
-mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
-local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
-do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
-rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
-}
+GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
+: OpKernel(info), GQAAttentionBase(info, true) {}

template <typename T>
Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
@@ -174,14 +164,14 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
-ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp,
-parameters.batch_size,
-parameters.sequence_length,
-parameters.num_heads,
-parameters.kv_num_heads,
-parameters.head_size,
-v_input,
-v_rotary));
+ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV<T>(tp,
+parameters.batch_size,
+parameters.sequence_length,
+parameters.num_heads,
+parameters.kv_num_heads,
+parameters.head_size,
+v_input,
+v_rotary));
}
}

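For readers of the hunk above: the `v_input`/`v_rotary` pointer arithmetic assumes the packed QKV buffer lays out, per batch, all Q heads, then all K heads, then all V heads, each head occupying sequence_length × head_size elements. A self-contained sketch of those offsets under that assumption (the function and variable names are illustrative, not from the PR):

```cpp
#include <cstddef>

// Sketch of the packed-QKV offsets implied by the pointer arithmetic above.
// Per-batch layout assumed: [num_heads Q | kv_num_heads K | kv_num_heads V].
template <typename T>
void GetPackedQkvPointers(const T* packed_qkv, int num_heads, int kv_num_heads,
                          int sequence_length, int head_size,
                          const T** q, const T** k, const T** v) {
  const std::ptrdiff_t head_elems =
      static_cast<std::ptrdiff_t>(sequence_length) * head_size;
  *q = packed_qkv;                      // Q heads come first
  *k = *q + num_heads * head_elems;     // K heads follow Q
  *v = *k + kv_num_heads * head_elems;  // V heads follow K (cf. v_input above)
}
```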
32 changes: 0 additions & 32 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query,

return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
}

-template <typename T>
-Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
-int batch_size,
-int sequence_length,
-int num_heads,
-int kv_num_heads,
-int head_size,
-const T* input,
-T* output) {
-int seq_stride = head_size;
-int head_stride = sequence_length * seq_stride;
-int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
-
-const int loop_len = batch_size * sequence_length * kv_num_heads;
-const double cost = static_cast<double>(head_size);
-ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
-const int b = static_cast<int>((ptr / kv_num_heads) / sequence_length);
-const int s = static_cast<int>((ptr / kv_num_heads) % sequence_length);
-const int n = static_cast<int>(ptr % kv_num_heads);
-const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
-const T* input_data = input + block_offset;
-T* output_data = output + block_offset;
-for (int i = 0; i < head_size; i++) {
-output_data[i] = input_data[i];
-}
-}
-});
-return Status::OK();
-}

} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime
47 changes: 47 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_helper.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rotary_helper {
+
+template <typename T>
+Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
+int batch_size,
+int sequence_length,
+int num_heads,
+int kv_num_heads,
+int head_size,
+const T* input,
+T* output) {
+int seq_stride = head_size;
+int head_stride = sequence_length * seq_stride;
+int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;
+
+const int loop_len = batch_size * sequence_length * kv_num_heads;
+const double cost = static_cast<double>(head_size);
+ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+const int b = static_cast<int>((ptr / kv_num_heads) / sequence_length);
+const int s = static_cast<int>((ptr / kv_num_heads) % sequence_length);
+const int n = static_cast<int>(ptr % kv_num_heads);
+const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
+const T* input_data = input + block_offset;
+T* output_data = output + block_offset;
+for (int i = 0; i < head_size; i++) {
+output_data[i] = input_data[i];
+}
+}
+});
+return Status::OK();
+}
+
+} // namespace rotary_helper
+} // namespace contrib
+} // namespace onnxruntime
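The moved helper copies each V head verbatim: the parallel loop decomposes the flat index into (batch, sequence, kv-head), and because `block_offset` is computed identically for input and output, V lands at the same position in the rotary-packed buffer that Q and K were rotated into. A hedged usage sketch — the shapes, names, and surrounding function are illustrative, not from this PR:

```cpp
// Hypothetical caller: after Q and K have been rotated into rotary_out,
// mirror the untouched V heads across at the same offsets. Assumes
// "contrib_ops/cpu/bert/rotary_helper.h" and a thread pool are available.
onnxruntime::Status CopyVExample(onnxruntime::concurrency::ThreadPool* tp,
                                 const float* packed_qkv, float* rotary_out) {
  const int batch_size = 2, sequence_length = 4, head_size = 16;  // illustrative
  const int num_heads = 8, kv_num_heads = 2;
  // V of batch 0 starts after its Q and K head blocks; the helper's
  // batch_stride advances over whole (Q + K + V) batches internally.
  const int v_offset = (num_heads + kv_num_heads) * sequence_length * head_size;
  return onnxruntime::contrib::rotary_helper::PackVIntoRotaryQKV<float>(
      tp, batch_size, sequence_length, num_heads, kv_num_heads, head_size,
      packed_qkv + v_offset, rotary_out + v_offset);
}
```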
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -21,6 +21,7 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
@@ -281,6 +282,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention)>,

GitHub Actions / Lint C++ (cpplint via reviewdog) on cpu_contrib_kernels.cc:285: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
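These two registrations (the kernel class declaration and its BuildKernelCreateInfo entry) are what let a com.microsoft.SparseAttention node resolve to the new float CPU kernel. A minimal, hedged C++ API sketch of loading such a model — the model path is hypothetical and this snippet is not part of the PR:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sparse_attention_cpu_demo");
  Ort::SessionOptions options;  // CPU execution provider is used by default
  // Hypothetical model containing a com.microsoft.SparseAttention node; with
  // this PR, the float variant can now run on the CPU provider.
  Ort::Session session(env, "sparse_attention_model.onnx", options);
  return 0;
}
```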