Add unit tests for DecoderMaskedMultiHeadAttention.
mindest committed Oct 22, 2024
1 parent d79c3c3 commit 9df4782
Showing 1 changed file with 278 additions and 19 deletions.
@@ -24,14 +24,14 @@ static std::vector<T> CreateOnes(int size) {
f.reserve(size);

for (int i = 0; i < size; ++i) {
- f.push_back(T(1));
+ f.push_back(T(1.0f));
}

return f;
}

template <typename T>
- static std::vector<T> CreateValues(int size, int val) {
+ static std::vector<T> CreateValues(int size, float val) {
std::vector<T> f;
f.reserve(size);

@@ -88,7 +88,7 @@ float ToFloat(MLFloat16 val) {
// QKV
template <typename T>
static std::vector<T> QKV(std::vector<T>& input, std::vector<T>& weights, std::vector<T>& bias,
int batch_size, int sequence_length, int hidden_size) {
std::vector<T> qkv;
qkv.resize(batch_size * sequence_length * 3 * hidden_size, static_cast<T>(0.f));

@@ -167,15 +167,17 @@ void CheckEquality(T* data_1, T* data_2, int batch_size, int num_heads, int num_
// Reorder 'K' from [B, N, S, H] to [B, N, H/x, S, x] where x = (sizeof(T) / 16);
// Copy 'V' over as is
template <typename T>
- static std::vector<T> ReorderKVCache(std::vector<T>& unordered_k_cache,
+ static std::vector<T> ReorderKVCache(const std::vector<T>& unordered_k_cache,
int batch_size, int num_heads, int sequence_length,
- int head_size, int max_sequence_length) {
+ int head_size, int max_sequence_length, bool merge_past_kv = true) {
std::vector<T> ordered(unordered_k_cache.size(), T{0.f});

// Copy V over
- size_t v_start = unordered_k_cache.size() / 2;
- for (size_t i = v_start; i < unordered_k_cache.size(); ++i) {
- ordered[i] = unordered_k_cache[i];
+ if (merge_past_kv) {
+ size_t v_start = unordered_k_cache.size() / 2;
+ for (size_t i = v_start; i < unordered_k_cache.size(); ++i) {
+ ordered[i] = unordered_k_cache[i];
+ }
}

// Now let us re-order K and copy it over to the final buffer
@@ -212,7 +214,7 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache
T* k,
int batch_size, int num_heads,
int past_sequence_length, int max_sequence_length,
- int head_size) {
+ int head_size, bool merge_past_kv = true) {
std::vector<T> merged = ordered_k_cache;

int total_seq_length = past_sequence_length + 1;
@@ -237,10 +239,11 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache
input_value = ordered_k_cache[input_offset];
} else {
int hidden_size = num_heads * head_size;
- int input_offset = (b * 3 * hidden_size) +
- (n * num_chunks * chunk_size) +
- (c * chunk_size) +
- h;
+ int input_offset = merge_past_kv ? ((b * 3 * hidden_size) +
+ (n * num_chunks * chunk_size) +
+ (c * chunk_size) +
+ h)
+ : ((b * hidden_size) + n * head_size + c * chunk_size + h);
input_value = k[input_offset];
}

@@ -260,7 +263,7 @@ static std::vector<T> MergeReorderedKVCacheWithK(std::vector<T>& ordered_k_cache
return merged;
}

- // GIven a pointer to the 'V' component of the past cache, we will merge it
+ // Given a pointer to the 'V' component of the past cache, we will merge it
// with current 'V' in-place
template <typename T>
static void MergeReorderedKVCacheWithV(T* v_cache,
@@ -306,7 +309,7 @@ static std::pair<std::vector<T>, std::vector<T>> MergePastKWithPresentKAndTransp
input_value = past_k[input_offset];
} else {
int hidden_size = num_heads * head_size;
- // Offset by 3* hidden_size because QKV data contains Q, K, and V per batch
+ // Offset by 3 * hidden_size because QKV data contains Q, K, and V per batch
int input_offset = (b * 3 * hidden_size) + (n * head_size) + h;
input_value = present_k[input_offset];
}
@@ -374,7 +377,7 @@ void ValidateReorderedMergedKWithK(T* k, T* k_cache, int batch_size, int num_hea
// QK_Transpose
template <typename T>
std::vector<T> QK_Transpose(T* q_matrix, T* k_transpose_matrix,
int batch_size, int num_heads, int total_sequence_length, int head_size) {
int hidden_size = num_heads * head_size;

std::vector<T> qk_transpose;
@@ -454,9 +457,9 @@ std::vector<T> Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int
template <typename T>
std::vector<T> Softmax_QK_Transpose_V(T* softmax_qk_transpose_matrix,
T* v_matrix,
int batch_size, int num_heads, int sequence_length,
int total_sequence_length, int max_sequence_length,
int head_size) {
if (sequence_length != 1) {
throw std::runtime_error("Not supported");
}
@@ -641,6 +644,238 @@ static void TestDecoderMaskedSelfAttention() {
}
}

template <typename T>
static std::vector<T> CalculateOutputQK(const std::vector<T>& q, const std::vector<T>& k,
const std::vector<int32_t>& mask_index, const std::vector<T>& attention_bias,
int batch_size, int num_heads,
int sequence_length, int max_sequence_length, int head_size) {
// q (B, 1, NH), k (B, N, L(M), H) -> qk (B, N, 1, L)
// mask_index (B, L), (optional) attention_bias (1, 1, 1, L)
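// Masked positions (mask_index == 0) receive a large negative additive term so that,
// after softmax, their attention weight is effectively zero; 1/sqrt(head_size) is the
// usual scaled-dot-product-attention scaling.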
float scale = 1 / sqrt(static_cast<float>(head_size));
std::vector<T> output_qk;
output_qk.resize(batch_size * num_heads * sequence_length, static_cast<T>(0.f));
for (int b = 0; b < batch_size; ++b) {
for (int n = 0; n < num_heads; ++n) {
for (int s = 0; s < sequence_length; ++s) {
float mask_value = (mask_index[b * sequence_length + s] == 0) ? -10000.f : 0.f;
float bias_value = (attention_bias.empty()) ? 0.f : ToFloat(attention_bias[s]);
float sum = 0;
for (int h = 0; h < head_size; ++h) {
sum += ToFloat(q[b * num_heads * head_size + n * head_size + h]) *
ToFloat(k[b * num_heads * max_sequence_length * head_size +
n * max_sequence_length * head_size + s * head_size + h]);
}

output_qk[b * num_heads * sequence_length + n * sequence_length + s] =
static_cast<T>(scale * (sum + mask_value + bias_value));
}
}
}

return output_qk;
}

template <typename T>
static std::vector<T> CalculateOutput(const std::vector<T>& softmax, const std::vector<T>& v, int batch_size,
int num_heads, int sequence_length, int max_sequence_length, int head_size) {
// softmax (B, N, 1, L) v (B, N, L(M), H) -> output (B, N, 1, H)
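// Attention-weighted sum of V rows: only the first `sequence_length` rows of each
// (b, n) slice are read; v is allocated with max_sequence_length rows per head.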
std::vector<T> output;
output.resize(batch_size * num_heads * head_size, static_cast<T>(0.f));
for (int b = 0; b < batch_size; ++b) {
for (int n = 0; n < num_heads; ++n) {
for (int h = 0; h < head_size; ++h) {
float sum = 0;
for (int s = 0; s < sequence_length; ++s) {
sum += ToFloat(softmax[b * num_heads * sequence_length + n * sequence_length + s]) *
ToFloat(v[b * num_heads * max_sequence_length * head_size +
n * max_sequence_length * head_size + s * head_size + h]);
}

output[b * num_heads * head_size + n * head_size + h] = static_cast<T>(sum);
}
}
}

return output;
}

template <typename T>
static std::vector<T> MergePast(const std::vector<T>& past, const std::vector<T>& current, int batch_size,
int num_heads, int past_seq_len, int max_seq_len, int head_size) {
// past (B, N, S(M), H), current (B, 1, NH) -> merged (B, N, S+1(M), H)
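// `merged` keeps the max_seq_len-row layout of the shared buffer; the current token's
// K/V is written into row `past_seq_len` of each (b, n) slice.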
std::vector<T> merged = past;
for (int b = 0; b < batch_size; ++b) {
for (int n = 0; n < num_heads; ++n) {
for (int h = 0; h < head_size; ++h) {
merged[b * num_heads * max_seq_len * head_size + n * max_seq_len * head_size + past_seq_len * head_size + h] =
current[b * num_heads * head_size + n * head_size + h];
}
}
}

return merged;
}

template <typename T>
static std::vector<T> ReorderKVByCacheIndirection(const std::vector<T>& key_or_value,
const int32_t* cache_indirection,
int batch_size, int beam_width, int max_sequence_length,
int num_heads, int head_size, int past_sequence_length) {
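// For beam search, cache_indirection records which beam's cache holds each past step's
// K/V for this beam; gather those rows to build the per-beam view used by the reference
// computation.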
std::vector<T> reordered = key_or_value;

for (int b = 0; b < batch_size; ++b) {
int beam_batch_index = b / beam_width;
const int* beam_indices = cache_indirection + b * max_sequence_length;
for (int n = 0; n < num_heads; ++n) {
for (int s = 0; s < past_sequence_length; ++s) {
int beam_offset = beam_indices[s] * num_heads * max_sequence_length * head_size;
int beam_batch_offset = (beam_batch_index * beam_width * num_heads + n) * max_sequence_length * head_size;
for (int h = 0; h < head_size; ++h) {
reordered[b * num_heads * max_sequence_length * head_size +
n * max_sequence_length * head_size + s * head_size + h] =
key_or_value[beam_offset + beam_batch_offset + s * head_size + h];
}
}
}
}

return reordered;
}

template <typename T>
static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool use_cuda = true) {
int batch_size = 8;
int past_sequence_length = 2;
int kv_sequence_length = 16;
int head_size = 32;
int num_heads = 12;
int beam_width = 4;
int hidden_size = head_size * num_heads;

OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);

// Attributes
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(!is_cross_attn));
// Output scaled Q * K^T by default for self-attention
tester.AddAttribute<int64_t>("output_qk", static_cast<int64_t>(!is_cross_attn));

// Inputs and outputs
auto query = CreateRandom<T>(batch_size * 1 * hidden_size);
tester.AddInput<T>("query", {batch_size, 1, hidden_size}, query);

if (is_cross_attn) {
auto key = CreateRandom<T>(batch_size * num_heads * kv_sequence_length * head_size);
auto value = CreateRandom<T>(batch_size * num_heads * kv_sequence_length * head_size);
tester.AddInput<T>("key", {batch_size, num_heads, kv_sequence_length, head_size}, key);
tester.AddInput<T>("value", {batch_size, num_heads, kv_sequence_length, head_size},
CreateRandom<T>(batch_size * num_heads * kv_sequence_length * head_size));

auto mask_index = CreateOnes<int32_t>(batch_size * kv_sequence_length);
tester.AddInput<int32_t>("mask_index", {batch_size, kv_sequence_length}, mask_index);

// Calculate Softmax(Q * K^T + (Optional) mask) * V
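// CalculateOutputQK treats an empty attention_bias vector as a zero bias.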
std::vector<T> empty_attention_bias;
auto output_qk = CalculateOutputQK(query, key, mask_index, empty_attention_bias, batch_size, num_heads,
kv_sequence_length, kv_sequence_length, head_size);
auto softmax = Softmax_QK_Transpose<T>(output_qk.data(), batch_size, num_heads,
1, kv_sequence_length, head_size);
auto output = CalculateOutput<T>(softmax, value, batch_size, num_heads,
kv_sequence_length, kv_sequence_length, head_size);

tester.AddOutput<T>("output", {batch_size, 1, hidden_size}, output);
} else {
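// Self-attention decodes a single new token on top of past_sequence_length cached ones;
// max_sequence_length is the capacity of the shared past/present buffer, not the valid length.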
int max_sequence_length = past_sequence_length + 10;
int total_sequence_length = past_sequence_length + 1;

auto key = CreateRandom<T>(batch_size * hidden_size);
auto value = CreateRandom<T>(batch_size * hidden_size);
tester.AddInput<T>("key", {batch_size, 1, hidden_size}, key);
tester.AddInput<T>("value", {batch_size, 1, hidden_size}, value);

auto mask_index = CreateOnes<int32_t>(batch_size * total_sequence_length);
auto attention_bias = CreateValues<T>(total_sequence_length, 0);
tester.AddInput<int32_t>("mask_index", {batch_size, total_sequence_length}, mask_index);
tester.AddInput<T>("attention_bias", {1, 1, 1, total_sequence_length}, attention_bias);

auto past_key = CreateRandom<T>(batch_size * num_heads * max_sequence_length * head_size);
auto past_value = CreateRandom<T>(batch_size * num_heads * max_sequence_length * head_size);

std::vector<T> reordered_past_key; // For CUDA, we need to reorder past key
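// ReorderKVCache rearranges K into the chunked layout the CUDA kernel expects; with
// merge_past_kv = false the V-copy half is skipped, since past_key holds only K here.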
if (use_cuda) {
reordered_past_key = ReorderKVCache<T>(past_key, batch_size, num_heads,
past_sequence_length, head_size, max_sequence_length, false);
}

tester.AddInput<T>("past_key", {batch_size, num_heads, max_sequence_length, head_size},
(use_cuda ? reordered_past_key : past_key));
tester.AddInput<T>("past_value", {batch_size, num_heads, max_sequence_length, head_size}, past_value);

// merge past key and value with current key and value
auto merged_key = MergePast<T>(past_key, key, batch_size, num_heads,
past_sequence_length, max_sequence_length, head_size);
std::vector<T> merged_reordered_key;
if (use_cuda) {
merged_reordered_key = MergeReorderedKVCacheWithK<T>(reordered_past_key, key.data(), batch_size, num_heads,
past_sequence_length, max_sequence_length, head_size, false);
}
auto merged_value = MergePast<T>(past_value, value, batch_size, num_heads,
past_sequence_length, max_sequence_length, head_size);

tester.AddInput<int32_t>("past_sequence_length", {1}, {past_sequence_length});

std::vector<T> mod_merged_key, mod_merged_value;
if (beam_width > 1) {
tester.AddInput<int32_t>("beam_width", {1}, {beam_width});

const std::vector<int64_t> cache_indir_dims = {batch_size, beam_width, max_sequence_length};
auto value_candidates = ValueRange<int32_t>(beam_width);
FixedPatternValueGenerator generator{};
auto cache_indir = generator.Discrete<int32_t>(cache_indir_dims, value_candidates);
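// Each cache_indirection entry is a beam index in [0, beam_width), so any past step
// may be sourced from any beam's cache.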
tester.AddInput<int32_t>("cache_indirection", cache_indir_dims, cache_indir);

// Modify merged_key and merged_value according to cache_indirection
mod_merged_key = ReorderKVByCacheIndirection<T>(merged_key, cache_indir.data(),
batch_size, beam_width, max_sequence_length,
num_heads, head_size, past_sequence_length);
mod_merged_value = ReorderKVByCacheIndirection<T>(merged_value, cache_indir.data(),
batch_size, beam_width, max_sequence_length,
num_heads, head_size, past_sequence_length);
}

// Calculate Softmax(Q * K^T + (Optional) mask) * V
auto output_qk = CalculateOutputQK<T>(query, (beam_width > 1 ? mod_merged_key : merged_key),
mask_index, attention_bias,
batch_size, num_heads, total_sequence_length, max_sequence_length, head_size);
auto softmax = Softmax_QK_Transpose<T>(output_qk.data(),
batch_size, num_heads, 1, total_sequence_length, head_size);
auto output = CalculateOutput<T>(softmax, (beam_width > 1 ? mod_merged_value : merged_value),
batch_size, num_heads, total_sequence_length, max_sequence_length, head_size);

tester.AddOutput<T>("output", {batch_size, 1, hidden_size}, output);
tester.AddOutput<T>("present_key", {batch_size, num_heads, max_sequence_length, head_size},
(use_cuda ? merged_reordered_key : merged_key));
tester.AddOutput<T>("present_value", {batch_size, num_heads, max_sequence_length, head_size}, merged_value);
tester.AddOutput<T>("output_qk", {batch_size, num_heads, 1, total_sequence_length}, output_qk);
}

if (std::is_same<T, MLFloat16>::value) {
tester.SetOutputTolerance(0.005f);
} else {
tester.SetOutputTolerance(0.001f, 0.001f);
}

{
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (use_cuda) {
execution_providers.push_back(DefaultCudaExecutionProvider());
} else {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
TestDecoderMaskedSelfAttention<float>();
}
@@ -649,6 +884,30 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
TestDecoderMaskedSelfAttention<MLFloat16>();
}

TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp32) {
TestDecoderMaskedMultiHeadAttention<float>();
}

TEST(DecoderMaskedMultiHeadAttentionTest, cuda_cross_attn_fp16) {
TestDecoderMaskedMultiHeadAttention<MLFloat16>();
}

TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp32) {
TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ false);
}

TEST(DecoderMaskedMultiHeadAttentionTest, cuda_self_attn_fp16) {
TestDecoderMaskedMultiHeadAttention<MLFloat16>(/* is_cross_attn = */ false);
}

TEST(DecoderMaskedMultiHeadAttentionTest, cpu_cross_attn_fp32) {
TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ true, /* use_cuda = */ false);
}

TEST(DecoderMaskedMultiHeadAttentionTest, cpu_self_attn_fp32) {
TestDecoderMaskedMultiHeadAttention<float>(/* is_cross_attn = */ false, /* use_cuda = */ false);
}

#endif

} // namespace test
