Skip to content

Commit

Permalink
cudnn flash attention draft
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Aug 5, 2024
1 parent 2653226 commit ab89d9e
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 41 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
Expand All @@ -261,6 +262,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

Expand Down Expand Up @@ -297,7 +299,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}

return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}

} // namespace cuda
Expand Down
72 changes: 71 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

Expand Down Expand Up @@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
Expand Down Expand Up @@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}

if (use_cudnn_flash_attention) {
return qkv_bytes;
}

return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
Expand Down Expand Up @@ -320,6 +326,62 @@ Status FlashAttention(
}
#endif


template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(nullptr == data.mask_index);
assert(nullptr == data.relative_position_bias);

constexpr bool is_bf16 = false;

cudnn_sdpa::run(
data.q,
data.k,
data.v,
data.output,
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale, // scale prior softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);

return Status::OK();
}

template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor");
}



#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
Expand Down Expand Up @@ -485,6 +547,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
Expand All @@ -502,7 +565,8 @@ Status QkvToContext(
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
int(data.fused_cross_attention_kernel != nullptr) +
int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

Expand Down Expand Up @@ -564,6 +628,10 @@ Status QkvToContext(
}
#endif

if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}

#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
Expand All @@ -581,13 +649,15 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);

template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
Expand Down Expand Up @@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);

template <typename T>
Expand Down Expand Up @@ -104,9 +106,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;

// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}

void PrintDebugInfo() const {
Expand Down Expand Up @@ -139,6 +143,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, true);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class AttentionKernelOptions {
bool use_flash_attention_{true};
bool use_efficient_attention_{true};
bool use_trt_fused_attention_{true};
bool use_cudnn_flash_attention_{false};
bool use_cudnn_flash_attention_{true};
bool use_unfused_{true};

bool use_trt_flash_attention_{true};
Expand Down
57 changes: 31 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
template <typename T>
bool NoQkvWorkspace_MHA_Cross(AttentionData<T>& data) {
// query, key and value are passed as Q, K and V for the following conditions.
return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr);
return (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
data.bias == nullptr;
}

// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format)
Expand All @@ -186,8 +189,9 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;

#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Add bias for Q
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
Expand All @@ -200,9 +204,7 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());
if (data.bias == nullptr) {
// Transpose query from BSNH to BNSH
Expand All @@ -229,7 +231,10 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
template <typename T>
bool NoQkvWorkspace_MHA_NoPast(AttentionData<T>& data) {
// query, key and value are passed as Q, K and V for the following conditions.
return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr;
return (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
data.bias == nullptr;
}

// For MultiHeadAttention without past state, with Q, K and V inputs
Expand Down Expand Up @@ -271,9 +276,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
Expand All @@ -286,9 +291,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
}

data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else if (data.fused_runner != nullptr) {
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.relative_position_bias == nullptr);

Expand Down Expand Up @@ -334,7 +337,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,

template <typename T>
bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData<T>& data) {
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV.
return data.past_key == nullptr && data.present_key != nullptr;
}
Expand Down Expand Up @@ -373,8 +378,9 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
data.v = data.present_value;
}

#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Use oiginal Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);

Expand All @@ -385,9 +391,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
Expand Down Expand Up @@ -436,8 +440,9 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
data.v = data.present_value;
}

#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
data.bias, data.query, data.q);
Expand All @@ -456,9 +461,7 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);

data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());

constexpr int format = 0;
Expand Down Expand Up @@ -514,7 +517,8 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention || data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// unpack qkv to BSNH.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
Expand Down Expand Up @@ -586,7 +590,8 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;

if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention || data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Note that there is no bias so we need not output query to q.
data.q = const_cast<T*>(data.query);
// Unpack kv to BSNH.
Expand Down
Loading

0 comments on commit ab89d9e

Please sign in to comment.