diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 88127387d08ea..9fd713ea4637c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -46,6 +46,7 @@ enum AttentionKernelType { AttentionKernel_TrtFusedCrossAttention, AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_FlashAttention, + AttentionKernel_CudnnFlashAttention, AttentionKernel_Default }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 5c0989bced70c..cbb095b2008ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -249,6 +249,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; + constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -261,6 +262,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_flash_attention, false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); @@ -297,7 +299,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); + cudnnHandle_t cudnn = GetCudnnHandle(context); + return QkvToContext(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index f9eabe27d97e4..9b511d96d80c8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize( bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, + bool use_cudnn_flash_attention, bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. 
const size_t qkv_size = element_size * batch_size * num_heads * @@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize( return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast(batch_size), true); } + if (use_cudnn_flash_attention) { + return qkv_bytes; + } + return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); } @@ -320,6 +326,62 @@ Status FlashAttention( } #endif + +template +Status CudnnFlashAttention( + cudnnHandle_t cudnn_handle, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + + constexpr bool is_bf16 = false; + + cudnn_sdpa::run( + data.q, + data.k, + data.v, + data.output, + parameters.batch_size, + parameters.num_heads, // num_heads_q, + parameters.num_heads, // num_heads_kv, + parameters.head_size, // head_size_qk + parameters.v_head_size, // head_size_v + parameters.sequence_length, // sequence_length_q + parameters.total_sequence_length, // sequence_length_kv + scale, // scale prior softmax + parameters.is_unidirectional, // causal + is_bf16, // True if bfloat16, otherwise float16 + data.qkv_format, + cudnn_handle, + ort_stream, + data.allocator); + + return Status::OK(); +} + +template <> +Status CudnnFlashAttention( + cudnnHandle_t cudnn_handle, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + ORT_UNUSED_PARAMETER(cudnn_handle); + ORT_UNUSED_PARAMETER(ort_stream); + ORT_UNUSED_PARAMETER(parameters); + ORT_UNUSED_PARAMETER(data); + ORT_UNUSED_PARAMETER(scale); + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} + + + #if USE_MEMORY_EFFICIENT_ATTENTION template Status EfficientAttention( @@ -485,6 +547,7 @@ template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data) { @@ -502,7 +565,8 @@ Status QkvToContext( assert((int(data.use_flash_attention) + int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); + int(data.fused_cross_attention_kernel != nullptr) + + int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1); ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); @@ -564,6 +628,10 @@ Status QkvToContext( } #endif + if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale); + } + #if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return EfficientAttention(device_prop, stream, parameters, data, scale); @@ -581,6 +649,7 @@ template struct AttentionData; template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data); @@ -588,6 +657,7 @@ template Status QkvToContext( template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* ort_stream, contrib::AttentionParameters& parameters, 
AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index fad353dcfeb07..7e9ad663cfde0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -9,6 +9,7 @@ #include #include #include "core/framework/allocator.h" +#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { @@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize( bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention, + bool use_cudnn_flash_attention, bool no_qkv_workspace); template @@ -104,9 +106,11 @@ struct AttentionData { size_t workspace_bytes = 0; bool allow_debug_info = false; + // For MultiHeadAttention only. + AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + AllocatorPtr allocator = nullptr; bool IsUnfused() const { - return !use_flash_attention && !use_memory_efficient_attention && - (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + return kernel_type == AttentionKernelType::AttentionKernel_Unfused; } void PrintDebugInfo() const { @@ -139,6 +143,7 @@ template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, + cudnnHandle_t& cudnn, Stream* stream, contrib::AttentionParameters& parameters, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc index 28a095e68131e..5990dc66d9e93 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc @@ -27,7 +27,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) { use_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFlashAttention, false); use_efficient_attention_ = !ParseEnvironmentVariableWithDefault(kDisableMemoryEfficientAttention, false); use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedSelfAttention, false); - use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, false); + use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault(kEnableCudnnFlashAttention, true); use_unfused_ = true; use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault(kDisableTrtFlashAttention, false); use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault(kDisableFusedCrossAttention, false); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index aba1e01bfd91b..28ac13955c676 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -46,7 +46,7 @@ class AttentionKernelOptions { bool use_flash_attention_{true}; bool use_efficient_attention_{true}; bool use_trt_fused_attention_{true}; - bool use_cudnn_flash_attention_{false}; + bool use_cudnn_flash_attention_{true}; bool use_unfused_{true}; bool use_trt_flash_attention_{true}; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 05c592ec61059..a54e9739966c5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -165,7 +165,10 @@ Status 
PrepareQkv_Attention(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { // query, key and value are passed as Q, K and V for the following conditions. - return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr); + return (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) && + data.bias == nullptr; } // For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format) @@ -186,8 +189,9 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Add bias for Q if (data.bias != nullptr) { LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, @@ -200,9 +204,7 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, data.k = const_cast(data.key); data.v = const_cast(data.value); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); if (data.bias == nullptr) { // Transpose query from BSNH to BNSH @@ -229,7 +231,10 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { // query, key and value are passed as Q, K and V for the following conditions. 
- return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; + return (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) && + data.bias == nullptr; } // For MultiHeadAttention without past state, with Q, K and V inputs @@ -271,9 +276,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); data.v = nullptr; data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { + } else if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { if (data.bias != nullptr) { LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, kv_sequence_length, @@ -286,9 +291,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, } data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else if (data.fused_runner != nullptr) { + } else if (data.fused_runner != nullptr) { assert(qk_head_size == v_head_size); assert(data.relative_position_bias == nullptr); @@ -334,7 +337,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, template bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) { - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV. return data.past_key == nullptr && data.present_key != nullptr; } @@ -373,8 +378,9 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, data.v = data.present_value; } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Use oiginal Query (BSNH) since there is no bias. 
data.q = const_cast(data.query); @@ -385,9 +391,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, data.value, data.v)); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.query, data.q)); @@ -436,8 +440,9 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, data.v = data.present_value; } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || + data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH) LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.bias, data.query, data.q); @@ -456,9 +461,7 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1); data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; - } else -#endif - { // unfused kernel + } else { // unfused kernel assert(data.IsUnfused()); constexpr int format = 0; @@ -514,7 +517,8 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // unpack qkv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; @@ -586,7 +590,8 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.use_memory_efficient_attention || data.use_flash_attention || + data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { // Note that there is no bias so we need not output query to q. data.q = const_cast(data.query); // Unpack kv to BSNH. 
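The PrepareQkv_MHA_* changes above all extend one and the same condition: cuDNN flash attention consumes unpacked Q/K/V in the same BSNH-style layouts (Q_K_V_BSNH or Q_K_V_BSNH_BNSH_BNSH) as flash attention and memory-efficient attention, so it shares their preparation path and, when there is no bias, their no-extra-QKV-workspace fast path. A minimal sketch of that shared predicate, using only names introduced or touched by this patch (illustrative; the patch inlines the condition at each call site):

// Illustrative sketch only; assumes the AttentionData/AttentionKernelType
// declarations from attention_impl.h and attention_common.h in this patch.
template <typename T>
bool PrefersUnpackedQkv(const AttentionData<T>& data) {
  return data.use_flash_attention ||
         data.use_memory_efficient_attention ||
         data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention;
}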
diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu new file mode 100644 index 0000000000000..ec137b4da12db --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu @@ -0,0 +1,359 @@ +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" +#include +#include + +#ifndef NDEBUG +#include +#endif + +// FP16/BF16 Flash Attention support in CUDNN backend: + +// version 8903 (8.9.3): +// Padding mask and causal mask +// Additive bias +// Multi-query attention (h_kv=1) +// Both self attention and cross attention +// (padded) variable sequence length +// Head dimensions 64 or 128 +// version 8903 (8.9.4): +// Alibi mask; +// version 8907 (8.9.7): +// Grouped Query Attention +// version 90100 (9.1.0): +// Head dimensions 256 +// version 90101 (9.1.1) +// Sliding window attention +// version 90300 (9.3.0) +// Bug fixes; Variable sequence length supports zero-sequence-length values +// For more information, please refer to cuDNN release notes, and the following link: +// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop +// TODO: For cuDNN version < 9.3, we will disable it by default, unless user explicitly enables it. + +#if CUDART_VERSION < 12000 || CUDNN_MAJOR < 9 +namespace onnxruntime::cudnn_sdpa { + +bool is_supported(const cudaDeviceProp& /*dprops*/, + int /*num_heads_q*/, + int /*num_heads_kv*/, + int /*head_size_qk*/, + int /*head_size_v*/, + int /*sequence_length_q*/, + int /*sequence_length_kv*/, + bool /*is_causal*/) { + return false; +} + +void run( + void* /*q*/, + void* /*k*/, + void* /*v*/, + void* /*output*/, + int /*batch_size*/, + int /*num_heads_q*/, + int /*num_heads_kv*/, + int /*head_size_qk*/, + int /*head_size_v*/, + int /*sequence_length_q*/, + int /*sequence_length_kv*/, + float /*scale*/, + bool /*is_causal*/, + bool /*is_bf16*/, + AttentionQkvFormat /*qkv_format*/, + cudnnHandle_t /*handle*/, + Stream* /*stream*/, + AllocatorPtr /*allocator*/) { + ORT_THROW("OnnxRuntime was not compiled with cuDNN Flash Attention."); +} + +} // namespace onnxruntime::cudnn_sdpa + +#else // CUDART_VERSION >= 12000 && CUDNN_MAJOR >= 9 + +#include +#include "core/providers/cuda/shared_inc/cudnn_fe_call.h" +#include "core/providers/cuda/cuda_stream_handle.h" + +namespace onnxruntime::cudnn_sdpa { + +namespace fe = cudnn_frontend; + +int get_max_head_size() { + static int max_head_size = 0; + static std::once_flag flag; + + if (max_head_size == 0) { + std::call_once(flag, []() { + auto version = cudnnGetVersion(); + if (version < 90100) { + max_head_size = 128; + } else { + max_head_size = 256; + } + }); + } + + return max_head_size; + +} +bool is_supported(const cudaDeviceProp& dprops, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + bool is_causal) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + // See https://github.com/NVIDIA/cudnn-frontend/blob/1.0/release/docs/operations/Attention.md + int max_head_size = get_max_head_size(); + return (is_sm8x || is_sm90) && + (head_size_qk % 8 == 0) && (head_size_qk <= max_head_size) && + (head_size_v % 8 == 0) && (head_size_v <= max_head_size) && + (num_heads_q % num_heads_kv == 0) && + // Bottom right causal mask is only supported with s_q multiple of 64 and s_kv multiple of 64 + (!is_causal || 
(sequence_length_q % 64 == 0 && sequence_length_kv % 64 == 0)); +} + +// A helper function to set stride for q, k, v or output tensor. +// Strides are calculated based on logical tensor layout BNSH (batch_size, num_heads, sequence_length, head_size). +// The physical tensor layout could be either BSNH (is_bsnh=True) or BNSH (is_bsnh=False). +inline void set_stride(std::vector& stride, + int64_t num_heads, + int64_t sequence_length, + int64_t head_size, + bool is_bsnh) { + stride = {num_heads * sequence_length * head_size, // stride for batch. + is_bsnh ? head_size : (head_size * sequence_length), // stride for head. + is_bsnh ? (num_heads * head_size) : head_size, // stride for sequence. + 1}; // stride for hidden dim of head, shall always be 1. +} + +// It is used as a key for hash table to store cached graphs. +// It contains all parameters used in builing graph. Do not include data pointers that only needed in graph execution. +struct GraphParams { + int batch_size; + int num_heads_q; + int num_heads_kv; + int head_size_qk; + int head_size_v; + int sequence_length_q; + int sequence_length_kv; + float scale; + bool is_causal; + bool is_bf16; // True if bfloat16, otherwise float16 + AttentionQkvFormat qkv_format; + cudnnHandle_t handle; + + bool operator == (const GraphParams &rhs) const { + return batch_size == rhs.batch_size && + num_heads_q == rhs.num_heads_q && + num_heads_kv == rhs.num_heads_kv && + head_size_qk == rhs.head_size_qk && + head_size_v == rhs.head_size_v && + sequence_length_q == rhs.sequence_length_q && + sequence_length_kv == rhs.sequence_length_kv && + scale == rhs.scale && + is_causal == rhs.is_causal && + is_bf16 == rhs.is_bf16 && + qkv_format == rhs.qkv_format && + handle == rhs.handle; + } +}; + + +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +// #define BIAS_UID 5 +// #define SEQ_LEN_Q_UID 6 +// #define SEQ_LEN_KV_UID 7 + +std::shared_ptr build_graph(GraphParams& params) { + int batch_size = params.batch_size; + int num_heads_q = params.num_heads_q; + int num_heads_kv = params.num_heads_kv; + int head_size_qk= params.head_size_qk; + int head_size_v= params.head_size_v; + int sequence_length_q= params.sequence_length_q; + int sequence_length_kv= params.sequence_length_kv; + float scale= params.scale; + bool is_causal= params.is_causal; + bool is_bf16= params.is_bf16; + AttentionQkvFormat qkv_format = params.qkv_format; + cudnnHandle_t handle = params.handle; + + assert(qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BNSH); + + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(is_bf16 ? 
fe::DataType_t::BFLOAT16 : fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + bool is_q_bsnh = (qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH || + qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + bool is_kv_bsnh = qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH; + + std::vector stride; + set_stride(stride, num_heads_q, sequence_length_q, head_size_qk, is_q_bsnh); + + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({batch_size, num_heads_q, sequence_length_q, head_size_qk}) // logical layout + .set_stride(stride)); + + set_stride(stride, num_heads_kv, sequence_length_kv, head_size_qk, is_kv_bsnh); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_qk}) + .set_stride(stride)); + + set_stride(stride, num_heads_kv, sequence_length_kv, head_size_v, is_kv_bsnh); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_v}) + .set_stride(stride)); + + auto attributes = fe::graph::SDPA_attributes() + .set_name("SDPA") + .set_is_inference(true) + .set_causal_mask(is_causal) + .set_causal_mask_bottom_right(is_causal) + .set_attn_scale(scale); + //.set_sliding_window_length(sliding_window_value); + + // auto bias = mha_graph.tensor(fe::graph::Tensor_attributes() + // .set_name("bias") + // .set_uid(BIAS_UID) + // .set_dim({b, 1, s_q, s_kv}) + // .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + // attributes.set_bias(bias); + + // if (padding_mask) { + // auto seq_q = graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_q") + // .set_uid(SEQ_LEN_Q_UID) + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // auto seq_kv = graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_kv") + // .set_uid(SEQ_LEN_KV_UID) + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // attributes.set_padding_mask(padding_mask).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + // } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, attributes); + + constexpr bool is_output_bsnh = true; + set_stride(stride, num_heads_q, sequence_length_q, head_size_v, is_output_bsnh); + + O->set_output(true) + .set_dim({batch_size, num_heads_q, sequence_length_q, head_size_v}) + .set_stride(stride) + .set_uid(O_UID); + +#ifndef NDEBUG + std::cout << "cudnn graph:" << *mha_graph; +#endif + + CUDNN_FE_CALL_THROW(mha_graph->build(handle, {fe::HeurMode_t::A})); + + return mha_graph; +} + + +// Compute hash based on content in memory byte by byte. This can be moved to a common header file if needed. +template +struct BytesHash { + // Verify that Params is good to hash byte by byte. + static_assert(std::is_standard_layout_v, "Params is not standard layout"); + + size_t operator()(const T& params) const { + auto ptr = reinterpret_cast(¶ms); + // Fowler–Noll–Vo hash function + uint32_t value = 0x811C9DC5; + constexpr size_t bytes = sizeof(T); + for (size_t i = 0; i < bytes; ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe. +// TODO: since we the key includes sequence lengths, we may want to limit the cache size. 
+thread_local std::unordered_map, BytesHash> mha_graph_cache; + +void run( + void* q, + void* k, + void* v, + void* output, + int batch_size, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + float scale, + bool is_causal, + bool is_bf16, // True if bfloat16, otherwise float16 + AttentionQkvFormat qkv_format, + cudnnHandle_t handle, + Stream* stream, + AllocatorPtr allocator) { + + GraphParams params; + params.batch_size = batch_size; + params.num_heads_q = num_heads_q; + params.num_heads_kv = num_heads_kv; + params.head_size_qk = head_size_qk; + params.head_size_v = head_size_v; + params.sequence_length_q = sequence_length_q; + params.sequence_length_kv = sequence_length_kv; + params.scale = scale; + params.is_causal = is_causal; + params.is_bf16 = is_bf16; + params.qkv_format = qkv_format; + params.handle = handle; + + std::shared_ptr mha_graph; + auto it = mha_graph_cache.find(params); + if (it != mha_graph_cache.end()) { + mha_graph = it->second; + } else { + mha_graph = build_graph(params); + mha_graph_cache[params] = mha_graph; + } + + std::unordered_map variant_pack = { + {Q_UID, q}, + {K_UID, k}, + {V_UID, v}, + {O_UID, output}, + //{bias, bTensor.devPtr}, + //{SCALE_UID, &scale} + }; + + // Allocate workspace. + auto bytes = mha_graph->get_workspace_size(); + + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr( + allocator, bytes, false, stream, WaitCudaNotificationOnDevice); + + CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get())); +} + +} // namespace onnxruntime::cudnn_sdpa +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h new file mode 100644 index 0000000000000..a8e766aff861d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h @@ -0,0 +1,37 @@ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +using onnxruntime::Stream; +using onnxruntime::contrib::AttentionQkvFormat; + +namespace onnxruntime::cudnn_sdpa { +bool is_supported(const cudaDeviceProp& dprops, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + bool is_causal); + +void run( + void* q, + void* k, + void* v, + void* output, + int batch_size, + int num_heads_q, + int num_heads_kv, + int head_size_qk, + int head_size_v, + int sequence_length_q, + int sequence_length_kv, + float scale, + bool is_causal, + bool is_bf16, // True if bfloat16, otherwise float16 + AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported + cudnnHandle_t handle, + Stream* stream, + AllocatorPtr allocator); + +} // namespace onnxruntime::cudnn_sdpa diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c36abc8e1d624..99a3c6a49e5b3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -6,6 +6,7 @@ #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" @@ -59,6 
+60,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention(); + enable_cudnn_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseCudnnFlashAttention(); + // Allocate cache buffers constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer(cache_bytes); @@ -148,6 +151,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; + AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default; + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && nullptr == relative_position_bias && @@ -173,6 +178,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; + kernel_type = AttentionKernelType::AttentionKernel_FlashAttention; } auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); @@ -182,8 +188,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif + bool use_cudnn_sdpa = kernel_type == AttentionKernelType::AttentionKernel_Default && + enable_cudnn_flash_attention_ && + nullptr == relative_position_bias && + nullptr == key_padding_mask && + onnxruntime::cudnn_sdpa::is_supported(device_prop, + parameters.num_heads, // num_heads_q + parameters.num_heads, // num_heads_kv + parameters.head_size, // head_size_qk + parameters.v_head_size, // head_size_v + parameters.sequence_length, // seq_len_q + parameters.total_sequence_length, // seq_len_kv + is_unidirectional_); + if (use_cudnn_sdpa) { + kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention; + } + bool use_fused_cross_attention = - !use_flash_attention && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_cross_attention_ && nullptr == key_padding_mask && nullptr == relative_position_bias && @@ -203,13 +225,13 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // The kernel has no limit on sequence length, and this checks whether the kernel has been loaded. if (fused_fp16_cross_attention_kernel_->isValid(sequence_length)) { fused_cross_attention_kernel = fused_fp16_cross_attention_kernel_; + kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention; } } bool use_fused_runner = - !use_flash_attention && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_self_attention_ && - fused_cross_attention_kernel == nullptr && nullptr == relative_position_bias && (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && nullptr == past_key && nullptr == present_key && @@ -232,6 +254,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length); if (fused_fp16_runner_->IsValid(normalized_seq_len)) { fused_runner = fused_fp16_runner_.get(); + // could also be AttentionKernel_TrtFlashAttention, but we don't classify it here. 
+ kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; } } @@ -245,19 +269,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = - !use_flash_attention && - fused_runner == nullptr && - fused_cross_attention_kernel == nullptr && + kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_memory_efficient_attention_ && is_long_sequence && (relative_position_bias == nullptr || is_good_for_rpb) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, std::is_same::value, parameters.head_size, parameters.v_head_size); + if (use_memory_efficient_attention) { + kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; + } #else constexpr bool use_memory_efficient_attention = false; #endif + if (kernel_type == AttentionKernelType::AttentionKernel_Default) { + kernel_type = AttentionKernelType::AttentionKernel_Unfused; + } + typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); @@ -276,6 +305,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.kernel_type = kernel_type; + data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). // The cache will be initialized only once, and become readonly after that. 
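Note on workspace handling for the new path: when use_cudnn_flash_attention is set, GetAttentionWorkspaceSize (see the attention_impl.cu change above and the call site below) reserves only the Q/K/V staging buffer, because cudnn_sdpa::run allocates the cuDNN graph workspace itself through the allocator that AttentionData now carries. A simplified sketch of that rule, with the early returns for the TRT fused kernels and flash attention elided (illustrative):

// Illustrative sketch; names match GetAttentionWorkspaceSize in attention_impl.cu.
size_t workspace_bytes = qkv_bytes;  // Q/K/V staging only for the cuDNN SDPA path
if (!use_cudnn_flash_attention) {
  workspace_bytes += 2 * GetAttentionScratchSize(element_size, batch_size, num_heads,
                                                 sequence_length, total_sequence_length);
}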
@@ -303,6 +334,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_sdpa, no_qkv_workspace); auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); @@ -321,6 +353,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (data.allow_debug_info) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = use_flash_attention; + debug_info.use_cudnn_flash_attention = use_cudnn_sdpa; debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; debug_info.use_efficient_attention = use_memory_efficient_attention; if (fused_fp16_runner_ != nullptr) { @@ -335,8 +368,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } cublasHandle_t cublas = GetCublasHandle(context); + cudnnHandle_t cudnn = GetCudnnHandle(context); return QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data); + device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 68fd0c9943fca..8edc1d0e6ac06 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -33,6 +33,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + bool enable_cudnn_flash_attention_; // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. 
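The MultiHeadAttention changes above replace the chain of per-kernel booleans with a single kernel_type that a candidate may only claim while it is still AttentionKernel_Default, which makes the selection priority explicit. A condensed sketch of that order (the can_use_* flags are placeholders for the patch's real eligibility checks on masks, bias, head sizes, sequence lengths and kernel options):

// Condensed dispatch order in MultiHeadAttention<T>::ComputeInternal after this patch.
// can_use_* are placeholders, not variables from the patch.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
if (can_use_flash_attention) {
  kernel_type = AttentionKernelType::AttentionKernel_FlashAttention;
} else if (can_use_cudnn_sdpa) {
  kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention;
} else if (can_use_trt_cross_attention) {
  kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention;
} else if (can_use_trt_fused_attention) {
  kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention;
} else if (can_use_memory_efficient_attention) {
  kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
}
if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
  kernel_type = AttentionKernelType::AttentionKernel_Unfused;
}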
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b62e566d43f89..60eee240c8d8b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -179,6 +179,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_fused_cross_attention = false; constexpr bool use_memory_efficient_attention = false; constexpr bool use_flash_attention = false; + constexpr bool use_cudnn_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, parameters.num_heads, @@ -191,6 +192,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention, + use_cudnn_flash_attention, true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -215,7 +217,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.present = reinterpret_cast(present->MutableData()); } - return QkvToContext(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data); + cudnnHandle_t cudnn = GetCudnnHandle(context); + return QkvToContext(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index f0255d7ece84e..427ec7c773065 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -301,6 +301,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -317,6 +318,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -333,6 +335,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -350,6 +353,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, 
@@ -367,6 +371,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -377,6 +382,22 @@ static void RunMultiHeadAttentionKernel( mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); } + + if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; + RunMultiHeadAttentionTest( + query_data, key_data, value_data, kv_data, qkv_data, bias_data, rel_pos_bias_data, + past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, + mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + } } static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { @@ -439,7 +460,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); } -#if USE_MEMORY_EFFICIENT_ATTENTION +#if USE_FLASH_ATTENTION kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( @@ -451,6 +472,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu } #endif + kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention; + if (!SkipAttentionKernel(data, kernel_type)) { + RunMultiHeadAttentionKernel( + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, + data.rel_pos_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, + data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, + data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + } + kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data,
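For reference, the stride convention used by set_stride in the new cudnn_flash_attention.cu: every Q/K/V/output tensor is described to cuDNN in logical BNSH order, and only the strides encode whether the memory is physically BSNH or BNSH. A worked example with assumed sizes (values are illustrative, not taken from the patch):

// Assume batch_size B=2, num_heads N=8, sequence_length S=128, head_size H=64;
// the logical dims passed to set_dim are always {B, N, S, H}.
std::vector<int64_t> stride;
set_stride(stride, /*num_heads*/ 8, /*sequence_length*/ 128, /*head_size*/ 64, /*is_bsnh*/ true);
// BSNH memory -> {N*S*H, H, N*H, 1} = {65536, 64, 512, 1}
set_stride(stride, /*num_heads*/ 8, /*sequence_length*/ 128, /*head_size*/ 64, /*is_bsnh*/ false);
// BNSH memory -> {N*S*H, S*H, H, 1} = {65536, 8192, 64, 1}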