
draft 2
tianleiwu committed Jan 26, 2024
1 parent 60c8bc4 · commit 5d08a7f
Showing 4 changed files with 54 additions and 36 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -665,6 +665,7 @@ Status QkvToContext(
#if USE_TENSORRT_LLM_FMHA
// Run TRT-LLM fused attention.
if (nullptr != llm_fmha_runner) {
printf("*Trt_LLM_Attention\n");
return FusedTrtLlmAttention(stream, parameters, data);
}
#endif
7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -29,6 +29,11 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,

T* qkv = data.workspace;

// For LLM FMHA
if (data.llm_fmha_runner != nullptr) {

}

bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);

@@ -240,7 +245,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,

T* qkv = data.workspace;

bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_kernel = data.llm_fmha_runner != nullptr || (nullptr != fused_runner && !parameters.is_unidirectional);

assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);
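Read together, the new use_fused_kernel condition in PrepareQkv_MHA_PackedQKV combines two independent gates. The snippet below is only an editorial decomposition of the boolean in the diff above, using the identifiers as they appear there; it is not additional commit code:

// Equivalent form of the condition in the diff: the TRT-LLM runner is accepted
// even for causal (unidirectional) attention, while the pre-existing TRT fused
// kernel still requires bidirectional attention.
bool use_trt_llm_runner = data.llm_fmha_runner != nullptr;
bool use_trt_fused_runner = nullptr != fused_runner && !parameters.is_unidirectional;
bool use_fused_kernel = use_trt_llm_runner || use_trt_fused_runner;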
79 changes: 45 additions & 34 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -50,9 +50,12 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)

enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);

#if USE_TENSORRT_LLM_FMHA
enable_trt_llm_fmha_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtLlmAttention, false);
#else
enable_trt_llm_fmha_ = false;
#endif

#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 ||
@@ -141,8 +144,42 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
#endif


void* llm_fmha_runner = nullptr;
#if USE_TENSORRT_LLM_FMHA
bool use_llm_attention = enable_trt_llm_fmha_ &&
nullptr == relative_position_bias &&
(value != nullptr || key == nullptr) &&
(nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) &&
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length;
// && tensorrt_llm::kernels::MHARunner::fmha_supported(parameters.head_size, sm);
if (use_llm_attention) {
if (!tensorrt_llm::kernels::MHARunner::fmha_supported(parameters.head_size, sm)){
printf("llm not supported for head_size=%d, sm=%d\n", parameters.head_size, sm);
} else {
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
if (nullptr == llm_fmha_runner_.get()) {
llm_fmha_runner_ = tensorrt_llm::kernels::FusedMHARunnerV2::Create(tensorrt_llm::kernels::DATA_TYPE_FP16, num_heads_, parameters.head_size, parameters.scale);
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads
constexpr bool force_fp32_accuracy = true;
constexpr bool is_s_padded = true;
llm_fmha_runner_->setup_flags(force_fp32_accuracy, is_s_padded, parameters.is_unidirectional, num_heads_);
}

llm_fmha_runner_->setup(parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.batch_size * parameters.sequence_length);

// int S = llm_fmha_runner_->getSFromMaxSeqLen(parameters.sequence_length);
// if (/*llm_fmha_runner_->fmha_supported() &&*/ llm_fmha_runner_->isValid(S)) {
llm_fmha_runner = reinterpret_cast<void*>(llm_fmha_runner_.get());
}
}
#endif

#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
bool use_flash_attention = llm_fmha_runner == nullptr &&
!disable_flash_attention_ &&
!past_no_bias &&
nullptr == relative_position_bias &&
nullptr == key_padding_mask &&
@@ -176,36 +213,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

void* llm_fmha_runner = nullptr;
#if USE_TENSORRT_LLM_FMHA
bool use_llm_attention = !use_flash_attention &&
enable_trt_flash_attention_ &&
nullptr == relative_position_bias &&
(value != nullptr || key == nullptr) &&
(nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) &&
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length;
if (use_llm_attention) {
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
if (nullptr == llm_fmha_runner_.get()) {
llm_fmha_runner_.reset(new tensorrt_llm::kernels::FusedMHARunnerV2(tensorrt_llm::kernels::DATA_TYPE_FP16, num_heads_, parameters.head_size, parameters.scale));
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads = num_heads
constexpr bool force_fp32_accuracy = true;
constexpr bool is_s_padded = true;
llm_fmha_runner_->setup_flags(force_fp32_accuracy, is_s_padded, parameters.is_unidirectional, num_heads_);
}

// llm_fmha_runner_->setup(parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.batch_size * parameters.sequence_length);

if (llm_fmha_runner_->fmha_supported() && llm_fmha_runner_->isValid(parameters.sequence_length)) {
llm_fmha_runner = reinterpret_cast<void*>(llm_fmha_runner_.get());
}
}
#endif

bool use_fused_cross_attention = !use_flash_attention &&
// llm_fmha_runner_ == nullptr &&
bool use_fused_cross_attention = llm_fmha_runner == nullptr &&
!use_flash_attention &&
!disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
nullptr == relative_position_bias &&
@@ -227,7 +236,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}

bool use_fused_runner = !use_flash_attention &&
bool use_fused_runner = llm_fmha_runner == nullptr &&
!use_flash_attention &&
!disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == relative_position_bias &&
@@ -262,7 +272,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {

bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;

bool use_memory_efficient_attention = !use_flash_attention &&
bool use_memory_efficient_attention = llm_fmha_runner == nullptr &&
!use_flash_attention &&
fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
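Taken as a whole, the multihead_attention.cc changes create the TRT-LLM runner once per node, configure it per call, and let it take priority over every other kernel path. The sketch below is an editorial condensation, not additional commit code; it assumes the tensorrt_llm::kernels API exactly as shown in the diff (MHARunner::fmha_supported, FusedMHARunnerV2::Create, setup_flags, setup), and use_llm_attention stands for the eligibility checks listed above (no relative position bias, no past key/value, a 1-D sequence-length mask or none, and matching q/kv hidden sizes and sequence lengths):

void* llm_fmha_runner = nullptr;
#if USE_TENSORRT_LLM_FMHA
if (use_llm_attention &&
    tensorrt_llm::kernels::MHARunner::fmha_supported(parameters.head_size, sm)) {
  // Created once and cached on the kernel; num_heads and head_size are assumed
  // not to change for a given MultiHeadAttention node.
  if (llm_fmha_runner_ == nullptr) {
    llm_fmha_runner_ = tensorrt_llm::kernels::FusedMHARunnerV2::Create(
        tensorrt_llm::kernels::DATA_TYPE_FP16, num_heads_, parameters.head_size, parameters.scale);
    llm_fmha_runner_->setup_flags(/*force_fp32_acc*/ true, /*is_s_padded*/ true,
                                  /*causal_mask*/ parameters.is_unidirectional,
                                  /*num_kv_heads*/ num_heads_);
  }
  // Per-call setup: batch size, q and kv sequence lengths, and total token count.
  llm_fmha_runner_->setup(parameters.batch_size, parameters.sequence_length,
                          parameters.sequence_length,
                          parameters.batch_size * parameters.sequence_length);
  llm_fmha_runner = reinterpret_cast<void*>(llm_fmha_runner_.get());
}
#endif
// Once llm_fmha_runner is non-null, the draft prefixes every later kernel choice
// (flash attention, fused cross attention, the fused self-attention runner, and
// memory-efficient attention) with "llm_fmha_runner == nullptr &&", so they are
// all skipped and QkvToContext dispatches to FusedTrtLlmAttention.

Keeping the selection to a single pointer check means the downstream conditions only needed the extra llm_fmha_runner == nullptr guard and are otherwise unchanged.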
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -38,7 +38,8 @@ class MultiHeadAttention final : public CudaKernel {
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
mutable tensorrt_llm::kernels::UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> llm_fmha_runner_;
//mutable tensorrt_llm::kernels::UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> llm_fmha_runner_;
mutable std::unique_ptr<tensorrt_llm::kernels::MHARunner> llm_fmha_runner_;
};

} // namespace cuda
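The header change swaps the UniqPtrWNullCopy wrapper for a plain std::unique_ptr holding the MHARunner base interface. As an editorial annotation (not part of the commit), the member reads as follows; the comment explains the mutable qualifier, which follows the same pattern as the cached kernels declared just above it:

// Lazily created inside ComputeInternal(), which is const, so the member is
// mutable like the other cached attention state in this class.
mutable std::unique_ptr<tensorrt_llm::kernels::MHARunner> llm_fmha_runner_;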
