diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 38be417767f8b..304aa77f5473c 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -39,7 +39,6 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/sqnbitgemm.h
   ${MLAS_SRC_DIR}/sqnbitgemm.cpp
   ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
-  ${MLAS_SRC_DIR}/flashattn.cpp
 )

 target_sources(onnxruntime_mlas PRIVATE
@@ -48,7 +47,6 @@ target_sources(onnxruntime_mlas PRIVATE
   ${MLAS_INC_DIR}/mlas_q4.h
   ${MLAS_INC_DIR}/mlas_qnbit.h
   ${MLAS_INC_DIR}/mlas.h
-  ${MLAS_INC_DIR}/mlas_flashattn.h
 )

 if (NOT onnxruntime_ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index d81437954e3ad..a5b9c84c63eb9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -166,9 +166,6 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
 // Environment variable to enable or disable flash attention. Default is 0 (enabled).
 constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

-// Environment variable for tuning attention algorithm
-constexpr const char* kAttentionAlgo = "ORT_ATTENTION_ALGO";
-
 // Minimum sequence length to enable memory efficient attention in FP32.
 constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;

diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 02ee9bf0e85bd..b39167f4498e0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -1,21 +1,19 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#include "contrib_ops/cpu/bert/multihead_attention.h" -#include -#include -#include -#include "contrib_ops/cpu/bert/multihead_attention_helper.h" -#include "contrib_ops/cpu/bert/attention_utils.h" +#include "attention_cpu_base.h" +#include "multihead_attention.h" +#include "multihead_attention_helper.h" +#include "attention_utils.h" + #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/common/safeint.h" -#include "core/platform/env_var_utils.h" #include "core/platform/threadpool.h" -#include "core/mlas/inc/mlas_flashattn.h" #include +#include using onnxruntime::concurrency::ThreadPool; @@ -41,12 +39,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - - const auto& env = Env::Default(); - l2_cache_size_ = env.GetL2CacheSize(); - - disable_flash_ = ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); - algo_ = ParseEnvironmentVariableWithDefault(attention::kAttentionAlgo, 0); } template @@ -68,6 +60,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { } AttentionParameters parameters = {}; + constexpr float scale = 1.0f; bool past_present_share_buffer = false; ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -81,7 +74,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ¶meters, num_heads_, mask_filter_value_, - scale_, + scale, is_unidirectional_, past_present_share_buffer, false)); @@ -106,14 +99,8 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { const int v_bias_offset = 2 * qk_hidden_size; // If optional outputs aren't needed, present_k and present_v will be null - std::vector present_k_shape({static_cast(batch_size), - static_cast(num_heads_), - static_cast(total_kv_sequence_length), - static_cast(qk_head_size)}); - std::vector present_v_shape({static_cast(batch_size), - static_cast(num_heads_), - static_cast(total_kv_sequence_length), - static_cast(v_head_size)}); + std::vector present_k_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(qk_head_size)}); + std::vector present_v_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(v_head_size)}); Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); @@ -151,59 +138,6 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V)); - if (std::is_same_v && - !disable_flash_ && - !is_unidirectional_ && - key_padding_mask == nullptr && - extra_add_qk == nullptr && - past_key == nullptr && - past_value == nullptr && - present_k == nullptr && - present_v == nullptr && - l2_cache_size_ > 0) { - FlashAttentionThreadedArgs args; - - if (algo_ == 1) { - args.q_block_size = q_sequence_length >= 768 ? 256 : (q_sequence_length >= 192 ? 
64 : 32); - args.kv_block_size = 512; - } else { - args.kv_block_size = l2_cache_size_ / (static_cast(sizeof(float)) * 4 * (qk_head_size + v_head_size)); - args.kv_block_size = std::max(args.kv_block_size, 1); // avoid row_size_kv = 0 - args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size); - } - args.q_block_size = std::min(args.q_block_size, q_sequence_length); - args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); - - args.batch_size = batch_size; - args.num_heads = num_heads_; - args.q_sequence_length = q_sequence_length; - args.kv_sequence_length = kv_sequence_length; - args.qk_head_size = qk_head_size; - args.v_head_size = v_head_size; - args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast(qk_head_size)) : scale_; - - auto* tp = context->GetOperatorThreadPool(); - args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); - - int columns = args.kv_block_size + 2 + args.v_head_size; // columns in qk + qk_max + qk_sum + out - args.buffer_size_per_thread = static_cast(args.q_block_size) * static_cast(columns); - - size_t total_buffer_size = args.buffer_size_per_thread * static_cast(args.thread_count); - IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, total_buffer_size); - args.buffer = buffer.get(); - - args.query = Q.Get().Data(); - args.key = K.Get().Data(); - args.value = V.Get().Data(); - args.output = output->MutableData(); - - concurrency::ThreadPool::TrySimpleParallelFor(tp, args.thread_count, [&](std::ptrdiff_t thread_id) { - FlashAttentionThreaded(thread_id, &args); - }); - - return Status::OK(); - } - // Compute the attention score and apply the score to V return ApplyAttention(Q.GetMutable()->MutableData(), K.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 17625cb61acc6..fb7da78a5c0a5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -5,7 +5,6 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/bert/attention_cpu_base.h" namespace onnxruntime { namespace contrib { @@ -20,9 +19,6 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { int num_heads_; // number of attention heads float mask_filter_value_; bool is_unidirectional_; - bool disable_flash_; - int l2_cache_size_; - int algo_; }; } // namespace contrib diff --git a/onnxruntime/core/mlas/inc/mlas_flashattn.h b/onnxruntime/core/mlas/inc/mlas_flashattn.h deleted file mode 100644 index 016a728547b80..0000000000000 --- a/onnxruntime/core/mlas/inc/mlas_flashattn.h +++ /dev/null @@ -1,45 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - mlas_flashattn.h - -Abstract: - - Utilities for FlashAttention on CPU. Used internally - by MLAS on platforms without half precision support. Provided here as - convenience for tests or other client libraries/apps. 
- ---*/ - -#pragma once -#include - -struct FlashAttentionThreadedArgs { - int batch_size; - int num_heads; - int q_sequence_length; - int kv_sequence_length; - int qk_head_size; - int v_head_size; - int q_block_size; - int kv_block_size; - float scale; - float* buffer; - size_t buffer_size_per_thread; // Number of float elements in buffer for each thread - int thread_count; - const float* query; - const float* key; - const float* value; - float* output; -}; - -void -FlashAttentionThreaded( - std::ptrdiff_t thread_id, - const FlashAttentionThreadedArgs* args -); diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp deleted file mode 100644 index e104824336c8b..0000000000000 --- a/onnxruntime/core/mlas/lib/flashattn.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include "mlas_flashattn.h" -#include -#include "mlasi.h" - -void -FlashAttentionThreaded( - std::ptrdiff_t thread_id, - const FlashAttentionThreadedArgs* args -) -{ - ptrdiff_t q_block_size = static_cast(args->q_block_size); - ptrdiff_t kv_block_size = static_cast(args->kv_block_size); - ptrdiff_t batch_size = static_cast(args->batch_size); - ptrdiff_t num_heads = static_cast(args->num_heads); - ptrdiff_t q_sequence_length = static_cast(args->q_sequence_length); - ptrdiff_t kv_sequence_length = static_cast(args->kv_sequence_length); - ptrdiff_t qk_head_size = static_cast(args->qk_head_size); - ptrdiff_t v_head_size = static_cast(args->v_head_size); - float* buffer = args->buffer; - ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); - ptrdiff_t thread_count = static_cast(args->thread_count); - const float* query = args->query; - const float* key = args->key; - const float* value = args->value; - float* output = args->output; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - auto&& mlas_platform = GetMlasPlatform(); -#endif - - ptrdiff_t q_block_count = (q_sequence_length + (q_block_size - 1)) / q_block_size; - - ptrdiff_t task_start = 0; - ptrdiff_t task_end = 0; - ptrdiff_t total_task_count = batch_size * num_heads * q_block_count; - ptrdiff_t quotient = total_task_count / thread_count; - ptrdiff_t remainder = total_task_count % thread_count; - if (thread_id < remainder) { - task_start = (quotient + 1) * thread_id; - task_end = task_start + quotient + 1; - } else { - task_start = quotient * thread_id + remainder; - task_end = task_start + quotient; - } - - for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { - ptrdiff_t ib = task_index; - ptrdiff_t il = (ib % q_block_count) * q_block_size; - ib /= q_block_count; - ptrdiff_t ih = ib % num_heads; - ib /= num_heads; - - float* buffer_current_thread = buffer + thread_id * buffer_size_per_thread; - float* l = buffer_current_thread; - - memset(l, 0, q_block_size * sizeof(float)); - float* m = l + q_block_size; - for (ptrdiff_t t = 0; t < q_block_size; ++t) { - m[t] = std::numeric_limits::lowest(); - } - float* intermediate = m + q_block_size; - float* temp_output = intermediate + q_block_size * kv_block_size; - float negmax = 0; - - for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) { - /* - S = Q[ib, ih, il:il+q_block_size, :] * (K[ib, ih, ir:ir+kv_block_size, :]).T - old_m = m - m = max(m, rowmax(S)) - diff = old_m - m - S = exp(S - m) - l = exp(diff) * l + rowsum(S) - O = diag(exp(diff)) * O + S * V[ib, ih, ir:ir+kv_block_size, :] - */ - // TODO: Need to concat if past_k is present - ptrdiff_t h = ib * num_heads + ih; - const float* inputQ = query + (h * 
q_sequence_length + il) * qk_head_size; - const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size; - const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size; - - size_t q_block_size_capped = static_cast(std::min(q_block_size, q_sequence_length - il)); - size_t kv_block_size_capped = static_cast(std::min(kv_block_size, kv_sequence_length - ir)); - - MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans, - CBLAS_TRANSPOSE::CblasTrans, - q_block_size_capped, - kv_block_size_capped, - static_cast(qk_head_size), - args->scale, - inputQ, - static_cast(qk_head_size), - inputK, - static_cast(qk_head_size), - 0.0f, - intermediate, - kv_block_size_capped, - nullptr); - - for (ptrdiff_t irow = 0; irow < static_cast(q_block_size_capped); ++irow) { - float* p = intermediate + irow * kv_block_size_capped; - -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, kv_block_size_capped); -#else - float rowmax = MlasReduceMaximumF32Kernel(p, kv_block_size_capped); -#endif - float m_diff = m[irow]; - m[irow] = std::max(m[irow], rowmax); // new m - negmax = -m[irow]; - m_diff -= m[irow]; // old - new (less than 0) - -#if defined(MLAS_TARGET_AMD64) - float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax); -#else - float rowsum = MlasComputeSumExpF32Kernel(p, p, kv_block_size_capped, &negmax); -#endif - - // Note: for ir == 0, there is actually no need to calculate exp_diff - if (ir != 0) { - float exp_diff = std::exp(m_diff); - l[irow] = exp_diff * l[irow] + rowsum; - - for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { - temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol]; - } - } else { - l[irow] = rowsum; - // When ir == 0, there is no need to scale the old result because it is zero. - } - } - MlasGemm(CBLAS_TRANSPOSE::CblasNoTrans, - CBLAS_TRANSPOSE::CblasNoTrans, - q_block_size_capped, - static_cast(v_head_size), - kv_block_size_capped, - 1.0f, - intermediate, - kv_block_size_capped, - inputV, - static_cast(v_head_size), - ir == 0 ? 0.0f : 1.0f, - temp_output, - static_cast(v_head_size), - nullptr); - } - - float* output_row = output + ((ib * q_sequence_length + il) * num_heads + ih) * v_head_size; - ptrdiff_t q_block_size_valid = std::min(q_block_size, q_sequence_length - il); - // TODO: leverage advanced instruction sets - for (ptrdiff_t irow = 0; irow < q_block_size_valid; ++irow) { - for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { - output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow]; - } - output_row += num_heads * v_head_size; - } - } -} diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index fd79abd4c908d..6917f42091bf3 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -147,8 +147,6 @@ class Env { virtual std::vector GetDefaultThreadAffinities() const = 0; - virtual int GetL2CacheSize() const = 0; - /// \brief Returns the number of micro-seconds since the Unix epoch. virtual uint64_t NowMicros() const { return env_time_->NowMicros(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 2fbe0ae9a91e1..9999550c241c8 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -43,10 +43,6 @@ limitations under the License. 
#define ORT_USE_CPUINFO #endif -#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) -#include -#endif - #include "core/common/common.h" #include "core/common/gsl.h" #include "core/common/logging/logging.h" @@ -306,22 +302,6 @@ class PosixEnv : public Env { return ret; } - int GetL2CacheSize() const override { -#ifdef _SC_LEVEL2_CACHE_SIZE - return static_cast(sysconf(_SC_LEVEL2_CACHE_SIZE)); -#else - int value = 0; // unknown -#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE) - int mib[2] = {CTL_HW, HW_L2CACHESIZE}; - size_t len = sizeof(value); - if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) { - return -1; // error - } -#endif - return value; -#endif - } - void SleepForMicroseconds(int64_t micros) const override { while (micros > 0) { timespec sleep_time; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 368688f617e79..dc090e446e60f 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -16,7 +16,6 @@ limitations under the License. #include "core/platform/windows/env.h" -#include #include #include #include @@ -304,10 +303,6 @@ std::vector WindowsEnv::GetDefaultThreadAffinities() const { return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; } -int WindowsEnv::GetL2CacheSize() const { - return l2_cache_size; -} - WindowsEnv& WindowsEnv::Instance() { static WindowsEnv default_env; return default_env; @@ -929,57 +924,9 @@ void WindowsEnv::InitializeCpuInfo() { } iter += size; } - - DWORD newLength = 0; - GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); - last_error = GetLastError(); - if (last_error != ERROR_INSUFFICIENT_BUFFER) { - const auto error_code = GetLastError(); - if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" - << ", error code: " << error_code - << ", error msg: " << std::system_category().message(error_code); - } - return; - } - - if (newLength > returnLength) { - // Re-allocate - allocation = std::make_unique(newLength); - processorInfos = reinterpret_cast(allocation.get()); - } - - if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { - const auto error_code = GetLastError(); - if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" - << ", error code: " << error_code - << ", error msg: " << std::system_category().message(error_code); - } - return; - } - - iter = reinterpret_cast(processorInfos); - end = iter + newLength; - - while (iter < end) { - auto processor_info = reinterpret_cast(iter); - auto size = processor_info->Size; - - if (processor_info->Relationship == RelationCache && - processor_info->Cache.Level == 2) { - // L2 cache - l2_cache_size = static_cast(processor_info->Cache.CacheSize); - break; - } - - iter += size; - } - if (logging::LoggingManager::HasDefaultLogger()) { LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:"; LOGS_DEFAULT(VERBOSE) << log_stream.str(); - LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size << " bytes"; } } } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/env.h b/onnxruntime/core/platform/windows/env.h index 84d57b889235c..79739db9e5640 100644 --- a/onnxruntime/core/platform/windows/env.h +++ b/onnxruntime/core/platform/windows/env.h @@ -55,7 +55,6 @@ class 
WindowsEnv : public Env { static int DefaultNumCores(); int GetNumPhysicalCpuCores() const override; std::vector GetDefaultThreadAffinities() const override; - int GetL2CacheSize() const override; static WindowsEnv& Instance(); PIDType GetSelfPid() const override; Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override; @@ -114,8 +113,6 @@ class WindowsEnv : public Env { * } */ std::vector cores_; - - int l2_cache_size; /* * "global_processor_info_map_" is a map of: * global_processor_id <--> (group_id, local_processor_id) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 33a17be38adbf..22578175846f7 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -8,19 +8,17 @@ sh benchmark_mha.sh """ -import csv import math import os import platform import statistics import time -from datetime import datetime from typing import List, Optional import torch from onnx import TensorProto, helper -from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from onnxruntime import InferenceSession, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession @@ -277,7 +275,9 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig): return model.SerializeToString() -def create_session(config: MultiHeadAttentionConfig, session_options=None) -> CudaSession: +def create_session( + config: MultiHeadAttentionConfig, +) -> CudaSession: onnx_model_str = create_multi_head_attention_onnx_model(config) if config.provider == "CUDAExecutionProvider": @@ -287,7 +287,7 @@ def create_session(config: MultiHeadAttentionConfig, session_options=None) -> Cu else: providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + ort_session = InferenceSession(onnx_model_str, providers=providers) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -297,8 +297,11 @@ def create_session(config: MultiHeadAttentionConfig, session_options=None) -> Cu class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__(self, config: MultiHeadAttentionConfig, session_options=None): - self.ort_session = create_session(config, session_options) + def __init__( + self, + config: MultiHeadAttentionConfig, + ): + self.ort_session = create_session(config) self.feed_dict = config.random_inputs() def infer(self): @@ -344,24 +347,13 @@ def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: return "Unfused" -def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: - # CPU Flash Attention does not support causal and kv cache etc. 
- if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "CPU:Flash" - +def get_cpu_kernel_name() -> str: + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "CPU:Flash" return "CPU:Unfused" -def run_tflops_test( - csv_writer: csv.DictWriter, - use_gpu: bool = True, - enable_cuda_graph: bool = False, - causal: bool = False, - use_kv_cache: bool = False, - intra_op_num_threads: int = 0, - repeats: int = 100, -): +def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): if use_gpu: device_id = torch.cuda.current_device() device = torch.device("cuda", device_id) @@ -415,32 +407,16 @@ def run_tflops_test( ] else: configs = [ - # TNLGv4 (1, 128, 0, 32, 128, True), (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), (1, 2048, 0, 32, 128, True), - # bert-base - (1, 128, 0, 12, 64, True), - (1, 384, 0, 12, 64, True), - (1, 512, 0, 12, 64, True), - (4, 128, 0, 12, 64, True), - (4, 384, 0, 12, 64, True), - (4, 512, 0, 12, 64, True), - # bert-large - (1, 128, 0, 16, 64, True), - (1, 384, 0, 16, 64, True), - (1, 512, 0, 16, 64, True), - (4, 128, 0, 16, 64, True), - (4, 384, 0, 16, 64, True), - (4, 512, 0, 16, 64, True), ] # List of environment variables to enable/disable attention kernels print("Environment Variables:") env_names = [ - "ORT_ATTENTION_ALGO", "ORT_DISABLE_FLASH_ATTENTION", "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", "ORT_DISABLE_FUSED_ATTENTION", @@ -449,127 +425,73 @@ def run_tflops_test( "ORT_DISABLE_FUSED_CROSS_ATTENTION", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", ] - - env_list = "" for name in env_names: value = os.getenv(name) if value is not None: print(f"{name}={value}") - if env_list: - env_list += "," - env_list += f"{name}={value}" - print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") + print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") + causal = False for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - use_kv_cache=use_kv_cache, - past_sequence_length=past_sequence_length, - max_cache_sequence_length=None, - kv_sequence_length=None, - provider=provider, - enable_cuda_graph=enable_cuda_graph, - device=device, - dtype=torch.float16 if use_gpu else torch.float, - share_past_present_buffer=False, - input_format=input_format, - ) - - sess_options = SessionOptions() - sess_options.intra_op_num_threads = intra_op_num_threads - session = create_session(config, sess_options) - - if use_gpu: - kernel = get_gpu_kernel_name(config) - else: - kernel = get_cpu_kernel_name(config) - - if kernel == "Unfused": - # Skip large sequence length for Unfused kernel to avoid OOM. - if not enable_unfused: - continue - - # Unfused kernel does not support packed QKV or packed KV formats. 
- if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: - continue - - input_dict = config.random_inputs() - - # warm up session - _ = measure_latency(session, input_dict) - - latency_list = [] - for _ in range(repeats): - latency = measure_latency(session, input_dict) - latency_list.append(latency) - average_latency = statistics.mean(latency_list) - - del session - - # compute TFLOPS per second - speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency) - - format = InputFormats.input_format_str(input_format) - print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" - f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" - ) - - row = { - "use_gpu": use_gpu, - "enable_cuda_graph": enable_cuda_graph, - "format": format, - "causal": causal, - "batch_size": batch_size, - "sequence_length": sequence_length, - "past_sequence_length": past_sequence_length, - "num_heads": num_heads, - "head_size": head_size, - "intra_op_num_threads": intra_op_num_threads, - "average_latency": average_latency, - "tflops": speed, - "kernel": kernel, - "environment_variables": env_list, - } - csv_writer.writerow(row) - - -def run_tflops_tests( - use_gpu: bool = True, - enable_cuda_graph: bool = False, -): - csv_filename = "benchmark_mha_{}_{}.csv".format( - "gpu" if use_gpu else "cpu", datetime.now().strftime("%Y%m%d-%H%M%S") - ) - with open(csv_filename, mode="a", newline="") as csv_file: - column_names = [ - "use_gpu", - "enable_cuda_graph", - "format", - "causal", - "batch_size", - "sequence_length", - "past_sequence_length", - "num_heads", - "head_size", - "intra_op_num_threads", - "average_latency", - "tflops", - "kernel", - "environment_variables", - ] - csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) - csv_writer.writeheader() - - for causal, use_kv_cache in [(False, False)]: - for intra_op_num_threads in [1, 2, 4, 8, 16, 0]: # 0 means using all CPU cores by default. - run_tflops_test(csv_writer, use_gpu, enable_cuda_graph, causal, use_kv_cache, intra_op_num_threads) + for use_kv_cache in [False]: + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=True, + use_kv_cache=use_kv_cache, + past_sequence_length=past_sequence_length, + max_cache_sequence_length=None, + kv_sequence_length=None, + provider=provider, + enable_cuda_graph=enable_cuda_graph, + device=device, + dtype=torch.float16 if use_gpu else torch.float, + share_past_present_buffer=False, + input_format=input_format, + ) + + session = create_session(config) + + if use_gpu: + kernel = get_gpu_kernel_name(config) + else: + kernel = get_cpu_kernel_name() + + if kernel == "Unfused": + # Skip large sequence length for Unfused kernel to avoid OOM. + if not enable_unfused: + continue + + # Unfused kernel does not support packed QKV or packed KV formats. 
+ if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + continue + + input_dict = config.random_inputs() + + # warm up session + _ = measure_latency(session, input_dict) + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + # compute TFLOPS per second + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + format = InputFormats.input_format_str(input_format) + print( + f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + ) def plot_prompt_performance( @@ -644,7 +566,7 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_causal_performance_test(sm: int): +def run_performance_test(sm: int): """ Run performance tests for prompt and token generation. @@ -678,9 +600,9 @@ def run_causal_performance_test(sm: int): if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_causal_performance_test(sm) + run_performance_test(sm) - run_tflops_tests(use_gpu=True, enable_cuda_graph=True) + run_tflops_test(use_gpu=True, enable_cuda_graph=True) # Test CPU provider - run_tflops_tests(use_gpu=False, enable_cuda_graph=False) + run_tflops_test(use_gpu=False, enable_cuda_graph=False)