
Implement FlashAttention for CPU #20805

Merged
merged 36 commits into from
Jul 11, 2024
Changes from all commits
36 commits
37a3175
Register new contrib op FlashAttention
duanqn May 7, 2024
4ebe454
Move getenv to constructor
duanqn Jun 14, 2024
0d65ce2
Get Env
duanqn Jun 14, 2024
42b2acb
Renaming
duanqn Jun 14, 2024
88a2600
Check for T==float
duanqn Jun 17, 2024
53e2e85
Lintrunner
duanqn Jun 17, 2024
945f656
Remove mlas function
duanqn Jun 19, 2024
ee323fb
Handle scale; Require present_k and present_v to be empty
duanqn Jun 20, 2024
63e76ad
Check is_unidirectional_
duanqn Jun 20, 2024
3d6368b
fix build
tianleiwu Jun 21, 2024
1fba73a
Merge with mlas.h
duanqn Jun 21, 2024
9479623
Add comment and MLASCALL
duanqn Jun 21, 2024
1e63e82
Remove unnecessary change
duanqn Jun 21, 2024
1fd0813
Fix onnxruntime_mlas.cmake
duanqn Jun 21, 2024
afb7466
Pick onnxruntime/test/python/transformers/benchmark_mha.py from lates…
duanqn Jun 21, 2024
327b4c2
Disable FlashAttention by default
duanqn Jun 21, 2024
ab0da5b
Fix value choice of row_size_q and row_size_kv; Add comments
duanqn Jun 21, 2024
8b19094
Fix order
duanqn Jun 21, 2024
8b2270a
causal=False
duanqn Jun 21, 2024
b449524
Add MLASCALL on implementation
duanqn Jun 24, 2024
06251b1
Improve comment
duanqn Jun 24, 2024
27b18d4
Enable FlashAttention by default
duanqn Jun 24, 2024
3059b44
lintrunner -a
duanqn Jun 24, 2024
412f219
Remove memset
duanqn Jun 24, 2024
44ff8f0
Fix l2_cache_size_
duanqn Jun 26, 2024
5421335
Fix PREfast
duanqn Jun 26, 2024
03d8f36
#include <algorithm>
duanqn Jun 26, 2024
d63e528
Fix bug
duanqn Jun 28, 2024
7a3d4a6
lintrunner
duanqn Jun 28, 2024
bf014d0
Renaming
duanqn Jul 5, 2024
baff456
Renaming
duanqn Jul 5, 2024
72f3c67
Use MlasSgemmOperation
duanqn Jul 5, 2024
e8a4373
Move threading inside MLAS kernel
duanqn Jul 8, 2024
46a8ce9
Remove MLASCALL
Jul 9, 2024
e1cf289
Remove 1 TODO
duanqn Jul 10, 2024
852fd98
Renaming
duanqn Jul 10, 2024
1 change: 1 addition & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
)

target_sources(onnxruntime_mlas PRIVATE
86 changes: 82 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -10,8 +10,12 @@
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/env_var_utils.h"
#include "core/platform/threadpool.h"
#include "core/mlas/inc/mlas.h"

#include <algorithm>
#include <type_traits>
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>

@@ -39,6 +43,11 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i

mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;

const auto& env = Env::Default();
l2_cache_size_ = env.GetL2CacheSize();

disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
}

template <typename T>
@@ -60,7 +69,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
}

AttentionParameters parameters = {};
constexpr float scale = 1.0f;
bool past_present_share_buffer = false;
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
@@ -74,7 +82,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
&parameters,
num_heads_,
mask_filter_value_,
scale,
scale_,
is_unidirectional_,
past_present_share_buffer,
false));
@@ -99,8 +107,14 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
const int v_bias_offset = 2 * qk_hidden_size;

// If optional outputs aren't needed, present_k and present_v will be null
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(qk_head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(num_heads_),
static_cast<int64_t>(total_kv_sequence_length),
static_cast<int64_t>(v_head_size)});
Tensor* present_k = context->Output(1, present_k_shape);
Tensor* present_v = context->Output(2, present_v_shape);

@@ -138,6 +152,70 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));

if (std::is_same_v<T, float> &&
!disable_flash_ &&
!is_unidirectional_ &&
key_padding_mask == nullptr &&
extra_add_qk == nullptr &&
past_key == nullptr &&
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
l2_cache_size_ > 0) {
MlasFlashAttentionThreadedArgs args;
args.batch_size = batch_size;
args.num_heads = num_heads_;
args.q_sequence_length = q_sequence_length;
args.kv_sequence_length = kv_sequence_length;
args.qk_head_size = qk_head_size;
args.v_head_size = v_head_size;
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
/*
q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper.
Let M = l2_cache_size / sizeof(float)
In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
slice of Q -- [Br, qk_head_size]
slice of K -- [Bc, qk_head_size]
slice of V -- [Bc, v_head_size]
result of QK -- [Br, Bc]
temporary output (same shape as QKV) -- [Br, v_head_size]
The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
By taking Bc = M / (4 * (qk_head_size + v_head_size)), and Br = min(Bc, qk_head_size + v_head_size), we have
(Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
<= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc
<= 2 * Bc * (qk_head_size + v_head_size) + M/4
<= 2 * M/4 + M/4 = M * (3/4)

We leave 1/4 of the L2 cache for
1. storing small tensors l and m
2. instruction (code)
*/
args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
args.kv_block_size = std::max(args.kv_block_size, 1); // avoid kv_block_size = 0
args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); // No point to have kv_block_size > kv_sequence_length
args.q_block_size = std::min(args.q_block_size, q_sequence_length); // No point to have q_block_size > q_sequence_length

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
sizeof(float);
size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);

args.buffer = reinterpret_cast<float*>(buffer.get());

args.query = Q.Get<Tensor>().Data<float>();
args.key = K.Get<Tensor>().Data<float>();
args.value = V.Get<Tensor>().Data<float>();
args.output = output->MutableData<float>();

MlasFlashAttention(&args, tp);
return Status::OK();
}

// Compute the attention score and apply the score to V
return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
K.GetMutable<Tensor>()->MutableData<T>(),
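To make the block-size arithmetic in the comment above concrete, here is a minimal standalone sketch of the same Br/Bc selection; the 1 MiB L2 cache size, head sizes, and sequence lengths are illustrative assumptions rather than values taken from this PR.

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the Br/Bc selection under assumed inputs: 1 MiB L2 cache,
// qk_head_size = v_head_size = 64, sequence lengths of 512.
int main() {
    const int l2_cache_size = 1024 * 1024;  // assumed 1 MiB L2 cache
    const int qk_head_size = 64;            // assumed head sizes
    const int v_head_size = 64;
    const int q_sequence_length = 512;
    const int kv_sequence_length = 512;

    // Bc = M / (4 * (qk_head_size + v_head_size)), with M = l2_cache_size / sizeof(float)
    int kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
    kv_block_size = std::max(kv_block_size, 1);
    // Br = min(Bc, qk_head_size + v_head_size)
    int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);
    // Never exceed the actual sequence lengths
    kv_block_size = std::min(kv_block_size, kv_sequence_length);
    q_block_size = std::min(q_block_size, q_sequence_length);

    // With these assumptions: Bc = 1048576 / (4 * 4 * 128) = 512, Br = min(512, 128) = 128.
    std::printf("kv_block_size (Bc) = %d, q_block_size (Br) = %d\n", kv_block_size, q_block_size);
    return 0;
}
```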
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -19,6 +19,8 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
int num_heads_; // number of attention heads
float mask_filter_value_;
bool is_unidirectional_;
bool disable_flash_;
int l2_cache_size_;
};

} // namespace contrib
32 changes: 32 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1825,3 +1825,35 @@ MlasNhwcAvgPool(
);

#endif

struct MlasFlashAttentionThreadedArgs {
int batch_size;
int num_heads;
int q_sequence_length;
int kv_sequence_length;
int qk_head_size;
int v_head_size;
int q_block_size;
int kv_block_size;
float scale;
int thread_count;
float* buffer;
size_t buffer_size_per_thread;
const float* query;
const float* key;
const float* value;
float* output;
};

/**
 * @brief fp32 FlashAttention kernel: partitions the work described in args across ThreadPool
 * @param args Problem dimensions, scale, input/output pointers, and the per-thread scratch buffer
 * @param ThreadPool Thread pool used to run the per-thread workers
 */
void
MLASCALL
MlasFlashAttention(
MlasFlashAttentionThreadedArgs* args,
MLAS_THREADPOOL* ThreadPool
);
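As a usage note, the caller is expected to size and provide the per-thread scratch buffer before invoking the kernel. The sketch below mirrors the caller in multihead_attention.cc; the dimensions, the plain std::vector scratch buffer, and the RunFlashAttentionExample helper are assumptions for illustration, not the ORT allocator path.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

#include "core/mlas/inc/mlas.h"  // include path as used inside the onnxruntime tree

// Illustrative only: fills MlasFlashAttentionThreadedArgs and calls the kernel.
// query/key/value are assumed packed as [batch, num_heads, seq, head_size] and the
// output as [batch, seq, num_heads, head_size], matching the indexing in flashattn.cpp.
void RunFlashAttentionExample(const float* q, const float* k, const float* v, float* out,
                              MLAS_THREADPOOL* thread_pool) {
    MlasFlashAttentionThreadedArgs args{};
    args.batch_size = 1;
    args.num_heads = 8;
    args.q_sequence_length = 256;
    args.kv_sequence_length = 256;
    args.qk_head_size = 64;
    args.v_head_size = 64;
    args.q_block_size = 64;    // Br
    args.kv_block_size = 256;  // Bc
    args.scale = 1.0f / std::sqrt(static_cast<float>(args.qk_head_size));
    args.thread_count = 4;

    // Per-thread scratch: l and m vectors, the QK^T tile, and the partial output tile.
    args.buffer_size_per_thread =
        (static_cast<size_t>(args.q_block_size) * 2 +
         static_cast<size_t>(args.q_block_size) * args.kv_block_size +
         static_cast<size_t>(args.q_block_size) * args.v_head_size) * sizeof(float);

    std::vector<float> scratch(args.buffer_size_per_thread / sizeof(float) * args.thread_count);
    args.buffer = scratch.data();

    args.query = q;
    args.key = k;
    args.value = v;
    args.output = out;

    MlasFlashAttention(&args, thread_pool);
}
```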
167 changes: 167 additions & 0 deletions onnxruntime/core/mlas/lib/flashattn.cpp
@@ -0,0 +1,167 @@
#include <numeric>

#include "mlasi.h"

void
MlasFlashAttentionThreaded(
void* argptr,
std::ptrdiff_t thread_id
)
{
const MlasFlashAttentionThreadedArgs* args = reinterpret_cast<MlasFlashAttentionThreadedArgs*>(argptr);
ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
ptrdiff_t kv_sequence_length = static_cast<ptrdiff_t>(args->kv_sequence_length);
ptrdiff_t qk_head_size = static_cast<ptrdiff_t>(args->qk_head_size);
ptrdiff_t v_head_size = static_cast<ptrdiff_t>(args->v_head_size);
float* buffer = args->buffer;
ptrdiff_t buffer_size_per_thread = static_cast<ptrdiff_t>(args->buffer_size_per_thread);
ptrdiff_t thread_count = static_cast<ptrdiff_t>(args->thread_count);
const float* query = args->query;
const float* key = args->key;
const float* value = args->value;
float* output = args->output;

#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
auto&& mlas_platform = GetMlasPlatform();
#endif

ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;

ptrdiff_t task_start = 0;
ptrdiff_t task_end = 0;
ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count;
ptrdiff_t quotient = total_task_count / thread_count;
ptrdiff_t remainder = total_task_count % thread_count;
if (thread_id < remainder) {
task_start = (quotient + 1) * thread_id;
task_end = task_start + quotient + 1;
} else {
task_start = quotient * thread_id + remainder;
task_end = task_start + quotient;
}

for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
ptrdiff_t batch_idx = task_index;
ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size;
batch_idx /= q_chunk_count;
ptrdiff_t head_idx = batch_idx % num_heads;
batch_idx /= num_heads;

char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
float* l = reinterpret_cast<float*>(buffer_current_thread);
float* m = l + q_block_size;
for (ptrdiff_t t = 0; t < q_block_size; ++t) {
m[t] = std::numeric_limits<float>::lowest();
}
float* intermediate = m + q_block_size;
float* temp_output = intermediate + q_block_size * kv_block_size;
float negmax = 0;

for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
/*
S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T
old_m = m
m = max(m, rowmax(S))
diff = old_m - m
S = exp(S - m)
l = exp(diff) * l + rowsum(S)
O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :]
*/
ptrdiff_t h = batch_idx * num_heads + head_idx;
const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size;
const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;

size_t row_size_q_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - q_idx));
size_t row_size_kv_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));

MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
CBLAS_TRANSPOSE::CblasTrans,
row_size_q_capped,
row_size_kv_capped,
static_cast<size_t>(qk_head_size),
args->scale,
inputQ,
static_cast<size_t>(qk_head_size),
inputK,
static_cast<size_t>(qk_head_size),
0.0f,
intermediate,
row_size_kv_capped);

for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(row_size_q_capped); ++irow) {
float* p = intermediate + irow * row_size_kv_capped;

#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped);
#else
float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped);
#endif
float m_diff = m[irow];
m[irow] = std::max(m[irow], rowmax); // new m
negmax = -m[irow];
m_diff -= m[irow]; // old - new (less than 0)

#if defined(MLAS_TARGET_AMD64)
float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
#else
float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
#endif

// Note: for ir == 0, there is actually no need to calculate exp_diff
if (ir != 0) {
float exp_diff = std::exp(m_diff);
l[irow] = exp_diff * l[irow] + rowsum;

for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol];
}
} else {
l[irow] = rowsum;
// When ir == 0, there is no need to scale the old result because it is zero.
}
}
MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
CBLAS_TRANSPOSE::CblasNoTrans,
row_size_q_capped,
static_cast<size_t>(v_head_size),
row_size_kv_capped,
1.0f,
intermediate,
row_size_kv_capped,
inputV,
static_cast<size_t>(v_head_size),
ir == 0 ? 0.0f : 1.0f,
temp_output,
static_cast<size_t>(v_head_size));
}

float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size;
ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx);
// TODO: leverage advanced instruction sets
for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow];
}
output_row += num_heads * v_head_size;
}
}
}

void
MLASCALL
MlasFlashAttention(
MlasFlashAttentionThreadedArgs* args,
MLAS_THREADPOOL* ThreadPool
)
{
MlasExecuteThreaded(
MlasFlashAttentionThreaded,
static_cast<void *>(args),
static_cast<std::ptrdiff_t>(args->thread_count),
ThreadPool);
}
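The recurrence documented inside the kv loop above (running maximum m, running sum l, and rescaling of the partial output by exp(old_m - new_m)) is the standard online-softmax trick. The self-contained sketch below, independent of MLAS, checks it on a single toy score row processed in two blocks against a direct softmax-weighted sum; the values are made up.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    const float scores[4] = {0.5f, 2.0f, -1.0f, 3.0f};  // one row of scaled QK^T
    const float values[4] = {1.0f, 2.0f, 3.0f, 4.0f};   // matching column of V (head size 1)

    float m = std::numeric_limits<float>::lowest();  // running max
    float l = 0.0f;                                  // running sum of exp
    float o = 0.0f;                                  // running (unnormalized) output

    for (int start = 0; start < 4; start += 2) {     // two kv blocks of size 2 (Bc = 2)
        float block_max = std::max(scores[start], scores[start + 1]);
        float new_m = std::max(m, block_max);
        float correction = std::exp(m - new_m);      // exp(old_m - new_m)
        float p0 = std::exp(scores[start] - new_m);
        float p1 = std::exp(scores[start + 1] - new_m);
        l = correction * l + p0 + p1;
        o = correction * o + p0 * values[start] + p1 * values[start + 1];
        m = new_m;
    }
    float streamed = o / l;

    // Direct reference: softmax over the full row, then weighted sum of values.
    float ref_max = *std::max_element(scores, scores + 4);
    float denom = 0.0f, num = 0.0f;
    for (int i = 0; i < 4; ++i) {
        float p = std::exp(scores[i] - ref_max);
        denom += p;
        num += p * values[i];
    }
    std::printf("streamed = %f, reference = %f\n", streamed, num / denom);
    return 0;
}
```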
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/env.h
@@ -147,6 +147,8 @@ class Env {

virtual std::vector<LogicalProcessors> GetDefaultThreadAffinities() const = 0;

virtual int GetL2CacheSize() const = 0;

/// \brief Returns the number of micro-seconds since the Unix epoch.
virtual uint64_t NowMicros() const {
return env_time_->NowMicros();
20 changes: 20 additions & 0 deletions onnxruntime/core/platform/posix/env.cc
@@ -43,6 +43,10 @@ limitations under the License.
#define ORT_USE_CPUINFO
#endif

#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
#include <sys/sysctl.h>
#endif

#include "core/common/common.h"
#include <gsl/gsl>
#include "core/common/logging/logging.h"
@@ -302,6 +306,22 @@ class PosixEnv : public Env {
return ret;
}

int GetL2CacheSize() const override {
#ifdef _SC_LEVEL2_CACHE_SIZE
return static_cast<int>(sysconf(_SC_LEVEL2_CACHE_SIZE));
#else
int value = 0; // unknown
#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE)
int mib[2] = {CTL_HW, HW_L2CACHESIZE};
size_t len = sizeof(value);
if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) {
return -1; // error
}
#endif
return value;
#endif
}

void SleepForMicroseconds(int64_t micros) const override {
while (micros > 0) {
timespec sleep_time;
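For context on the new Env::GetL2CacheSize() API, the caller in multihead_attention.cc only takes the FlashAttention path when the value is positive, since 0 means "unknown" and -1 signals a sysctl error. A minimal consumer might look like the sketch below; the helper name and logging are illustrative assumptions.

```cpp
#include <cstdio>

#include "core/platform/env.h"  // include path as within the onnxruntime source tree

// Illustrative only: treat a non-positive L2 cache size as "unavailable" and skip
// the cache-size-dependent fast path, matching the l2_cache_size_ > 0 check above.
bool CanUseFlashAttentionFastPath() {
    const int l2_bytes = onnxruntime::Env::Default().GetL2CacheSize();
    if (l2_bytes <= 0) {
        std::printf("L2 cache size unavailable (%d); using the default attention path\n", l2_bytes);
        return false;
    }
    std::printf("L2 cache size: %d bytes\n", l2_bytes);
    return true;
}
```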