
Commit

move option object to cuda provider
tianleiwu committed Jul 17, 2024
1 parent a1c1eec commit 8a758a0
Showing 11 changed files with 176 additions and 180 deletions.
13 changes: 13 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -147,6 +147,19 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_
} // namespace sparse_attention

namespace attention {

enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
MATH = 8, // unfused

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 16,
TRT_CROSS_ATTENTION = 32,
TRT_CAUSAL_ATTENTION = 64,
};

// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

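The AttentionBackend values are powers of two, so a single integer (the value later passed in through the sdpa_kernel provider option) can enable several backends at once. Below is a minimal, self-contained sketch of composing and decoding such a value; the enum is copied from the hunk above, and the rest is illustrative only, not part of this commit.

```cpp
#include <cstdio>

// Copied from the enum added to attention_common.h above.
enum class AttentionBackend : int {
  FLASH_ATTENTION = 1,
  EFFICIENT_ATTENTION = 2,
  TRT_FUSED_ATTENTION = 4,
  MATH = 8,  // unfused
  TRT_FLASH_ATTENTION = 16,
  TRT_CROSS_ATTENTION = 32,
  TRT_CAUSAL_ATTENTION = 64,
};

int main() {
  // Compose: enable flash attention and memory-efficient attention only.
  constexpr int sdpa_kernel =
      static_cast<int>(AttentionBackend::FLASH_ATTENTION) |
      static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION);
  static_assert(sdpa_kernel == 3, "1 | 2 == 3");

  // Decode: the same bit test AttentionKernelOptions::Initialize performs
  // when the value is positive.
  const bool use_flash =
      (sdpa_kernel & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
  const bool use_math =
      (sdpa_kernel & static_cast<int>(AttentionBackend::MATH)) > 0;
  std::printf("flash=%d math=%d\n", use_flash, use_math);  // flash=1 math=0
  return 0;
}
```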
10 changes: 1 addition & 9 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -39,25 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
kernel_options_ = this->GetAttentionKernelOptions();

disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();

enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif

#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
#endif
}

template <typename T>
54 changes: 27 additions & 27 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -6,14 +6,10 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

// Initialize the singleton instance
AttentionKernelOptions AttentionKernelOptions::instance;
using namespace onnxruntime::contrib::attention;

[cpplint warning, attention_kernel_options.cc:9] Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

void AttentionKernelOptions::Initialize(int value) {
namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@@ -23,35 +23,19 @@ void AttentionKernelOptions::Initialize(int value) {
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
}

// When value is positive, 0 is used as the default minimum sequence length to align with common usage in testing.
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
value > 0 ? 0 : attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
kMinSeqLenForFlashAttentionPackedQKV,
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV);

min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : attention::kDefaultMinSeqLenForEfficientAttentionFp32);

initialized_ = true;
}

const AttentionKernelOptions* AttentionKernelOptions::GetInstance(int sdpa_kernel, bool force_init) {
if (force_init || !instance.initialized_) {
instance.Initialize(sdpa_kernel);
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);

if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
use_flash_attention_ = false;
#endif

#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
}
}

return &instance;
void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
std::call_once(this->initialize_once_flag_, [&]() { this->Initialize(sdpa_kernel, use_build_flag); });
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
25 changes: 4 additions & 21 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
@@ -2,26 +2,12 @@
// Licensed under the MIT License.

#pragma once
#include <mutex>

[cpplint warning, attention_kernel_options.h:5] <mutex> is an unapproved C++11 header. [build/c++11] [5]

namespace onnxruntime {
namespace contrib {
namespace cuda {

enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
MATH = 8, // unfused

// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 16,
TRT_CROSS_ATTENTION = 32,
TRT_CAUSAL_ATTENTION = 64,
};

class AttentionKernelOptions {
public:
static const AttentionKernelOptions* GetInstance(int sdpa_kernel, bool force_init);
void InitializeOnce(int sdpa_kernel, bool use_build_flag);

bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
@@ -35,7 +21,7 @@ class AttentionKernelOptions {
int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; }

protected:
void Initialize(int value);
void Initialize(int value, bool use_build_flag);

private:
bool use_flash_attention_{true};
@@ -52,10 +38,7 @@ class AttentionKernelOptions {

int min_seq_len_for_efficient_attention_fp32_{0};

bool initialized_{false};
static AttentionKernelOptions instance;
std::once_flag initialize_once_flag_;
};

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
11 changes: 2 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -52,19 +52,12 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
#if USE_FLASH_ATTENTION
kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
// Memory efficient attention only supports float and float16, not bfloat16.
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value || !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif

if (!disable_flash_attention_) {
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
}
10 changes: 1 addition & 9 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -46,21 +46,13 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");

kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
kernel_options_ = this->GetAttentionKernelOptions();
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();

#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif

disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -35,7 +35,7 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
TrtFusedAttention<T>::TrtFusedAttention(const OpKernelInfo& info)
: CudaKernel(info) {
kernel_options_ = AttentionKernelOptions::GetInstance(this->SdpaKernel(), false);
kernel_options_ = this->GetAttentionKernelOptions();
disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
}
onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
@@ -42,17 +42,9 @@ PackedMultiHeadAttention<T>::PackedMultiHeadAttention(const OpKernelInfo& info)

scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention();
#else
disable_flash_attention_ = true;
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention();
#else
disable_memory_efficient_attention_ = true;
#endif
}

template <typename T>
18 changes: 17 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -17,6 +17,10 @@
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/tunable/cuda_tuning_context.h"

#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#endif

namespace onnxruntime {

void RunOnUnload(std::function<void()> function);
@@ -79,7 +83,14 @@ class CUDAExecutionProvider : public IExecutionProvider {
bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; }
bool IsNHWCPreferred() const { return info_.prefer_nhwc; }
bool UseTF32() const { return info_.use_tf32; }
int GetSdpaKernel() const { return info_.sdpa_kernel; }

#ifndef DISABLE_CONTRIB_OPS
// Attention kernel options parsed from sdpa_kernel cuda provider option.
const AttentionKernelOptions* GetAttentionKernelOptions() const {
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true);
return &attention_kernel_options_;
}
#endif

ProviderOptions GetProviderOptions() const override {
return CUDAExecutionProviderInfo::ToProviderOptions(info_);
@@ -111,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider {
// the tuning context might be altered when calling into a TunableOp
mutable cuda::tunable::CudaTuningContext tuning_context_;

#ifndef DISABLE_CONTRIB_OPS
// Attention kernel options parsed from sdpa_kernel cuda provider option.
mutable AttentionKernelOptions attention_kernel_options_;
#endif

class PerThreadContext final {
public:
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
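With the process-wide singleton gone, each CUDA execution provider now owns its AttentionKernelOptions and initializes it lazily from its own sdpa_kernel setting the first time a kernel asks for it. The stand-in classes below are a simplified sketch of that pattern (not the ORT types): a const accessor plus a mutable member and std::call_once gives thread-safe, one-time initialization.

```cpp
#include <cstdio>
#include <mutex>

// Simplified stand-in for AttentionKernelOptions.
class Options {
 public:
  void InitializeOnce(int value) {
    std::call_once(once_flag_, [&]() { value_ = value; });
  }
  int Value() const { return value_; }

 private:
  int value_{0};
  std::once_flag once_flag_;
};

// Simplified stand-in for CUDAExecutionProvider.
class Provider {
 public:
  explicit Provider(int sdpa_kernel) : sdpa_kernel_(sdpa_kernel) {}

  // Const accessor, like GetAttentionKernelOptions(): the first caller
  // triggers initialization; later calls see the same initialized object.
  const Options* GetOptions() const {
    options_.InitializeOnce(sdpa_kernel_);
    return &options_;
  }

 private:
  int sdpa_kernel_;
  mutable Options options_;  // lazily initialized on first use
};

int main() {
  Provider provider(/*sdpa_kernel=*/3);
  std::printf("value=%d\n", provider.GetOptions()->Value());  // value=3
  return 0;
}
```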
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_kernel.h
@@ -94,9 +94,11 @@ class CudaKernel : public OpKernel {
return provider_->UseTF32();
}

int SdpaKernel() const {
return provider_->GetSdpaKernel();
#ifndef DISABLE_CONTRIB_OPS
const AttentionKernelOptions* GetAttentionKernelOptions() const {
return provider_->GetAttentionKernelOptions();
}
#endif

tunable::CudaTuningContext* GetTuningContext() const {
return static_cast<tunable::CudaTuningContext*>(provider_->GetTuningContext());
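Since the backend selection now arrives through a CUDA provider option rather than only environment variables, a session could request specific kernels roughly as follows. This is a sketch against the public C/C++ API's V2 CUDA provider options; the "sdpa_kernel" key string and the model path are assumptions based on the option name referenced in this commit, not something shown in the diff.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sdpa_demo");
  Ort::SessionOptions session_options;

  // Create V2 CUDA provider options; keys and values are strings.
  // "3" = FLASH_ATTENTION (1) | EFFICIENT_ATTENTION (2) from the enum above.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"sdpa_kernel"};
  const char* values[] = {"3"};
  Ort::ThrowOnError(
      Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);

  // "model.onnx" is a placeholder path (wide string on Windows).
  Ort::Session session(env, "model.onnx", session_options);

  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}
```

Passing 0 (or any non-positive value) keeps the previous behaviour: Initialize() above falls back to the ORT_DISABLE_*/ORT_ENABLE_* environment variables declared in attention_common.h.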