From fbc3927231b81f19c9c3c5ee044418184a5381d0 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Tue, 20 Aug 2024 08:50:22 -0700
Subject: [PATCH] [CUDA] cuDNN Flash Attention (#21629)

### Description
- [x] Add cuDNN flash attention using the cuDNN frontend, and enable it in the MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

cuDNN SDPA is disabled by default. To enable it:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1`, or set the `sdpa_kernel=8` CUDA provider option (see the usage sketch at the end of this patch).
(3) The device must have compute capability >= 8.0.
Note that some parameter combinations might still be rejected because support for certain head dimensions or sequence lengths is limited.

Future Work:
(1) Expose FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (inputs with padding removed).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH or BNSH format. Experiment to see whether converting q to BNSH is faster in some cases.

### Example Benchmark Results on H100

The following tests run the FP16 MultiHeadAttention operator without attention mask or attention bias. The tflops values are consistent with 4 * batch_size * num_heads * sequence_length^2 * head_size / average_latency / 1e12.

#### Test Setting 1

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
---
 cmake/external/cuDNN.cmake | 2 -
 cmake/onnxruntime_rocm_hipify.cmake | 1 +
 .../contrib_ops/cpu/bert/attention_common.h | 1 +
 .../contrib_ops/cuda/bert/attention.cc | 5 +-
 .../contrib_ops/cuda/bert/attention_impl.cu | 84 +++-
 .../contrib_ops/cuda/bert/attention_impl.h | 9 +-
 .../cuda/bert/attention_kernel_options.cc | 16 +-
 .../cuda/bert/attention_kernel_options.h | 4 +-
 .../cuda/bert/attention_prepare_qkv.cu | 57 +--
 .../bert/cudnn_fmha/cudnn_flash_attention.cu | 405 ++++++++++++++++++
 .../bert/cudnn_fmha/cudnn_flash_attention.h | 50 +++
 .../cuda/bert/multihead_attention.cc | 46 +-
 .../cuda/bert/multihead_attention.h | 1 +
 .../quantization/attention_quantization.cc | 5 +-
 .../providers/cuda/cuda_execution_provider.h
| 2 +- .../contrib_ops/attention_op_test_helper.cc | 3 +- .../multihead_attention_op_test.cc | 27 ++ .../test/python/transformers/benchmark_mha.py | 8 +- .../test/python/transformers/test_mha.py | 5 + 19 files changed, 681 insertions(+), 50 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h diff --git a/cmake/external/cuDNN.cmake b/cmake/external/cuDNN.cmake index 3d05f6406a80e..f416b207676cf 100644 --- a/cmake/external/cuDNN.cmake +++ b/cmake/external/cuDNN.cmake @@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9) CUDNN::cudnn_heuristic ) endif() - -mark_as_advanced(CUDNN_INCLUDE_DIR) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 1740144cf6553..fcddd2a51e0d1 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED) # GLOB pattern of file to be excluded set(contrib_ops_excluded_files + "bert/cudnn_fmha/*" "bert/cutlass_fmha/*" "bert/fastertransformer_decoder_attention/*" "bert/flash_attention/*" diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5a5899166f5ba..1e01aa765ca6d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -47,6 +47,7 @@ enum AttentionKernelType { AttentionKernel_TrtFusedCrossAttention, AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_FlashAttention, + AttentionKernel_CudnnFlashAttention, AttentionKernel_Default }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 1d1416995a673..e5686b255425c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -246,6 +246,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; + constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -258,6 +259,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_flash_attention, false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); @@ -294,7 +296,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); + cudnnHandle_t cudnn = GetCudnnHandle(context); + return QkvToContext(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 28e2b7b28764b..a02f5c7329b9a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. 
#include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize( bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, + bool use_cudnn_flash_attention, bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. const size_t qkv_size = element_size * batch_size * num_heads * @@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize( return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast(batch_size), true); } + if (use_cudnn_flash_attention) { + return qkv_bytes; + } + return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); } @@ -320,6 +326,68 @@ Status FlashAttention( } #endif +template +Status CudnnFlashAttention( + cudnnHandle_t cudnn_handle, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(parameters.mask_type == AttentionMaskType::MASK_NONE || + parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN); + constexpr bool is_bf16 = false; + + T* attention_bias = const_cast(data.attention_bias); + int* mask_sequence_lengths_kv = const_cast(data.mask_index); + + cudnn_sdpa::run( + data.output, + data.q, + data.k, + data.v, + attention_bias, + nullptr, // (optional) mask_sequence_lengths_q + mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv + parameters.batch_size, + parameters.num_heads, // num_heads_q, + parameters.num_heads, // num_heads_kv, + parameters.head_size, // head_size_qk + parameters.v_head_size, // head_size_v + parameters.sequence_length, // sequence_length_q + parameters.total_sequence_length, // sequence_length_kv + scale, // scaling factor applied prior softmax + parameters.is_unidirectional, // causal + is_bf16, // True if bfloat16, otherwise float16 + parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not + parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not + 0, // sliding window length. 0 means no sliding window. 
+ data.qkv_format, + cudnn_handle, + ort_stream, + data.allocator); + + return Status::OK(); +} + +template <> +Status CudnnFlashAttention( + cudnnHandle_t cudnn_handle, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + ORT_UNUSED_PARAMETER(cudnn_handle); + ORT_UNUSED_PARAMETER(ort_stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "cudnn flash attention does not support float tensor"); +} + #if USE_MEMORY_EFFICIENT_ATTENTION template Status EfficientAttention( @@ -498,6 +566,7 @@ template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data) { @@ -512,10 +581,11 @@ Status QkvToContext( void* fused_runner = data.fused_runner; // At most one fused kernel is enabled. - assert((int(data.use_flash_attention) + - int(data.use_memory_efficient_attention) + - int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); + assert((static_cast(data.use_flash_attention) + + static_cast(data.use_memory_efficient_attention) + + static_cast(fused_runner != nullptr) + + static_cast(data.fused_cross_attention_kernel != nullptr) + + static_cast(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1); ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); @@ -577,6 +647,10 @@ Status QkvToContext( } #endif + if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale); + } + #if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return EfficientAttention(device_prop, stream, parameters, data, scale); @@ -594,6 +668,7 @@ template struct AttentionData; template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data); @@ -601,6 +676,7 @@ template Status QkvToContext( template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index a6760f84e69f3..fcc9af9681223 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -9,6 +9,7 @@ #include #include #include "core/framework/allocator.h" +#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { @@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize( bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, + bool use_cudnn_flash_attention, bool no_qkv_workspace); template @@ -104,9 +106,11 @@ struct AttentionData { size_t workspace_bytes = 0; bool allow_debug_info = false; + // For MultiHeadAttention only. 
+ AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + AllocatorPtr allocator = nullptr; bool IsUnfused() const { - return !use_flash_attention && !use_memory_efficient_attention && - (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + return kernel_type == AttentionKernelType::AttentionKernel_Unfused; } void PrintDebugInfo() const { @@ -139,6 +143,7 @@ template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* stream, contrib::AttentionParameters& parameters, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 28a095e68131e..b2e80cb5035cb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -9,11 +9,12 @@ #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" using namespace onnxruntime::contrib::attention; namespace onnxruntime { -void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { +void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) { if (value > 0) { use_flash_attention_ = (value & static_cast(AttentionBackend::FLASH_ATTENTION)) > 0; use_efficient_attention_ = (value & static_cast(AttentionBackend::EFFICIENT_ATTENTION)) > 0; @@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); + use_unfused_ = true; use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableTrtFlashAttention, false); use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedCrossAttention, false); @@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { kMinSeqLenForEfficientAttentionFp32, value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32); + // Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0). + if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) { + use_cudnn_flash_attention_ = false; + if (enable_kernel_debug_info_) { + std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl; + } + } + if (use_build_flag) { // Some kernels can be disabled at build time. If they are disabled, we should not use them. 
#ifndef USE_FLASH_ATTENTION @@ -58,9 +68,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { } void AttentionKernelOptions::InitializeOnce( - int sdpa_kernel, bool use_build_flag) { + int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) { std::call_once(this->initialize_once_flag_, [&]() { - this->Initialize(sdpa_kernel, use_build_flag); + this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version); if (this->enable_kernel_debug_info_) { this->Print(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index aba1e01bfd91b..a27fb199a6272 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo { class AttentionKernelOptions { public: - void InitializeOnce(int sdpa_kernel, bool use_build_flag); + void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false); bool UseFlashAttention() const { return use_flash_attention_; } bool UseEfficientAttention() const { return use_efficient_attention_; } @@ -40,7 +40,7 @@ class AttentionKernelOptions { protected: void Print() const; - void Initialize(int value, bool use_build_flag); + void Initialize(int value, bool use_build_flag, bool check_cudnn_version); private: bool use_flash_attention_{true}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 575e65ebef0e9..a079076f2881b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -169,7 +169,10 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { // query, key and value are passed as Q, K and V for the following conditions. 
- return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr); + return (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) && + data.bias == nullptr; } // For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format) @@ -190,8 +193,9 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Add bias for Q if (data.bias != nullptr) { LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, @@ -204,9 +208,7 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, data.k = const_cast(data.key); data.v = const_cast(data.value); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); if (data.bias == nullptr) { // Transpose query from BSNH to BNSH @@ -233,7 +235,10 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { // query, key and value are passed as Q, K and V for the following conditions. - return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; + return (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) && + data.bias == nullptr; } // For MultiHeadAttention without past state, with Q, K and V inputs @@ -275,9 +280,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); data.v = nullptr; data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { + } else if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { if (data.bias != nullptr) { LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, kv_sequence_length, @@ -290,9 +295,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, } data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else if (data.fused_runner != nullptr) { + } else if (data.fused_runner != nullptr) { assert(qk_head_size == v_head_size); assert(data.attention_bias == nullptr); @@ -338,7 +341,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) { - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV. 
return data.past_key == nullptr && data.present_key != nullptr; } @@ -377,8 +382,9 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, data.v = data.present_value; } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Use oiginal Query (BSNH) since there is no bias. data.q = const_cast(data.query); @@ -389,9 +395,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, data.value, data.v)); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.query, data.q)); @@ -440,8 +444,9 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, data.v = data.present_value; } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH) LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.bias, data.query, data.q); @@ -460,9 +465,7 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); constexpr int format = 0; @@ -518,7 +521,8 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // unpack qkv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; @@ -590,7 +594,8 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Note that there is no bias so we need not output query to q. data.q = const_cast(data.query); // Unpack kv to BSNH. diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu new file mode 100644 index 0000000000000..426b105dff8db --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" +#include +#include +#include +#include + +#if CUDNN_MAJOR < 9 +namespace onnxruntime::cudnn_sdpa { + +bool is_stable() { + return false; +} + +bool is_supported(const cudaDeviceProp& /*dprops*/, + int /*num_heads_q*/, + int /*num_heads_kv*/, + int /*head_size_qk*/, + int /*head_size_v*/, + int /*sequence_length_q*/, + int /*sequence_length_kv*/, + bool /*is_causal*/) { + return false; +} + +void run( + void* /*output*/, + void* /*q*/, + void* /*k*/, + void* /*v*/, + void* /*bias*/, + int* /*mask_sequence_lengths_q*/, + int* /*mask_sequence_lengths_kv*/, + int /*batch_size*/, + int /*num_heads_q*/, + int /*num_heads_kv*/, + int /*head_size_qk*/, + int /*head_size_v*/, + int /*sequence_length_q*/, + int /*sequence_length_kv*/, + float /*scale*/, + bool /*is_causal*/, + bool /*is_bf16*/, + bool /*broadcast_attn_bias_dim_0*/, + bool /*broadcast_attn_bias_dim_1*/, + int /*sliding_window*/, + AttentionQkvFormat /*qkv_format*/, + cudnnHandle_t /*handle*/, + Stream* /*stream*/, + AllocatorPtr /*allocator*/) { + ORT_THROW("OnnxRuntime was not compiled with cuDNN Flash Attention."); +} + +} // namespace onnxruntime::cudnn_sdpa + +#else // CUDNN_MAJOR >= 9 + +#include +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/cuda/cuda_stream_handle.h" + +namespace onnxruntime::cudnn_sdpa { + +bool is_stable() { + // FP16/BF16 Flash Attention support in CUDNN backend: + // version 8903 (8.9.3): + // Padding mask and causal mask + // Additive bias + // Multi-query attention (h_kv=1) + // Both self attention and cross attention + // (padded) variable sequence length + // Head dimensions 64 or 128 + // version 8903 (8.9.4): + // Alibi mask; + // version 8907 (8.9.7): + // Grouped Query Attention + // version 90100 (9.1.0): + // Head dimensions 256 + // version 90101 (9.1.1) + // Sliding window attention + // version 90300 (9.3.0) + // Bug fixes; Variable sequence length supports zero-sequence-length values + // For more information, please refer to cuDNN release notes, and the following links: + // https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop + // https://github.com/NVIDIA/cudnn-frontend/blob/v1.5.2/docs/operations/Attention.md + + // For cuDNN version < 9.3, we will disable it by default. + return cudnnGetVersion() >= 90300; +} + +namespace fe = cudnn_frontend; + +bool is_supported(const cudaDeviceProp& dprops, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + bool is_causal) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && + (head_size_qk % 8 == 0) && (head_size_qk <= 256) && + (head_size_v % 8 == 0) && (head_size_v <= 256) && + (num_heads_q % num_heads_kv == 0) && + // Bottom right causal mask is only supported with s_q multiple of 64 and s_kv multiple of 64. + (!is_causal || (sequence_length_q != sequence_length_kv && + sequence_length_q % 64 == 0 && + sequence_length_kv % 64 == 0)); +} + +// A helper function to set stride for q, k, v or output tensor. +// Strides are calculated based on logical tensor layout BNSH (batch_size, num_heads, sequence_length, head_size). +// The physical tensor layout could be either BSNH (is_bsnh=True) or BNSH (is_bsnh=False). 
+inline void set_stride(std::vector& stride, + int64_t num_heads, + int64_t sequence_length, + int64_t head_size, + bool is_bsnh) { + stride = {num_heads * sequence_length * head_size, // stride for batch. + is_bsnh ? head_size : (head_size * sequence_length), // stride for head. + is_bsnh ? (num_heads * head_size) : head_size, // stride for sequence. + 1}; // stride for hidden dim of head, shall always be 1. +} + +// It is used as a key for hash table to store cached graphs. +// It contains all parameters used in builing graph. Do not include data pointers that only needed in graph execution. +struct GraphParams { + int batch_size; + int num_heads_q; + int num_heads_kv; + int head_size_qk; + int head_size_v; + int sequence_length_q; + int sequence_length_kv; + float scale; + bool is_causal; + bool is_bf16; // True if bfloat16, otherwise float16 + AttentionQkvFormat qkv_format; + cudnnHandle_t handle; + bool has_bias; + bool broadcast_bias_dim_0; + bool broadcast_bias_dim_1; + bool has_padding_mask_q; + bool has_padding_mask_kv; + int sliding_window; + + bool operator==(const GraphParams& rhs) const { + return batch_size == rhs.batch_size && + num_heads_q == rhs.num_heads_q && + num_heads_kv == rhs.num_heads_kv && + head_size_qk == rhs.head_size_qk && + head_size_v == rhs.head_size_v && + sequence_length_q == rhs.sequence_length_q && + sequence_length_kv == rhs.sequence_length_kv && + scale == rhs.scale && + is_causal == rhs.is_causal && + is_bf16 == rhs.is_bf16 && + qkv_format == rhs.qkv_format && + handle == rhs.handle && + has_bias == rhs.has_bias && + broadcast_bias_dim_0 == rhs.broadcast_bias_dim_0 && + broadcast_bias_dim_1 == rhs.broadcast_bias_dim_1 && + has_padding_mask_q == rhs.has_padding_mask_q && + has_padding_mask_kv == rhs.has_padding_mask_kv && + sliding_window == rhs.sliding_window; + } +}; + +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define BIAS_UID 5 +#define SEQ_LEN_Q_UID 6 +#define SEQ_LEN_KV_UID 7 + +std::shared_ptr build_graph(GraphParams& params) { + int batch_size = params.batch_size; + int num_heads_q = params.num_heads_q; + int num_heads_kv = params.num_heads_kv; + int head_size_qk = params.head_size_qk; + int head_size_v = params.head_size_v; + int sequence_length_q = params.sequence_length_q; + int sequence_length_kv = params.sequence_length_kv; + float scale = params.scale; + bool is_causal = params.is_causal; + bool is_bf16 = params.is_bf16; + AttentionQkvFormat qkv_format = params.qkv_format; + cudnnHandle_t handle = params.handle; + + assert(qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BNSH); + + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(is_bf16 ? 
fe::DataType_t::BFLOAT16 : fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + bool is_q_bsnh = (qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + bool is_kv_bsnh = qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH; + + std::vector stride; + set_stride(stride, num_heads_q, sequence_length_q, head_size_qk, is_q_bsnh); + + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({batch_size, num_heads_q, sequence_length_q, head_size_qk}) // logical layout + .set_stride(stride)); + + set_stride(stride, num_heads_kv, sequence_length_kv, head_size_qk, is_kv_bsnh); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_qk}) + .set_stride(stride)); + + set_stride(stride, num_heads_kv, sequence_length_kv, head_size_v, is_kv_bsnh); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_v}) + .set_stride(stride)); + + auto attributes = fe::graph::SDPA_attributes() + .set_name("SDPA") + .set_is_inference(true) + .set_causal_mask(is_causal) + .set_causal_mask_bottom_right(is_causal && sequence_length_q != sequence_length_kv) + .set_attn_scale(scale); + + if (params.sliding_window > 0) { + attributes.set_sliding_window_length(params.sliding_window); + } + + if (params.has_bias) { + std::vector bias_shape = {params.broadcast_bias_dim_0 ? 1 : batch_size, + params.broadcast_bias_dim_1 ? 1 : num_heads_q, + sequence_length_q, + sequence_length_kv}; + stride = {bias_shape[1] * bias_shape[2] * bias_shape[3], bias_shape[2] * bias_shape[3], bias_shape[3], 1}; + auto bias = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_uid(BIAS_UID) + .set_dim(bias_shape) + .set_stride(stride)); + attributes.set_bias(bias); + } + + if (params.has_padding_mask_q || params.has_padding_mask_kv) { + attributes.set_padding_mask(true); + + if (params.has_padding_mask_q) { + auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_uid(SEQ_LEN_Q_UID) + .set_dim({batch_size, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + attributes.set_seq_len_q(seq_q); + } + + if (params.has_padding_mask_kv) { + auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_uid(SEQ_LEN_KV_UID) + .set_dim({batch_size, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + attributes.set_seq_len_kv(seq_kv); + } + } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, attributes); + + constexpr bool is_output_bsnh = true; + set_stride(stride, num_heads_q, sequence_length_q, head_size_v, is_output_bsnh); + + O->set_output(true) + .set_dim({batch_size, num_heads_q, sequence_length_q, head_size_v}) + .set_stride(stride) + .set_uid(O_UID); + + if (!mha_graph->build(handle, {fe::HeurMode_t::A}).is_good()) { + ORT_THROW("Failed to build cuDNN graph for Flash Attention:", *mha_graph, "cudnn version:", cudnnGetVersion()); + } + + return mha_graph; +} + +// Compute hash based on content in memory byte by byte. This can be moved to a common header file if needed. +template +struct BytesHash { + // Verify that Params is good to hash byte by byte. 
+ static_assert(std::is_standard_layout_v, "Params is not standard layout"); + + size_t operator()(const T& params) const { + auto ptr = reinterpret_cast(¶ms); + // Fowler–Noll–Vo hash function + uint32_t value = 0x811C9DC5; + constexpr size_t bytes = sizeof(T); + for (size_t i = 0; i < bytes; ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); + } +}; + +// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe. +// TODO(tianleiwu): since we the key includes sequence lengths, we may want to limit the cache size. +thread_local +std::unordered_map, BytesHash > mha_graph_cache; + +void run( + void* output, + void* q, + void* k, + void* v, + void* attn_bias, + int* mask_sequence_lengths_q, + int* mask_sequence_lengths_kv, + int batch_size, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + float scale, + bool is_causal, + bool is_bf16, + bool broadcast_attn_bias_dim_0, + bool broadcast_attn_bias_dim_1, + int sliding_window, + AttentionQkvFormat qkv_format, + cudnnHandle_t handle, + Stream* stream, + AllocatorPtr allocator) { + + GraphParams params; + params.batch_size = batch_size; + params.num_heads_q = num_heads_q; + params.num_heads_kv = num_heads_kv; + params.head_size_qk = head_size_qk; + params.head_size_v = head_size_v; + params.sequence_length_q = sequence_length_q; + params.sequence_length_kv = sequence_length_kv; + params.scale = scale; + params.is_causal = is_causal; + params.is_bf16 = is_bf16; + params.qkv_format = qkv_format; + params.handle = handle; + params.has_bias = attn_bias != nullptr; + params.broadcast_bias_dim_0 = broadcast_attn_bias_dim_0; + params.broadcast_bias_dim_1 = broadcast_attn_bias_dim_1; + params.has_padding_mask_q = (mask_sequence_lengths_q != nullptr); + params.has_padding_mask_kv = (mask_sequence_lengths_kv != nullptr); + params.sliding_window = sliding_window; + + std::shared_ptr mha_graph; + auto it = mha_graph_cache.find(params); + if (it != mha_graph_cache.end()) { + mha_graph = it->second; + } else { + mha_graph = build_graph(params); + mha_graph_cache[params] = mha_graph; + } + + std::unordered_map variant_pack = { + {Q_UID, q}, + {K_UID, k}, + {V_UID, v}, + {O_UID, output}, + }; + + if (attn_bias != nullptr) { + variant_pack[BIAS_UID] = attn_bias; + } + + if (mask_sequence_lengths_q != nullptr) { + variant_pack[SEQ_LEN_Q_UID] = mask_sequence_lengths_q; + } + + if (mask_sequence_lengths_kv != nullptr) { + variant_pack[SEQ_LEN_KV_UID] = mask_sequence_lengths_kv; + } + + // Allocate workspace. + auto bytes = mha_graph->get_workspace_size(); + + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr( + allocator, bytes, false, stream, WaitCudaNotificationOnDevice); + + CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get())); +} + +} // namespace onnxruntime::cudnn_sdpa +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h new file mode 100644 index 0000000000000..858a22a6b9187 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +using onnxruntime::Stream; +using onnxruntime::contrib::AttentionQkvFormat; + +namespace onnxruntime::cudnn_sdpa { + +bool is_stable(); + +bool is_supported(const cudaDeviceProp& dprops, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + bool is_causal); + +void run( + void* output, + void* q, + void* k, + void* v, + void* bias, // (optional) attention bias with shape [b or 1, h_q or 1, s_q, s_kv]. + int* mask_sequence_lengths_q, // (optional) sequence lengths of q for padding mask. Shape: [batch_size] + int* mask_sequence_lengths_kv, // (optional) sequence lengths of k or v for padding mask. Shape: [batch_size] + int batch_size, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + float scale, + bool is_causal, + bool is_bf16, // True if bfloat16, otherwise float16 + bool broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 + bool broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 + int sliding_window, // sliding window length. 0 means no sliding window. + AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported + cudnnHandle_t handle, + Stream* stream, + AllocatorPtr allocator); + +} // namespace onnxruntime::cudnn_sdpa diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index b2fd9b5e89de1..2ad8bc4015a47 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -6,6 +6,7 @@ #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" @@ -59,6 +60,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); + enable_cudnn_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseCudnnFlashAttention(); + // Allocate cache buffers constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer(cache_bytes); @@ -148,6 +151,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; + AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && nullptr == attention_bias && @@ -173,6 +178,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; + kernel_type = AttentionKernelType::AttentionKernel_FlashAttention; } auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); @@ -184,8 +190,23 @@ 
Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + bool use_cudnn_sdpa = kernel_type == AttentionKernelType::AttentionKernel_Default && + enable_cudnn_flash_attention_ && + is_mask_none_or_1d_k_len && + onnxruntime::cudnn_sdpa::is_supported(device_prop, + parameters.num_heads, // num_heads_q + parameters.num_heads, // num_heads_kv + parameters.head_size, // head_size_qk + parameters.v_head_size, // head_size_v + parameters.sequence_length, // seq_len_q + parameters.total_sequence_length, // seq_len_kv + is_unidirectional_); + if (use_cudnn_sdpa) { + kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention; + } + bool use_fused_cross_attention = - !use_flash_attention && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_cross_attention_ && nullptr == key_padding_mask && nullptr == attention_bias && @@ -205,11 +226,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // The kernel has no limit on sequence length, and this checks whether the kernel has been loaded. if (fused_fp16_cross_attention_kernel_->isValid(sequence_length)) { fused_cross_attention_kernel = fused_fp16_cross_attention_kernel_; + kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention; } } bool use_fused_runner = - !use_flash_attention && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && nullptr == attention_bias && @@ -234,6 +256,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length); if (fused_fp16_runner_->IsValid(normalized_seq_len)) { fused_runner = fused_fp16_runner_.get(); + // could also be AttentionKernel_TrtFlashAttention, but we don't classify it here. + kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; } } @@ -244,9 +268,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length >= length_threshold; bool use_memory_efficient_attention = - !use_flash_attention && - fused_runner == nullptr && - fused_cross_attention_kernel == nullptr && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_memory_efficient_attention_ && is_long_sequence && // Check whether the attention bias alignment is good for memory efficient attention. @@ -254,10 +276,17 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, std::is_same::value, parameters.head_size, parameters.v_head_size); + if (use_memory_efficient_attention) { + kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; + } #else constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_type == AttentionKernelType::AttentionKernel_Default) { + kernel_type = AttentionKernelType::AttentionKernel_Unfused; + } + typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? 
nullptr : reinterpret_cast(bias->Data()); @@ -278,6 +307,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.kernel_type = kernel_type; + data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). // The cache will be initialized only once, and become readonly after that. @@ -305,6 +336,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_sdpa, no_qkv_workspace); auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); @@ -323,6 +355,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (data.allow_debug_info) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; + debug_info.use_cudnn_flash_attention = use_cudnn_sdpa; debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; debug_info.use_efficient_attention = use_memory_efficient_attention; if (fused_fp16_runner_ != nullptr) { @@ -337,8 +370,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); return QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data); + device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 68fd0c9943fca..8edc1d0e6ac06 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -33,6 +33,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + bool enable_cudnn_flash_attention_; // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. 
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 3a5fc401c53af..1b774b163888f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -179,6 +179,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_fused_cross_attention = false; constexpr bool use_memory_efficient_attention = false; constexpr bool use_flash_attention = false; + constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, parameters.num_heads, @@ -191,6 +192,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_flash_attention, true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -215,7 +217,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.present = reinterpret_cast(present->MutableData()); } - return QkvToContext(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data); + cudnnHandle_t cudnn = GetCudnnHandle(context); + return QkvToContext(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 0871f7e4d0a74..c5736733beb1d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -88,7 +88,7 @@ class CUDAExecutionProvider : public IExecutionProvider { #ifndef DISABLE_CONTRIB_OPS // Attention kernel options parsed from sdpa_kernel cuda provider option. 
const AttentionKernelOptions* GetAttentionKernelOptions() const { - attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true); + attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true, true); return &attention_kernel_options_; } #endif diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 1ea67314f62d6..5df521bd6381d 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -104,7 +104,8 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, AttentionKernelType::AttentionKernel_TrtFusedAttention, - AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention}; + AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention, + AttentionKernelType::AttentionKernel_CudnnFlashAttention}; LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.query_data", data.query_data); LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.key_data", data.key_data); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 3aaf710c33db4..1d167b5dffdb5 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -367,6 +367,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -377,6 +378,22 @@ static void RunMultiHeadAttentionKernel( mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); } + + if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; + RunMultiHeadAttentionTest( + query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, + past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, + mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + } } static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { @@ -451,6 +468,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu } #endif + kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention; + if (!SkipAttentionKernel(data, kernel_type)) { + 
RunMultiHeadAttentionKernel( + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, + data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + } + kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 50b94e7af285e..4cc5ce4201ea1 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -791,7 +791,13 @@ def run_tflops_test( # flash attention is available for sm >= 80 sm = get_compute_capability() if sm >= 80: - backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] + backends = [ + SdpaKernel.DEFAULT, + SdpaKernel.FLASH_ATTENTION, + SdpaKernel.EFFICIENT_ATTENTION, + SdpaKernel.CUDNN_FLASH_ATTENTION, + SdpaKernel.MATH, + ] else: backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH] else: diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 5ebc02c84acb2..92653ffb053ce 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -804,6 +804,10 @@ def run_mha_cuda_multi_threading_default(self): if get_compute_capability() >= 60: self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + def run_mha_cuda_multi_threading_cudnn(self): + if get_compute_capability() in [80, 86, 89, 90]: + self.run_mha_cuda_multi_threading(SdpaKernel.CUDNN_FLASH_ATTENTION) + def run_mha_cuda_multi_threading_efficient(self): if comprehensive_mode and get_compute_capability() >= 60: self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) @@ -826,6 +830,7 @@ def test_all(self): self.run_mha_cpu() self.run_mha_cuda() self.run_mha_cuda_multi_threading_default() + self.run_mha_cuda_multi_threading_cudnn() self.run_mha_cuda_multi_threading_efficient() self.run_mha_cuda_multi_threading_math() self.run_mha_cuda_multi_threading_trt()
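
As a companion to the enablement steps in the description above, here is a minimal usage sketch. It is not part of the patch and rests on assumptions: an onnxruntime-gpu build that includes this change, cuDNN 9.3 or newer, a GPU with compute capability >= 8.0, and a placeholder `model.onnx` containing FP16 MultiHeadAttention nodes. The value `sdpa_kernel=8` is the backend flag quoted in the description.

```python
# Minimal sketch (not part of the patch): enable cuDNN Flash Attention in ORT.
# Assumptions: onnxruntime-gpu with this change, cuDNN >= 9.3, compute capability >= 8.0,
# and a placeholder "model.onnx" that contains FP16 MultiHeadAttention nodes.
import os

import onnxruntime as ort

# Option 1: environment variable; set it before the session is created.
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: the sdpa_kernel CUDA provider option; 8 selects the cuDNN flash
# attention backend per the description above.
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": 8}),
    "CPUExecutionProvider",
]

session = ort.InferenceSession("model.onnx", providers=providers)
# session.run(None, feeds) dispatches MultiHeadAttention to the cuDNN kernel
# when the mask, bias, head-size, and sequence-length constraints are satisfied.
```

Either mechanism alone is sufficient; per `AttentionKernelOptions::Initialize` in this patch, a non-zero `sdpa_kernel` value takes precedence and the environment variables are then ignored.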