From 2cf31a20cf98de3386afac7c90a37d4412d4a269 Mon Sep 17 00:00:00 2001
From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com>
Date: Sat, 15 Jul 2023 00:41:06 -0700
Subject: [PATCH] Cuda: Decoder Masked Multihead Attention Q values get corrupted when using cross attention (#16721)

### Description
Some code was accidentally moved inside the `if (!params.is_cross_attention)` block; it must stay outside the block so it runs in both cases.

### Motivation and Context
This causes invalid results. We first detected it as a performance bug: because the Q values were corrupted, the EOS early exit never triggered, so runs always took `max_length` steps to complete, which was slow.

---
 .../decoder_masked_multihead_attention_impl.cu | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 194ddb787e3db..5827bdfee1ab5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -179,6 +179,11 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
 
   const float inv_sqrt_dh = params.scale;
 
+  if (!is_masked) {
+    // Store the Q values to shared memory.
+    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
+  }
+
   if (!params.is_cross_attention) {
     Qk_vec_k k;
 
@@ -241,9 +246,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
   }
 
   if (!is_masked) {
-    // Store the Q values to shared memory.
-    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
-
     // Write the K values to the global memory cache.
     // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
     // system. We designed it this way as it allows much better memory loads (and there are many
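
To make the shape of the fix easier to see outside the diff context, here is a minimal, self-contained toy sketch (not the actual ONNX Runtime kernel; `toy_attention_step`, its parameters, and the simplified `q_smem` layout are invented for illustration, and only the placement of the shared-memory Q store mirrors the change above):

```cuda
// Toy illustration: the per-thread Q value that later stages read from shared
// memory must be stored *before* the branch that is skipped for cross
// attention; otherwise the cross-attention path reads uninitialized memory.
#include <cstdio>

__global__ void toy_attention_step(const float* q_in, float* out, bool is_cross_attention) {
  __shared__ float q_smem[256];
  int tidx = threadIdx.x;

  // Correct placement: every thread publishes its Q value unconditionally,
  // so both the self-attention and cross-attention paths can read it.
  q_smem[tidx] = q_in[tidx];

  if (!is_cross_attention) {
    // Self-attention-only work (e.g. appending K to the cache) would go here.
  }
  __syncthreads();

  // Both paths consume the shared Q values.
  out[tidx] = q_smem[tidx] * 2.0f;
}

int main() {
  const int n = 256;
  float *q, *out;
  cudaMallocManaged(&q, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) q[i] = static_cast<float>(i);

  // Cross-attention case: still correct because the Q store is outside the branch.
  toy_attention_step<<<1, n>>>(q, out, /*is_cross_attention=*/true);
  cudaDeviceSynchronize();
  printf("out[1] = %f\n", out[1]);  // expect 2.0

  cudaFree(q);
  cudaFree(out);
  return 0;
}
```

Keeping the store ahead of the `is_cross_attention` branch guarantees the shared Q values are initialized on both paths, which is what the first hunk above restores.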