undo mha
tianleiwu committed Jul 1, 2024
1 parent 1de2033 commit 7550a31
Showing 11 changed files with 87 additions and 520 deletions.
2 changes: 0 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -39,7 +39,6 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
)

target_sources(onnxruntime_mlas PRIVATE
@@ -48,7 +47,6 @@ target_sources(onnxruntime_mlas PRIVATE
${MLAS_INC_DIR}/mlas_q4.h
${MLAS_INC_DIR}/mlas_qnbit.h
${MLAS_INC_DIR}/mlas.h
${MLAS_INC_DIR}/mlas_flashattn.h
)

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
3 changes: 0 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,9 +166,6 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Environment variable for tuning attention algorithm
constexpr const char* kAttentionAlgo = "ORT_ATTENTION_ALGO";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;

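Note: the switches declared above are read through ORT's `ParseEnvironmentVariableWithDefault` helper (see the calls removed from `multihead_attention.cc` below). The following is a minimal standalone sketch of that pattern using plain `std::getenv` rather than the ORT helper; the parsing policy here is simplified and illustrative only.

```cpp
// Minimal standalone sketch (not ORT code): reads an on/off switch the way the
// kernel reads kDisableFlashAttention, defaulting to "enabled" when unset.
#include <cstdlib>
#include <cstring>
#include <iostream>

namespace {
bool ParseBoolEnvOrDefault(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') {
    return default_value;
  }
  // Simplified policy: treat "1"/"true" as true, everything else as false.
  return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0;
}
}  // namespace

int main() {
  const bool disable_flash =
      ParseBoolEnvOrDefault("ORT_DISABLE_FLASH_ATTENTION", /*default_value=*/false);
  std::cout << "flash attention " << (disable_flash ? "disabled" : "enabled") << "\n";
  return 0;
}
```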
86 changes: 10 additions & 76 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -1,21 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/bert/multihead_attention.h"
#include <type_traits>
#include <vector>
#include <algorithm>

#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "attention_cpu_base.h"
#include "multihead_attention.h"
#include "multihead_attention_helper.h"
#include "attention_utils.h"

#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/env_var_utils.h"
#include "core/platform/threadpool.h"
#include "core/mlas/inc/mlas_flashattn.h"

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>

using onnxruntime::concurrency::ThreadPool;

@@ -41,12 +39,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;

const auto& env = Env::Default();
l2_cache_size_ = env.GetL2CacheSize();

disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
algo_ = ParseEnvironmentVariableWithDefault<int>(attention::kAttentionAlgo, 0);
}

template <typename T>
@@ -68,6 +60,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
}

AttentionParameters parameters = {};
constexpr float scale = 1.0f;
bool past_present_share_buffer = false;
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
@@ -81,7 +74,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
&parameters,
num_heads_,
mask_filter_value_,
scale_,
scale,
is_unidirectional_,
past_present_share_buffer,
false));
@@ -106,14 +99,8 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
const int v_bias_offset = 2 * qk_hidden_size;

// If optional outputs aren't needed, present_k and present_v will be null
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(v_head_size)});
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
Tensor* present_k = context->Output(1, present_k_shape);
Tensor* present_v = context->Output(2, present_v_shape);

@@ -151,59 +138,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

if (std::is_same_v<T, float> &&
!disable_flash_ &&
!is_unidirectional_ &&
key_padding_mask == nullptr &&
extra_add_qk == nullptr &&
past_key == nullptr &&
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
l2_cache_size_ > 0) {
FlashAttentionThreadedArgs args;

if (algo_ == 1) {
args.q_block_size = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 64 : 32);
args.kv_block_size = 512;
} else {
args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
args.kv_block_size = std::max(args.kv_block_size, 1); // avoid row_size_kv = 0
args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
}
args.q_block_size = std::min(args.q_block_size, q_sequence_length);
args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length);

args.batch_size = batch_size;
args.num_heads = num_heads_;
args.q_sequence_length = q_sequence_length;
args.kv_sequence_length = kv_sequence_length;
args.qk_head_size = qk_head_size;
args.v_head_size = v_head_size;
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);

int columns = args.kv_block_size + 2 + args.v_head_size; // columns in qk + qk_max + qk_sum + out
args.buffer_size_per_thread = static_cast<size_t>(args.q_block_size) * static_cast<size_t>(columns);

size_t total_buffer_size = args.buffer_size_per_thread * static_cast<size_t>(args.thread_count);
IAllocatorUniquePtr<float> buffer = IAllocator::MakeUniquePtr<float>(allocator, total_buffer_size);
args.buffer = buffer.get();

args.query = Q.Get<Tensor>().Data<float>();
args.key = K.Get<Tensor>().Data<float>();
args.value = V.Get<Tensor>().Data<float>();
args.output = output->MutableData<float>();

concurrency::ThreadPool::TrySimpleParallelFor(tp, args.thread_count, [&](std::ptrdiff_t thread_id) {
FlashAttentionThreaded(thread_id, &args);
});

return Status::OK();
}

// Compute the attention score and apply the score to V
return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
K.GetMutable<Tensor>()->MutableData<T>(),
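For readers tracking what the revert removes: the deleted branch above derived its tile sizes from the CPU's L2 cache and sized one scratch buffer per thread. The standalone sketch below reproduces that arithmetic with made-up inputs (1 MiB L2, head size 64, sequence length 1024); the formulas mirror the deleted lines, but the concrete numbers are illustrative only.

```cpp
// Standalone arithmetic sketch of the block/buffer sizing used by the removed branch.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const int l2_cache_size = 1024 * 1024;  // hypothetical L2 size in bytes
  const int qk_head_size = 64;
  const int v_head_size = 64;
  const int q_sequence_length = 1024;
  const int kv_sequence_length = 1024;

  // kv block sized so roughly 4 float rows of (qk_head_size + v_head_size) fit in L2.
  int kv_block_size =
      l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
  kv_block_size = std::max(kv_block_size, 1);  // avoid a zero-sized block
  int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);

  q_block_size = std::min(q_block_size, q_sequence_length);
  kv_block_size = std::min(kv_block_size, kv_sequence_length);

  // Per-thread scratch: qk block + running max + running sum + output block.
  const int columns = kv_block_size + 2 + v_head_size;
  const size_t buffer_size_per_thread =
      static_cast<size_t>(q_block_size) * static_cast<size_t>(columns);

  std::printf("kv_block_size=%d q_block_size=%d floats/thread=%zu (~%zu KiB)\n",
              kv_block_size, q_block_size, buffer_size_per_thread,
              buffer_size_per_thread * sizeof(float) / 1024);
  return 0;
}
```

With these example inputs the sketch reports a kv block of 512, a q block of 128, and about 289 KiB of scratch per thread, which gives a sense of the per-thread footprint the heuristic targeted.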
4 changes: 0 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -5,7 +5,6 @@

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/attention_cpu_base.h"

namespace onnxruntime {
namespace contrib {
@@ -20,9 +19,6 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
int num_heads_; // number of attention heads
float mask_filter_value_;
bool is_unidirectional_;
bool disable_flash_;
int l2_cache_size_;
int algo_;
};

} // namespace contrib
45 changes: 0 additions & 45 deletions onnxruntime/core/mlas/inc/mlas_flashattn.h

This file was deleted.

157 changes: 0 additions & 157 deletions onnxruntime/core/mlas/lib/flashattn.cpp

This file was deleted.
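The deleted kernel is not reproduced here. As general background only: a flash-attention style CPU kernel is built around a streaming ("online") softmax, so the attention scores never need to be materialized for the whole key/value sequence at once. The sketch below shows that idea for a single query row; it is a generic illustration, not the removed MLAS code.

```cpp
// Generic illustration (not the deleted MLAS kernel): one query row of
// streaming softmax attention, keeping a running max and running sum so the
// full score row never has to be stored.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void AttendOneRow(const float* q, const float* k, const float* v,
                  int kv_len, int qk_head_size, int v_head_size,
                  float scale, float* out) {
  float running_max = -INFINITY;
  float running_sum = 0.0f;
  std::vector<float> acc(v_head_size, 0.0f);

  for (int j = 0; j < kv_len; ++j) {
    float score = 0.0f;
    for (int d = 0; d < qk_head_size; ++d) {
      score += q[d] * k[j * qk_head_size + d];
    }
    score *= scale;

    const float new_max = std::max(running_max, score);
    const float correction = std::exp(running_max - new_max);  // rescale prior state
    const float p = std::exp(score - new_max);

    running_sum = running_sum * correction + p;
    for (int d = 0; d < v_head_size; ++d) {
      acc[d] = acc[d] * correction + p * v[j * v_head_size + d];
    }
    running_max = new_max;
  }

  for (int d = 0; d < v_head_size; ++d) {
    out[d] = acc[d] / running_sum;
  }
}

int main() {
  const float q[2] = {1.0f, 0.0f};
  const float k[4] = {1.0f, 0.0f, 0.0f, 1.0f};  // two keys, head size 2
  const float v[4] = {1.0f, 2.0f, 3.0f, 4.0f};  // two values, head size 2
  float out[2];
  AttendOneRow(q, k, v, /*kv_len=*/2, /*qk_head_size=*/2, /*v_head_size=*/2,
               /*scale=*/1.0f, out);
  std::printf("%f %f\n", out[0], out[1]);
  return 0;
}
```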

2 changes: 0 additions & 2 deletions onnxruntime/core/platform/env.h
@@ -147,8 +147,6 @@ class Env {

virtual std::vector<LogicalProcessors> GetDefaultThreadAffinities() const = 0;

virtual int GetL2CacheSize() const = 0;

/// \brief Returns the number of micro-seconds since the Unix epoch.
virtual uint64_t NowMicros() const {
return env_time_->NowMicros();
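`GetL2CacheSize` was a per-platform virtual on `Env`; its removal goes with dropping the cache-aware block sizing above. For illustration only, one way such a query could be implemented on Linux is to read the sysfs cache topology. This sketch is an assumption about a possible implementation, not the code that was removed.

```cpp
// Hedged, Linux-only sketch of how an L2 cache size query might be implemented
// via sysfs. Illustrative only; not the implementation behind the removed virtual.
#include <cstdio>
#include <fstream>
#include <string>

int GetL2CacheSizeLinux() {
  for (int index = 0; index < 16; ++index) {
    const std::string base =
        "/sys/devices/system/cpu/cpu0/cache/index" + std::to_string(index) + "/";
    std::ifstream level_file(base + "level");
    int level = 0;
    if (!(level_file >> level) || level != 2) {
      continue;  // not the L2 entry
    }
    std::ifstream size_file(base + "size");
    std::string size_str;  // e.g. "1024K"
    if (!(size_file >> size_str) || size_str.empty()) {
      continue;
    }
    int multiplier = 1;
    if (size_str.back() == 'K' || size_str.back() == 'k') {
      multiplier = 1024;
      size_str.pop_back();
    } else if (size_str.back() == 'M' || size_str.back() == 'm') {
      multiplier = 1024 * 1024;
      size_str.pop_back();
    }
    return std::stoi(size_str) * multiplier;
  }
  return 0;  // unknown
}

int main() {
  std::printf("L2 cache size: %d bytes\n", GetL2CacheSizeLinux());
  return 0;
}
```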