Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Continuous Decoding support in GQA #21523

Merged
merged 28 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ struct GroupQueryAttentionParameters {
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_interactive; // indicates whether seqlens_k is 1 or 2 dimensional. 2-d case enables interactive decoding
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
Expand Down
20 changes: 13 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,25 @@ T* ConcatStateChunkGQA(const T* past,
size_t past_buff_chunk_length,
size_t past_chunk_length,
size_t new_chunk_length,
bool is_prompt,
bool past_present_share_buffer,
std::ptrdiff_t i) {
T* start = present + i * present_buff_chunk_length;

T* p = start;
if (!is_prompt) {
if (!past_present_share_buffer) {
const T* src_past = past + i * past_buff_chunk_length;
memcpy(p, src_past, past_chunk_length * sizeof(T));
}
p += past_chunk_length;
if (!past_present_share_buffer && past_chunk_length > 0) {
const T* src_past = past + i * past_buff_chunk_length;
memcpy(p, src_past, past_chunk_length * sizeof(T));
}
p += past_chunk_length;
// if (!is_prompt) {
// if (!past_present_share_buffer) {
// const T* src_past = past + i * past_buff_chunk_length;
// memcpy(p, src_past, past_chunk_length * sizeof(T));
// }
// p += past_chunk_length;
// }
// std::cout << "past_chunk_length: " << past_chunk_length << std::endl;
// std::cout << "new_chunk_length: " << new_chunk_length << std::endl;

memcpy(p, chunk, new_chunk_length * sizeof(T));
return start;
Expand Down
113 changes: 63 additions & 50 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"

#include <iostream>

Check warning on line 14 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: gqa_attention_base.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:14: Found C++ system header after other header. Should be: gqa_attention_base.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime {
namespace contrib {

Expand Down Expand Up @@ -53,6 +55,8 @@
GroupQueryAttentionParameters& parameters, // attention parameters
AllocatorPtr allocator, // allocator for temporary tensors
OpKernelContext* context) const {
const bool is_interactive = parameters.is_interactive;
const bool is_prompt = parameters.is_prompt;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int head_size = parameters.head_size;
Expand Down Expand Up @@ -82,14 +86,14 @@
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, tp);
present_key_data, past_present_share_buffer, packed_qkv, is_interactive, is_prompt, tp);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
tp);
is_interactive, is_prompt, tp);

return Status::OK();
}
Expand All @@ -110,10 +114,11 @@
int head_size, // head size of self-attention
const T* past_key, // past key only
T* present_key, // present key only
bool past_present_share_buffer, // whether present key and value share the same buffer
bool packed_qkv, // whether Q, K, V are packed
const bool past_present_share_buffer,// whether present key and value share the same buffer

Check warning on line 117 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:117: At least two spaces is best between code and comments [whitespace/comments] [2]
const bool packed_qkv, // whether Q, K, V are packed
const bool is_interactive, // whether it is interactive
const bool is_prompt, // whether it is prompt
ThreadPool* tp) const { // thread pool
const bool is_prompt = sequence_length != 1;
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
Expand Down Expand Up @@ -150,12 +155,14 @@

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i) / num_heads_;
const int head_index = static_cast<int>(i) % num_heads_;
const int past_seqlen =
sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t batch_index = static_cast<size_t>(i) / static_cast<size_t>(num_heads_);
const size_t head_index = static_cast<size_t>(i) % static_cast<size_t>(num_heads_);
// const int past_seqlen =
// sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_seqlen =
is_interactive || !is_prompt ? static_cast<size_t>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
const int total_seqlen = seqlens_k[batch_index] + 1;
const size_t total_seqlen = is_interactive ? seqlens_k[batch_size + batch_index] : seqlens_k[batch_index] + 1;

const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
T* output = attention_probs + output_offset;
Expand All @@ -167,9 +174,10 @@
k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
}
if (nullptr != present_key) {
// TODO(aciddelgado): refactor now that interactive decoding is supported
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
i / kv_num_heads_factor);
!is_interactive && is_prompt ? 0 : past_chunk_length, kv_input_chunk_length,
past_present_share_buffer, i / kv_num_heads_factor);
}

// Compute Q*K' + AttentionMask
Expand All @@ -183,16 +191,19 @@
} else {
q = Q + q_input_chunk_length * i;
}
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,

size_t q_seqlen = is_interactive ? total_seqlen - past_seqlen : sequence_length;
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, q_seqlen, total_seqlen, head_size, alpha, q,
head_size, k, head_size, 0.0f /*beta*/, output, present_buffer_sequence_length,
nullptr);

// compute Softmax
T* output_softmax = output;
for (int seq = 0; seq < sequence_length; seq++) {
int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1;
if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) {
for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
for (size_t seq = 0; seq < q_seqlen; seq++) {
// int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1; // E
size_t seq_causal_length = is_interactive ? past_seqlen + seq + 1 : (is_prompt ? seq + 1 : total_seqlen);
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}
ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
Expand All @@ -202,7 +213,7 @@
}

