add cudnn flash attention to MultiHeadAttention op
tianleiwu committed Aug 6, 2024
1 parent 2653226 commit 1ac4cf8
Showing 14 changed files with 683 additions and 42 deletions.
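Before the per-file hunks, a minimal standalone sketch (not onnxruntime code) of the idea this commit extends: the CUDA attention ops pick exactly one execution path per call, and cuDNN flash attention becomes a new candidate alongside the existing flash, TRT fused, and memory-efficient kernels. The names and priority order below are illustrative assumptions; the real selection logic lives in the Attention/MultiHeadAttention ops and AttentionKernelOptions.

```cpp
#include <cstdio>

// Trimmed mirror of AttentionKernelType (see attention_common.h below).
enum class KernelType {
  FlashAttention,
  CudnnFlashAttention,              // added by this commit
  CutlassMemoryEfficientAttention,
  Unfused,
};

struct KernelAvailability {
  bool flash_attention = false;
  bool cudnn_flash_attention = false;   // e.g. gated by kEnableCudnnFlashAttention
  bool memory_efficient_attention = false;
};

// Assumed priority for illustration only; the exact ordering used by the ops
// is not fully visible in this excerpt.
KernelType SelectKernel(const KernelAvailability& caps) {
  if (caps.flash_attention) return KernelType::FlashAttention;
  if (caps.cudnn_flash_attention) return KernelType::CudnnFlashAttention;
  if (caps.memory_efficient_attention) return KernelType::CutlassMemoryEfficientAttention;
  return KernelType::Unfused;
}

int main() {
  KernelAvailability caps;
  caps.cudnn_flash_attention = true;
  std::printf("selected kernel = %d\n", static_cast<int>(SelectKernel(caps)));
  return 0;
}
```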
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#pragma once
#include <gsl/gsl>

namespace onnxruntime {
namespace contrib {
@@ -46,6 +47,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};

@@ -69,6 +71,7 @@ struct AttentionParameters {
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_res_pos_bias;
gsl::span<const int64_t> relative_position_bias_dims;
float mask_filter_value;
float scale;
bool use_tf32;
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -429,6 +429,9 @@ Status CheckInputs(const T* query,
output_parameters->mask_type = mask_type;
output_parameters->scale = scale;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
if (relative_position_bias != nullptr) {
output_parameters->relative_position_bias_dims = relative_position_bias->Shape().GetDims();
}
output_parameters->qkv_format = qkv_format;
}

5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -249,6 +249,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -261,6 +262,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

@@ -297,7 +299,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
80 changes: 79 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

@@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
@@ -144,6 +146,10 @@
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}

if (use_cudnn_flash_attention) {
return qkv_bytes;
}

return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
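The new branch above returns only qkv_bytes for the cuDNN path: as the diff shows, it skips the two GetAttentionScratchSize allocations that the final (unfused) return adds, so the workspace is just the packed Q/K/V staging buffer. A rough, self-contained illustration of what that costs; the layout and formula here are assumptions for the example, not the exact expression used by GetAttentionWorkspaceSize.

```cpp
#include <cstddef>
#include <cstdio>

// Assumed packed layout for illustration: Q is [B, N, S_q, H_qk],
// K is [B, N, S_kv, H_qk], V is [B, N, S_kv, H_v], stored contiguously.
size_t QkvWorkspaceBytes(size_t element_size, size_t batch, size_t num_heads,
                         size_t seq_q, size_t seq_kv, size_t head_qk, size_t head_v) {
  const size_t q_elems = batch * num_heads * seq_q * head_qk;
  const size_t k_elems = batch * num_heads * seq_kv * head_qk;
  const size_t v_elems = batch * num_heads * seq_kv * head_v;
  return element_size * (q_elems + k_elems + v_elems);
}

int main() {
  // fp16 (2 bytes), batch 4, 16 heads, 512 query and key/value tokens, head size 64.
  const size_t bytes = QkvWorkspaceBytes(2, 4, 16, 512, 512, 64, 64);
  std::printf("qkv workspace = %zu bytes (%.1f MiB)\n",
              bytes, static_cast<double>(bytes) / (1024.0 * 1024.0));
  return 0;
}
```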
@@ -320,6 +326,70 @@ Status FlashAttention(
}
#endif


template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
assert(data.relative_position_bias == nullptr || parameters.relative_position_bias_dims.size() == 4);
constexpr bool is_bf16 = false;

T* attention_bias = const_cast<T*>(data.relative_position_bias);
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);

cudnn_sdpa::run(
data.q,
data.k,
data.v,
data.output,
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale, // scaling factor applied prior to softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
attention_bias, // (optional) relative position bias.
parameters.relative_position_bias_dims, // Shape of attention_bias like [b or 1, h_q or 1, s_q, s_kv] or empty.
nullptr, // (optional) mask_sequence_lengths_q
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
0, // sliding window length. 0 means no sliding window.
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);

return Status::OK();
}

template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "cudnn flash attention does not support float tensor");
}
#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
@@ -485,6 +555,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
@@ -502,7 +573,8 @@ Status QkvToContext(
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
int(data.fused_cross_attention_kernel != nullptr) +
int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

@@ -564,6 +636,10 @@
}
#endif

if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}

#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
@@ -581,13 +657,15 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);

template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);
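The CudnnFlashAttention wrapper added in this file forwards an optional attention bias whose shape must be 4-D and broadcastable over the batch and head dimensions, i.e. [b or 1, h_q or 1, s_q, s_kv] (see the relative_position_bias_dims assert and the cudnn_sdpa::run comment above). A small self-contained sketch of that broadcast rule, using hypothetical helper and parameter names:

```cpp
#include <array>
#include <cstdint>

// Checks the documented contract: [b or 1, h_q or 1, s_q, s_kv].
bool IsValidAttentionBiasShape(const std::array<int64_t, 4>& dims,
                               int64_t batch, int64_t num_heads_q,
                               int64_t seq_q, int64_t seq_kv) {
  const bool batch_ok = (dims[0] == batch || dims[0] == 1);
  const bool heads_ok = (dims[1] == num_heads_q || dims[1] == 1);
  return batch_ok && heads_ok && dims[2] == seq_q && dims[3] == seq_kv;
}

int main() {
  // A per-head bias shared across the batch: [1, h_q, s_q, s_kv].
  const bool ok = IsValidAttentionBiasShape({1, 16, 512, 512},
                                            /*batch=*/4, /*num_heads_q=*/16,
                                            /*seq_q=*/512, /*seq_kv=*/512);
  return ok ? 0 : 1;
}
```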
10 changes: 8 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
@@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);

template <typename T>
@@ -69,6 +71,7 @@ struct AttentionData {
const T* past = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;

const T* relative_position_bias = nullptr;

bool has_qkv_workspace = false;
@@ -104,9 +107,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;

// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}

void PrintDebugInfo() const {
@@ -139,6 +144,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -9,6 +9,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"

using namespace onnxruntime::contrib::attention;

@@ -27,7 +28,11 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);

// Enable cuDNN flash attention by default only when it is stable (requires cuDNN version >= 9.3.0).
bool is_cudnn_stable = ::onnxruntime::cudnn_sdpa::is_stable();
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, is_cudnn_stable);

use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
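The Initialize() change above flips the default for cuDNN flash attention from off to "on when the cuDNN SDPA backend is stable" (cuDNN >= 9.3.0), while still letting the kEnableCudnnFlashAttention environment option override it. A standalone sketch of that default-on-when-stable pattern; the helpers and the environment variable name here are placeholders, not the ORT APIs, and the real option string behind kEnableCudnnFlashAttention is not reproduced.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical stand-in for ::onnxruntime::cudnn_sdpa::is_stable().
static bool IsCudnnSdpaStable() { return true; }  // assume cuDNN >= 9.3.0 at runtime

// Hypothetical stand-in for ParseEnvironmentVariableWithDefault<bool>().
static bool ParseBoolEnvOr(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) return default_value;
  const std::string s(value);
  return s == "1" || s == "true" || s == "True";
}

int main() {
  // Mirrors the Initialize() logic: enabled by default only when stable,
  // overridable via the environment option.
  const bool use_cudnn_flash_attention =
      ParseBoolEnvOr("CUDNN_FLASH_ATTENTION_ENV_PLACEHOLDER", IsCudnnSdpaStable());
  return use_cudnn_flash_attention ? 0 : 1;
}
```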