Commit
reserve a flag for cudnn flash attention; print debug info
tianleiwu committed Jul 18, 2024
1 parent 8a758a0 commit 81ee535
Showing 13 changed files with 263 additions and 9 deletions.
15 changes: 11 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -152,14 +152,18 @@ enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
MATH = 8, // unfused
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention
MATH = 16, // unfused

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 16,
TRT_CROSS_ATTENTION = 32,
TRT_CAUSAL_ATTENTION = 64,
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,
};

// Environment variable to enable printing of attention kernel debug information. Default is 0 (disabled).
constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO";

// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

@@ -170,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable cuDNN flash attention. Default is 0 (disabled).
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";

// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

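The AttentionBackend values above are bit flags, so the integer passed through the sdpa_kernel option can combine several backends. A minimal decoding sketch, assuming the enum as defined in this commit; the enum copy below is illustrative rather than the real header, and the combined value 9 is only an example:

#include <cstdio>

// Illustrative copy of the flag values from attention_common.h above.
enum class AttentionBackend : int {
  FLASH_ATTENTION = 1,
  EFFICIENT_ATTENTION = 2,
  TRT_FUSED_ATTENTION = 4,
  CUDNN_FLASH_ATTENTION = 8,  // reserved for cuDNN flash attention
  MATH = 16,                  // unfused
};

int main() {
  const int sdpa_kernel = 9;  // example: FLASH_ATTENTION | CUDNN_FLASH_ATTENTION
  const bool use_flash = (sdpa_kernel & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
  const bool use_cudnn_flash = (sdpa_kernel & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
  std::printf("flash=%d cudnn_flash=%d\n", use_flash, use_cudnn_flash);  // prints: flash=1 cudnn_flash=1
  return 0;
}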
23 changes: 22 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -40,6 +40,9 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
kernel_options_ = this->GetAttentionKernelOptions();
if (kernel_options_->AllowDebugInfo()) {
node_name_ = info.node().Name();
}

disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();

@@ -211,6 +214,25 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
if (is_unidirectional_) {
debug_info.use_trt_causal_attention = true;
} else if (enable_trt_flash_attention_ && sequence_length >= kMinSequenceLengthFlashAttention) {
debug_info.use_trt_flash_attention = true;
} else {
debug_info.use_trt_fused_attention = true;
}
}
debug_info.is_float16 = sizeof(T) == 2;
debug_info.operator_name = "Attention";
debug_info.node_name = &(node_name_);
debug_info.Print();
}

cublasHandle_t cublas = GetCublasHandle(context);

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -248,7 +270,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_fused_cross_attention,
use_memory_efficient_attention);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
;

typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -30,7 +30,9 @@ class Attention final : public CudaKernel, public AttentionBase {
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;

const AttentionKernelOptions* kernel_options_;
std::string node_name_;

Check warning on line 35 in onnxruntime/contrib_ops/cuda/bert/attention.h (GitHub Actions / Lint C++, cpplint): Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace cuda
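As the cpplint annotation above points out, the new node_name_ member is a std::string, so this header (and the other headers below that gain the same member) presumably also needs the corresponding include; a one-line sketch, not part of the diff as shown:

#include <string>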
94 changes: 93 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -2,6 +2,10 @@
// Licensed under the MIT License.

#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include <iomanip>
#include <iostream>
#include <sstream>
//#include "core/common/common.h"

Check warning on line 8 in onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc (GitHub Actions / Lint C++, cpplint): Should have a space between // and comment [whitespace/comments] [4]
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
@@ -14,6 +18,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
@@ -22,12 +27,15 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
}

enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false);

// When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing.
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForFlashAttentionPackedQKV,
@@ -51,7 +59,91 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {

void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
std::call_once(this->initialize_once_flag_, [&]() { this->Initialize(sdpa_kernel, use_build_flag); });
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
if (this->enable_kernel_debug_info_) {
this->Print();
}
});
}

void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_);
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_);
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_);
sstream << " MATH=" << int(use_unfused_);

// Output text in Cyan color to make it easier to spot.
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

void AttentionKernelDebugInfo::Print() const {
std::stringstream sstream;
if (operator_name != nullptr) {
sstream << "Operator=" << operator_name;
}

if (node_name != nullptr && node_name->length() > 0) {
sstream << " Node=" << *(node_name);
}

if (is_bfloat16) {
sstream << " DataType=bf16";
} else if (is_float16) {
sstream << " DataType=fp16";
} else {
sstream << " DataType=fp32";
}

if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}

if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}

if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}

if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}

if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}

if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_cross_attention.value());
}

if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_causal_attention.value());
}

bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());

// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
}

// Output text in Cyan color to make it easier to spot.
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}

} // namespace onnxruntime
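Both new environment variables are read once inside AttentionKernelOptions::Initialize above, and InitializeOnce guards that with std::call_once, so they need to be set before the first attention kernel is constructed. A minimal sketch of enabling them from application code, assuming a POSIX environment (setenv) and omitting the actual ONNX Runtime session setup:

#include <stdlib.h>  // POSIX setenv

int main() {
  // Read via kEnableAttentionKernelDebugInfo; turns on the per-node debug print.
  setenv("ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO", "1", /*overwrite=*/1);
  // Read via kEnableCudnnFlashAttention; the backend itself is still reserved.
  setenv("ORT_ENABLE_CUDNN_FLASH_ATTENTION", "1", /*overwrite=*/1);
  // ... create the ORT session and run inference here ...
  return 0;
}

With the debug flag set, each attention operator prints one cyan line per node in the form produced by AttentionKernelDebugInfo::Print, for example "Operator=Attention Node=attn_0 DataType=fp16 FLASH_ATTENTION=1" (the node name here is a placeholder).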
26 changes: 26 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -3,37 +3,63 @@

#pragma once
#include <mutex>

Check warning on line 5 in onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h (GitHub Actions / Lint C++, cpplint): <mutex> is an unapproved C++11 header. [build/c++11] [5]
#include <optional>
#include <string>

namespace onnxruntime {
struct AttentionKernelDebugInfo {
std::optional<bool> use_flash_attention = std::nullopt;
std::optional<bool> use_efficient_attention = std::nullopt;
std::optional<bool> use_trt_fused_attention = std::nullopt;
std::optional<bool> use_cudnn_flash_attention = std::nullopt;
std::optional<bool> use_trt_flash_attention = std::nullopt;
std::optional<bool> use_trt_cross_attention = std::nullopt;
std::optional<bool> use_trt_causal_attention = std::nullopt;
const char* operator_name = nullptr;
const std::string* node_name = nullptr;
bool is_float16 = false;
bool is_bfloat16 = false;
void Print() const;
};

class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
bool UseTrtFusedAttention() const { return use_trt_fused_attention_; }
bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; }
bool UseUnfusedAttention() const { return use_unfused_; }
bool UseTrtFlashAttention() const { return use_trt_flash_attention_; }
bool UseTrtCrossAttention() const { return use_trt_cross_attention_; }
bool UseTrtCausalAttention() const { return use_trt_causal_attention_; }

bool AllowDebugInfo() const { return enable_kernel_debug_info_; }

int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; }
int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; }

protected:
void Print() const;

Check warning on line 44 in onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h (GitHub Actions / Lint C++, cpplint): Weird number of spaces at line-start. Are you using a 2-space indent? [whitespace/indent] [3]

void Initialize(int value, bool use_build_flag);

private:
bool use_flash_attention_{true};
bool use_efficient_attention_{true};
bool use_trt_fused_attention_{true};
bool use_cudnn_flash_attention_{false};
bool use_unfused_{true};

bool use_trt_flash_attention_{true};
bool use_trt_cross_attention_{true};

// Causal attention is disabled by default in #14732.
bool use_trt_causal_attention_{false};

bool enable_kernel_debug_info_{false};

int min_seq_len_for_flash_attention_packed_qkv_{0};

int min_seq_len_for_efficient_attention_fp32_{0};
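The AttentionKernelDebugInfo struct declared in this header is the small value type each CUDA attention kernel fills in; optional flags that are never set simply do not appear in the printed line. A minimal usage sketch mirroring the pattern used in attention.cc, where ReportBackend is a hypothetical helper rather than anything added by this commit:

#include <string>

#include "contrib_ops/cuda/bert/attention_kernel_options.h"

namespace onnxruntime {

// Hypothetical helper: record which backend a kernel actually picked and print it.
void ReportBackend(const AttentionKernelOptions& options,
                   const std::string& node_name,
                   bool picked_flash_attention) {
  if (!options.AllowDebugInfo()) {
    return;  // debug output was not requested
  }
  AttentionKernelDebugInfo debug_info;
  debug_info.operator_name = "Attention";
  debug_info.node_name = &node_name;  // pointer must stay valid until Print() returns
  debug_info.is_float16 = true;       // example: an fp16 kernel
  if (picked_flash_attention) {
    debug_info.use_flash_attention = true;  // unset optionals are omitted from the output
  }
  debug_info.Print();  // prints MATH=1 when no fused backend flag is set
}

}  // namespace onnxruntime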
16 changes: 16 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -53,6 +53,10 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

kernel_options_ = this->GetAttentionKernelOptions();
if (kernel_options_->AllowDebugInfo()) {
node_name_ = info.node().Name();
}

disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();

// Memory efficient attention only supports float and float16, not bfloat16.
@@ -193,6 +197,18 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
debug_info.is_float16 = std::is_same<T, MLFloat16>::value;
debug_info.is_bfloat16 = std::is_same<T, BFloat16>::value;
debug_info.operator_name = "GroupQueryAttention";
debug_info.node_name = &(node_name_);
debug_info.Print();
}

// seqlens_k buffer
size_t seqlens_k_bytes = 0;
seqlens_k_bytes = sizeof(int) * parameters.batch_size;
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -34,6 +34,7 @@ class GroupQueryAttention final : public CudaKernel {
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;
const AttentionKernelOptions* kernel_options_;
std::string node_name_;

Check warning on line 37 in onnxruntime/contrib_ops/cuda/bert/group_query_attention.h (GitHub Actions / Lint C++, cpplint): Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace cuda
22 changes: 22 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -47,6 +47,10 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");

kernel_options_ = this->GetAttentionKernelOptions();
if (kernel_options_->AllowDebugInfo()) {
node_name_ = info.node().Name();
}

disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

@@ -233,6 +237,24 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_fp16_runner_ != nullptr) {
if (enable_trt_flash_attention_ && sequence_length >= kMinSequenceLengthFlashAttention) {
debug_info.use_trt_flash_attention = true;
} else {
debug_info.use_trt_fused_attention = true;
}
}
debug_info.is_float16 = sizeof(T) == 2;
debug_info.operator_name = "MultiHeadAttention";
debug_info.node_name = &(node_name_);
debug_info.Print();
}

// When packed kv or packed qkv is used, there is no need for add bias transpose, thus no qkv workspace.
// TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime.
bool no_qkv_workspace = nullptr == value &&
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -38,6 +38,7 @@ class MultiHeadAttention final : public CudaKernel {
mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
const AttentionKernelOptions* kernel_options_;
std::string node_name_;

Check warning on line 41 in onnxruntime/contrib_ops/cuda/bert/multihead_attention.h (GitHub Actions / Lint C++, cpplint): Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace cuda
19 changes: 19 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -36,6 +36,9 @@ template <typename T>
TrtFusedAttention<T>::TrtFusedAttention(const OpKernelInfo& info)
: CudaKernel(info) {
kernel_options_ = this->GetAttentionKernelOptions();
if (kernel_options_->AllowDebugInfo()) {
node_name_ = info.node().Name();
}
disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
}
@@ -295,6 +298,22 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
#endif

if (this->kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
if (this->enable_trt_flash_attention_ && parameters.sequence_length >= kMinSequenceLengthFlashAttention) {
debug_info.use_trt_flash_attention = true;
} else {
debug_info.use_trt_fused_attention = true;
}
}
debug_info.is_float16 = std::is_same<T, MLFloat16>::value;
debug_info.operator_name = "PackedAttention";
debug_info.node_name = &(this->node_name_);
debug_info.Print();
}

typedef typename ToCudaType<T>::MappedType CudaT;
CudaT one = ToCudaType<T>::FromFloat(1.0f);
CudaT zero = ToCudaType<T>::FromFloat(0.0f);