From 4ffc1ff3b4aafafab5a7e5045b219e62d9638a87 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Sat, 2 Nov 2024 22:05:56 +0900 Subject: [PATCH] DMMHA: add unit tests; fix CPU, CUDA kernel (#22567) ### Description Fixes: (1) cpu kernel: applying scale before bias and mask like other MHA ops (2) cpu kernel: correct offset during appending past to present. (3) cuda kernel: apply mask if provided; fix output_qk offset. Add DMMHA unit tests --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 2 +- .../contrib_ops/cpu/bert/attention_helper.h | 3 +- .../decoder_masked_multihead_attention.cc | 10 +- .../bert/decoder_masked_multihead_attention.h | 4 +- ...decoder_masked_multihead_attention_impl.cu | 5 +- .../core/graph/contrib_ops/bert_defs.cc | 1 - ...oder_masked_multihead_attention_op_test.cc | 723 +++++++++--------- 7 files changed, 381 insertions(+), 367 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index dc9ba80af5ba4..87938f3728750 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -77,7 +77,7 @@ class AttentionCPUBase : public AttentionBase { // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f). // Merge padding mask with causal mask, and broadcast to 3D (BxSxT). PrepareMask(mask_index_data, mask_index_dims, static_cast(mask_data), - causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); + causal, batch_size, sequence_length, kv_sequence_length, past_sequence_length, mask_filter_value_); DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 4d435f71cc195..37bb5664393c9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -120,9 +120,10 @@ void PrepareMask(const int32_t* mask_index, bool causal, int batch_size, int sequence_length, + int kv_sequence_length, int past_sequence_length, float mask_filter_value) { - const int all_sequence_length = past_sequence_length + sequence_length; + const int all_sequence_length = past_sequence_length + kv_sequence_length; // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc index b2aaa9cb11beb..e6f65f92e14f4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc @@ -339,6 +339,7 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( T* attention_probs_ptr = reinterpret_cast(attention_probs) + last_offset; math::Dot(head_size, q_vec, K + i * head_size, attention_probs_ptr, nullptr); + *attention_probs_ptr *= scale; // Apply the attention bias and mask if (attn_bias_data != nullptr) { *attention_probs_ptr += attn_bias_data[attn_bias_base_offset + past_sequence_length]; @@ -348,7 +349,6 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( if (is_masked) { *attention_probs_ptr += mask_filter_value_; } - *attention_probs_ptr *= scale; } { @@ -362,6 +362,8 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( const T* past_k_vec = 
past_key_data + beam_batch_offset + beam_offset + j * head_size; T* output = reinterpret_cast(attention_probs) + j + i * probs_matrix_size; math::Dot(head_size, q_vec, past_k_vec, output, nullptr); + + *output *= scale; // Apply the attention bias and mask if (attn_bias_data != nullptr) { *output += attn_bias_data[attn_bias_base_offset + j]; @@ -371,11 +373,11 @@ void DecoderMaskedMultiHeadAttention::ComputeAttentionProbsWithBeams( if (is_masked) { *output += mask_filter_value_; } - *output *= scale; } } // Append current key to present key (past_present_share_buffer_ is true) - memcpy(present_key_data + i * max_sequence_length * head_size, K + i * head_size, head_size * sizeof(T)); + memcpy(present_key_data + (i * max_sequence_length + past_sequence_length) * head_size, + K + i * head_size, head_size * sizeof(T)); } }); @@ -460,7 +462,7 @@ void DecoderMaskedMultiHeadAttention::ComputeVxAttentionScoreWithBeams( } } // Append current value to present value (past_present_share_buffer_ is true) - memcpy(present_value_data + i * max_sequence_length * v_head_size, + memcpy(present_value_data + (i * max_sequence_length + past_sequence_length) * v_head_size, V + i * v_head_size, v_head_size * sizeof(T)); } diff --git a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h index 68d1b9751301d..d5167e8989669 100644 --- a/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.h @@ -33,7 +33,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC const Tensor* cache_indir, OpKernelContext* context, int beam_width, - Tensor* scaled_qk = nullptr) const; + Tensor* output_qk = nullptr) const; void ComputeAttentionProbsWithBeams(T* attention_probs, const T* Q, const T* K, @@ -50,7 +50,7 @@ class DecoderMaskedMultiHeadAttention final : public OpKernel, public AttentionC bool broadcast_attn_bias_dim_1, const int32_t* cache_indir_data, int beam_width, - T* scaled_qk_data = nullptr) const; + T* output_qk_data = nullptr) const; void ComputeVxAttentionScoreWithBeams(T* output, T* tmp_buffer, const T* attention_probs, diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 8edae863ff44e..e4c1659c0fb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -298,6 +298,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.attention_bias != nullptr) { qk = add_vec(qk, reinterpret_cast(params.attention_bias)[attn_bias_offset + tlength]); } + if (params.mask != nullptr && params.mask[bi_total_seq_length + params.past_sequence_length] == 0) { + qk += params.mask_filter_value; + } qk_max = qk; qk_smem[tlength] = qk; } @@ -534,7 +537,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio if (params.out_qk != nullptr) { // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] - float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + float* target = (reinterpret_cast(params.out_qk)) + (static_cast(bhi) * 
(sum_tlength + 1)); for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { target[ti] = (float)(qk_smem[ti]); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0a261d8f731f2..f2a2a52f8334f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -908,7 +908,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(9, "cache_indirection", - // This input is useful for CUDA EP only. "A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies " "which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration", "M", diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 17c9e8592f64e..17685ab82f0ef 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -15,23 +15,20 @@ namespace onnxruntime { namespace test { -// This op is currently only supported on CUDA- so test it only for CUDA -#ifdef USE_CUDA - template static std::vector CreateOnes(int size) { std::vector f; f.reserve(size); for (int i = 0; i < size; ++i) { - f.push_back(T(1)); + f.push_back(T(1.0f)); } return f; } template -static std::vector CreateValues(int size, int val) { +static std::vector CreateValues(int size, float val) { std::vector f; f.reserve(size); @@ -72,39 +69,25 @@ static std::vector CreateRandom(int size) { return f; } -// QKV template -static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size); +float ToFloat(T val); template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, 0); - - for (int b = 0; b < batch_size; ++b) { - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < 3 * hidden_size; ++j) { - float sum = 0; - - for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k] * weights[k * 3 * hidden_size + j]; - } - - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = sum + bias[j]; - } - } - } - - return qkv; +constexpr float ToFloat(float val) { + return val; } template <> -std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, - int batch_size, int sequence_length, int hidden_size) { - std::vector qkv; - qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); +float ToFloat(MLFloat16 val) { + return val.ToFloat(); +} + +// QKV +template +static std::vector QKV(std::vector& input, std::vector& weights, std::vector& bias, + int batch_size, int sequence_length, int hidden_size) { + std::vector qkv; + qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int i = 0; i < sequence_length; ++i) { @@ -112,10 +95,11 @@ std::vector QKV(std::vector& input, std::vector float sum = 0; for (int k = 0; k < hidden_size; ++k) { - sum += input[b * sequence_length * hidden_size + i * hidden_size + k].ToFloat() * weights[k * 3 * hidden_size + j].ToFloat(); + sum += ToFloat(input[b * sequence_length * hidden_size + i * hidden_size + k]) * + 
ToFloat(weights[k * 3 * hidden_size + j]); } - qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + bias[j].ToFloat()); + qkv[b * sequence_length * 3 * hidden_size + i * 3 * hidden_size + j] = static_cast(sum + ToFloat(bias[j])); } } } @@ -180,15 +164,17 @@ void CheckEquality(T* data_1, T* data_2, int batch_size, int num_heads, int num_ // Reorder 'K' from [B, N, S, H] to [B, N, H/x, S, x] where x = (sizeof(T) / 16); // Copy 'V' over as is template -static std::vector ReorderKVCache(std::vector& unordered_k_cache, +static std::vector ReorderKVCache(const std::vector& unordered_k_cache, int batch_size, int num_heads, int sequence_length, - int head_size, int max_sequence_length) { + int head_size, int max_sequence_length, bool merge_past_kv = true) { std::vector ordered(unordered_k_cache.size(), T{0.f}); // Copy V over - size_t v_start = unordered_k_cache.size() / 2; - for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { - ordered[i] = unordered_k_cache[i]; + if (merge_past_kv) { + size_t v_start = unordered_k_cache.size() / 2; + for (size_t i = v_start; i < unordered_k_cache.size(); ++i) { + ordered[i] = unordered_k_cache[i]; + } } // Now let us re-order K and copy it over to the final buffer @@ -203,7 +189,8 @@ static std::vector ReorderKVCache(std::vector& unordered_k_cache, (h * max_sequence_length * head_size); int input_base_offset = base_offset + (s * head_size) + (c * num_inner_elements); - int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + (s * num_inner_elements); + int output_base_offset = base_offset + (c * max_sequence_length * num_inner_elements) + + (s * num_inner_elements); for (int e = 0; e < num_inner_elements; ++e) { ordered[output_base_offset + e] = unordered_k_cache[input_base_offset + e]; @@ -224,7 +211,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache T* k, int batch_size, int num_heads, int past_sequence_length, int max_sequence_length, - int head_size) { + int head_size, bool merge_past_kv = true) { std::vector merged = ordered_k_cache; int total_seq_length = past_sequence_length + 1; @@ -249,10 +236,11 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache input_value = ordered_k_cache[input_offset]; } else { int hidden_size = num_heads * head_size; - int input_offset = (b * 3 * hidden_size) + - (n * num_chunks * chunk_size) + - (c * chunk_size) + - h; + int input_offset = merge_past_kv ? 
((b * 3 * hidden_size) + + (n * num_chunks * chunk_size) + + (c * chunk_size) + + h) + : ((b * hidden_size) + n * head_size + c * chunk_size + h); input_value = k[input_offset]; } @@ -272,7 +260,7 @@ static std::vector MergeReorderedKVCacheWithK(std::vector& ordered_k_cache return merged; } -// GIven a pointer to the 'V' component of the past cache, we will merge it +// Given a pointer to the 'V' component of the past cache, we will merge it // with current 'V' in-place template static void MergeReorderedKVCacheWithV(T* v_cache, @@ -299,7 +287,8 @@ static void MergeReorderedKVCacheWithV(T* v_cache, template static std::pair, std::vector> MergePastKWithPresentKAndTranspose(T* past_k, T* present_k, int num_batch, int num_heads, - int past_sequence_length, int max_sequence_length, + int past_sequence_length, + int max_sequence_length, int head_size) { int total_seq_length = (past_sequence_length + 1); std::vector merged_k(num_batch * num_heads * total_seq_length * head_size, T{0.f}); @@ -312,16 +301,18 @@ static std::pair, std::vector> MergePastKWithPresentKAndTransp T input_value{0.f}; if (s < past_sequence_length) { - int input_offset = b * num_heads * max_sequence_length * head_size + (n * max_sequence_length * head_size) + (s * head_size) + h; + int input_offset = b * num_heads * max_sequence_length * head_size + + (n * max_sequence_length * head_size) + (s * head_size) + h; input_value = past_k[input_offset]; } else { int hidden_size = num_heads * head_size; - // Offset by 3* hidden_size because QKV data contains Q, K, and V per batch + // Offset by 3 * hidden_size because QKV data contains Q, K, and V per batch int input_offset = (b * 3 * hidden_size) + (n * head_size) + h; input_value = present_k[input_offset]; } - int output_offset = b * num_heads * total_seq_length * head_size + (n * total_seq_length * head_size) + (s * head_size) + h; + int output_offset = b * num_heads * total_seq_length * head_size + + (n * total_seq_length * head_size) + (s * head_size) + h; merged_k[output_offset] = input_value; } @@ -383,15 +374,11 @@ void ValidateReorderedMergedKWithK(T* k, T* k_cache, int batch_size, int num_hea // QK_Transpose template std::vector QK_Transpose(T* q_matrix, T* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size); - -template <> -std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { + int batch_size, int num_heads, int total_sequence_length, int head_size) { int hidden_size = num_heads * head_size; - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, 0); + std::vector qk_transpose; + qk_transpose.resize(batch_size * num_heads * total_sequence_length, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -409,50 +396,12 @@ std::vector QK_Transpose(float* q_matrix, float* k_transpose_matrix, for (int j = 0; j < total_sequence_length; ++j) { float sum = 0; for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k] * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j]); + sum += (ToFloat(q_matrix[input_1_base_offset + i * head_size + k]) * + ToFloat(k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j])); } float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = scale * sum; - } - } - } - } - - return qk_transpose; -} - 
-template <> -std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_matrix, - int batch_size, int num_heads, int total_sequence_length, int head_size) { - int hidden_size = num_heads * head_size; - - std::vector qk_transpose; - qk_transpose.resize(batch_size * num_heads * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * 3 * hidden_size) + - (n * head_size); - - int input_2_base_offset = (b * num_heads * total_sequence_length * head_size) + - (n * total_sequence_length * head_size); - - int output_base_offset = (b * num_heads * total_sequence_length) + - (n * total_sequence_length); - - // sequence_length == 1 - for (int i = 0; i < 1; ++i) { - for (int j = 0; j < total_sequence_length; ++j) { - float sum = 0; - for (int k = 0; k < head_size; ++k) { - sum += (q_matrix[input_1_base_offset + i * head_size + k].ToFloat() * - k_transpose_matrix[input_2_base_offset + k * total_sequence_length + j].ToFloat()); - } - - float scale = 1 / sqrt(static_cast(head_size)); - qk_transpose[output_base_offset + i * total_sequence_length + j] = MLFloat16(scale * sum); + qk_transpose[output_base_offset + i * total_sequence_length + j] = static_cast(scale * sum); } } } @@ -464,26 +413,23 @@ std::vector QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int head_size); - -template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { + int sequence_length, int total_sequence_length) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * sequence_length * total_sequence_length, 0); + std::vector softmax_qk_transpose; + softmax_qk_transpose.resize(static_cast(batch_size) * num_heads * sequence_length * total_sequence_length, + static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { int base_offset = (b * num_heads * sequence_length * total_sequence_length) + (n * sequence_length * total_sequence_length); - float max = std::numeric_limits::min(); + float max = std::numeric_limits::lowest(); for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); if (val > max) { max = val; } @@ -491,52 +437,13 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_si float denom = 0; for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); denom += std::exp(val - max); } for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s]; - softmax_qk_transpose[base_offset + s] = std::exp(val - max) / (denom + (float)0.000001); - } - } - } - - return softmax_qk_transpose; -} - -template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, - int sequence_length, int total_sequence_length, int /*head_size*/) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector softmax_qk_transpose; - softmax_qk_transpose.resize(batch_size * num_heads * 
sequence_length * total_sequence_length, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - float max = std::numeric_limits::min(); - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - if (val > max) { - max = val; - } - } - - float denom = 0; - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - denom += std::exp(val - max); - } - - for (int s = 0; s < total_sequence_length; ++s) { - auto val = qk_transpose_matrix[base_offset + s].ToFloat(); - softmax_qk_transpose[base_offset + s] = MLFloat16(std::exp(val - max) / (denom + (float)0.000001)); + auto val = ToFloat(qk_transpose_matrix[base_offset + s]); + softmax_qk_transpose[base_offset + s] = static_cast(std::exp(val - max) / (denom + (float)0.000001)); } } } @@ -550,19 +457,13 @@ std::vector Softmax_QK_Transpose_V(T* softmax_qk_transpose_matrix, T* v_matrix, int batch_size, int num_heads, int sequence_length, int total_sequence_length, int max_sequence_length, - int head_size); -template <> -std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, - float* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { + int head_size) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, 0); + std::vector output; + output.resize(batch_size * sequence_length * num_heads * head_size, static_cast(0.f)); for (int b = 0; b < batch_size; ++b) { for (int n = 0; n < num_heads; ++n) { @@ -580,11 +481,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, float sum = 0; for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k] * - v_matrix[input_2_base_offset + k * head_size + j]); + sum += (ToFloat(softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k]) * + ToFloat(v_matrix[input_2_base_offset + k * head_size + j])); } - output[output_base_offset + i * head_size + j] = sum; + output[output_base_offset + i * head_size + j] = static_cast(sum); } } } @@ -593,48 +494,11 @@ std::vector Softmax_QK_Transpose_V(float* softmax_qk_transpose_matrix, return output; } -template <> -std::vector Softmax_QK_Transpose_V(MLFloat16* softmax_qk_transpose_matrix, - MLFloat16* v_matrix, - int batch_size, int num_heads, int sequence_length, - int total_sequence_length, int max_sequence_length, - int head_size) { - if (sequence_length != 1) { - throw std::runtime_error("Not supported"); - } - - std::vector output; - output.resize(batch_size * sequence_length * num_heads * head_size, MLFloat16(0.f)); - - for (int b = 0; b < batch_size; ++b) { - for (int n = 0; n < num_heads; ++n) { - int input_1_base_offset = (b * num_heads * sequence_length * total_sequence_length) + - (n * sequence_length * total_sequence_length); - - int input_2_base_offset = (b * num_heads * max_sequence_length * head_size) + - (n * max_sequence_length * head_size); - - int output_base_offset = (b * num_heads * sequence_length * head_size) + - (n * sequence_length * head_size); - - for (int i = 0; i < sequence_length; ++i) { - for (int j = 0; j < head_size; ++j) { - float sum 
= 0; - - for (int k = 0; k < total_sequence_length; ++k) { - sum += (softmax_qk_transpose_matrix[input_1_base_offset + i * total_sequence_length + k].ToFloat() * - v_matrix[input_2_base_offset + k * head_size + j].ToFloat()); - } - - output[output_base_offset + i * head_size + j] = MLFloat16(sum); - } - } - } - } +// Currently we only support CUDA for DecoderMaskedSelfAttention +#ifdef USE_CUDA - return output; -} -TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { +template +static void TestDecoderMaskedSelfAttention() { // The kernel is only supported on CC 5.3 or higher GPUs if (NeedSkipIfCudaArchLowerThan(530)) { return; @@ -661,19 +525,19 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { }; constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; + constexpr int num_heads = 12; for (MyTestCase test_case : test_cases) { int batch_size = test_case.batch_size; int past_sequence_length = test_case.past_sequence_length; int hidden_size = test_case.hidden_size; - int head_size = (hidden_size / number_of_heads); + int head_size = (hidden_size / num_heads); int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("past_present_share_buffer", static_cast(1)); std::vector input_dims = {batch_size, sequence_length, hidden_size}; @@ -681,38 +545,38 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { std::vector bias_dims = {3 * hidden_size}; std::vector output_dims = {batch_size, sequence_length, hidden_size}; - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); // Mask tester.AddOptionalInputEdge(); // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + std::vector past_dims = {2, batch_size, num_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * num_heads * max_sequence_length * head_size; - auto kv_cache = CreateRandom(past_present_size); + auto kv_cache = CreateRandom(past_present_size); - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + num_heads, past_sequence_length, head_size, max_sequence_length); // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); + int chunk_size = 16 / sizeof(T); int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, 
num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + auto transposed = Transpose(kv_cache.data(), batch_size, num_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, num_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); // Rel - tester.AddOptionalInputEdge(); + tester.AddOptionalInputEdge(); // Past sequence length std::vector arr_past_sequence_len(1, past_sequence_length); @@ -722,41 +586,44 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); auto k_merged = pair.first; auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, num_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, num_heads, + sequence_length, total_sequence_length); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); // Validate our test logic // We want to validate if our merged "unordered" K is the same as // the merged "ordered" K so that the QKT we do in our test code // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, num_heads, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + num_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, num_heads, sequence_length, total_sequence_length, + max_sequence_length, head_size); // Output(s) 
- tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("output", input_dims, output); + tester.AddOutput("present", past_dims, present); - tester.SetOutputTolerance(0.001f, 0.001f); + if (std::is_same::value) { + tester.SetOutputTolerance(0.005f); + } else { + tester.SetOutputTolerance(0.001f, 0.001f); + } // Run - Regular kernel execution path { @@ -778,150 +645,292 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { } } -TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { - // The kernel is only supported on CC 5.3 or higher GPUs - if (NeedSkipIfCudaArchLowerThan(530)) { - return; - } - - // Buckets for test data: - // batch_size: 1, >=2 - // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) - // head_size: 32, 64, 128 - struct MyTestCase { - int batch_size; - int past_sequence_length; - int hidden_size; - } test_cases[] = { - {1, 0, 768}, - {1, 1, 768}, - {3, 30, 384}, - {8, 31, 1536}, - {4, 256, 384}, - {3, 1024, 768}, - {2, 2046, 1536}, - {1, 2047, 384}, - {2, 3000, 768}, - }; - - constexpr int sequence_length = 1; - constexpr int number_of_heads = 12; - - for (MyTestCase test_case : test_cases) { - int batch_size = test_case.batch_size; - int past_sequence_length = test_case.past_sequence_length; - int hidden_size = test_case.hidden_size; +#endif // USE_CUDA - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); +template +static std::vector CalculateOutputQK(const std::vector& q, const std::vector& k, + const std::vector& mask_index, const std::vector& attention_bias, + int batch_size, int num_heads, + int sequence_length, int max_sequence_length, int head_size) { + // q (B, 1, NH), k (B, N, L(M), H) -> qk (B, N, 1, L) + // mask_index (B, L), (optional) attention_bias (1, 1, 1, L) + float scale = 1 / sqrt(static_cast(head_size)); + std::vector output_qk; + output_qk.resize(static_cast(batch_size) * num_heads * sequence_length, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < sequence_length; ++s) { + float mask_value = (mask_index[b * sequence_length + s] == 0) ? -10000.f : 0.f; + float bias_value = (attention_bias.empty()) ? 
0.f : ToFloat(attention_bias[s]); + float sum = 0; + for (int h = 0; h < head_size; ++h) { + sum += ToFloat(q[b * num_heads * head_size + n * head_size + h]) * + ToFloat(k[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + output_qk[b * num_heads * sequence_length + n * sequence_length + s] = + static_cast(scale * sum + mask_value + bias_value); + } + } + } - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + return output_qk; +} - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); +template +static std::vector CalculateOutput(const std::vector& softmax, const std::vector& v, int batch_size, + int num_heads, int sequence_length, int max_sequence_length, int head_size) { + // softmax (B, N, 1, L) v (B, N, L(M), H) -> output (B, N, 1, H) + std::vector output; + output.resize(static_cast(batch_size) * num_heads * head_size, static_cast(0.f)); + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + float sum = 0; + for (int s = 0; s < sequence_length; ++s) { + sum += ToFloat(softmax[b * num_heads * sequence_length + n * sequence_length + s]) * + ToFloat(v[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h]); + } - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + output[b * num_heads * head_size + n * head_size + h] = static_cast(sum); + } + } + } - // Mask - tester.AddOptionalInputEdge(); + return output; +} - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; +template +static std::vector MergePast(const std::vector& past, const std::vector& current, int batch_size, + int num_heads, int past_seq_len, int max_seq_len, int head_size) { + // past (B, N, S(M), H), current (B, 1, NH) -> merged (B, N, S+1(M), H) + std::vector merged = past; + for (int b = 0; b < batch_size; ++b) { + for (int n = 0; n < num_heads; ++n) { + for (int h = 0; h < head_size; ++h) { + merged[b * num_heads * max_seq_len * head_size + n * max_seq_len * head_size + past_seq_len * head_size + h] = + current[b * num_heads * head_size + n * head_size + h]; + } + } + } - auto kv_cache = CreateRandom(past_present_size); + return merged; +} - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); +template +static std::vector ReorderKVByCacheIndirection(const std::vector& key_or_value, + const int32_t* cache_indirection, + int batch_size, int beam_width, int max_sequence_length, + int num_heads, int head_size, int past_sequence_length) { + std::vector reordered = key_or_value; - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), 
reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + for (int b = 0; b < batch_size; ++b) { + int beam_batch_index = b / beam_width; + const int* beam_indices = cache_indirection + b * max_sequence_length; + for (int n = 0; n < num_heads; ++n) { + for (int s = 0; s < past_sequence_length; ++s) { + int beam_offset = beam_indices[s] * num_heads * max_sequence_length * head_size; + int beam_batch_offset = (beam_batch_index * beam_width * num_heads + n) * max_sequence_length * head_size; + for (int h = 0; h < head_size; ++h) { + reordered[b * num_heads * max_sequence_length * head_size + + n * max_sequence_length * head_size + s * head_size + h] = + key_or_value[beam_offset + beam_batch_offset + s * head_size + h]; + } + } + } + } - tester.AddInput("past", past_dims, reordered_kv_cache); + return reordered; +} - // Rel - tester.AddOptionalInputEdge(); +template +static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool use_cuda = true) { + int batch_size = 8; + int past_sequence_length = 2; + int kv_sequence_length = 16; + int head_size = 32; + int num_heads = 12; + int beam_width = 4; + int hidden_size = head_size * num_heads; + + OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain); + FixedPatternValueGenerator generator{}; + RandomValueGenerator random{}; + + // Attributes + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(!is_cross_attn)); + // Output scaled Q * K^T by default for cross-attention + tester.AddAttribute("output_qk", static_cast(is_cross_attn)); + + // Inputs and outputs + auto query = CreateRandom(batch_size * 1 * hidden_size); + tester.AddInput("query", {batch_size, 1, hidden_size}, query); + + if (is_cross_attn) { + auto key = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + std::vector reordered_key; + if (use_cuda) { + reordered_key = ReorderKVCache(key, batch_size, num_heads, + kv_sequence_length, head_size, kv_sequence_length, false); + } + auto value = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + tester.AddInput("key", {batch_size, num_heads, kv_sequence_length, head_size}, (use_cuda ? 
reordered_key : key)); + tester.AddInput("value", {batch_size, num_heads, kv_sequence_length, head_size}, + CreateRandom(batch_size * num_heads * kv_sequence_length * head_size)); + + const std::vector mask_index_dims = {batch_size, kv_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, kv_sequence_length}, mask_index); + + // Calculate Softmax(Q * K^T + (Optional) mask) * V + std::vector empty_attention_bias; + auto output_qk = CalculateOutputQK(query, key, mask_index, empty_attention_bias, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + std::vector output_qk_float(output_qk.size()); + for (size_t i = 0; i < output_qk.size(); ++i) { + output_qk_float[i] = static_cast(output_qk[i]); + } + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, kv_sequence_length); + auto output = CalculateOutput(softmax, value, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOptionalOutputEdge(); // optional present_key + tester.AddOptionalOutputEdge(); // optional present_value + tester.AddOutput("qk", {batch_size, num_heads, 1, kv_sequence_length}, output_qk_float); + } else { + int max_sequence_length = past_sequence_length + 10; + int total_sequence_length = past_sequence_length + 1; + + auto key = CreateRandom(batch_size * hidden_size); + auto value = CreateRandom(batch_size * hidden_size); + tester.AddInput("key", {batch_size, 1, hidden_size}, key); + tester.AddInput("value", {batch_size, 1, hidden_size}, value); + + const std::vector mask_index_dims = {batch_size, total_sequence_length}; + auto mask_index = generator.Discrete(mask_index_dims, AsSpan({0, 1})); + tester.AddInput("mask_index", {batch_size, total_sequence_length}, mask_index); + std::vector attention_bias_dims = {1, 1, 1, total_sequence_length}; + auto attention_bias_float = random.Gaussian(attention_bias_dims, 0.0f, 0.3f); + std::vector attention_bias(attention_bias_float.size()); + for (size_t i = 0; i < attention_bias.size(); ++i) { + attention_bias[i] = static_cast(attention_bias_float[i]); + } + tester.AddInput("attention_bias", {1, 1, 1, total_sequence_length}, attention_bias); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + auto past_key = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); + auto past_value = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + std::vector reordered_past_key; // For CUDA, we need to reorder past key + if (use_cuda) { + reordered_past_key = ReorderKVCache(past_key, batch_size, num_heads, + past_sequence_length, head_size, max_sequence_length, false); + } - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + tester.AddInput("past_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? 
reordered_past_key : past_key)); + tester.AddInput("past_value", {batch_size, num_heads, max_sequence_length, head_size}, past_value); + + // merge past key and value with current key and value + auto merged_key = MergePast(past_key, key, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + std::vector merged_reordered_key; + if (use_cuda) { + merged_reordered_key = MergeReorderedKVCacheWithK(reordered_past_key, key.data(), batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size, false); + } + auto merged_value = MergePast(past_value, value, batch_size, num_heads, + past_sequence_length, max_sequence_length, head_size); + + tester.AddInput("past_sequence_length", {1}, {past_sequence_length}); + + std::vector mod_merged_key, mod_merged_value; + if (beam_width > 1) { + tester.AddInput("beam_width", {1}, {beam_width}); + + const std::vector cache_indir_dims = {batch_size, beam_width, max_sequence_length}; + auto value_candidates = ValueRange(beam_width); + auto cache_indir = generator.Discrete(cache_indir_dims, value_candidates); + tester.AddInput("cache_indirection", cache_indir_dims, cache_indir); + + // Modify merged_key and merged_value according to cache_indirection + mod_merged_key = ReorderKVByCacheIndirection(merged_key, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + mod_merged_value = ReorderKVByCacheIndirection(merged_value, cache_indir.data(), + batch_size, beam_width, max_sequence_length, + num_heads, head_size, past_sequence_length); + } - auto k_merged = pair.first; - auto k_transpose = pair.second; + // Calculate Softmax(Q * K^T + (Optional) mask) * V + auto output_qk = CalculateOutputQK(query, (beam_width > 1 ? mod_merged_key : merged_key), + mask_index, attention_bias, + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, 1, total_sequence_length); + auto output = CalculateOutput(softmax, (beam_width > 1 ? mod_merged_value : merged_value), + batch_size, num_heads, total_sequence_length, max_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + tester.AddOutput("present_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? 
merged_reordered_key : merged_key)); + tester.AddOutput("present_value", {batch_size, num_heads, max_sequence_length, head_size}, merged_value); + } - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + if (std::is_same::value) { + tester.SetOutputTolerance(0.02f); + } else { + tester.SetOutputTolerance(0.0001f, 0.0001f); + } - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + { + std::vector> execution_providers; + if (use_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } else { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +#ifdef USE_CUDA - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { + TestDecoderMaskedSelfAttention(); +} - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); +TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { + TestDecoderMaskedSelfAttention(); +} - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(); +} - // Output(s) - tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(); +} - tester.SetOutputTolerance(0.005f); +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp16) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false); +} - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; +#endif - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cross_attn_fp32) { + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ true, /* use_cuda = */ false); } -#endif +TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) 
{ + TestDecoderMaskedMultiHeadAttention(/* is_cross_attn = */ false, /* use_cuda = */ false); +} } // namespace test } // namespace onnxruntime
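
The two CPU-side fixes in this patch boil down to (a) scaling the raw Q·K dot product before the additive attention bias and mask are applied, and (b) writing the current K/V vector into the shared present buffer at slot `past_sequence_length` rather than slot 0. The sketch below is a minimal, standalone illustration of that corrected behavior, not the kernel code itself; `Dot`, `ComputeLogit`, and `AppendToPresent` are hypothetical helpers with made-up signatures.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Plain dot product over one head's vectors.
static float Dot(const float* a, const float* b, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += a[i] * b[i];
  return s;
}

// Corrected ordering: scale Q*K^T first, then add the attention bias,
// then add the mask filter value for masked positions (matches the
// ordering used by the other MHA ops).
static float ComputeLogit(const float* q, const float* k, int head_size, float scale,
                          float attn_bias, bool is_masked, float mask_filter_value) {
  float logit = scale * Dot(q, k, head_size);
  logit += attn_bias;
  if (is_masked) logit += mask_filter_value;
  return logit;
}

// Corrected offset when past/present share the buffer: the new K (or V)
// vector for this head goes into slot `past_sequence_length` of its
// max_sequence_length-sized region, not into slot 0.
static void AppendToPresent(std::vector<float>& present, const float* current,
                            int head_index, int past_sequence_length,
                            int max_sequence_length, int head_size) {
  size_t offset = (static_cast<size_t>(head_index) * max_sequence_length +
                   past_sequence_length) *
                  static_cast<size_t>(head_size);
  std::memcpy(present.data() + offset, current, head_size * sizeof(float));
}
```

The analogous CUDA change follows the same idea: the per-(batch, head) row of `out_qk` is strided by the total sequence length of the current step rather than by `tlength`, and the padding mask, when provided, contributes `mask_filter_value` to the logit before softmax.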