Add Continuous Decoding support in GQA #21523

Merged · 28 commits · Sep 13, 2024
Changes shown from 18 commits
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters {
   int local_window_size;
   bool kv_share_buffer;
   bool is_packed_qkv;
-  bool is_prompt;       // determines if seqlens_k is past or kv sequence length tensor
+  bool is_interactive;  // indicates whether we have past context and seqlen > 1
+  bool is_prompt;       // indicates whether this is first decoding step
   bool do_rotary;
   bool rotary_interleaved;
   bool use_smooth_softmax;
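Taken together, the two flags pick one of three serving modes: first prompt (empty KV cache), token generation (one new token per step), and the interactive mode this PR adds for continuous decoding (a multi-token chunk appended to an existing cache). Below is a minimal sketch restating the CheckInputs mode selection from later in this diff; GqaMode and ClassifyGqaStep are illustrative names, not part of the ONNX Runtime API.

#include <iostream>
#include <stdexcept>

// Illustrative restatement of the CheckInputs mode selection in this PR.
enum class GqaMode { kPrompt, kTokenGen, kInteractive };

GqaMode ClassifyGqaStep(int batch_size, int sequence_length, int total_sequence_length) {
  if (sequence_length > 1 && sequence_length != total_sequence_length) {
    if (batch_size != 1) {
      throw std::invalid_argument("batch_size must be 1 when sequence_length > 1 and past context is given");
    }
    return GqaMode::kInteractive;  // is_interactive = true, is_prompt = false
  }
  if (sequence_length == total_sequence_length) {
    return GqaMode::kPrompt;  // is_interactive = false, is_prompt = true
  }
  if (sequence_length != 1) {
    throw std::invalid_argument("sequence_length shall be 1 when it is not prompt");
  }
  return GqaMode::kTokenGen;  // is_interactive = false, is_prompt = false
}

int main() {
  // One hypothetical chat turn: a 12-token prompt, two generated tokens, then a 5-token user reply.
  std::cout << static_cast<int>(ClassifyGqaStep(1, 12, 12)) << '\n';  // 0 = kPrompt
  std::cout << static_cast<int>(ClassifyGqaStep(1, 1, 13)) << '\n';   // 1 = kTokenGen
  std::cout << static_cast<int>(ClassifyGqaStep(1, 1, 14)) << '\n';   // 1 = kTokenGen
  std::cout << static_cast<int>(ClassifyGqaStep(1, 5, 19)) << '\n';   // 2 = kInteractive
  return 0;
}

The last call is the case this PR enables: a multi-token chunk arriving after the KV cache is already populated.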
11 changes: 4 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past,
                        size_t past_buff_chunk_length,
                        size_t past_chunk_length,
                        size_t new_chunk_length,
-                       bool is_prompt,
                        bool past_present_share_buffer,
                        std::ptrdiff_t i) {
   T* start = present + i * present_buff_chunk_length;
 
   T* p = start;
-  if (!is_prompt) {
-    if (!past_present_share_buffer) {
-      const T* src_past = past + i * past_buff_chunk_length;
-      memcpy(p, src_past, past_chunk_length * sizeof(T));
-    }
-    p += past_chunk_length;
+  if (!past_present_share_buffer && past_chunk_length > 0) {
+    const T* src_past = past + i * past_buff_chunk_length;
+    memcpy(p, src_past, past_chunk_length * sizeof(T));
   }
+  p += past_chunk_length;
 
   memcpy(p, chunk, new_chunk_length * sizeof(T));
   return start;
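With the is_prompt parameter gone, callers now signal the first prompt step by passing past_chunk_length == 0 (the sparse attention call sites below encode this as is_prompt ? 0 : past_chunk_length). A minimal standalone sketch with a local copy of the patched helper and toy sizes (one head, head_size 1):

#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

// Local copy of the patched helper, reproduced here for illustration.
template <typename T>
T* ConcatStateChunkGQA(const T* past, const T* chunk, T* present,
                       size_t present_buff_chunk_length,
                       size_t past_buff_chunk_length,
                       size_t past_chunk_length,  // 0 on the first prompt step
                       size_t new_chunk_length,
                       bool past_present_share_buffer,
                       std::ptrdiff_t i) {
  T* start = present + i * present_buff_chunk_length;
  T* p = start;
  if (!past_present_share_buffer && past_chunk_length > 0) {
    const T* src_past = past + i * past_buff_chunk_length;
    std::memcpy(p, src_past, past_chunk_length * sizeof(T));
  }
  p += past_chunk_length;
  std::memcpy(p, chunk, new_chunk_length * sizeof(T));
  return start;
}

int main() {
  // Continuous-decoding step: 3 cached tokens plus 2 new ones; past is copied first.
  std::vector<float> past = {1, 2, 3};
  std::vector<float> chunk = {4, 5};
  std::vector<float> present(5, 0.f);
  ConcatStateChunkGQA(past.data(), chunk.data(), present.data(),
                      /*present_buff_chunk_length=*/5, /*past_buff_chunk_length=*/3,
                      /*past_chunk_length=*/3, /*new_chunk_length=*/2,
                      /*past_present_share_buffer=*/false, /*i=*/0);
  for (float v : present) std::cout << v << ' ';  // 1 2 3 4 5
  std::cout << '\n';

  // First prompt step: past_chunk_length == 0, so only the new chunk is written.
  std::vector<float> present2(5, 0.f);
  ConcatStateChunkGQA<float>(nullptr, chunk.data(), present2.data(), 5, 0, 0, 2, false, 0);
  for (float v : present2) std::cout << v << ' ';  // 4 5 0 0 0
  std::cout << '\n';
  return 0;
}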
177 changes: 89 additions & 88 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Large diffs are not rendered by default.

31 changes: 23 additions & 8 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -45,7 +45,7 @@
   const Tensor* past_key = context->Input<Tensor>(3);
   const Tensor* past_value = context->Input<Tensor>(4);
   const Tensor* seqlens_k = context->Input<Tensor>(5);
-  const Tensor* total_seqlen = context->Input<Tensor>(6);
+  const Tensor* total_seqlen_tensor = context->Input<Tensor>(6);
   const Tensor* cos_cache = context->Input<Tensor>(7);
   const Tensor* sin_cache = context->Input<Tensor>(8);

@@ -61,7 +61,7 @@
                                                 num_heads_,
                                                 kv_num_heads_,
                                                 seqlens_k,
-                                                total_seqlen,
+                                                total_seqlen_tensor,
                                                 scale_,
                                                 softcap_));

@@ -103,6 +103,7 @@
   }
 
   if (do_rotary_) {
+    // Initialize rotary parameters
     rotary_embedding_helper::RotaryParameters rotary_params = {};
     rotary_params.batch_size = batch_size;
     rotary_params.sequence_length = sequence_length;
@@ -114,17 +115,29 @@
     rotary_params.seq_stride = head_size;
     rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
     rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
-    rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
+    rotary_params.position_ids_format = parameters.is_interactive || !parameters.is_prompt ? 1 : 0;
     rotary_params.transposed = true;
     auto* tp = context->GetOperatorThreadPool();
-    std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
-    if (sequence_length == 1) {
+    // Generate position ids
+    const int pos_ids_size = (parameters.is_prompt && !parameters.is_interactive) ? 1 : batch_size * sequence_length;
+    std::vector<int64_t> pos_ids(pos_ids_size);
+    if (parameters.is_prompt) {
+      pos_ids[0] = static_cast<int64_t>(0);
+    } else {
+      // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1.
       for (int b = 0; b < batch_size; b++) {
-        pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
+        for (int s = 0; s < sequence_length; s++) {
+          const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1;
+          const int past_seqlen = total_seqlen - sequence_length;
+          if (past_seqlen + s < total_seqlen) {
+            pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen) + s;
+          } else {
+            pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
+          }
+        }
       }
-    } else {
-      pos_ids[0] = static_cast<int64_t>(0);
     }
+    // Initialize separate buffers for rotary embeddings
     const T* q_input;
     const T* k_input;
     T* q_rotary;
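For non-prompt steps the kernel now emits one position id per new token, recovering the past length from seqlens_k (the code above reads each entry as total_sequence_length - 1). A hedged standalone sketch of that rule; it drops the fallback branch that writes 1, which the loop bounds above make unreachable, and MakePosIds is an illustrative name, not an ORT function:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative re-derivation of the position-id generation above.
// seqlens_k[b] is read as total_sequence_length - 1 for batch entry b.
std::vector<int64_t> MakePosIds(const std::vector<int32_t>& seqlens_k,
                                int batch_size, int sequence_length,
                                bool is_first_prompt) {
  if (is_first_prompt) return {0};  // position_ids_format 0: a single start offset
  std::vector<int64_t> pos_ids(static_cast<size_t>(batch_size) * sequence_length);
  for (int b = 0; b < batch_size; b++) {
    const int total_seqlen = seqlens_k[b] + 1;
    const int past_seqlen = total_seqlen - sequence_length;
    for (int s = 0; s < sequence_length; s++) {
      pos_ids[b * sequence_length + s] = past_seqlen + s;  // continue numbering after the cache
    }
  }
  return pos_ids;
}

int main() {
  // Continuous decoding: batch 1, 4 new tokens on top of 6 cached ones (seqlens_k = 10 - 1 = 9).
  for (int64_t p : MakePosIds({9}, 1, 4, false)) std::cout << p << ' ';  // 6 7 8 9
  std::cout << '\n';
  // Token generation: batch 2, one new token each, totals 6 and 8 -> positions 5 and 7.
  for (int64_t p : MakePosIds({5, 7}, 2, 1, false)) std::cout << p << ' ';
  std::cout << '\n';
  return 0;
}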
@@ -149,6 +162,7 @@
       Q = RotaryQ;
       K = RotaryK;
     }
+    // Run rotary embedding for Q and K
     ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
                                               pos_ids.data(), cos_cache->Data<T>(),
                                               sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
@@ -161,6 +175,7 @@
     ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
                                               pos_ids.data(), cos_cache->Data<T>(),
                                               sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
+    // Pack V into rotary QKV buffer
     if (packed_qkv) {
       const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
       T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
33 changes: 26 additions & 7 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}

// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
if (seqlens_k_dim[0] != batch_size) {
aciddelgado marked this conversation as resolved.
Show resolved Hide resolved
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
}

// Set present sequence length and kv_share_buffer from input total_seqlen tensor
// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
@@ -195,11 +194,11 @@
   }
   if (cos_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "cos_cache dimension 0 should be not be less than total_sequence_length.");
+                           "cos_cache dimension 0 shall not be less than total_sequence_length.");
   }
   if (sin_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "sin_cache dimension 0 should be not be less than total_sequence_length.");
+                           "sin_cache dimension 0 shall not be less than total_sequence_length.");
   }
   if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -219,7 +218,26 @@
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

bool is_prompt = sequence_length != 1;
bool is_interactive = false;
if (sequence_length > 1 && sequence_length != total_sequence_length) {
if (batch_size != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"batch_size must be 1 when sequence_length > 1 and past context is given.");
}
is_interactive = true;
}

bool is_prompt;
if (is_interactive) {
is_prompt = false; // irrelevant for interactive decoding
} else {
// If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
}

if (parameters != nullptr) {
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
@@ -235,6 +253,7 @@
     output_parameters->rotary_dim = rotary_dim;
     output_parameters->is_packed_qkv = is_packed_qkv;
     output_parameters->is_unidirectional = true;
+    output_parameters->is_interactive = is_interactive;
     output_parameters->is_prompt = is_prompt;
     output_parameters->scale = scale;
     output_parameters->softcap = softcap;
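A contract both these checks and the rotary code rely on: seqlens_k has shape (batch_size), and the CPU kernel reads each entry as total_sequence_length - 1 for its sequence, i.e. the index of the last valid token once the new chunk is appended. A small sketch of that accounting, with illustrative values:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Token generation, batch 2: caches hold 6 and 8 tokens including the new one.
  std::vector<int32_t> seqlens_k = {5, 7};  // per-sequence index of the last valid token
  assert(seqlens_k.size() == 2);            // shape (batch_size)

  // Continuous decoding, batch 1: 4 new tokens on top of 6 cached ones -> total 10.
  std::vector<int32_t> seqlens_k_interactive = {9};
  const int total_sequence_length = 10;
  assert(seqlens_k_interactive[0] + 1 == total_sequence_length);
  return 0;
}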
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
@@ -184,7 +184,7 @@ class SparseAttentionBase {
     // Concatenate past_k + k -> present_k
     // TODO: avoid copying mutiple times for a group.
     k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
-                            past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+                            is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
                             i / kv_num_heads_factor);
 
     // Compute Q*K' + AttentionMask
@@ -365,7 +365,7 @@

     // Concatenate past_v + v -> present_v
     v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
-                            past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+                            is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
                             i / kv_num_heads_factor);
 
     DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size);
@@ -42,7 +42,6 @@ struct RightPaddingBatchHook {

   auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;
 
-  // Advance to current batch - in case of different sequence lengths
   if (p.seqlen_k_ptr) {
     p.num_keys = p.seqlen_k_ptr[batch_id];
   }
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -253,7 +253,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
     data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
   }
   if (seqlens_k_buffer != nullptr) {
-    data.seqlens_k_total = reinterpret_cast<int*>(seqlens_k_buffer.get());
+    data.seqlens_k_buff = reinterpret_cast<int*>(seqlens_k_buffer.get());
   }
   // Memory Efficient Buffers
   if (k_buffer != nullptr) {
35 changes: 25 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -189,14 +189,13 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}

// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
if (seqlens_k_dim[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
}

// Set present sequence length and kv_share_buffer from input total_seqlen tensor
// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
@@ -216,11 +215,11 @@
   }
   if (cos_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "cos_cache dimension 0 should be not be less than total_sequence_length.");
+                           "cos_cache dimension 0 shall not be less than total_sequence_length.");
   }
   if (sin_dims[0] < total_sequence_length) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "sin_cache dimension 0 should be not be less than total_sequence_length.");
+                           "sin_cache dimension 0 shall not be less than total_sequence_length.");
   }
   if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -240,10 +239,25 @@
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

bool is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
bool is_interactive = false;
if (sequence_length > 1 && sequence_length != total_sequence_length) {
if (batch_size != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"batch_size must be 1 when sequence_length > 1 and past context is given.");
}
is_interactive = true;
}

bool is_prompt;
if (is_interactive) {
is_prompt = false; // irrelevant for interactive decoding
} else {
// If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
is_prompt = (sequence_length == total_sequence_length);
if (!is_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
}

if (parameters != nullptr) {
@@ -260,6 +274,7 @@
     output_parameters->kv_num_heads = kv_num_heads;
     output_parameters->rotary_dim = rotary_dim;
     output_parameters->is_packed_qkv = is_packed_qkv;
+    output_parameters->is_interactive = is_interactive;
     output_parameters->is_prompt = is_prompt;
     output_parameters->scale = scale;
     output_parameters->softcap = softcap;