[CUDA] FusedMHARunnerFP16v2 thread-safe #21420

Merged (5 commits) on Jul 22, 2024
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -149,8 +149,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
       nullptr == relative_position_bias &&
       parameters.past_sequence_length == 0 &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
-                                         enable_trt_flash_attention_, true);
+      FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
+                                        enable_trt_flash_attention_, true);
   if (use_causal_fused_runner) {
     // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
     if (nullptr == fused_fp16_runner_.get()) {
@@ -171,8 +171,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
       nullptr == present &&
       nullptr == relative_position_bias &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
-                                         enable_trt_flash_attention_, false);
+      FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
+                                        enable_trt_flash_attention_, false);
 
   if (use_fused_runner) {
     // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
@@ -184,8 +184,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     }
 
     // In case some kernel not loaded due to shared memory limit, we need to double check here.
-    const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
-    if (fused_fp16_runner_->isValid(S)) {
+    const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
+    if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
       fused_runner = fused_fp16_runner_.get();
     }
   }
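
Taken together, the call sites in this file imply the reshaped runner interface below. This is only an orientation sketch reconstructed from the diff, not the real header: the class name is suffixed with Sketch, the pointer types are generic, and the constructor/factory path (collapsed out of the hunks above) is omitted.

// Orientation sketch of the FusedMHARunnerFP16v2 surface implied by the call
// sites above. Approximate only: parameter types and the creation path are
// not visible in this diff, so they are placeholders here.
#include <cuda_runtime.h>

class FusedMHARunnerFP16v2Sketch {
 public:
  // Static capability check; needs no instance and touches no mutable state.
  static bool IsSupported(int sm, int head_size, int sequence_length,
                          bool enable_flash_attention, bool causal);

  // Maps a request's sequence length to the sequence length the fused TRT
  // kernels were built for (formerly getSFromMaxSeqLen).
  int NormalizeSequenceLength(int max_sequence_length) const;

  // Confirms that a kernel for this normalized length actually loaded; it can
  // be skipped at load time, e.g. due to the shared memory limit noted above.
  bool IsValid(int normalized_sequence_length) const;

  // Batch size and normalized sequence length are per-call arguments now,
  // replacing the stateful setup() step, so one cached runner can serve
  // concurrent ComputeInternal calls.
  void Run(int batch_size, int normalized_sequence_length,
           const void* packed_input, const void* cumulative_sequence_length,
           void* output, cudaStream_t stream) const;
};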
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -245,12 +245,10 @@ Status FusedTrtSelfAttention(
 
   FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(data.fused_runner);
 
-  const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
+  const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length);
 
   // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
-  const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
-
-  fused_fp16_runner->setup(S, B);
+  const int b = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
 
   if (!causal) {
     assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
@@ -261,12 +259,12 @@ Status FusedTrtSelfAttention(
       packed_qkv = data.query;
     }
 
-    fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
+    fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream);
     DUMP_TENSOR("fused output", data.output,
                 batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
   } else {
     assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
-    fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
+    fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream);
     DUMP_TENSOR("fused causal output", data.output,
                 batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
   }
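
The b and s arguments to Run replace the setup(S, B) call that used to store the launch shape on the runner itself. Presumably that stored state is what made the old runner unsafe to share: two requests with different shapes could interleave setup and run on the same cached object. The stand-alone snippet below (not ONNX Runtime code; both runner structs are invented for illustration) shows the difference in shape handling.

// Illustration of why a stateful setup()/run() pair is racy when one runner is
// cached per kernel instance and shared across concurrent requests, and why a
// shape-per-call Run() is not. Deliberately simplified; not ONNX Runtime code.
#include <cstdio>
#include <thread>

struct OldStyleRunner {
  int s_ = 0;                                  // launch shape stored on the object
  void setup(int s, int /*b*/) { s_ = s; }
  void run(int expected_s) const {             // reads whatever setup() left behind
    if (s_ != expected_s) {
      std::printf("race: launching with s=%d, caller expected s=%d\n", s_, expected_s);
    }
  }
};

struct NewStyleRunner {
  void Run(int /*b*/, int /*s*/) const {}      // shape passed per call; nothing shared
};

int main() {
  OldStyleRunner shared_old;  // intentionally racy: both threads mutate s_
  std::thread a([&] { shared_old.setup(128, 1); shared_old.run(128); });
  std::thread b([&] { shared_old.setup(384, 1); shared_old.run(384); });
  a.join();
  b.join();

  NewStyleRunner shared_new;  // safe to share: Run() only reads its arguments
  std::thread c([&] { shared_new.Run(1, 128); });
  std::thread d([&] { shared_new.Run(1, 384); });
  c.join();
  d.join();
  return 0;
}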
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -193,8 +193,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       (nullptr == key_padding_mask || is_mask_1d_seq_len) &&
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&
-      FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
-                                         enable_trt_flash_attention_, false);
+      FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
+                                        enable_trt_flash_attention_, false);
   if (use_fused_runner) {
     // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
     if (nullptr == fused_fp16_runner_.get()) {
@@ -206,8 +206,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
     }
 
     // In case some kernel not loaded due to shared memory limit, we need to double check here.
-    const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
-    if (fused_fp16_runner_->isValid(S)) {
+    const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
+    if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
       fused_runner = fused_fp16_runner_.get();
     }
   }
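
These hunks keep the lazy creation of fused_fp16_runner_ (the factory call itself is collapsed out of view). If ComputeInternal can run concurrently on one kernel instance, that lazy initialization needs its own guard as well; one common way to write it is std::call_once, sketched below with invented placeholder types, since neither the real factory call nor any synchronization added elsewhere in this PR is visible in these hunks.

// Sketch of guarding the "if (nullptr == fused_fp16_runner_.get())" lazy
// creation seen above against concurrent callers. RunnerStandIn and
// CreateRunner() are placeholders; the real type and factory arguments are
// not shown in this diff.
#include <memory>
#include <mutex>

struct RunnerStandIn {};

std::unique_ptr<RunnerStandIn> CreateRunner() {
  return std::make_unique<RunnerStandIn>();
}

class AttentionKernelSketch {
 public:
  RunnerStandIn* GetOrCreateRunner() const {
    // At most one thread creates the runner; the rest wait and then reuse it.
    std::call_once(runner_created_, [this] { fused_fp16_runner_ = CreateRunner(); });
    return fused_fp16_runner_.get();
  }

 private:
  mutable std::once_flag runner_created_;
  mutable std::unique_ptr<RunnerStandIn> fused_fp16_runner_;
};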
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -55,11 +55,11 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
 
   // Check whether we can use fused kernel
   int sm = device_prop.major * 10 + device_prop.minor;
-  bool is_fMHA_supported = FusedMHARunnerFP16v2::is_supported(sm,
-                                                              parameters.head_size,
-                                                              parameters.sequence_length,
-                                                              enable_trt_flash_attention_,
-                                                              false /*causal*/);
+  bool is_fMHA_supported = FusedMHARunnerFP16v2::IsSupported(sm,
+                                                             parameters.head_size,
+                                                             parameters.sequence_length,
+                                                             enable_trt_flash_attention_,
+                                                             false /*causal*/);
 
   if (!is_fMHA_supported) {
     return fused_runner;
@@ -72,8 +72,8 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
   }
 
   // In case some kernel not loaded due to shared memory limit, we need to double check here.
-  const int S = fused_fp16_runner_->getSFromMaxSeqLen(parameters.sequence_length);
-  if (fused_fp16_runner_->isValid(S)) {
+  const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(parameters.sequence_length);
+  if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
    fused_runner = fused_fp16_runner_.get();
  }
 
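
For reference, the pieces this PR renames fit together in one caller-side sequence: a static support check, sequence-length normalization, the kernel-loaded double check, and the stateless Run. The template below restates that flow; Runner stands for any type exposing these methods, and the buffer pointers are placeholders.

// Caller-side shape of the renamed API, combining the checks and the
// stateless Run() seen in the hunks above. Runner is any type exposing the
// methods this PR names; buffer pointers are placeholders.
#include <cuda_runtime.h>

template <typename Runner>
bool TryRunFusedMHA(Runner& runner, int sm, int head_size, int sequence_length,
                    bool enable_flash_attention, bool causal,
                    int batch_size, const void* packed_qkv,
                    const void* cumulative_sequence_length,
                    void* output, cudaStream_t stream) {
  // 1. Static capability check (SM version, head size, sequence length).
  if (!Runner::IsSupported(sm, head_size, sequence_length, enable_flash_attention, causal)) {
    return false;
  }
  // 2. Map the request's sequence length onto a kernel-supported length.
  const int s = runner.NormalizeSequenceLength(sequence_length);
  // 3. Double check: the kernel for s may not have loaded (shared memory limit).
  if (!runner.IsValid(s)) {
    return false;
  }
  // 4. Launch with batch and sequence length passed per call; no setup() step.
  runner.Run(batch_size, s, packed_qkv, cumulative_sequence_length, output, stream);
  return true;
}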
7 changes: 3 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -459,10 +459,9 @@ Status FusedScaledDotProductAttention(
                                       parameters.token_count, stream);
 
   FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
-  const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
-  fused_fp16_runner->setup(S, batch_size);
-
-  fused_fp16_runner->run(data.workspace, data.cumulative_sequence_length, data.output, stream);
+  const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
+  fused_fp16_runner->Run(batch_size, normalized_seq_len,
+                         data.workspace, data.cumulative_sequence_length, data.output, stream);
   return Status::OK();
 }
 
@@ -575,10 +575,8 @@ Status FusedAttentionTrt(
   }
 
   FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
-  const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
-  fused_fp16_runner->setup(S, batch_size);
-
-  fused_fp16_runner->run(qkv, data.cumulative_sequence_length, data.output, stream);
+  const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
+  fused_fp16_runner->Run(batch_size, normalized_seq_len, qkv, data.cumulative_sequence_length, data.output, stream);
   return Status::OK();
 }
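
Both packed call sites pass batch_size to Run unchanged (no doubling for padding, unlike FusedTrtSelfAttention above), since the packed paths operate on padding-removed tokens described by data.cumulative_sequence_length. As a reminder of that buffer's usual varlen layout (its exact dtype and construction in ONNX Runtime are not part of this diff), the small helper below builds it for a toy batch.

// Toy construction of a cumulative_sequence_length buffer in the usual varlen
// layout: entry i is the offset of sequence i's first token in the packed
// stream, plus one trailing entry for the total token count. Illustrative
// only; ONNX Runtime builds this on device, which is not shown here.
#include <cstdio>
#include <vector>

std::vector<int> BuildCumulativeSequenceLengths(const std::vector<int>& seq_lens) {
  std::vector<int> cu(seq_lens.size() + 1, 0);
  for (size_t i = 0; i < seq_lens.size(); ++i) {
    cu[i + 1] = cu[i] + seq_lens[i];
  }
  return cu;
}

int main() {
  // Three packed sequences of lengths 5, 3 and 7 -> offsets 0 5 8 15.
  for (int offset : BuildCumulativeSequenceLengths({5, 3, 7})) {
    std::printf("%d ", offset);
  }
  std::printf("\n");
  return 0;
}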