[CPU] SparseAttention op #21110

Merged: 26 commits, Jul 4, 2024
Changes from 16 commits
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
)

target_sources(onnxruntime_mlas PRIVATE
@@ -47,6 +48,7 @@ target_sources(onnxruntime_mlas PRIVATE
${MLAS_INC_DIR}/mlas_q4.h
${MLAS_INC_DIR}/mlas_qnbit.h
${MLAS_INC_DIR}/mlas.h
${MLAS_INC_DIR}/mlas_flashattn.h
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -5646,7 +5646,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain integer type.</dd>
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -68,7 +68,6 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;

int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,6 +166,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Environment variable for tuning attention algorithm
constexpr const char* kAttentionAlgo = "ORT_ATTENTION_ALGO";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;

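Note: the new kAttentionAlgo constant is consumed the same way as the existing attention flags. A minimal sketch of the reading side, assuming the attention namespace nesting in attention_common.h and the ParseEnvironmentVariableWithDefault helper from core/platform/env_var_utils.h that the multihead_attention.cc change below also uses:

```cpp
// Sketch only, not part of this diff: reading ORT_ATTENTION_ALGO.
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/platform/env_var_utils.h"

int ReadAttentionAlgo() {
  // Unset or 0 selects the default L2-cache based block-size heuristic;
  // 1 selects the fixed block sizes (see multihead_attention.cc below).
  return onnxruntime::ParseEnvironmentVariableWithDefault<int>(
      onnxruntime::contrib::attention::kAttentionAlgo, /*default_value=*/0);
}
```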
5 changes: 2 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -3,9 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"

#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"
#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
32 changes: 25 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -3,8 +3,8 @@

#pragma once

#include "attention_base.h"
#include "attention_helper.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cpu/bert/attention_helper.h"

#include "core/common/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -14,14 +14,32 @@
namespace onnxruntime {
namespace contrib {

class GQAAttentionBase : public AttentionBase {
class GQAAttentionBase {
protected:
GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
: AttentionBase(info, require_same_hidden_size) {}
GQAAttentionBase(const OpKernelInfo& info, bool has_local) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);

int local_window_size_;
bool do_rotary_;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
kv_num_heads_ = static_cast<int>(kv_num_heads);

scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;

// local_window_size is used in GQA but not in SparseAttention.
local_window_size_ = has_local ? static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1)) : -1;
}

int num_heads_; // number of attention heads of Q
int kv_num_heads_; // number of attention heads of K or V
float scale_; // the scaling factor applied before softmax
bool do_rotary_; // whether or not to use rotary embeddings
bool rotary_interleaved_;
int local_window_size_;

template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
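The attribute parsing that used to live in the GroupQueryAttention constructor now sits in GQAAttentionBase, so any kernel sharing the GQA-style attributes can reuse it. A hypothetical sketch (the SparseAttention kernel this PR adds is not among the files shown here): a kernel without local attention passes has_local=false, while GroupQueryAttention below passes true.

```cpp
// Hypothetical example, not code from this PR: a CPU kernel reusing the
// shared attribute parsing. A SparseAttention-style op passes
// has_local=false so local_window_size_ stays at its -1 default.
class MySparseAttentionKernel : public OpKernel, public GQAAttentionBase {
 public:
  explicit MySparseAttentionKernel(const OpKernelInfo& info)
      : OpKernel(info), GQAAttentionBase(info, /*has_local=*/false) {}

  Status Compute(OpKernelContext* context) const override {
    ORT_UNUSED_PARAMETER(context);
    return Status::OK();  // a real kernel would call ApplyAttention here
  }
};
```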
42 changes: 16 additions & 26 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "group_query_attention.h"
#include "group_query_attention_helper.h"
#include "attention_utils.h"
#include "rotary_embedding.h"
#include "rotary_embedding_helper.h"
#include "contrib_ops/cpu/bert/group_query_attention.h"
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
#include "contrib_ops/cpu/bert/rotary_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
@@ -33,19 +34,8 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
GroupQueryAttention<float>);

template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
}
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
: OpKernel(info), GQAAttentionBase(info, true) {}

template <typename T>
Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
@@ -174,14 +164,14 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
v_input,
v_rotary));
ORT_RETURN_IF_ERROR(rotary_helper::PackVIntoRotaryQKV<T>(tp,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
v_input,
v_rotary));
}
}

32 changes: 0 additions & 32 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -263,38 +263,6 @@ Status CheckInputs(const Tensor* query,

return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
}

template <typename T>
Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp,
int batch_size,
int sequence_length,
int num_heads,
int kv_num_heads,
int head_size,
const T* input,
T* output) {
int seq_stride = head_size;
int head_stride = sequence_length * seq_stride;
int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;

const int loop_len = batch_size * sequence_length * kv_num_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / kv_num_heads) / sequence_length);
const int s = static_cast<int>((ptr / kv_num_heads) % sequence_length);
const int n = static_cast<int>(ptr % kv_num_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
for (int i = 0; i < head_size; i++) {
output_data[i] = input_data[i];
}
}
});
return Status::OK();
}

} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime
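PackVIntoRotaryQKV is deleted here because it moved into rotary_helper; the call site in group_query_attention.cc above now uses rotary_helper::PackVIntoRotaryQKV. For reference, a small self-contained check of the stride arithmetic the helper relies on, with illustrative sizes that are not taken from the PR:

```cpp
// Illustrative numbers only: packed QKV is indexed as
// [batch][num_heads + 2*kv_num_heads][seq][head_size], and the input/output
// pointers already point at the V region, so one block_offset serves both
// sides of the copy.
#include <cassert>

int main() {
  const int sequence_length = 2, num_heads = 4, kv_num_heads = 2, head_size = 8;
  const int seq_stride = head_size;                                       // 8
  const int head_stride = sequence_length * seq_stride;                   // 16
  const int batch_stride = (num_heads + 2 * kv_num_heads) * head_stride;  // 128
  const int b = 0, s = 1, n = 1;  // one (batch, seq, kv-head) block
  const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
  assert(block_offset == 24);  // element offset of this head block within V
  return 0;
}
```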
86 changes: 76 additions & 10 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -1,19 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/bert/multihead_attention.h"
#include <type_traits>
#include <vector>
#include <algorithm>

#include "attention_cpu_base.h"
#include "multihead_attention.h"
#include "multihead_attention_helper.h"
#include "attention_utils.h"

#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/env_var_utils.h"
#include "core/platform/threadpool.h"
#include "core/mlas/inc/mlas_flashattn.h"

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>

using onnxruntime::concurrency::ThreadPool;

@@ -39,6 +41,12 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;

const auto& env = Env::Default();
l2_cache_size_ = env.GetL2CacheSize();

disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
algo_ = ParseEnvironmentVariableWithDefault<int>(attention::kAttentionAlgo, 0);
}

template <typename T>
@@ -60,7 +68,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
}

AttentionParameters parameters = {};
constexpr float scale = 1.0f;
bool past_present_share_buffer = false;
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
@@ -74,7 +81,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
&parameters,
num_heads_,
mask_filter_value_,
scale,
scale_,
is_unidirectional_,
past_present_share_buffer,
false));
@@ -99,8 +106,14 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
const int v_bias_offset = 2 * qk_hidden_size;

// If optional outputs aren't needed, present_k and present_v will be null
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(v_head_size)});
Tensor* present_k = context->Output(1, present_k_shape);
Tensor* present_v = context->Output(2, present_v_shape);

@@ -138,6 +151,59 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

if (std::is_same_v<T, float> &&
!disable_flash_ &&
!is_unidirectional_ &&
key_padding_mask == nullptr &&
extra_add_qk == nullptr &&
past_key == nullptr &&
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
l2_cache_size_ > 0) {
FlashAttentionThreadedArgs args;

if (algo_ == 1) {
args.q_block_size = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 64 : 32);
args.kv_block_size = 512;
} else {
args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
args.kv_block_size = std::max(args.kv_block_size, 1); // avoid row_size_kv = 0
args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
}
args.q_block_size = std::min(args.q_block_size, q_sequence_length);
args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length);

args.batch_size = batch_size;
args.num_heads = num_heads_;
args.q_sequence_length = q_sequence_length;
args.kv_sequence_length = kv_sequence_length;
args.qk_head_size = qk_head_size;
args.v_head_size = v_head_size;
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);

int columns = args.kv_block_size + 2 + args.v_head_size; // columns in qk + qk_max + qk_sum + out
args.buffer_size_per_thread = static_cast<size_t>(args.q_block_size) * static_cast<size_t>(columns);

size_t total_buffer_size = args.buffer_size_per_thread * static_cast<size_t>(args.thread_count);
IAllocatorUniquePtr<float> buffer = IAllocator::MakeUniquePtr<float>(allocator, total_buffer_size);
args.buffer = buffer.get();

args.query = Q.Get<Tensor>().Data<float>();
args.key = K.Get<Tensor>().Data<float>();
args.value = V.Get<Tensor>().Data<float>();
args.output = output->MutableData<float>();

concurrency::ThreadPool::TrySimpleParallelFor(tp, args.thread_count, [&](std::ptrdiff_t thread_id) {
FlashAttentionThreaded(thread_id, &args);
});

return Status::OK();
}

// Compute the attention score and apply the score to V
return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
K.GetMutable<Tensor>()->MutableData<T>(),
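A quick sanity check of the default (algo_ == 0) block-size heuristic and the per-thread scratch sizing in the new FP32 flash attention path above. The cache size and head dimensions below are assumptions for illustration, not values from the PR:

```cpp
// Illustrative numbers only (a 1 MiB L2 cache and 64-dim heads are assumed);
// the arithmetic mirrors the algo_ == 0 branch above.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const int l2_cache_size = 1 << 20;  // assume 1 MiB
  const int qk_head_size = 64, v_head_size = 64;
  const int q_sequence_length = 1024, kv_sequence_length = 1024;

  // kv_block_size scales with L2 size divided by the float cost of one K row
  // plus one V row, with a 4x safety factor; clamp to at least 1.
  int kv_block_size = l2_cache_size /
      (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
  kv_block_size = std::max(kv_block_size, 1);                              // 512
  int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);  // 128
  q_block_size = std::min(q_block_size, q_sequence_length);
  kv_block_size = std::min(kv_block_size, kv_sequence_length);

  // Per-thread scratch: qk block + running max + running sum + output block.
  const int columns = kv_block_size + 2 + v_head_size;                     // 578
  const std::size_t floats_per_thread =
      static_cast<std::size_t>(q_block_size) * static_cast<std::size_t>(columns);
  std::printf("kv_block=%d q_block=%d scratch=%zu floats (%.1f KiB) per thread\n",
              kv_block_size, q_block_size, floats_per_thread,
              floats_per_thread * sizeof(float) / 1024.0);
  return 0;
}
```

With these assumed inputs the heuristic yields kv_block_size=512, q_block_size=128 and about 289 KiB of float scratch per thread, which is then multiplied by the thread count for the single shared buffer allocation.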
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -5,6 +5,7 @@

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/attention_cpu_base.h"

namespace onnxruntime {
namespace contrib {
@@ -19,6 +20,9 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
int num_heads_; // number of attention heads
float mask_filter_value_;
bool is_unidirectional_;
bool disable_flash_;
int l2_cache_size_;
int algo_;
};

} // namespace contrib