[CUDA] cuDNN Flash Attention (#21629)

### Description
- [x] Add cuDNN flash attention using the cuDNN frontend, and enable it in the MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA is disabled by default. To enable it, the following are required:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1`, or set the `sdpa_kernel=8` CUDA provider option (see the example below).
(3) The device must have compute capability >= 8.0.

Note that some combinations of parameters might be rejected due to limited support for certain head dimensions or sequence lengths.
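
Below is a minimal sketch of both options using the onnxruntime Python API (a CUDA build is assumed; the model file name is a placeholder for any model containing the MultiHeadAttention contrib op):

```python
import os
import onnxruntime as ort

# Option 1: environment variable. It is read when the session is created,
# so set it before constructing the InferenceSession.
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: request the backend through the CUDA execution provider option.
# The value 8 is the sdpa_kernel setting for cuDNN flash attention described above.
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": 8}),
    "CPUExecutionProvider",
]

# "model_with_mha.onnx" is a hypothetical model that uses MultiHeadAttention.
session = ort.InferenceSession("model_with_mha.onnx", providers=providers)
```

Either option alone is sufficient to request the cuDNN backend.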

Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (inputs with padding removed).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH or BNSH format. Experiments may show whether converting q to BNSH is better in some cases.

### Example Benchmark Results on H100

The following tests use the FP16 MultiHeadAttention operator without attention mask or attention bias.

#### Test Setting 1
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH)  | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH)  | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH)  | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
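
The tflops column can be roughly cross-checked from the latency and the problem size. The sketch below assumes the usual attention FLOP count (two matmuls, Q·Kᵀ and P·V, with no causal masking); small deviations from the table come from latency rounding:

```python
def attention_tflops(batch, seq_q, seq_kv, num_heads, head_size, latency_s):
    # Q*K^T and P*V each cost about 2 * batch * num_heads * seq_q * seq_kv * head_size FLOPs.
    flops = 4.0 * batch * num_heads * seq_q * seq_kv * head_size
    return flops / latency_s / 1e12

# Test setting 1, ort:cudnn row: prints roughly 229, close to the 227.8 in the table.
print(attention_tflops(16, 256, 256, 32, 128, 0.000075))
```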
tianleiwu authored Aug 20, 2024
1 parent 9f7e19c commit fbc3927
Showing 19 changed files with 681 additions and 50 deletions.
2 changes: 0 additions & 2 deletions cmake/external/cuDNN.cmake
@@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
CUDNN::cudnn_heuristic
)
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)

# GLOB pattern of file to be excluded
set(contrib_ops_excluded_files
"bert/cudnn_fmha/*"
"bert/cutlass_fmha/*"
"bert/fastertransformer_decoder_attention/*"
"bert/flash_attention/*"
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -47,6 +47,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};

5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

@@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
84 changes: 80 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

@@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
@@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}

if (use_cudnn_flash_attention) {
return qkv_bytes;
}

return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
@@ -320,6 +326,68 @@ Status FlashAttention(
}
#endif

template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
constexpr bool is_bf16 = false;

T* attention_bias = const_cast<T*>(data.attention_bias);
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);

cudnn_sdpa::run(
data.output,
data.q,
data.k,
data.v,
attention_bias,
nullptr, // (optional) mask_sequence_lengths_q
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale, // scaling factor applied prior softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not
parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not
0, // sliding window length. 0 means no sliding window.
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);

return Status::OK();
}

template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"cudnn flash attention does not support float tensor");
}

#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
@@ -498,6 +566,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
@@ -512,10 +581,11 @@ Status QkvToContext(
void* fused_runner = data.fused_runner;

// At most one fused kernel is enabled.
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
assert((static_cast<int>(data.use_flash_attention) +
static_cast<int>(data.use_memory_efficient_attention) +
static_cast<int>(fused_runner != nullptr) +
static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

@@ -577,6 +647,10 @@ Status QkvToContext(
}
#endif

if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}

#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
@@ -594,13 +668,15 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);

template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
@@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);

template <typename T>
@@ -104,9 +106,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;

// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}

void PrintDebugInfo() const {
@@ -139,6 +143,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -9,11 +9,12 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"

using namespace onnxruntime::contrib::attention;

namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);

use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
@@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

// Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
use_cudnn_flash_attention_ = false;
if (enable_kernel_debug_info_) {
std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
}
}

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
@@ -58,9 +68,9 @@
}

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
if (this->enable_kernel_debug_info_) {
this->Print();
}
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {

class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
@@ -40,7 +40,7 @@ class AttentionKernelOptions {
protected:
void Print() const;

void Initialize(int value, bool use_build_flag);
void Initialize(int value, bool use_build_flag, bool check_cudnn_version);

private:
bool use_flash_attention_{true};