[CUDA] Improve performance of DecoderMaskedMultiheadAttention on A100 (#18695)

### Description

Currently, there are two memory-latency-bound hotspots in the
DecoderMaskedMultiheadAttention kernel, both involving reads from global
memory: one reading K values and the other reading V values.

The current logic for both reads looks roughly like this:

for (int i = 0; i < all_time_steps; ++i) {
  auto data_in_register = load_chunk_from_global_memory(i);
  do_compute(data_in_register);
}

This pattern incurs a data-read stall: each compute instruction has to wait
for its data to be fetched from global memory into registers before it can
begin, and it also does not fully utilize the memory bandwidth of the A100.
The logic can be rewritten with some manual loop unrolling so that more data
reads are triggered "in flight":

// Unroll factor: 4
for (int i = 0; i < all_time_steps; i += 4) {
  auto data_in_register_0 = load_chunk_from_global_memory(i);

  // Do bounds check for the following
  auto data_in_register_1 = load_chunk_from_global_memory(i + 1);
  auto data_in_register_2 = load_chunk_from_global_memory(i + 2);
  auto data_in_register_3 = load_chunk_from_global_memory(i + 3);

  do_compute(data_in_register_0);

  // Do bounds check for the following
  do_compute(data_in_register_1);
  do_compute(data_in_register_2);
  do_compute(data_in_register_3);
}

The idea is that the memory read latency is hidden because the load
instructions for subsequent chunks are issued before the dependent compute
consumes them. See here for more details:
https://forums.developer.nvidia.com/t/global-memory-access-synchronous-or-asynchronous-read-write/3256/4
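
To make the pattern concrete, here is a minimal, self-contained CUDA C++ sketch of the same idea. It is not the ORT kernel itself: a baseline loop where each iteration's math waits on its own load is contrasted with an unroll-by-4 variant that issues four independent loads back to back before the dependent math runs. All names here (`qk_baseline`, `qk_unrolled`, `UNROLL`, `time_steps`) are illustrative and are not identifiers from this change.

```cpp
#include <cuda_runtime.h>
#include <cstdio>

__device__ inline float dot4(float4 a, float4 b) {
  return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

// Baseline: one load per iteration, immediately followed by the compute that depends on it.
__global__ void qk_baseline(const float4* __restrict__ k_cache, float4 q,
                            int time_steps, float* out) {
  float qk_accum = 0.0f;
  for (int ti = threadIdx.x; ti < time_steps; ti += blockDim.x) {
    float4 k = k_cache[ti];  // the FMA below stalls until this load completes
    qk_accum += dot4(q, k);
  }
  atomicAdd(out, qk_accum);
}

// Unrolled: issue UNROLL independent loads "in flight", then consume them.
__global__ void qk_unrolled(const float4* __restrict__ k_cache, float4 q,
                            int time_steps, float* out) {
  constexpr int UNROLL = 4;
  float qk_accum = 0.0f;
  for (int ti = threadIdx.x; ti < time_steps; ti += UNROLL * blockDim.x) {
    float4 k[UNROLL];
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      int t = ti + u * blockDim.x;
      if (t < time_steps) k[u] = k_cache[t];  // independent loads issued back to back
    }
#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
      int t = ti + u * blockDim.x;
      if (t < time_steps) qk_accum += dot4(q, k[u]);  // compute after the loads are in flight
    }
  }
  atomicAdd(out, qk_accum);
}

int main() {
  const int time_steps = 1 << 20;
  float4* k_cache;
  float* out;
  cudaMalloc(&k_cache, time_steps * sizeof(float4));
  cudaMalloc(&out, sizeof(float));
  cudaMemset(k_cache, 0, time_steps * sizeof(float4));
  cudaMemset(out, 0, sizeof(float));
  float4 q = make_float4(1.f, 2.f, 3.f, 4.f);
  qk_baseline<<<1, 256>>>(k_cache, q, time_steps, out);
  qk_unrolled<<<1, 256>>>(k_cache, q, time_steps, out);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(k_cache);
  cudaFree(out);
  return 0;
}
```

The actual kernel change below applies the same idea with per-slot bounds checks and unroll factors of 4 for the K-cache reads and 8 for the V-cache reads (see `K_CACHE_DATA_LOAD_UNROLL` and `V_CACHE_DATA_LOAD_UNROLL` in the diff).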

Kernel clock cycles, latency, and memory bandwidth usage before:

<img width="1210" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/9969784/7a1f41f9-fdaa-47b3-b629-996d7b5eef17">

Kernel clock cycles, latency, and memory bandwidth usage after:

<img width="1205" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/9969784/c76b2d2f-43e3-43c9-a710-b5fae76f69b6">


As can be seen, kernel latency improves by more than 30% and memory
throughput improves by more than 14%.

We have a 1P customer using the Whisper model (sampling with BeamSearch),
and the E2E perf gain for a representative production input is more than
6.5%.

Whisper E2E Latency for sample input before (on A100):

<img width="194" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/9969784/84ef59f5-84f2-4277-b9f8-b04c27336642">

Whisper E2E Latency for sample input after (on A100):

<img width="191" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/9969784/ca9fe5d3-f726-403e-b27c-be4ee07e0625">


Loading more data in flight may not always yield gains; the benefit is
workload dependent. For now, the feature is kept turned OFF by default. It
can be turned ON by the user via an environment variable when needed (see
the usage sketch below).
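
As a usage sketch (an assumption about deployment style, not part of this change): since the kernels read the flag from the process environment via `ParseEnvironmentVariableWithDefault`, setting the variable before inference runs is enough to opt in. On a POSIX system this could look like the following; the variable name comes from attention_common.h in this change, while everything else (`setenv`, the placeholder session setup) is illustrative.

```cpp
#include <cstdlib>

int main() {
  // "1" enables the alternate in-flight KV loading path; leaving the variable
  // unset (or "0") keeps the default single-chunk-per-iteration path.
  setenv("ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT", "1", /*overwrite=*/1);

  // ... create the ONNX Runtime session with the CUDA EP and run the
  // Whisper/BeamSearch workload as usual ...
  return 0;
}
```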

### Motivation and Context
Improve BeamSearch performance on the CUDA EP.
hariharans29 authored Jan 11, 2024
1 parent 2eb3db6 commit f68dfcd
Showing 6 changed files with 253 additions and 52 deletions.
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -133,6 +133,10 @@ constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FL
// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;

// Environment variable to enable loading more KV data in flight in
// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels
constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT";

} // namespace attention

} // namespace contrib
@@ -70,6 +70,10 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*

auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;

parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);

bool is_dmmha_packing = (key == nullptr && value == nullptr);
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
@@ -52,6 +52,10 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont

auto& device_prop = GetDeviceProp();
DecoderMaskedMultiHeadAttentionParams parameters;

parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault<bool>(
attention::kDecoderMaskedAttentionLoadKVDataInFlight, false);

ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@@ -344,52 +344,148 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
bool has_beams = params.cache_indir != nullptr && !params.is_cross_attention;
const int* beam_indices = has_beams ? &params.cache_indir[bi_max_seq_length] : nullptr;

  if (!params.kv_data_in_flight) {
    for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
      bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);

      // The keys loaded from the key cache.
      K_vec_k k_vec[K_VECS_PER_THREAD];
      if (ti < tlength) {
        if (has_beams) {
          const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;

          #pragma unroll
          for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            int jj = ii * params.max_sequence_length + ti;

            k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
                (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
          }
        } else {
          #pragma unroll
          for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            int jj = ii * params.max_sequence_length + ti;

            k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
                (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
          }
        }
      }

      // Perform the dot product and normalize qk.
      // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
      float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * inv_sqrt_dh;

      // This is a deviation from FasterTransformer kernel implementation
      // but this aligns with ORT's other Attention kernels which strives to
      // mimic PyTorch when dealing with mask filter values
      if (is_masked) {
        qk += params.mask_filter_value;
      }

      // Store the product to shared memory. There's one qk value per timestep. Update the max.
      if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
        if (params.relative_attention_bias != nullptr) {
          qk = add_vec(qk,
                       reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]);
        }
        qk_max = fmaxf(qk_max, qk);
        qk_smem[ti] = qk;
      }
    }
  } else {
    // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model
    // Also tune it for different architectures. This works best for Whisper on 80GB A100.
    constexpr int K_CACHE_DATA_LOAD_UNROLL = 4;

    for (int ti = ko; ti < ti_end; ti += (K_CACHE_DATA_LOAD_UNROLL * K_PER_ITER)) {
      int is_masked[K_CACHE_DATA_LOAD_UNROLL];
      int beam_offset[K_CACHE_DATA_LOAD_UNROLL];
      int time_step[K_CACHE_DATA_LOAD_UNROLL];
      bool time_bounds_cond[K_CACHE_DATA_LOAD_UNROLL];

      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        is_masked[k_unroll] = 1;
        beam_offset[k_unroll] = 0;
        time_step[k_unroll] = ti + k_unroll * K_PER_ITER;
        time_bounds_cond[k_unroll] = (time_step[k_unroll] < tlength);
      }

      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        if (time_bounds_cond[k_unroll] && params.mask != nullptr) {
          is_masked[k_unroll] = params.mask[bi_total_seq_length + time_step[k_unroll]];
        }
      }

      if (has_beams) {
        int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size;

        #pragma unroll
        for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
          if (time_bounds_cond[k_unroll]) {
            beam_offset[k_unroll] = beam_indices[time_step[k_unroll]] * head_maxlength_headsize_prod;
          }
        }
      }

      // The keys loaded from the key cache.
      K_vec_k k_vec[K_CACHE_DATA_LOAD_UNROLL][K_VECS_PER_THREAD];

      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        if (time_bounds_cond[k_unroll]) {
          if (has_beams) {
            #pragma unroll
            for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
              int jj = ii * params.max_sequence_length + time_step[k_unroll];

              k_vec[k_unroll][ii] = vec_conversion<K_vec_k, K_vec_m>(
                  (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset[k_unroll] + jj * QK_ELTS_IN_16B])));
            }
          } else {
            #pragma unroll
            for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
              int jj = ii * params.max_sequence_length + time_step[k_unroll];

              k_vec[k_unroll][ii] = vec_conversion<K_vec_k, K_vec_m>(
                  (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
            }
          }
        }
      }

      // Perform the dot product and normalize qk.
      // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
      float qk[K_CACHE_DATA_LOAD_UNROLL];
      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        qk[k_unroll] = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec[k_unroll]) * inv_sqrt_dh;
      }

      // This is a deviation from FasterTransformer kernel implementation
      // but this aligns with ORT's other Attention kernels which strives to
      // mimic PyTorch when dealing with mask filter values
      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        if (time_bounds_cond[k_unroll] && is_masked[k_unroll] == 0) {
          qk[k_unroll] += params.mask_filter_value;
        }
      }

      // Store the product to shared memory. There's one qk value per timestep. Update the max.
      #pragma unroll
      for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
        if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) {
          if (params.relative_attention_bias != nullptr) {
            qk[k_unroll] = add_vec(qk[k_unroll],
                                   reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]);
          }
          qk_max = fmaxf(qk_max, qk[k_unroll]);
          qk_smem[time_step[k_unroll]] = qk[k_unroll];
        }
      }
    }
  }

@@ -504,18 +600,80 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
V_vec_acum out;
zero(out);

  if (!params.kv_data_in_flight) {
    // Loop over the timesteps to compute the partial outputs.
    for (int ti = vo; ti < tlength; ti += V_PER_ITER) {
      // Fetch offset based on cache_indir when beam sampling
      const int beam_src = has_beams ? params.cache_indir[bi_max_seq_length + ti] : 0;
      const int beam_offset = has_beams ? beam_src * params.num_heads * params.max_sequence_length * head_size : 0;

      // Load the values from the cache.
      V_vec_k v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * head_size]));

      // Load the logits from shared memory.
      T logit = logits_smem[ti];
      out = fma(logit, v, out);
    }
  } else {
    // Loop over the timesteps to compute the partial outputs.

    // TODO(hasesh): Tune this value for different workloads. Currently, it is tuned for Whisper model
    // Also tune it for different architectures. This works best for Whisper on 80GB A100.
    constexpr int V_CACHE_DATA_LOAD_UNROLL = 8;

    for (int ti = vo; ti < tlength; ti += V_CACHE_DATA_LOAD_UNROLL * V_PER_ITER) {
      int beam_src[V_CACHE_DATA_LOAD_UNROLL];
      int beam_offset[V_CACHE_DATA_LOAD_UNROLL];
      int time_step[V_CACHE_DATA_LOAD_UNROLL];
      bool time_bounds_cond[V_CACHE_DATA_LOAD_UNROLL];

      #pragma unroll
      for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
        beam_src[v_unroll] = 0;
        beam_offset[v_unroll] = 0;
        time_step[v_unroll] = ti + v_unroll * V_PER_ITER;
        time_bounds_cond[v_unroll] = (time_step[v_unroll] < tlength);
      }

      int head_maxlength_headsize_prod = params.num_heads * params.max_sequence_length * head_size;

      if (has_beams) {
        // Do the global memory read and corresponding compute in separate unrolled loops
        #pragma unroll
        for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
          if (time_bounds_cond[v_unroll]) {
            beam_src[v_unroll] = params.cache_indir[bi_max_seq_length + time_step[v_unroll]];
          }
        }

        #pragma unroll
        for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
          if (time_bounds_cond[v_unroll]) {
            beam_offset[v_unroll] = beam_src[v_unroll] * head_maxlength_headsize_prod;
          }
        }
      }

      // Load the values from the V-cache and logits from shared memory.
      V_vec_k v[V_CACHE_DATA_LOAD_UNROLL];
      T logits[V_CACHE_DATA_LOAD_UNROLL];

      // Do the global memory read and compute in separate unrolled loops
      #pragma unroll
      for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
        if (time_bounds_cond[v_unroll]) {
          v[v_unroll] = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset[v_unroll] + time_step[v_unroll] * head_size]));
          logits[v_unroll] = logits_smem[time_step[v_unroll]];
        }
      }

      #pragma unroll
      for (int v_unroll = 0; v_unroll < V_CACHE_DATA_LOAD_UNROLL; ++v_unroll) {
        if (time_bounds_cond[v_unroll]) {
          out = fma(logits[v_unroll], v[v_unroll], out);
        }
      }
    }
  }

// One group of threads computes the product(s) for the current timestep.
@@ -22,6 +22,12 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

@@ -62,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud
} // namespace cuda

} // namespace contrib
} // namespace onnxruntime
@@ -738,10 +738,23 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {

tester.AddOutput<float>("present", past_dims, present);

      // Run - Regular kernel execution path
      {
        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        execution_providers.push_back(DefaultCudaExecutionProvider());
        tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }

      // Test alternate kernel path of loading more KV data "in flight"
      {
        ScopedEnvironmentVariables scoped_env_vars{
            EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};

        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        execution_providers.push_back(DefaultCudaExecutionProvider());

        tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }
}
}
}
@@ -852,10 +865,22 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {

tester.AddOutput<MLFloat16>("present", past_dims, present);

      // Run - Regular kernel execution path
      {
        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        execution_providers.push_back(DefaultCudaExecutionProvider());
        tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }

      // Test alternate kernel path of loading more KV data "in flight"
      {
        ScopedEnvironmentVariables scoped_env_vars{
            EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};

        std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
        execution_providers.push_back(DefaultCudaExecutionProvider());
        tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
      }
}
}
}
