Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Fix performance bug in DecoderMaskedMultiheadAttention for BeamSearch #17613

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
q = add_vec(q, q_bias);
}


T* params_k_cache = reinterpret_cast<T*>(params.k_cache);

const float inv_sqrt_dh = params.scale;
Expand Down Expand Up @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

// The keys loaded from the key cache.
K_vec_k k_vec[K_VECS_PER_THREAD];
if (ti < tlength) {
if (has_beams) {
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;

if (has_beams) {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;

if (ti < tlength) {
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
}
}
} else {
} else {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;

if (ti < tlength) {
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
}
Expand Down
Loading