diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index a5b9c84c63eb9..d81437954e3ad 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,6 +166,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
 // Environment variable to enable or disable flash attention. Default is 0 (enabled).
 constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
 
+// Environment variable for tuning attention algorithm
+constexpr const char* kAttentionAlgo = "ORT_ATTENTION_ALGO";
+
 // Minimum sequence length to enable memory efficient attention in FP32.
 constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index e019a2b5affd0..cd3f70b90be83 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -5,7 +5,6 @@
 #include 
 #include 
-#include "contrib_ops/cpu/bert/attention_cpu_base.h"
 #include "contrib_ops/cpu/bert/multihead_attention_helper.h"
 #include "contrib_ops/cpu/bert/attention_utils.h"
 #include "core/common/common.h"
@@ -47,6 +46,7 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
   l2_cache_size_ = env.GetL2CacheSize();
 
   disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
+  algo_ = ParseEnvironmentVariableWithDefault<int>(attention::kAttentionAlgo, 0);
 }
 
 template <typename T>
@@ -161,9 +161,18 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       present_k == nullptr && present_v == nullptr &&
       l2_cache_size_ > 0) {
-    int row_size_kv = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
-    if (row_size_kv > 0) {
-      FlashAttentionThreadedArgs args;
+    FlashAttentionThreadedArgs args;
+    if (algo_ == 1) {
+      int q_block_size = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 64 : 32);
+      int kv_block_size = 512;
+      args.q_block_size = q_block_size > q_sequence_length ? q_sequence_length : q_block_size;
+      args.kv_block_size = kv_block_size > kv_sequence_length ? kv_sequence_length : kv_block_size;
+    } else {
+      args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
+      args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
+    }
+
+    if (args.kv_block_size > 0) {
       args.batch_size = batch_size;
       args.num_heads = num_heads_;
       args.q_sequence_length = q_sequence_length;
@@ -171,17 +180,16 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
       args.qk_head_size = qk_head_size;
      args.v_head_size = v_head_size;
       args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
-      args.row_size_kv = row_size_kv;
-      args.row_size_q = std::min(row_size_kv, qk_head_size + v_head_size);
 
       auto* tp = context->GetOperatorThreadPool();
       args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
-      args.buffer_size_per_thread = static_cast<size_t>(args.row_size_q) *
-                                    static_cast<size_t>(2 + args.row_size_kv + args.v_head_size) * sizeof(float);
-      size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
-      IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
-      args.buffer = reinterpret_cast<float*>(buffer.get());
+      int columns = args.kv_block_size + 2 + args.v_head_size;  // qk + qk_max + qk_sum + dst
+      args.buffer_size_per_thread = static_cast<size_t>(args.q_block_size) * static_cast<size_t>(columns);
+
+      size_t total_buffer_size = args.buffer_size_per_thread * static_cast<size_t>(args.thread_count);
+      IAllocatorUniquePtr<float> buffer = IAllocator::MakeUniquePtr<float>(allocator, total_buffer_size);
+      args.buffer = buffer.get();
 
       args.query = Q.Get<Tensor>().Data<float>();
       args.key = K.Get<Tensor>().Data<float>();
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
index 8a9bef1b2bf0d..17625cb61acc6 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -5,6 +5,7 @@
 
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
+#include "contrib_ops/cpu/bert/attention_cpu_base.h"
 
 namespace onnxruntime {
 namespace contrib {
@@ -21,6 +22,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
   bool is_unidirectional_;
   bool disable_flash_;
   int l2_cache_size_;
+  int algo_;
 };
 
 }  // namespace contrib
diff --git a/onnxruntime/core/mlas/inc/mlas_flashattn.h b/onnxruntime/core/mlas/inc/mlas_flashattn.h
index 280e5b31cb267..016a728547b80 100644
--- a/onnxruntime/core/mlas/inc/mlas_flashattn.h
+++ b/onnxruntime/core/mlas/inc/mlas_flashattn.h
@@ -26,11 +26,11 @@ struct FlashAttentionThreadedArgs {
     int kv_sequence_length;
     int qk_head_size;
     int v_head_size;
-    int row_size_q;
-    int row_size_kv;
+    int q_block_size;
+    int kv_block_size;
     float scale;
     float* buffer;
-    size_t buffer_size_per_thread;
+    size_t buffer_size_per_thread;  // Number of float elements in buffer for each thread
    int thread_count;
     const float* query;
     const float* key;
diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp
index ed7f0379961a0..e104824336c8b 100644
--- a/onnxruntime/core/mlas/lib/flashattn.cpp
+++ b/onnxruntime/core/mlas/lib/flashattn.cpp
@@ -8,8 +8,8 @@ FlashAttentionThreaded(
     const FlashAttentionThreadedArgs* args
 )
 {
-    ptrdiff_t row_size_q = static_cast<ptrdiff_t>(args->row_size_q);
-    ptrdiff_t row_size_kv = static_cast<ptrdiff_t>(args->row_size_kv);
+    ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
+    ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
     ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
     ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
     ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
@@ -28,11 +28,11 @@ FlashAttentionThreaded(
     auto&& mlas_platform = GetMlasPlatform();
 #endif
 
-    ptrdiff_t q_chunk_count = (q_sequence_length + (row_size_q - 1)) / row_size_q;
+    ptrdiff_t q_block_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;
 
     ptrdiff_t task_start = 0;
     ptrdiff_t task_end = 0;
-    ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count;
+    ptrdiff_t total_task_count = batch_size * num_heads * q_block_count;
     ptrdiff_t quotient = total_task_count / thread_count;
     ptrdiff_t remainder = total_task_count % thread_count;
     if (thread_id < remainder) {
@@ -45,32 +45,32 @@ FlashAttentionThreaded(
     for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
         ptrdiff_t ib = task_index;
-        ptrdiff_t il = (ib % q_chunk_count) * row_size_q;
-        ib /= q_chunk_count;
+        ptrdiff_t il = (ib % q_block_count) * q_block_size;
+        ib /= q_block_count;
         ptrdiff_t ih = ib % num_heads;
         ib /= num_heads;
 
-        char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
-        float* l = reinterpret_cast<float*>(buffer_current_thread);
+        float* buffer_current_thread = buffer + thread_id * buffer_size_per_thread;
+        float* l = buffer_current_thread;
 
-        memset(l, 0, row_size_q * sizeof(float));
-        float* m = l + row_size_q;
-        for (ptrdiff_t t = 0; t < row_size_q; ++t) {
+        memset(l, 0, q_block_size * sizeof(float));
+        float* m = l + q_block_size;
+        for (ptrdiff_t t = 0; t < q_block_size; ++t) {
             m[t] = std::numeric_limits<float>::lowest();
         }
-        float* intermediate = m + row_size_q;
-        float* temp_output = intermediate + row_size_q * row_size_kv;
+        float* intermediate = m + q_block_size;
+        float* temp_output = intermediate + q_block_size * kv_block_size;
 
         float negmax = 0;
 
-        for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += row_size_kv) {
+        for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
             /*
-                S = Q[ib, ih, il:il+row_size_q, :] * (K[ib, ih, ir:ir+row_size_kv, :]).T
+                S = Q[ib, ih, il:il+q_block_size, :] * (K[ib, ih, ir:ir+kv_block_size, :]).T
                 old_m = m
                 m = max(m, rowmax(S))
                 diff = old_m - m
                 S = exp(S - m)
                 l = exp(diff) * l + rowsum(S)
-                O = diag(exp(diff)) * O + S * V[ib, ih, ir:ir+row_size_kv, :]
+                O = diag(exp(diff)) * O + S * V[ib, ih, ir:ir+kv_block_size, :]
             */
             // TODO: Need to concat if past_k is present
             ptrdiff_t h = ib * num_heads + ih;
@@ -78,13 +78,13 @@ FlashAttentionThreaded(
             const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
             const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;
 
-            size_t row_size_q_capped = static_cast<size_t>(std::min(row_size_q, q_sequence_length - il));
-            size_t row_size_kv_capped = static_cast<size_t>(std::min(row_size_kv, kv_sequence_length - ir));
+            size_t q_block_size_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - il));
+            size_t kv_block_size_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));
 
             MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans,
                      CBLAS_TRANSPOSE::CblasTrans,
-                     row_size_q_capped,
-                     row_size_kv_capped,
+                     q_block_size_capped,
+                     kv_block_size_capped,
                      static_cast<size_t>(qk_head_size),
                      args->scale,
                      inputQ,
@@ -93,16 +93,16 @@ FlashAttentionThreaded(
                      static_cast<size_t>(qk_head_size),
                      0.0f,
                      intermediate,
-                     row_size_kv_capped,
+                     kv_block_size_capped,
                      nullptr);
 
-            for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(row_size_q_capped); ++irow) {
-                float* p = intermediate + irow * row_size_kv_capped;
+            for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(q_block_size_capped); ++irow) {
+                float* p = intermediate + irow * kv_block_size_capped;
 #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
-                float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped);
+                float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, kv_block_size_capped);
 #else
-                float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped);
+                float rowmax = MlasReduceMaximumF32Kernel(p, kv_block_size_capped);
 #endif
                 float m_diff = m[irow];
                 m[irow] = std::max(m[irow], rowmax);  // new m
@@ -110,9 +110,9 @@ FlashAttentionThreaded(
                 m_diff -= m[irow];  // old - new (less than 0)
 
 #if defined(MLAS_TARGET_AMD64)
-                float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+                float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax);
 #else
-                float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
+                float rowsum = MlasComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax);
 #endif
 
                 // Note: for ir == 0, there is actually no need to calculate exp_diff
@@ -130,12 +130,12 @@ FlashAttentionThreaded(
             }
             MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans,
                      CBLAS_TRANSPOSE::CblasNoTrans,
-                     row_size_q_capped,
+                     q_block_size_capped,
                      static_cast<size_t>(v_head_size),
-                     row_size_kv_capped,
+                     kv_block_size_capped,
                      1.0f,
                      intermediate,
-                     row_size_kv_capped,
+                     kv_block_size_capped,
                      inputV,
                      static_cast<size_t>(v_head_size),
                      ir == 0 ? 0.0f : 1.0f,
@@ -145,9 +145,9 @@ FlashAttentionThreaded(
         }
 
         float* output_row = output + ((ib * q_sequence_length + il) * num_heads + ih) * v_head_size;
-        ptrdiff_t row_size_q_valid = std::min(row_size_q, q_sequence_length - il);
+        ptrdiff_t q_block_size_valid = std::min(q_block_size, q_sequence_length - il);
         // TODO: leverage advanced instruction sets
-        for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
+        for (ptrdiff_t irow = 0; irow < q_block_size_valid; ++irow) {
             for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
                 output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow];
             }
diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py
index c9e3a11ff5a7b..aaa12b3cc012d 100644
--- a/onnxruntime/test/python/transformers/benchmark_mha.py
+++ b/onnxruntime/test/python/transformers/benchmark_mha.py
@@ -429,6 +429,7 @@ def run_tflops_test(
     # List of environment variables to enable/disable attention kernels
     print("Environment Variables:")
    env_names = [
+        "ORT_ATTENTION_ALGO",
         "ORT_DISABLE_FLASH_ATTENTION",
         "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV",
         "ORT_DISABLE_FUSED_ATTENTION",
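
Note on the sizing change in multihead_attention.cc: when `ORT_ATTENTION_ALGO=1` the block sizes come from a sequence-length heuristic, otherwise the KV block is derived from the L2 cache size, and the per-thread scratch buffer is sized in float elements rather than bytes. The Python below is an illustrative sketch of that logic (the helper functions are not part of ONNX Runtime; only the formulas mirror the patch).

```python
def flash_attention_block_sizes(algo, q_seq_len, kv_seq_len,
                                qk_head_size, v_head_size, l2_cache_bytes):
    """Sketch of the block-size selection added to MultiHeadAttention::Compute."""
    if algo == 1:
        q_block = 256 if q_seq_len >= 768 else (64 if q_seq_len >= 192 else 32)
        kv_block = 512
        q_block = min(q_block, q_seq_len)
        kv_block = min(kv_block, kv_seq_len)
    else:
        # Default path: size the KV block from the L2 cache (sizeof(float) == 4).
        kv_block = l2_cache_bytes // (4 * 4 * (qk_head_size + v_head_size))
        q_block = min(kv_block, qk_head_size + v_head_size)
    return q_block, kv_block


def buffer_floats_per_thread(q_block, kv_block, v_head_size):
    # Per-thread scratch layout used by FlashAttentionThreaded:
    #   l (q_block) | m (q_block) | intermediate (q_block * kv_block) | temp_output (q_block * v_head_size)
    # which is q_block * (kv_block + 2 + v_head_size) floats, i.e. "qk + qk_max + qk_sum + dst".
    return q_block * (kv_block + 2 + v_head_size)
```

With the added env_names entry, benchmark_mha.py now prints `ORT_ATTENTION_ALGO` alongside the other attention-related environment variables, so benchmark runs record which sizing path was active.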
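The kernel change in flashattn.cpp only renames the blocking parameters; the computation is still the streaming-softmax recurrence spelled out in the block comment. For reference, here is a NumPy sketch of that recurrence for a single (batch, head) pair. It is an illustration of the algorithm, not the MLAS implementation: the function name is hypothetical, masking and past K/V are omitted, and it initializes the running max with -inf rather than std::numeric_limits<float>::lowest().

```python
import numpy as np

def flash_attention_single_head(q, k, v, q_block_size, kv_block_size, scale):
    """Blocked online-softmax attention: q is (Sq, Dqk), k is (Skv, Dqk), v is (Skv, Dv)."""
    q_seq = q.shape[0]
    kv_seq, v_head = v.shape
    out = np.empty((q_seq, v_head), dtype=np.float32)
    for il in range(0, q_seq, q_block_size):
        qb = q[il:il + q_block_size]
        m = np.full(qb.shape[0], -np.inf, dtype=np.float32)    # running row max
        l = np.zeros(qb.shape[0], dtype=np.float32)            # running row sum
        o = np.zeros((qb.shape[0], v_head), dtype=np.float32)  # unnormalized output
        for ir in range(0, kv_seq, kv_block_size):
            kb = k[ir:ir + kv_block_size]
            vb = v[ir:ir + kv_block_size]
            s = scale * (qb @ kb.T)                  # S = Q * K^T
            new_m = np.maximum(m, s.max(axis=1))     # m = max(m, rowmax(S))
            p = np.exp(s - new_m[:, None])           # S = exp(S - m)
            diff = np.exp(m - new_m)                 # exp(old_m - new_m)
            l = diff * l + p.sum(axis=1)             # l = exp(diff) * l + rowsum(S)
            o = diff[:, None] * o + p @ vb           # O = diag(exp(diff)) * O + S * V
            m = new_m
        out[il:il + q_block_size] = o / l[:, None]   # final division by l, as in the output loop
    return out
```

For small inputs this matches a plain `softmax(scale * q @ k.T) @ v` reference to within float32 tolerance, independent of the chosen q/kv block sizes, which is why only performance (not accuracy) depends on the block-size tuning above.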