// set causal [seq_causal_length, total_seqlen) to 0.f
for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) {
for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
}

Expand All @@ -225,10 +236,11 @@
int hidden_size, // hidden size of Output
const T* past_value, // past value only
T* present_value, // present value only
bool past_present_share_buffer, // whether present key and value share the same buffer
bool packed_qkv, // whether Q, K, V are packed
const bool past_present_share_buffer,// whether present key and value share the same buffer

Check warning on line 239 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:239: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 239 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 At least two spaces is best between code and comments [whitespace/comments] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:239: At least two spaces is best between code and comments [whitespace/comments] [2]
const bool packed_qkv, // whether Q, K, V are packed
const bool is_interactive, // whether it is interactive
const bool is_prompt, // whether it is prompt
ThreadPool* tp) const {
const bool is_prompt = sequence_length != 1;
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
Expand All @@ -241,6 +253,8 @@
memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}

const int loop_len = batch_size * num_heads_;

// The cost of Gemm
TensorOpCost unit_cost;
unit_cost.compute_cycles =
Expand All @@ -260,37 +274,36 @@
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;

ThreadPool::TryParallelFor(
tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i / num_heads_);
const int head_index = static_cast<int>(i % num_heads_);
const int past_seqlen =
sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
const int total_seqlen = seqlens_k[batch_index] + 1;

const T* v;
if (packed_qkv) {
v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
} else {
v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
}
if (nullptr != present_value) {
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
i / kv_num_heads_factor);
}
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i / num_heads_);
const int head_index = static_cast<int>(i % num_heads_);
const size_t past_seqlen =
is_interactive || !is_prompt ? static_cast<size_t>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
const size_t total_seqlen = is_interactive ? seqlens_k[batch_size + batch_index] : seqlens_k[batch_index] + 1;

T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
const T* v;
if (packed_qkv) {
v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
Fixed Show fixed Hide fixed
} else {
v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
}
if (nullptr != present_value) {
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
!is_interactive && is_prompt ? 0 : past_chunk_length, kv_input_chunk_length,
past_present_share_buffer, i / kv_num_heads_factor);
}

math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
1.f, /*alpha*/
attention_probs + attention_probs_offset, present_buffer_sequence_length, v,
head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
}
});
T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

size_t q_seqlen = is_interactive ? total_seqlen - past_seqlen : sequence_length;
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, q_seqlen, head_size, total_seqlen, 1.f, /*alpha*/
attention_probs + attention_probs_offset, present_buffer_sequence_length, v,
head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
}
});
}
};

Expand Down
46 changes: 41 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>

#include <iostream>

Check warning on line 19 in onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: group_query_attention.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc:19: Found C++ system header after other header. Should be: group_query_attention.h, c system, c++ system, other. [build/include_order] [4]

using onnxruntime::concurrency::ThreadPool;

namespace onnxruntime {
Expand Down Expand Up @@ -103,6 +105,7 @@
}

if (do_rotary_) {
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
rotary_params.sequence_length = sequence_length;
Expand All @@ -114,17 +117,48 @@
rotary_params.seq_stride = head_size;
rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
rotary_params.position_ids_format = parameters.is_interactive || !parameters.is_prompt ? 1 : 0;
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
rotary_params.transposed = true;
auto* tp = context->GetOperatorThreadPool();
std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
if (sequence_length == 1) {
// Generate position ids
// std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
// if (sequence_length == 1) {
// for (int b = 0; b < batch_size; b++) {
// pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
// }
// } else {
// pos_ids[0] = static_cast<int64_t>(0);
// }
const int pos_ids_size = parameters.is_interactive ? batch_size * sequence_length : (parameters.is_prompt ? 1 : batch_size);

Check warning on line 132 in onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc:132: Lines should be <= 120 characters long [whitespace/line_length] [2]
std::vector<int64_t> pos_ids(pos_ids_size);
// const int32_t* seqlens_k_data = seqlens_k->Data<const int32_t>();
if (parameters.is_interactive) {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
for (int s = 0; s < sequence_length; s++) {
if (seqlens_k->Data<int32_t>()[b] + s < seqlens_k->Data<int32_t>()[batch_size + b]) {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b] + s);
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
Fixed Show fixed Hide fixed
} else {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
}
}
}
} else {
// print pos_ids
// std::cout << "pos_ids: ";
// for (int i = 0; i < pos_ids_size; i++) {
// if (i % sequence_length == 0) {
// std::cout << std::endl;
// }
// std::cout << pos_ids[i] << " ";
// }
// std::cout << std::endl;
} else if (parameters.is_prompt) {
pos_ids[0] = static_cast<int64_t>(0);
} else {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
}
}
// Initialize separate buffers for rotary embeddings
const T* q_input;
const T* k_input;
T* q_rotary;
Expand All @@ -149,6 +183,7 @@
Q = RotaryQ;
K = RotaryK;
}
// Run rotary embedding for Q and K
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
Expand All @@ -161,6 +196,7 @@
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
// Pack V into rotary QKV buffer
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
Expand Down
Loading
Loading