[CUDA] cuDNN Flash Attention #21629

Merged: 3 commits, Aug 20, 2024
2 changes: 0 additions & 2 deletions cmake/external/cuDNN.cmake
@@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
CUDNN::cudnn_heuristic
)
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)

# GLOB pattern of file to be excluded
set(contrib_ops_excluded_files
"bert/cudnn_fmha/*"
"bert/cutlass_fmha/*"
"bert/fastertransformer_decoder_attention/*"
"bert/flash_attention/*"
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -47,6 +47,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};
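For orientation, a minimal sketch (not code from this PR) of how the new enum value can participate in kernel selection. The helper name ChooseKernelType and the precedence order are illustrative assumptions; the real selection logic in the CUDA attention operators also considers build flags, data type, head size, and mask type.

// Hypothetical helper for illustration only; the precedence is an assumption.
AttentionKernelType ChooseKernelType(bool use_flash, bool use_cudnn_flash,
                                     bool use_memory_efficient) {
  if (use_flash) return AttentionKernel_FlashAttention;
  if (use_cudnn_flash) return AttentionKernel_CudnnFlashAttention;
  if (use_memory_efficient) return AttentionKernel_CutlassMemoryEfficientAttention;
  return AttentionKernel_Default;
}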

5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
Expand All @@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

@@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
84 changes: 80 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

@@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
@@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}

if (use_cudnn_flash_attention) {
return qkv_bytes;
}

return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
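To make the sizing above concrete, here is a simplified sketch under assumed buffer formulas (the real GetAttentionWorkspaceSize also handles alignment and the TRT fused paths). The point of the new branch is that the cuDNN path reserves only the packed QKV buffer and no softmax scratch; the cuDNN wrapper receives an allocator separately (data.allocator, passed to cudnn_sdpa::run below) for anything else it needs.

#include <cstddef>

// Illustrative only; the QKV and scratch formulas here are simplified assumptions.
size_t ExampleWorkspaceBytes(size_t element_size, size_t batch_size, size_t num_heads,
                             size_t qk_head_size, size_t v_head_size,
                             size_t sequence_length, size_t total_sequence_length,
                             bool use_cudnn_flash_attention) {
  const size_t qkv_bytes = element_size * batch_size * num_heads *
                           ((2 * qk_head_size + v_head_size) * sequence_length);
  if (use_cudnn_flash_attention) {
    return qkv_bytes;  // no softmax scratch reserved for the cuDNN path
  }
  // Unfused fallback keeps two scratch buffers for the attention probabilities.
  const size_t scratch_bytes = element_size * batch_size * num_heads *
                               sequence_length * total_sequence_length;
  return qkv_bytes + 2 * scratch_bytes;
}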
@@ -320,6 +326,68 @@ Status FlashAttention(
}
#endif

template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
constexpr bool is_bf16 = false;

T* attention_bias = const_cast<T*>(data.attention_bias);
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);

cudnn_sdpa::run(
data.output,
data.q,
data.k,
data.v,
attention_bias,
nullptr, // (optional) mask_sequence_lengths_q
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale, // scaling factor applied prior to softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not
parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not
0, // sliding window length. 0 means no sliding window.
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);

return Status::OK();
}

template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"cudnn flash attention does not support float tensor");
}

#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
@@ -498,6 +566,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
@@ -512,10 +581,11 @@ Status QkvToContext(
void* fused_runner = data.fused_runner;

// At most one fused kernel is enabled.
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
assert((static_cast<int>(data.use_flash_attention) +
static_cast<int>(data.use_memory_efficient_attention) +
static_cast<int>(fused_runner != nullptr) +
static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

@@ -577,6 +647,10 @@ Status QkvToContext(
}
#endif

if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}

#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
@@ -594,13 +668,15 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);

template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
@@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);

template <typename T>
@@ -104,9 +106,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;

// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}

void PrintDebugInfo() const {
@@ -139,6 +143,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -9,11 +9,12 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"

using namespace onnxruntime::contrib::attention;

namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);

use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
@@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

// Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
use_cudnn_flash_attention_ = false;
if (enable_kernel_debug_info_) {
std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
}
}

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
Expand All @@ -58,9 +68,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
}

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
if (this->enable_kernel_debug_info_) {
this->Print();
}
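The is_stable() check referenced above is not part of this diff; a minimal sketch, assuming it simply compares the cuDNN runtime version against 9.3.0 (cudnnGetVersion() reports 9.3.0 as 90300), could look like this:

#include <cudnn.h>

// Hypothetical stand-in for cudnn_sdpa::is_stable(); the real implementation may differ.
bool IsCudnnSdpaStable() {
  // cuDNN 9 encodes its version as major * 10000 + minor * 100 + patch.
  return cudnnGetVersion() >= 90300;
}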
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {

class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
@@ -40,7 +40,7 @@ class AttentionKernelOptions {
protected:
void Print() const;

void Initialize(int value, bool use_build_flag);
void Initialize(int value, bool use_build_flag, bool check_cudnn_version);

private:
bool use_flash_attention_{true};
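A hedged usage sketch of the updated interface: an execution provider initializes the options once and opts into the cuDNN version gate through the new parameter. The accessor UseCudnnFlashAttention() is assumed here (it is not shown in this hunk), and the argument values are examples rather than ONNX Runtime defaults.

// Illustrative call site only.
AttentionKernelOptions options;
options.InitializeOnce(/*sdpa_kernel=*/0, /*use_build_flag=*/true,
                       /*check_cudnn_version=*/true);
if (options.UseCudnnFlashAttention()) {
  // cuDNN SDPA will be considered during attention kernel selection.
}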