rename row to block, and tune block size
tianleiwu committed Jun 21, 2024
1 parent e68f60c commit afc4325
Showing 6 changed files with 60 additions and 46 deletions.
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,6 +166,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

+ // Environment variable for tuning attention algorithm
+ constexpr const char* kAttentionAlgo = "ORT_ATTENTION_ALGO";
+
// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;

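The new constant only names the switch; the policy it selects lives in multihead_attention.cc below (the default 0 derives the KV block size from the L2 cache, while 1 picks hand-tuned block sizes). As a reading aid, and not code from this commit, the lookup the constructor performs with ParseEnvironmentVariableWithDefault<int>(attention::kAttentionAlgo, 0) behaves roughly like this sketch:

// Hedged sketch: approximate behavior of the ORT_ATTENTION_ALGO lookup.
#include <cstdlib>

int GetAttentionAlgo() {
  const char* value = std::getenv("ORT_ATTENTION_ALGO");
  return value != nullptr ? std::atoi(value) : 0;  // 0: L2-cache-derived blocks, 1: tuned fixed blocks
}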
30 changes: 19 additions & 11 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -5,7 +5,6 @@
#include <vector>
#include <algorithm>

#include "contrib_ops/cpu/bert/attention_cpu_base.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cpu/bert/attention_utils.h"
#include "core/common/common.h"
@@ -47,6 +46,7 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
l2_cache_size_ = env.GetL2CacheSize();

disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
+ algo_ = ParseEnvironmentVariableWithDefault<int>(attention::kAttentionAlgo, 0);
}

template <typename T>
@@ -161,27 +161,35 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
present_k == nullptr &&
present_v == nullptr &&
l2_cache_size_ > 0) {
- int row_size_kv = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
- if (row_size_kv > 0) {
- FlashAttentionThreadedArgs args;
+ FlashAttentionThreadedArgs args;
+ if (algo_ == 1) {
+ int q_block_size = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 64 : 32);
+ int kv_block_size = 512;
+ args.q_block_size = q_block_size > q_sequence_length ? q_sequence_length : q_block_size;
+ args.kv_block_size = kv_block_size > kv_sequence_length ? kv_sequence_length : kv_block_size;
+ } else {
+ args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+ args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+ }
+
+ if (args.kv_block_size > 0) {
args.batch_size = batch_size;
args.num_heads = num_heads_;
args.q_sequence_length = q_sequence_length;
args.kv_sequence_length = kv_sequence_length;
args.qk_head_size = qk_head_size;
args.v_head_size = v_head_size;
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
- args.row_size_kv = row_size_kv;
- args.row_size_q = std::min(row_size_kv, qk_head_size + v_head_size);

auto* tp = context->GetOperatorThreadPool();
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
- args.buffer_size_per_thread = static_cast<size_t>(args.row_size_q) *
- static_cast<size_t>(2 + args.row_size_kv + args.v_head_size) * sizeof(float);
- size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
- IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);

- args.buffer = reinterpret_cast<float*>(buffer.get());
+ int columns = args.kv_block_size + 2 + args.v_head_size; // qk + qk_max + qk_sum + dst
+ args.buffer_size_per_thread = static_cast<size_t>(args.q_block_size) * static_cast<size_t>(columns);
+
+ size_t total_buffer_size = args.buffer_size_per_thread * static_cast<size_t>(args.thread_count);
+ IAllocatorUniquePtr<float> buffer = IAllocator::MakeUniquePtr<float>(allocator, total_buffer_size);
+ args.buffer = buffer.get();

args.query = Q.Get<Tensor>().Data<float>();
args.key = K.Get<Tensor>().Data<float>();
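To make the hunk above easier to follow, here is a standalone sketch (not part of the commit) that reruns the two selection paths and the per-thread scratch sizing with hypothetical numbers: a 2 MB L2 cache, 64-dimensional heads, and sequence length 1024. The expected outputs are noted in the comments.

// Standalone sketch of the block-size selection and scratch sizing above.
// All input values here are hypothetical.
#include <algorithm>
#include <cstdio>

int main() {
  const int l2_cache_size = 2 * 1024 * 1024;  // bytes
  const int qk_head_size = 64;
  const int v_head_size = 64;
  const int q_sequence_length = 1024;
  const int kv_sequence_length = 1024;

  // Default path (algo == 0): derive the KV block from the L2 cache size.
  int kv_block_size = l2_cache_size / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
  int q_block_size = std::min(kv_block_size, qk_head_size + v_head_size);
  std::printf("algo=0: q_block=%d kv_block=%d\n", q_block_size, kv_block_size);  // 128, 1024

  // Tuned path (algo == 1): fixed sizes keyed off the query length, capped by the sequence lengths.
  int q_block = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 64 : 32);
  int kv_block = std::min(512, kv_sequence_length);
  q_block = std::min(q_block, q_sequence_length);
  std::printf("algo=1: q_block=%d kv_block=%d\n", q_block, kv_block);  // 256, 512

  // Per-thread scratch in floats: l and m (one float per query row in the block),
  // a q_block x kv_block score tile, and a q_block x v_head_size output tile.
  size_t columns = static_cast<size_t>(kv_block) + 2 + v_head_size;  // qk + qk_max + qk_sum + dst
  size_t floats_per_thread = static_cast<size_t>(q_block) * columns;
  std::printf("scratch per thread: %zu floats (%zu KB)\n",
              floats_per_thread, floats_per_thread * sizeof(float) / 1024);  // 147968 floats, 578 KB
  return 0;
}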
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -5,6 +5,7 @@

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/attention_cpu_base.h"

namespace onnxruntime {
namespace contrib {
@@ -21,6 +22,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
bool is_unidirectional_;
bool disable_flash_;
int l2_cache_size_;
+ int algo_;
};

} // namespace contrib
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/inc/mlas_flashattn.h
@@ -26,11 +26,11 @@ struct FlashAttentionThreadedArgs {
int kv_sequence_length;
int qk_head_size;
int v_head_size;
- int row_size_q;
- int row_size_kv;
+ int q_block_size;
+ int kv_block_size;
float scale;
float* buffer;
- size_t buffer_size_per_thread;
+ size_t buffer_size_per_thread; // Number of float elements in buffer for each thread
int thread_count;
const float* query;
const float* key;
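Since buffer_size_per_thread is now documented as a float count, it may help to see how flashattn.cpp (further down in this commit) carves each thread's region into the four pieces the kernel uses. The sketch below is a reading aid with assumed names, not part of the commit.

// Reading aid: per-thread scratch layout used by FlashAttentionThreaded.
// buffer_size_per_thread is expected to equal q_block_size * (2 + kv_block_size + v_head_size).
#include <cstddef>

struct ThreadScratch {
  float* l;             // [q_block_size]                  running softmax denominators
  float* m;             // [q_block_size]                  running row maxima
  float* intermediate;  // [q_block_size * kv_block_size]  QK^T score tile for one KV block
  float* temp_output;   // [q_block_size * v_head_size]    unnormalized output accumulator
};

inline ThreadScratch PartitionScratch(float* buffer, size_t buffer_size_per_thread,
                                      int thread_id, int q_block_size, int kv_block_size) {
  float* base = buffer + static_cast<size_t>(thread_id) * buffer_size_per_thread;
  ThreadScratch s;
  s.l = base;
  s.m = s.l + q_block_size;
  s.intermediate = s.m + q_block_size;
  s.temp_output = s.intermediate + static_cast<size_t>(q_block_size) * kv_block_size;
  return s;
}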
64 changes: 32 additions & 32 deletions onnxruntime/core/mlas/lib/flashattn.cpp
@@ -8,8 +8,8 @@ FlashAttentionThreaded(
const FlashAttentionThreadedArgs* args
)
{
- ptrdiff_t row_size_q = static_cast<ptrdiff_t>(args->row_size_q);
- ptrdiff_t row_size_kv = static_cast<ptrdiff_t>(args->row_size_kv);
+ ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
+ ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
@@ -28,11 +28,11 @@
auto&& mlas_platform = GetMlasPlatform();
#endif

- ptrdiff_t q_chunk_count = (q_sequence_length + (row_size_q - 1)) / row_size_q;
+ ptrdiff_t q_block_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;

ptrdiff_t task_start = 0;
ptrdiff_t task_end = 0;
- ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count;
+ ptrdiff_t total_task_count = batch_size * num_heads * q_block_count;
ptrdiff_t quotient = total_task_count / thread_count;
ptrdiff_t remainder = total_task_count % thread_count;
if (thread_id < remainder) {
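The hunk above only shows the head of the thread partitioning; the branch bodies sit outside the displayed context. For orientation, the sketch below shows one standard balanced split that is consistent with the quotient and remainder computed above, using made-up sizes; it is not the commit's code.

// Illustrative balanced task split (hypothetical sizes, not ORT code).
#include <cstdio>

int main() {
  const int total_task_count = 2 * 12 * 4;  // batch_size * num_heads * q_block_count = 96
  const int thread_count = 10;
  const int quotient = total_task_count / thread_count;   // 9
  const int remainder = total_task_count % thread_count;  // 6
  for (int thread_id = 0; thread_id < thread_count; ++thread_id) {
    int task_start, task_end;
    if (thread_id < remainder) {  // the first `remainder` threads take one extra task
      task_start = (quotient + 1) * thread_id;
      task_end = task_start + quotient + 1;
    } else {
      task_start = quotient * thread_id + remainder;
      task_end = task_start + quotient;
    }
    std::printf("thread %d handles tasks [%d, %d)\n", thread_id, task_start, task_end);
  }
  return 0;
}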
@@ -45,46 +45,46 @@

for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
ptrdiff_t ib = task_index;
- ptrdiff_t il = (ib % q_chunk_count) * row_size_q;
- ib /= q_chunk_count;
+ ptrdiff_t il = (ib % q_block_count) * q_block_size;
+ ib /= q_block_count;
ptrdiff_t ih = ib % num_heads;
ib /= num_heads;

- char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
- float* l = reinterpret_cast<float*>(buffer_current_thread);
+ float* buffer_current_thread = buffer + thread_id * buffer_size_per_thread;
+ float* l = buffer_current_thread;

- memset(l, 0, row_size_q * sizeof(float));
- float* m = l + row_size_q;
- for (ptrdiff_t t = 0; t < row_size_q; ++t) {
+ memset(l, 0, q_block_size * sizeof(float));
+ float* m = l + q_block_size;
+ for (ptrdiff_t t = 0; t < q_block_size; ++t) {
m[t] = std::numeric_limits<float>::lowest();
}
- float* intermediate = m + row_size_q;
- float* temp_output = intermediate + row_size_q * row_size_kv;
+ float* intermediate = m + q_block_size;
+ float* temp_output = intermediate + q_block_size * kv_block_size;
float negmax = 0;

- for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += row_size_kv) {
+ for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
/*
- S = Q[ib, ih, il:il+row_size_q, :] * (K[ib, ih, ir:ir+row_size_kv, :]).T
+ S = Q[ib, ih, il:il+q_block_size, :] * (K[ib, ih, ir:ir+kv_block_size, :]).T
old_m = m
m = max(m, rowmax(S))
diff = old_m - m
S = exp(S - m)
l = exp(diff) * l + rowsum(S)
- O = diag(exp(diff)) * O + S * V[ib, ih, ir:ir+row_size_kv, :]
+ O = diag(exp(diff)) * O + S * V[ib, ih, ir:ir+kv_block_size, :]
*/
// TODO: Need to concat if past_k is present
ptrdiff_t h = ib * num_heads + ih;
const float* inputQ = query + (h * q_sequence_length + il) * qk_head_size;
const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;

- size_t row_size_q_capped = static_cast<size_t>(std::min(row_size_q, q_sequence_length - il));
- size_t row_size_kv_capped = static_cast<size_t>(std::min(row_size_kv, kv_sequence_length - ir));
+ size_t q_block_size_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - il));
+ size_t kv_block_size_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));

MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans,
CBLAS_TRANSPOSE::CblasTrans,
- row_size_q_capped,
- row_size_kv_capped,
+ q_block_size_capped,
+ kv_block_size_capped,
static_cast<size_t>(qk_head_size),
args->scale,
inputQ,
@@ -93,26 +93,26 @@ FlashAttentionThreaded(
static_cast<size_t>(qk_head_size),
0.0f,
intermediate,
- row_size_kv_capped,
+ kv_block_size_capped,
nullptr);

- for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(row_size_q_capped); ++irow) {
- float* p = intermediate + irow * row_size_kv_capped;
+ for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(q_block_size_capped); ++irow) {
+ float* p = intermediate + irow * kv_block_size_capped;

#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
- float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped);
+ float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, kv_block_size_capped);
#else
- float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped);
+ float rowmax = MlasReduceMaximumF32Kernel(p, kv_block_size_capped);
#endif
float m_diff = m[irow];
m[irow] = std::max(m[irow], rowmax); // new m
negmax = -m[irow];
m_diff -= m[irow]; // old - new (less than 0)

#if defined(MLAS_TARGET_AMD64)
- float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+ float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax);
#else
- float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+ float rowsum = MlasComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax);
#endif

// Note: for ir == 0, there is actually no need to calculate exp_diff
@@ -130,12 +130,12 @@ }
}
MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans,
CBLAS_TRANSPOSE::CblasNoTrans,
- row_size_q_capped,
+ q_block_size_capped,
static_cast<size_t>(v_head_size),
- row_size_kv_capped,
+ kv_block_size_capped,
1.0f,
intermediate,
- row_size_kv_capped,
+ kv_block_size_capped,
inputV,
static_cast<size_t>(v_head_size),
ir == 0 ? 0.0f : 1.0f,
@@ -145,9 +145,9 @@
}

float* output_row = output + ((ib * q_sequence_length + il) * num_heads + ih) * v_head_size;
- ptrdiff_t row_size_q_valid = std::min(row_size_q, q_sequence_length - il);
+ ptrdiff_t q_block_size_valid = std::min(q_block_size, q_sequence_length - il);
// TODO: leverage advanced instruction sets
- for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
+ for (ptrdiff_t irow = 0; irow < q_block_size_valid; ++irow) {
for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow];
}
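The pseudocode comment inside the KV loop (S, m, l, O) is unchanged apart from the renamed block sizes, but the recurrence is easier to check outside a diff. The toy program below applies the same streaming-softmax update to a single query row with scalar values and a block size that does not divide the sequence length; it is an illustration only, not ORT code.

// Toy single-row version of the streaming softmax update (illustration only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  // One query row of attention scores against 8 keys, with scalar "values".
  std::vector<float> scores = {0.1f, 1.2f, -0.3f, 0.7f, 2.0f, -1.5f, 0.4f, 0.9f};
  std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
  const size_t kv_block_size = 3;  // deliberately does not divide 8 evenly

  float m = std::numeric_limits<float>::lowest();  // running max
  float l = 0.0f;                                  // running sum of exp(S - m)
  float o = 0.0f;                                  // running unnormalized output

  for (size_t start = 0; start < scores.size(); start += kv_block_size) {
    size_t end = std::min(scores.size(), start + kv_block_size);

    // m_new = max(m_old, rowmax(S_block)); rescale l and o by exp(m_old - m_new).
    float block_max = m;
    for (size_t i = start; i < end; ++i) block_max = std::max(block_max, scores[i]);
    float rescale = std::exp(m - block_max);  // underflows to 0 on the first block
    m = block_max;
    l *= rescale;
    o *= rescale;

    // Accumulate exp(S - m) and its value-weighted sum for this block.
    for (size_t i = start; i < end; ++i) {
      float p = std::exp(scores[i] - m);
      l += p;
      o += p * values[i];
    }
  }

  // Matches a plain one-pass softmax-weighted average over all 8 positions.
  std::printf("streaming softmax average = %f\n", o / l);
  return 0;
}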
1 change: 1 addition & 0 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -429,6 +429,7 @@ def run_tflops_test(
# List of environment variables to enable/disable attention kernels
print("Environment Variables:")
env_names = [
"ORT_ATTENTION_ALGO",
"ORT_DISABLE_FLASH_ATTENTION",
"ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV",
"ORT_DISABLE_FUSED_ATTENTION",
