Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Continuous Decoding support in GQA #21523

Merged
merged 28 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ struct GroupQueryAttentionParameters {
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool is_interactive; // indicates whether seqlens_k is 1 or 2 dimensional. 2-d case enables interactive decoding
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
Expand Down
11 changes: 4 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,16 @@ T* ConcatStateChunkGQA(const T* past,
size_t past_buff_chunk_length,
size_t past_chunk_length,
size_t new_chunk_length,
bool is_prompt,
bool past_present_share_buffer,
std::ptrdiff_t i) {
T* start = present + i * present_buff_chunk_length;

T* p = start;
if (!is_prompt) {
if (!past_present_share_buffer) {
const T* src_past = past + i * past_buff_chunk_length;
memcpy(p, src_past, past_chunk_length * sizeof(T));
}
p += past_chunk_length;
if (!past_present_share_buffer && past_chunk_length > 0) {
const T* src_past = past + i * past_buff_chunk_length;
memcpy(p, src_past, past_chunk_length * sizeof(T));
}
p += past_chunk_length;

memcpy(p, chunk, new_chunk_length * sizeof(T));
return start;
Expand Down
155 changes: 81 additions & 74 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Large diffs are not rendered by default.

26 changes: 21 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
}

if (do_rotary_) {
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
rotary_params.sequence_length = sequence_length;
Expand All @@ -114,17 +115,30 @@
rotary_params.seq_stride = head_size;
rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
rotary_params.position_ids_format = parameters.is_interactive || !parameters.is_prompt ? 1 : 0;
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
rotary_params.transposed = true;
auto* tp = context->GetOperatorThreadPool();
std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
if (sequence_length == 1) {
// Generate position ids
const int pos_ids_size = parameters.is_interactive ? batch_size * sequence_length : (parameters.is_prompt ? 1 : batch_size);

Check warning on line 122 in onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc:122: Lines should be <= 120 characters long [whitespace/line_length] [2]
std::vector<int64_t> pos_ids(pos_ids_size);
if (parameters.is_interactive) {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
for (int s = 0; s < sequence_length; s++) {
if (seqlens_k->Data<int32_t>()[b] + s < seqlens_k->Data<int32_t>()[batch_size + b]) {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b] + s);
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
Fixed Show fixed Hide fixed
} else {
pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
}
}
}
} else {
} else if (parameters.is_prompt) {
pos_ids[0] = static_cast<int64_t>(0);
} else {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
}
}
// Initialize separate buffers for rotary embeddings
const T* q_input;
const T* k_input;
T* q_rotary;
Expand All @@ -149,6 +163,7 @@
Q = RotaryQ;
K = RotaryK;
}
// Run rotary embedding for Q and K
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
Expand All @@ -161,6 +176,7 @@
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
// Pack V into rotary QKV buffer
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
Expand Down
27 changes: 21 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Status CheckInputs(const Tensor* query,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
float scale) {
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
// Note: Here S* is past_cache_sequence_length, S+ is seqlen_present_kv_cache
// past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// no packing for q/k/v:
Expand Down Expand Up @@ -169,12 +169,16 @@ Status CheckInputs(const Tensor* query,

// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
bool is_interactive = seqlens_dim.size() == 2;
if (is_interactive && (seqlens_dim[1] != batch_size || seqlens_dim[0] != 2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (2, batch_size), or shape (batch_size).");
} else if (!is_interactive && (seqlens_dim.size() > 1 || seqlens_dim[0] != batch_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
}

// Set present sequence length and kv_share_buffer from input total_seqlen tensor
// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
Expand All @@ -194,11 +198,11 @@ Status CheckInputs(const Tensor* query,
}
if (cos_dims[0] < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 should be not be less than total_sequence_length.");
"cos_cache dimension 0 shall not be less than total_sequence_length.");
}
if (sin_dims[0] < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 should be not be less than total_sequence_length.");
"sin_cache dimension 0 shall not be less than total_sequence_length.");
}
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand All @@ -218,7 +222,17 @@ Status CheckInputs(const Tensor* query,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

bool is_prompt = sequence_length != 1;
bool is_prompt;
if (is_interactive) {
is_prompt = false; // irrelevant for interactive decoding
} else {
// If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
}

if (parameters != nullptr) {
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
Expand All @@ -234,6 +248,7 @@ Status CheckInputs(const Tensor* query,
output_parameters->rotary_dim = rotary_dim;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_interactive = is_interactive;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
output_parameters->qkv_format = qkv_format;
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class SparseAttentionBase {
// Concatenate past_k + k -> present_k
// TODO: avoid copying multiple times for a group.
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
i / kv_num_heads_factor);

// Compute Q*K' + AttentionMask
Expand Down Expand Up @@ -365,7 +365,7 @@ class SparseAttentionBase {

// Concatenate past_v + v -> present_v
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
i / kv_num_heads_factor);

DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@

auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;

// Advance to current batch - in case of different sequence lengths
if (p.seqlen_k_ptr) {
// When seqstart_k_ptr is provided, we interpret it as past sequence length. This is used for interactive mode in GQA

Check warning on line 45 in onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h:45: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (p.seqstart_k_ptr) {
p.num_keys = p.seqstart_k_ptr[batch_id] + p.num_queries;
} else if (p.seqlen_k_ptr) {
p.num_keys = p.seqlen_k_ptr[batch_id];
}

Expand Down
30 changes: 21 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,19 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}

// Check seqlens_k tensor (holding past seqlen for token gen)
// Check seqlens_k tensor. Holds past_sequence_length and total_sequence_length for each sequence,
// or (total_sequence_length - 1) for each sequence. 2-d case enables interactive decoding.
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
bool is_interactive = seqlens_dim.size() == 2;
if (is_interactive && (seqlens_dim[1] != batch_size || seqlens_dim[0] != 2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (2, batch_size), or shape (batch_size).");
} else if (!is_interactive && (seqlens_dim.size() > 1 || seqlens_dim[0] != batch_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
}

// Set present sequence length and kv_share_buffer from input total_seqlen tensor
// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
Expand All @@ -215,11 +220,11 @@ Status CheckInputs(const Tensor* query,
}
if (cos_dims[0] < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 should be not be less than total_sequence_length.");
"cos_cache dimension 0 shall not be less than total_sequence_length.");
}
if (sin_dims[0] < total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 should be not be less than total_sequence_length.");
"sin_cache dimension 0 shall not be less than total_sequence_length.");
}
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand All @@ -239,10 +244,16 @@ Status CheckInputs(const Tensor* query,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

bool is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
bool is_prompt;
if (is_interactive) {
is_prompt = false; // irrelevant for interactive decoding
} else {
// If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
}

if (parameters != nullptr) {
Expand All @@ -258,6 +269,7 @@ Status CheckInputs(const Tensor* query,
output_parameters->kv_num_heads = kv_num_heads;
output_parameters->rotary_dim = rotary_dim;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_interactive = is_interactive;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
output_parameters->qkv_format = qkv_format;
Expand Down
Loading
Loading