Add Continuous Decoding support in GQA #21523

Merged (28 commits) on Sep 13, 2024
Changes from 1 commit
single batch implementation unclean
aciddelgado committed Aug 6, 2024
commit 60fe746d4e5e184b5dce5c438413e8180abe5d5f
66 changes: 31 additions & 35 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -50,7 +50,7 @@
Tensor* present_key, // present K output tensor (if separating present KV)
Tensor* present_value, // present V output tensor (if separating present KV)
const Tensor* seqlens_k, // past sequence lengths tensor
const Tensor* seqlens_q, // past sequence lengths tensor
// const Tensor* seqlens_q, // past sequence lengths tensor
GroupQueryAttentionParameters& parameters, // attention parameters
AllocatorPtr allocator, // allocator for temporary tensors
OpKernelContext* context) const {
@@ -84,14 +84,14 @@

const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(),
seqlens_q->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
/*seqlens_q->Data<int32_t>(),*/ batch_size, sequence_length, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, past_key_data, present_key_data,
past_present_share_buffer, packed_qkv, is_interactive, is_prompt, tp);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
seqlens_q->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
/*seqlens_q->Data<int32_t>(),*/ batch_size, sequence_length, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
past_present_share_buffer, packed_qkv, is_interactive, is_prompt, tp);
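For orientation, the two calls above split GQA attention into the usual two passes: ComputeAttentionProbs forms softmax(Q·Kᵀ / √head_size) against the concatenated past+new keys, and ComputeVxAttentionScore multiplies the result by V. Below is a rough, untested single-(batch, head) reference of what that pair computes; the real kernels additionally handle packed QKV, head grouping, shared past/present buffers, and threading, and all names here are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference: probs = softmax(Q * K^T / sqrt(head_size)) with a causal
// limit, then out = probs * V. q: [sequence_length x head_size],
// k/v: [total_seqlen x head_size], out: [sequence_length x head_size].
// Assumes past_seqlen + sequence_length == total_seqlen.
void NaiveGqaHead(const float* q, const float* k, const float* v, float* out,
                  std::size_t sequence_length, std::size_t total_seqlen,
                  std::size_t head_size, std::size_t past_seqlen) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  std::vector<float> probs(sequence_length * total_seqlen, 0.0f);
  for (std::size_t s = 0; s < sequence_length; ++s) {
    float* row = probs.data() + s * total_seqlen;
    const std::size_t causal_len = past_seqlen + s + 1;  // token s sees the past plus itself
    // Pass 1: scaled dot products against the visible keys.
    for (std::size_t t = 0; t < causal_len; ++t) {
      float dot = 0.0f;
      for (std::size_t h = 0; h < head_size; ++h) dot += q[s * head_size + h] * k[t * head_size + h];
      row[t] = dot * scale;
    }
    // Numerically stable softmax over the visible keys.
    const float row_max = *std::max_element(row, row + causal_len);
    float sum = 0.0f;
    for (std::size_t t = 0; t < causal_len; ++t) sum += (row[t] = std::exp(row[t] - row_max));
    for (std::size_t t = 0; t < causal_len; ++t) row[t] /= sum;
    // Pass 2: out = probs * V.
    for (std::size_t h = 0; h < head_size; ++h) {
      float acc = 0.0f;
      for (std::size_t t = 0; t < causal_len; ++t) acc += row[t] * v[t * head_size + h];
      out[s * head_size + h] = acc;
    }
  }
}
```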

@@ -107,15 +107,15 @@
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const int32_t* seqlens_q, // (optional) new sequence lengths tensor
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int past_buffer_sequence_length, // sequence length of past state
int present_buffer_sequence_length, // sequence length of present state
int head_size, // head size of self-attention
// const int32_t* seqlens_q, // (optional) new sequence lengths tensor
const size_t batch_size, // batch size of self-attention
const size_t sequence_length, // sequence length of self-attention (S)
const size_t past_buffer_sequence_length, // sequence length of past state
const size_t present_buffer_sequence_length, // sequence length of present state
const size_t head_size, // head size of self-attention
const T* past_key, // past key only
T* present_key, // present key only
const bool past_present_share_buffer, // whether present key and value share the same buffer

const bool packed_qkv, // whether Q, K, V are packed
const bool is_interactive, // whether it is interactive
const bool is_prompt, // whether it is prompt
@@ -123,7 +123,7 @@
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // L x H
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
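kv_num_heads_factor is the GQA group size: num_heads_ query heads share kv_num_heads_ key/value heads, so each K/V head serves num_heads_ / kv_num_heads_ query heads and a query head's K/V chunk is picked by integer division. The helper below is a hypothetical standalone restatement of that standard grouping, not code from this PR.

```cpp
#include <cstddef>

// Hypothetical helper showing the GQA head grouping implied by kv_num_heads_factor.
// Example: num_heads = 32, kv_num_heads = 8 -> factor 4, so query heads 0..3 read
// KV head 0, query heads 4..7 read KV head 1, and so on.
std::size_t KvHeadForQueryHead(std::size_t query_head, std::size_t num_heads, std::size_t kv_num_heads) {
  const std::size_t kv_num_heads_factor = num_heads / kv_num_heads;
  return query_head / kv_num_heads_factor;
}
```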
@@ -156,11 +156,10 @@

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const size_t batch_index = static_cast<size_t>(i) / static_cast<size_t>(num_heads_);
const size_t head_index = static_cast<size_t>(i) % static_cast<size_t>(num_heads_);
const size_t batch_index = i / num_heads_;
const size_t head_index = i % num_heads_;
const size_t total_seqlen = seqlens_k[batch_index] + 1;
const size_t past_seqlen = is_interactive ? total_seqlen - static_cast<size_t>(seqlens_q[batch_index])
: (is_prompt ? 0 : total_seqlen - 1);
const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;
const size_t past_chunk_length = past_seqlen * head_size;

const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
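With the seqlens_q path disabled in this commit, the past/new split is recovered from seqlens_k alone (which stores total length minus one) and the number of tokens in the current call. A worked example of the index math above, with illustrative values for continuous decoding of 4 new tokens onto a 10-token cache:

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative values only; mirrors the arithmetic in the loop above.
const std::size_t sequence_length = 4;     // new tokens processed in this call
const std::size_t head_size = 128;         // assumed head size for the offset example
const int32_t seqlens_k_b = 13;            // seqlens_k[batch_index] = total_sequence_length - 1
const std::size_t total_seqlen = static_cast<std::size_t>(seqlens_k_b) + 1;       // 14
const bool is_prompt = false;
const std::size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;   // 14 - 4 = 10
const std::size_t past_chunk_length = past_seqlen * head_size;  // element offset into this head's KV cache
```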
@@ -190,15 +189,14 @@
q = Q + q_input_chunk_length * i;
}

const size_t q_seqlen = is_interactive ? static_cast<size_t>(seqlens_q[batch_index]) : sequence_length;
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, q_seqlen, total_seqlen, head_size, alpha, q,
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length,
nullptr);

// compute Softmax
T* output_softmax = output;
for (size_t seq = 0; seq < q_seqlen; seq++) {
size_t seq_causal_length = is_interactive ? past_seqlen + seq + 1 : (is_prompt ? seq + 1 : total_seqlen);
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
output_softmax[total_seq_id] = 0.f;
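The causal length above grows with the row index: row seq of the new block may attend to past_seqlen + seq + 1 keys, and when local_window_size_ is set, only the most recent local_window_size_ + 1 of those keep their probabilities. A standalone illustration of the masking (hypothetical helper; the loop bound is rewritten equivalently to avoid unsigned underflow):

```cpp
#include <cstddef>

// Zero out probabilities for keys that fall outside the sliding window.
// Example: past_seqlen = 10, seq = 3 -> seq_causal_length = 14; with
// local_window_size = 8, positions 0..4 are zeroed and 5..13 are kept.
void MaskSlidingWindowRow(float* row, std::size_t seq_causal_length, std::size_t local_window_size) {
  for (std::size_t t = 0; t + local_window_size + 1 < seq_causal_length; ++t) {
    row[t] = 0.f;
  }
}
```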
@@ -225,16 +223,16 @@
const T* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const int32_t* seqlens_q, // (optional) new sequence lengths tensor
int batch_size, // batch size
int sequence_length, // sequence length
int past_buffer_sequence_length, // sequence length in past state
int present_buffer_sequence_length, // sequence length in past state
int head_size, // head size of Q, K, V
int hidden_size, // hidden size of Output
// const int32_t* seqlens_q, // (optional) new sequence lengths tensor
const size_t batch_size, // batch size
const size_t sequence_length, // sequence length
const size_t past_buffer_sequence_length, // sequence length in past state
const size_t present_buffer_sequence_length, // sequence length in past state
const size_t head_size, // head size of Q, K, V
const size_t hidden_size, // hidden size of Output
const T* past_value, // past value only
T* present_value, // present value only
const bool past_present_share_buffer, // whether present key and value share the same buffer

const bool packed_qkv, // whether Q, K, V are packed
const bool is_interactive, // whether it is interactive
const bool is_prompt, // whether it is prompt
@@ -242,10 +240,10 @@
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
const int kv_input_chunk_length = sequence_length * head_size; // L x H
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size; // T x H
const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t kv_input_chunk_length = sequence_length * head_size; // L x H

const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
@@ -274,11 +272,10 @@

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i / num_heads_);
const int head_index = static_cast<int>(i % num_heads_);
const size_t batch_index = i / num_heads_;
const size_t head_index = i % num_heads_;
const size_t total_seqlen = seqlens_k[batch_index] + 1;
const size_t past_seqlen = is_interactive ? total_seqlen - static_cast<size_t>(seqlens_q[batch_index])
: (is_prompt ? 0 : total_seqlen - 1);
const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length;
const size_t past_chunk_length = past_seqlen * head_size;

const T* v;
@@ -296,8 +293,7 @@
T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

size_t q_seqlen = is_interactive ? static_cast<size_t>(seqlens_q[batch_index]) : sequence_length;
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, q_seqlen, head_size, total_seqlen, 1.f, /*alpha*/
math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
attention_probs + attention_probs_offset, present_buffer_sequence_length, v,
head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
}
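The GemmEx call above computes a (sequence_length x total_seqlen) by (total_seqlen x head_size) product for one (batch, head) pair, reading probability rows with stride present_buffer_sequence_length and writing output rows with stride hidden_size so each head fills its own slice of the interleaved B x S x (num_heads x head_size) output. A plain-loop sketch of those strides, illustrative only:

```cpp
#include <cstddef>

// Plain-loop equivalent of the strided GEMM above:
// out_current[s, h] = sum_t probs[s, t] * v[t, h].
void ProbsTimesV(const float* probs, const float* v, float* out_current,
                 std::size_t sequence_length, std::size_t total_seqlen, std::size_t head_size,
                 std::size_t present_buffer_sequence_length, std::size_t hidden_size) {
  for (std::size_t s = 0; s < sequence_length; ++s) {
    for (std::size_t h = 0; h < head_size; ++h) {
      float acc = 0.0f;
      for (std::size_t t = 0; t < total_seqlen; ++t) {
        // Probability rows live in a buffer that is present_buffer_sequence_length wide.
        acc += probs[s * present_buffer_sequence_length + t] * v[t * head_size + h];
      }
      // Each head writes a head_size-wide slice inside a hidden_size-wide output row.
      out_current[s * hidden_size + h] = acc;
    }
  }
}
```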
21 changes: 9 additions & 12 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -48,7 +48,7 @@
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
const Tensor* seqlens_q = context->Input<Tensor>(9);
// const Tensor* seqlens_q = context->Input<Tensor>(9);

GroupQueryAttentionParameters parameters = {};
constexpr float scale = 1.0f;
@@ -63,7 +63,7 @@
num_heads_,
kv_num_heads_,
seqlens_k,
seqlens_q,
// seqlens_q,
total_seqlen,
scale));

@@ -121,26 +121,23 @@
rotary_params.transposed = true;
auto* tp = context->GetOperatorThreadPool();
// Generate position ids
const int pos_ids_size = parameters.is_interactive ? batch_size * sequence_length : (parameters.is_prompt ? 1 : batch_size);
const int pos_ids_size = (parameters.is_prompt && !parameters.is_interactive) ? 1 : batch_size * sequence_length;
std::vector<int64_t> pos_ids(pos_ids_size);
if (parameters.is_interactive) {
if (parameters.is_prompt) {
pos_ids[0] = static_cast<int64_t>(0);
} else {
// Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1.

for (int b = 0; b < batch_size; b++) {
for (int s = 0; s < sequence_length; s++) {
const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1;
const int past_seqlen = total_seqlen - seqlens_q->Data<int32_t>()[b];
const int past_seqlen = total_seqlen - sequence_length;
if (past_seqlen + s < total_seqlen) {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen + s);
} else {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
}
}
}
} else if (parameters.is_prompt) {
pos_ids[0] = static_cast<int64_t>(0);
} else {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
}
}
// Initialize separate buffers for rotary embeddings
const T* q_input;
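After this change, the rotary position id of each new token is its absolute position in the full sequence, past_seqlen + s, with past_seqlen derived from seqlens_k and sequence_length rather than from seqlens_q. A condensed, illustrative sketch of that rule for one batch entry (the actual kernel also broadcasts a single 0 for the first prompt and keeps a fallback branch for padded positions, as the diff shows):

```cpp
#include <cstdint>
#include <vector>

// Illustrative position-id rule for one batch entry (values are examples):
// a 10-token cache plus 4 new tokens yields pos_ids = {10, 11, 12, 13}.
std::vector<int64_t> MakePosIdsForBatchEntry(int32_t seqlens_k_b, int sequence_length, bool is_prompt) {
  if (is_prompt) {
    return {0};  // first prompt: a single broadcast position id
  }
  const int total_seqlen = seqlens_k_b + 1;                // e.g. 13 + 1 = 14
  const int past_seqlen = total_seqlen - sequence_length;  // e.g. 14 - 4 = 10
  std::vector<int64_t> pos_ids(sequence_length);
  for (int s = 0; s < sequence_length; ++s) {
    pos_ids[s] = static_cast<int64_t>(past_seqlen + s);
  }
  return pos_ids;
}
```

For ordinary single-token generation (sequence_length == 1) this reduces to seqlens_k[b], matching the previous token-generation branch.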
@@ -199,7 +196,7 @@
// Compute the attention score and apply the score to V
return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v,
seqlens_k, seqlens_q, parameters, allocator, context);
seqlens_k,/* seqlens_q,*/ parameters, allocator, context);

}
} // namespace contrib
} // namespace onnxruntime
24 changes: 12 additions & 12 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -22,7 +22,7 @@
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* seqlens_q,
// const Tensor* seqlens_q,
const Tensor* total_seqlen,
float scale) {
// Note: Here S* is past_cache_sequence_length, S+ is seqlen_present_kv_cache
@@ -174,15 +174,6 @@
"seqlens_k must be shape (batch_size).");
}

bool is_interactive = seqlens_q != nullptr;
if (is_interactive) {
const auto& seqlens_q_dim = seqlens_q->Shape().GetDims();
if (seqlens_q_dim[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_q must be shape (batch_size) when it is present.");
}
}

// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -227,6 +218,15 @@
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

bool is_interactive = false;
if (sequence_length > 1 && sequence_length != total_sequence_length) {
if (batch_size != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"batch_size must be 1 when sequence_length > 1 and past context is given.");
}
is_interactive = true;
}

bool is_prompt;
if (is_interactive) {
is_prompt = false; // irrelevant for interactive decoding
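The added block infers continuous (interactive) decoding from shapes alone: more than one new token while the total sequence length exceeds the new block means tokens are being appended to an existing cache, and for now that path requires batch_size == 1. A compact sketch of the classification, assuming the usual prompt rule (sequence_length == total_sequence_length) for the non-interactive case, which is not shown in this hunk:

```cpp
#include <stdexcept>

enum class GqaMode { kPrompt, kInteractive, kTokenGeneration };

// Sketch of the shape-based mode check added above; error handling is modeled
// with an exception instead of ORT_MAKE_STATUS, and the prompt rule is assumed.
GqaMode ClassifyGqaCall(int batch_size, int sequence_length, int total_sequence_length) {
  if (sequence_length > 1 && sequence_length != total_sequence_length) {
    if (batch_size != 1) {
      throw std::invalid_argument(
          "batch_size must be 1 when sequence_length > 1 and past context is given.");
    }
    return GqaMode::kInteractive;  // continuous decoding: several new tokens on top of a cache
  }
  if (sequence_length == total_sequence_length) {
    return GqaMode::kPrompt;  // first call, no past context (assumed rule, not part of this hunk)
  }
  return GqaMode::kTokenGeneration;  // sequence_length == 1 with an existing cache
}
```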
@@ -274,15 +274,15 @@
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* seqlens_q,
// const Tensor* seqlens_q,
const Tensor* total_seqlen,
float scale,
int max_threads_per_block) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}

return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, seqlens_q, total_seqlen, scale);
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k,/* seqlens_q,*/ total_seqlen, scale);

}
} // namespace group_query_attention_helper
} // namespace contrib
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -42,19 +42,20 @@

auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;

// // When seqstart_k_ptr is provided, we interpret it as past sequence length. This is used for interactive mode in GQA

// if (p.seqstart_k_ptr) {
// p.num_keys = p.seqstart_k_ptr[batch_id] + p.num_queries;
// } else if (p.seqlen_k_ptr) {
// p.num_keys = p.seqlen_k_ptr[batch_id];
// }

// When seqstart_q_ptr is provided we interpret it as new sequence length, we use it to calculate past sequence length.

// Used primarily in interactive mode in GQA.
if (p.seqstart_q_ptr && p.seqlen_k_ptr) {
const int past_seqlen = p.seqlen_k_ptr[batch_id] - p.seqstart_q_ptr[batch_id];
p.num_keys = past_seqlen + p.num_queries;
} else if (p.seqlen_k_ptr) {
// if (p.seqstart_q_ptr && p.seqlen_k_ptr) {
// const int past_seqlen = p.seqlen_k_ptr[batch_id] - p.seqstart_q_ptr[batch_id];
// p.num_keys = past_seqlen + p.num_queries;
// } else
if (p.seqlen_k_ptr) {
p.num_keys = p.seqlen_k_ptr[batch_id];
}
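With the seqstart_q-based reconstruction commented out in this commit, the CUDA FMHA path takes the per-batch key count directly from seqlen_k_ptr. A minimal illustration with hypothetical values, assuming seqlen_k_ptr already holds the total (past plus new) key count for each batch entry:

```cpp
// Hypothetical values for the surviving branch above (not PR code).
const int num_queries = 4;        // new tokens in this launch
const int seqlen_k_b  = 14;       // seqlen_k_ptr[batch_id]: 10 cached + 4 new keys (assumed semantics)
const int num_keys    = seqlen_k_b;  // the kernel attends over all 14 keys; no seqstart_q adjustment
```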
