
Commit

Cuda: Decoder Masked Multihead Attention Q values get corrupted when using cross attention (#16721)
### Description
Some code was accidentally moved inside the `if (!params.is_cross_attention)` block; it must stay outside the block to work correctly in both the self-attention and cross-attention cases.

### Motivation and Context
This bug caused invalid attention results. We first detected it as a performance regression: the corrupted outputs meant the EOS early exit never triggered, so every run took the full max_length steps to complete, which was slow.
RyanUnderhill authored Jul 15, 2023
1 parent 2b7a94e commit 2cf31a2
Showing 1 changed file with 5 additions and 3 deletions.
@@ -179,6 +179,11 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

const float inv_sqrt_dh = params.scale;

+  if (!is_masked) {
+    // Store the Q values to shared memory.
+    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
+  }
+
if (!params.is_cross_attention) {
Qk_vec_k k;

@@ -241,9 +246,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
}

if (!is_masked) {
-    // Store the Q values to shared memory.
-    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
-
// Write the K values to the global memory cache.
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
// system. We designed it this way as it allows much better memory loads (and there are many
