Add cudnn sdpa
tianleiwu committed Aug 19, 2024
1 parent d79e3c5 commit 9c78a6d
Showing 22 changed files with 681 additions and 51 deletions.
2 changes: 0 additions & 2 deletions cmake/external/cuDNN.cmake
@@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
CUDNN::cudnn_heuristic
)
endif()

mark_as_advanced(CUDNN_INCLUDE_DIR)
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)

# GLOB pattern of file to be excluded
set(contrib_ops_excluded_files
"bert/cudnn_fmha/*"
"bert/cutlass_fmha/*"
"bert/fastertransformer_decoder_attention/*"
"bert/flash_attention/*"
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -47,6 +47,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};

6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
const char* name,
const T* tensor,
gsl::span<const int64_t>& dims) {
if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
if (!dumper->IsEnabled()) {
return;
}

if ((tensor == nullptr || dims.size() == 0)) {
std::cout << std::string(name) << " is None" << std::endl;
return;
}
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

@@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
77 changes: 76 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

@@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
@@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}

if (use_cudnn_flash_attention) {
return qkv_bytes;
}

return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
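Side note: the effect of the new branch on workspace sizing reduces to a simple rule: the cuDNN path requests only the packed Q/K/V staging buffer, while the default path also reserves two scratch buffers. A simplified standalone sketch (not the ORT code; qkv_bytes and scratch_bytes stand in for the terms this function actually computes, and the claim that cuDNN manages its own intermediates is an assumption):

#include <cstddef>

// Simplified model of the branch above, for illustration only.
size_t EstimateAttentionWorkspace(size_t qkv_bytes, size_t scratch_bytes,
                                  bool use_cudnn_flash_attention) {
  if (use_cudnn_flash_attention) {
    // cuDNN SDPA path: only the packed Q/K/V buffer is requested here;
    // additional temporaries are assumed to be managed inside cuDNN.
    return qkv_bytes;
  }
  // Default path: Q/K/V buffer plus two scratch buffers
  // (attention scores and softmax intermediates).
  return qkv_bytes + 2 * scratch_bytes;
}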
@@ -320,6 +326,67 @@ Status FlashAttention(
}
#endif

template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
constexpr bool is_bf16 = false;

T* attention_bias = const_cast<T*>(data.attention_bias);
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);

cudnn_sdpa::run(
data.output,
data.q,
data.k,
data.v,
attention_bias,
nullptr, // (optional) mask_sequence_lengths_q
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale, // scaling factor applied prior to softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not
parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not
0, // sliding window length. 0 means no sliding window.
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);

return Status::OK();
}

template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "cudnn flash attention does not support float tensor");

[cpplint] onnxruntime/contrib_ops/cuda/bert/attention_impl.cu:387: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
@@ -498,6 +565,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
@@ -515,7 +583,8 @@ Status QkvToContext(
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
int(data.fused_cross_attention_kernel != nullptr) +

[cpplint] onnxruntime/contrib_ops/cuda/bert/attention_impl.cu:586: Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]
int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

[cpplint] onnxruntime/contrib_ops/cuda/bert/attention_impl.cu:587: Using deprecated casting style. Use static_cast<int>(...) instead [readability/casting] [4]

ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

@@ -577,6 +646,10 @@
}
#endif

if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}

#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
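For orientation, the kernel selection in QkvToContext after this change tries the specialized backends in a fixed priority order before falling back to the unfused kernel. A compilable toy mirroring that order as read from this function (the placement of the TRT fused branches is inferred from surrounding code not shown in this diff; all names below are illustrative, not ORT symbols):

enum class AttentionBackendChoice {
  kFlash,
  kTrtFusedCross,
  kTrtFusedSelf,
  kCudnnFlash,
  kMemoryEfficient,
  kUnfused
};

// First matching backend wins; the assert earlier in this function enforces
// that at most one of these flags is set.
AttentionBackendChoice SelectBackend(bool use_flash, bool has_trt_cross_kernel,
                                     bool has_trt_fused_runner, bool use_cudnn_flash,
                                     bool use_memory_efficient) {
  if (use_flash) return AttentionBackendChoice::kFlash;
  if (has_trt_cross_kernel) return AttentionBackendChoice::kTrtFusedCross;
  if (has_trt_fused_runner) return AttentionBackendChoice::kTrtFusedSelf;
  if (use_cudnn_flash) return AttentionBackendChoice::kCudnnFlash;
  if (use_memory_efficient) return AttentionBackendChoice::kMemoryEfficient;
  return AttentionBackendChoice::kUnfused;
}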
@@ -594,13 +667,15 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);

template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
@@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);

template <typename T>
@@ -104,9 +106,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;

// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}

void PrintDebugInfo() const {
@@ -139,6 +143,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -9,11 +9,12 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"

using namespace onnxruntime::contrib::attention;

namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);

use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
@@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

// Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
use_cudnn_flash_attention_ = false;
if (enable_kernel_debug_info_) {
std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
}
}

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
@@ -58,9 +68,9 @@
}

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
if (this->enable_kernel_debug_info_) {
this->Print();
}
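The version gate above calls cudnn_sdpa::is_stable(), whose definition is not part of the hunks shown here. A plausible minimal implementation, assuming it simply compares the runtime cuDNN version against 9.3.0 as the comment states (the actual check in cudnn_fmha/cudnn_flash_attention may differ):

#include <cudnn.h>

namespace onnxruntime {
namespace cudnn_sdpa {

// Hypothetical sketch: consider cuDNN flash attention stable on 9.3.0+.
// cudnnGetVersion() returns the runtime version; for cuDNN 9.x it is encoded
// as major * 10000 + minor * 100 + patch, so 9.3.0 maps to 90300 (assumption).
bool is_stable() {
  return cudnnGetVersion() >= 90300;
}

}  // namespace cudnn_sdpa
}  // namespace onnxruntime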
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {

class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
@@ -40,7 +40,7 @@
protected:
void Print() const;

void Initialize(int value, bool use_build_flag);
void Initialize(int value, bool use_build_flag, bool check_cudnn_version);

private:
bool use_flash_attention_{true};
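For reference, a caller-side sketch of the extended interface. The new check_cudnn_version parameter defaults to false, so existing call sites keep compiling; a CUDA execution provider that wants the cuDNN 9.3 gate opts in explicitly. Values and the object name below are illustrative, not taken from this commit:

// Hypothetical call site; sdpa_kernel value and object name are illustrative.
AttentionKernelOptions attention_kernel_options;
attention_kernel_options.InitializeOnce(/*sdpa_kernel=*/0,
                                        /*use_build_flag=*/true,
                                        /*check_cudnn_version=*/true);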