From 11f6a18ee898eb829981cf4fe739a38f5bff41c0 Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 18 Sep 2023 18:02:03 -0700 Subject: [PATCH] Main Optimized --- ...decoder_masked_multihead_attention_impl.cu | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c8877a5e3f872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio q = add_vec(q, q_bias); } - T* params_k_cache = reinterpret_cast(params.k_cache); const float inv_sqrt_dh = params.scale; @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The keys loaded from the key cache. K_vec_k k_vec[K_VECS_PER_THREAD]; + if (ti < tlength) { + if (has_beams) { + const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; - if (has_beams) { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { - const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } - } - } else { + } else { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); }