From 9df4782c96c0ee06daf50cef454172662d311e83 Mon Sep 17 00:00:00 2001
From: mindest
Date: Tue, 22 Oct 2024 15:38:09 +0000
Subject: [PATCH] Add unit tests for DecoderMaskedMultiHeadAttention.

---
 ...oder_masked_multihead_attention_op_test.cc | 297 ++++++++++++++++--
 1 file changed, 278 insertions(+), 19 deletions(-)

diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
index d03ba3eac916d..b01d9a0c4a636 100644
--- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
@@ -24,14 +24,14 @@ static std::vector<T> CreateOnes(int size) {
   std::vector<T> f;
   f.reserve(size);
 
   for (int i = 0; i < size; ++i) {
-    f.push_back(T(1));
+    f.push_back(T(1.0f));
   }
 
   return f;
 }
 
 template <typename T>
-static std::vector<T> CreateValues(int size, int val) {
+static std::vector<T> CreateValues(int size, float val) {
   std::vector<T> f;
   f.reserve(size);
@@ -88,7 +88,7 @@ float ToFloat(MLFloat16 val) {
 
 // QKV
 template <typename T>
 static std::vector<T> QKV(std::vector<T>& input, std::vector<T>& weights, std::vector<T>& bias,
-                       int batch_size, int sequence_length, int hidden_size) {
+                          int batch_size, int sequence_length, int hidden_size) {
   std::vector<T> qkv;
   qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast<T>(0.f));
 
@@ -167,15 +167,17 @@ void CheckEquality(T* data_1, T* data_2, int batch_size, int num_heads, int num_
 // Reorder 'K' from [B, N, S, H] to [B, N, H/x, S, x] where x = (16 / sizeof(T));
 // Copy 'V' over as is
 template <typename T>
-static std::vector<T> ReorderKVCache(std::vector<T>& unordered_k_cache,
+static std::vector<T> ReorderKVCache(const std::vector<T>& unordered_k_cache,
                                      int batch_size, int num_heads, int sequence_length,
-                                     int head_size, int max_sequence_length) {
+                                     int head_size, int max_sequence_length, bool merge_past_kv = true) {
   std::vector<T> ordered(unordered_k_cache.size(), T{0.f});
 
   // Copy V over
-  size_t v_start = unordered_k_cache.size() / 2;
-  for (size_t i = v_start; i < unordered_k_cache.size(); ++i) {
-    ordered[i] = unordered_k_cache[i];
+  if (merge_past_kv) {
+    size_t v_start = unordered_k_cache.size() / 2;
+    for (size_t i = v_start; i < unordered_k_cache.size(); ++i) {
+      ordered[i] = unordered_k_cache[i];
+    }
   }
 
   // Now let us re-order K and copy it over to the final buffer
@@ -212,7 +214,7 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache,
                                                  T* k,
                                                  int batch_size, int num_heads, int past_sequence_length,
                                                  int max_sequence_length,
-                                                 int head_size) {
+                                                 int head_size, bool merge_past_kv = true) {
   std::vector<T> merged = ordered_k_cache;
 
   int total_seq_length = past_sequence_length + 1;
@@ -237,10 +239,11 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache,
           input_value = ordered_k_cache[input_offset];
         } else {
           int hidden_size = num_heads * head_size;
-          int input_offset = (b * 3 * hidden_size) +
-                             (n * num_chunks * chunk_size) +
-                             (c * chunk_size) +
-                             h;
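+          // Current K is either packed together with Q and V (batch stride of
+          // 3 * hidden_size) when merging past KV, or a standalone (B, 1, NH) tensor.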
+          int input_offset = merge_past_kv ?
+                                 ((b * 3 * hidden_size) +
+                                  (n * num_chunks * chunk_size) +
+                                  (c * chunk_size) +
+                                  h)
+                                 : ((b * hidden_size) + n * head_size + c * chunk_size + h);
           input_value = k[input_offset];
         }
 
@@ -260,7 +263,7 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache,
   return merged;
 }
 
-// GIven a pointer to the 'V' component of the past cache, we will merge it
+// Given a pointer to the 'V' component of the past cache, we will merge it
 // with current 'V' in-place
 template <typename T>
 static void MergeReorderedKVCacheWithV(T* v_cache,
@@ -306,7 +309,7 @@ static std::pair<std::vector<T>, std::vector<T>> MergePastKWithPresentKAndTranspose(
         input_value = past_k[input_offset];
       } else {
         int hidden_size = num_heads * head_size;
-        // Offset by 3* hidden_size because QKV data contains Q, K, and V per batch
+        // Offset by 3 * hidden_size because QKV data contains Q, K, and V per batch
         int input_offset = (b * 3 * hidden_size) + (n * head_size) + h;
         input_value = present_k[input_offset];
       }
@@ -374,7 +377,7 @@ void ValidateReorderedMergedKWithK(T* k, T* k_cache, int batch_size, int num_heads,
 // QK_Transpose
 template <typename T>
 std::vector<T> QK_Transpose(T* q_matrix, T* k_transpose_matrix,
-                         int batch_size, int num_heads, int total_sequence_length, int head_size) {
+                            int batch_size, int num_heads, int total_sequence_length, int head_size) {
   int hidden_size = num_heads * head_size;
 
   std::vector<T> qk_transpose;
@@ -454,9 +457,9 @@ std::vector<T> Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads,
 template <typename T>
 std::vector<T> Softmax_QK_Transpose_V(T* softmax_qk_transpose_matrix,
                                       T* v_matrix,
-                                   int batch_size, int num_heads, int sequence_length,
-                                   int total_sequence_length, int max_sequence_length,
-                                   int head_size) {
+                                      int batch_size, int num_heads, int sequence_length,
+                                      int total_sequence_length, int max_sequence_length,
+                                      int head_size) {
   if (sequence_length != 1) {
     throw std::runtime_error("Not supported");
   }
@@ -641,6 +644,238 @@ static void TestDecoderMaskedSelfAttention() {
   }
 }
 
+template <typename T>
+static std::vector<T> CalculateOutputQK(const std::vector<T>& q, const std::vector<T>& k,
+                                        const std::vector<int32_t>& mask_index, const std::vector<T>& attention_bias,
+                                        int batch_size, int num_heads,
+                                        int sequence_length, int max_sequence_length, int head_size) {
+  // q (B, 1, NH), k (B, N, L(M), H) -> qk (B, N, 1, L)
+  // mask_index (B, L), (optional) attention_bias (1, 1, 1, L)
+  float scale = 1 / sqrt(static_cast<float>(head_size));
+  std::vector<T> output_qk;
+  output_qk.resize(batch_size * num_heads * sequence_length, static_cast<T>(0.f));
+  for (int b = 0; b < batch_size; ++b) {
+    for (int n = 0; n < num_heads; ++n) {
+      for (int s = 0; s < sequence_length; ++s) {
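+        // A zero entry in mask_index masks this position out via a large
+        // negative additive logit; attention_bias (if present) is broadcast
+        // over batch and heads.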
+        float mask_value = (mask_index[b * sequence_length + s] == 0) ? -10000.f : 0.f;
+        float bias_value = (attention_bias.empty()) ?
+                               0.f : ToFloat(attention_bias[s]);
+        float sum = 0;
+        for (int h = 0; h < head_size; ++h) {
+          sum += ToFloat(q[b * num_heads * head_size + n * head_size + h]) *
+                 ToFloat(k[b * num_heads * max_sequence_length * head_size +
+                           n * max_sequence_length * head_size + s * head_size + h]);
+        }
+
+        output_qk[b * num_heads * sequence_length + n * sequence_length + s] =
+            static_cast<T>(scale * (sum + mask_value + bias_value));
+      }
+    }
+  }
+
+  return output_qk;
+}
+
+template <typename T>
+static std::vector<T> CalculateOutput(const std::vector<T>& softmax, const std::vector<T>& v, int batch_size,
+                                      int num_heads, int sequence_length, int max_sequence_length, int head_size) {
+  // softmax (B, N, 1, L), v (B, N, L(M), H) -> output (B, N, 1, H)
+  std::vector<T> output;
+  output.resize(batch_size * num_heads * head_size, static_cast<T>(0.f));
+  for (int b = 0; b < batch_size; ++b) {
+    for (int n = 0; n < num_heads; ++n) {
+      for (int h = 0; h < head_size; ++h) {
+        float sum = 0;
+        for (int s = 0; s < sequence_length; ++s) {
+          sum += ToFloat(softmax[b * num_heads * sequence_length + n * sequence_length + s]) *
+                 ToFloat(v[b * num_heads * max_sequence_length * head_size +
+                           n * max_sequence_length * head_size + s * head_size + h]);
+        }
+
+        output[b * num_heads * head_size + n * head_size + h] = static_cast<T>(sum);
+      }
+    }
+  }
+
+  return output;
+}
+
+template <typename T>
+static std::vector<T> MergePast(const std::vector<T>& past, const std::vector<T>& current, int batch_size,
+                                int num_heads, int past_seq_len, int max_seq_len, int head_size) {
+  // past (B, N, S(M), H), current (B, 1, NH) -> merged (B, N, S+1(M), H)
+  std::vector<T> merged = past;
+  for (int b = 0; b < batch_size; ++b) {
+    for (int n = 0; n < num_heads; ++n) {
+      for (int h = 0; h < head_size; ++h) {
+        merged[b * num_heads * max_seq_len * head_size + n * max_seq_len * head_size + past_seq_len * head_size + h] =
+            current[b * num_heads * head_size + n * head_size + h];
+      }
+    }
+  }
+
+  return merged;
+}
+
+template <typename T>
+static std::vector<T> ReorderKVByCacheIndirection(const std::vector<T>& key_or_value,
+                                                  const int32_t* cache_indirection,
+                                                  int batch_size, int beam_width, int max_sequence_length,
+                                                  int num_heads, int head_size, int past_sequence_length) {
+  std::vector<T> reordered = key_or_value;
+
+  for (int b = 0; b < batch_size; ++b) {
+    int beam_batch_index = b / beam_width;
+    const int32_t* beam_indices = cache_indirection + b * max_sequence_length;
+    for (int n = 0; n < num_heads; ++n) {
+      for (int s = 0; s < past_sequence_length; ++s) {
+        int beam_offset = beam_indices[s] * num_heads * max_sequence_length * head_size;
+        int beam_batch_offset = (beam_batch_index * beam_width * num_heads + n) * max_sequence_length * head_size;
+        for (int h = 0; h < head_size; ++h) {
+          reordered[b * num_heads * max_sequence_length * head_size +
+                    n * max_sequence_length * head_size + s * head_size + h] =
+              key_or_value[beam_offset + beam_batch_offset + s * head_size + h];
+        }
+      }
+    }
+  }
+
+  return reordered;
+}
+
+template <typename T>
+static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool use_cuda = true) {
+  int batch_size = 8;
+  int past_sequence_length = 2;
+  int kv_sequence_length = 16;
+  int head_size = 32;
+  int num_heads = 12;
+  int beam_width = 4;
+  int hidden_size = head_size * num_heads;
+
+  OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
+
+  // Attributes
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(!is_cross_attn));
+  // Output scaled Q * K^T by default for self-attention
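+  // (only the self-attention branch below adds output_qk as a checked output;
+  // the cross-attention branch checks just the attention output)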
tester.AddAttribute("output_qk", static_cast(!is_cross_attn)); + + // Inputs and outputs + auto query = CreateRandom(batch_size * 1 * hidden_size); + tester.AddInput("query", {batch_size, 1, hidden_size}, query); + + if (is_cross_attn) { + auto key = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + auto value = CreateRandom(batch_size * num_heads * kv_sequence_length * head_size); + tester.AddInput("key", {batch_size, num_heads, kv_sequence_length, head_size}, key); + tester.AddInput("value", {batch_size, num_heads, kv_sequence_length, head_size}, + CreateRandom(batch_size * num_heads * kv_sequence_length * head_size)); + + auto mask_index = CreateOnes(batch_size * kv_sequence_length); + tester.AddInput("mask_index", {batch_size, kv_sequence_length}, mask_index); + + // Calculate Softmax(Q * K^T + (Optional) mask) * V + std::vector empty_attention_bias; + auto output_qk = CalculateOutputQK(query, key, mask_index, empty_attention_bias, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + auto softmax = Softmax_QK_Transpose(output_qk.data(), batch_size, num_heads, + 1, kv_sequence_length, head_size); + auto output = CalculateOutput(softmax, value, batch_size, num_heads, + kv_sequence_length, kv_sequence_length, head_size); + + tester.AddOutput("output", {batch_size, 1, hidden_size}, output); + } else { + int max_sequence_length = past_sequence_length + 10; + int total_sequence_length = past_sequence_length + 1; + + auto key = CreateRandom(batch_size * hidden_size); + auto value = CreateRandom(batch_size * hidden_size); + tester.AddInput("key", {batch_size, 1, hidden_size}, key); + tester.AddInput("value", {batch_size, 1, hidden_size}, value); + + auto mask_index = CreateOnes(batch_size * total_sequence_length); + auto attention_bias = CreateValues(total_sequence_length, 0); + tester.AddInput("mask_index", {batch_size, total_sequence_length}, mask_index); + tester.AddInput("attention_bias", {1, 1, 1, total_sequence_length}, attention_bias); + + auto past_key = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); + auto past_value = CreateRandom(batch_size * num_heads * max_sequence_length * head_size); + + std::vector reordered_past_key; // For CUDA, we need to reorder past key + if (use_cuda) { + reordered_past_key = ReorderKVCache(past_key, batch_size, num_heads, + past_sequence_length, head_size, max_sequence_length, false); + } + + tester.AddInput("past_key", {batch_size, num_heads, max_sequence_length, head_size}, + (use_cuda ? 
+    tester.AddInput<T>("past_key", {batch_size, num_heads, max_sequence_length, head_size},
+                       (use_cuda ? reordered_past_key : past_key));
+    tester.AddInput<T>("past_value", {batch_size, num_heads, max_sequence_length, head_size}, past_value);
+
+    // Merge past key and value with current key and value
+    auto merged_key = MergePast(past_key, key, batch_size, num_heads,
+                                past_sequence_length, max_sequence_length, head_size);
+    std::vector<T> merged_reordered_key;
+    if (use_cuda) {
+      merged_reordered_key = MergeReorderedKVCacheWithK(reordered_past_key, key.data(), batch_size, num_heads,
+                                                        past_sequence_length, max_sequence_length, head_size, false);
+    }
+    auto merged_value = MergePast(past_value, value, batch_size, num_heads,
+                                  past_sequence_length, max_sequence_length, head_size);
+
+    tester.AddInput<int32_t>("past_sequence_length", {1}, {past_sequence_length});
+
+    std::vector<T> mod_merged_key, mod_merged_value;
+    if (beam_width > 1) {
+      tester.AddInput<int32_t>("beam_width", {1}, {beam_width});
+
+      const std::vector<int64_t> cache_indir_dims = {batch_size, beam_width, max_sequence_length};
+      auto value_candidates = ValueRange<int32_t>(beam_width);
+      FixedPatternValueGenerator generator{};
+      auto cache_indir = generator.Discrete<int32_t>(cache_indir_dims, value_candidates);
+      tester.AddInput<int32_t>("cache_indirection", cache_indir_dims, cache_indir);
+
+      // Modify merged_key and merged_value according to cache_indirection
+      mod_merged_key = ReorderKVByCacheIndirection(merged_key, cache_indir.data(),
+                                                   batch_size, beam_width, max_sequence_length,
+                                                   num_heads, head_size, past_sequence_length);
+      mod_merged_value = ReorderKVByCacheIndirection(merged_value, cache_indir.data(),
+                                                     batch_size, beam_width, max_sequence_length,
+                                                     num_heads, head_size, past_sequence_length);
+    }
+
+    // Calculate Softmax(Q * K^T + (Optional) mask) * V
+    auto output_qk = CalculateOutputQK(query, (beam_width > 1 ? mod_merged_key : merged_key),
+                                       mask_index, attention_bias,
+                                       batch_size, num_heads, total_sequence_length, max_sequence_length, head_size);
+    auto softmax = Softmax_QK_Transpose(output_qk.data(),
+                                        batch_size, num_heads, 1, total_sequence_length, head_size);
+    auto output = CalculateOutput(softmax, (beam_width > 1 ? mod_merged_value : merged_value),
+                                  batch_size, num_heads, total_sequence_length, max_sequence_length, head_size);
+
+    tester.AddOutput<T>("output", {batch_size, 1, hidden_size}, output);
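+    // present_key must mirror the past_key layout: chunked (reordered) for
+    // CUDA, plain [B, N, M, H] for CPU.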
+    tester.AddOutput<T>("present_key", {batch_size, num_heads, max_sequence_length, head_size},
+                        (use_cuda ? merged_reordered_key : merged_key));
+    tester.AddOutput<T>("present_value", {batch_size, num_heads, max_sequence_length, head_size}, merged_value);
+    tester.AddOutput<T>("output_qk", {batch_size, num_heads, 1, total_sequence_length}, output_qk);
+  }
+
+  if (std::is_same<T, MLFloat16>::value) {
+    tester.SetOutputTolerance(0.005f);
+  } else {
+    tester.SetOutputTolerance(0.001f, 0.001f);
+  }
+
+  {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    if (use_cuda) {
+      execution_providers.push_back(DefaultCudaExecutionProvider());
+    } else {
+      execution_providers.push_back(DefaultCpuExecutionProvider());
+    }
+    tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
   TestDecoderMaskedSelfAttention<float>();
 }
@@ -649,6 +884,30 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
 TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
   TestDecoderMaskedSelfAttention<MLFloat16>();
 }
 
+TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp32) {
+  TestDecoderMaskedMultiHeadAttention<float>();
+}
+
+TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp16) {
+  TestDecoderMaskedMultiHeadAttention<MLFloat16>();
+}
+
+TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp32) {
+  TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ false);
+}
+
+TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp16) {
+  TestDecoderMaskedMultiHeadAttention<MLFloat16>(/* is_cross_attn = */ false);
+}
+
+TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cross_attn_fp32) {
+  TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ true, /* use_cuda = */ false);
+}
+
+TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) {
+  TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ false, /* use_cuda = */ false);
+}
+
 #endif
 } // namespace test