Commit

support sequence length > 1 for non prompt
tianleiwu committed Jul 2, 2024
1 parent b67b706 commit 2684e97
Showing 3 changed files with 41 additions and 16 deletions.
22 changes: 16 additions & 6 deletions onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
@@ -82,10 +82,11 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(q_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
-  parameters.past_present_share_buffer = true;  // Only supports share kv cache buffer for past and present for now.
+  constexpr bool past_present_share_buffer = true;  // Only supports share buffer for past and present for now.
+  parameters.past_present_share_buffer = past_present_share_buffer;
 
   int head_size = parameters.head_size;
-  const int cache_length = parameters.past_present_share_buffer
+  const int cache_length = past_present_share_buffer
                                ? parameters.max_cache_sequence_length
                                : parameters.total_sequence_length;
   std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
@@ -100,7 +101,7 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
   Tensor* present_value = context->Output(2, present_v_shape);
 
   // Check past and present share buffer.
-  if (parameters.past_present_share_buffer) {
+  if (past_present_share_buffer) {
     ORT_ENFORCE(past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw());
   }
 
@@ -142,13 +143,22 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
     rotary_params.transposed = true;
     auto* tp = context->GetOperatorThreadPool();
 
-    std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
-    if (sequence_length == 1) {
+    const bool is_prompt = parameters.total_sequence_length == parameters.sequence_length;
+    std::vector<int64_t> pos_ids(is_prompt ? 1 : batch_size * sequence_length);
+    if (is_prompt) {
+      pos_ids[0] = static_cast<int64_t>(0);
+    } else if (sequence_length == 1) {
       for (int b = 0; b < batch_size; b++) {
         pos_ids[b] = static_cast<int64_t>(total_key_lengths->Data<int32_t>()[b]) - 1;
       }
     } else {
-      pos_ids[0] = static_cast<int64_t>(0);
+      // This supports a rare case that sequence_length > 1 when it is not prompt.
+      for (int b = 0; b < batch_size; b++) {
+        for (int s = 0; s < sequence_length; s++) {
+          pos_ids[b * sequence_length + s] = static_cast<int64_t>(total_key_lengths->Data<int32_t>()[b]) -
+                                             (sequence_length - s);
+        }
+      }
     }
 
     const T* q_input;
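
For readers skimming the hunk above, here is a minimal standalone Python sketch of the new position-id computation. It is an illustration under assumed names (rotary_position_ids is hypothetical and not part of ONNX Runtime); only the arithmetic mirrors the C++ change.

# Illustration only: mirrors the pos_ids branches added in sparse_attention.cc.
from typing import List


def rotary_position_ids(total_key_lengths: List[int],
                        sequence_length: int,
                        total_sequence_length: int) -> List[int]:
    is_prompt = total_sequence_length == sequence_length
    if is_prompt:
        # Prompt: the C++ branch stores a single 0.
        return [0]
    if sequence_length == 1:
        # Token-by-token decoding: one position id per batch entry.
        return [length - 1 for length in total_key_lengths]
    # Non-prompt chunk with sequence_length > 1: the s-th new token of batch b
    # sits at absolute position total_key_lengths[b] - (sequence_length - s).
    return [
        length - (sequence_length - s)
        for length in total_key_lengths
        for s in range(sequence_length)
    ]


# Batch of 2, appending a 3-token chunk; total key lengths are 10 and 7.
print(rotary_position_ids([10, 7], sequence_length=3, total_sequence_length=10))
# -> [7, 8, 9, 4, 5, 6]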
6 changes: 0 additions & 6 deletions onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
@@ -128,12 +128,6 @@ Status CheckInputs(void* params,
   }
   int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
 
-  // // Make sure that query sequence length is 1 when it is not prompt.
-  // if (total_sequence_length > sequence_length && sequence_length != 1) {
-  //   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-  //                          "sequence_length shall be 1 when total_sequence_length > sequence_length.");
-  // }
-
   // Check block_row_indices
   const auto& block_row_indices_dim = block_row_indices->Shape().GetDims();
   if (!(block_row_indices_dim.size() == 2 &&
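
The guard deleted here (it was already commented out) covered exactly the case this commit enables. A small hedged sketch of that condition, using a hypothetical helper name:

# Illustration only; would_have_been_rejected is a hypothetical name.
def would_have_been_rejected(total_sequence_length: int, sequence_length: int) -> bool:
    # Non-prompt call (past KV present) whose query carries more than one token.
    return total_sequence_length > sequence_length and sequence_length != 1


assert would_have_been_rejected(total_sequence_length=10, sequence_length=3)      # now supported
assert not would_have_been_rejected(total_sequence_length=10, sequence_length=1)  # decoding
assert not would_have_been_rejected(total_sequence_length=8, sequence_length=8)   # prompt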
29 changes: 25 additions & 4 deletions onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -843,11 +843,20 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
                 for head_size in head_sizes:
                     for format in formats:
                         packed_qkv = format == InputFormats.QKV_BSN3H
+
+                        non_prompt_len = 1
+                        if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+                            # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+                            non_prompt_len = batch_size
+
+                        query_sequence_length = non_prompt_len if has_past_kv else sequence_length
                         config = SparseAttentionConfig(
                             batch_size=batch_size,
-                            sequence_length=1 if has_past_kv else sequence_length,
+                            sequence_length=query_sequence_length,
                             max_sequence_length=256,
-                            past_sequence_length=min(255, sequence_length) if has_past_kv else 0,
+                            past_sequence_length=(
+                                min(256 - query_sequence_length, sequence_length) if has_past_kv else 0
+                            ),
                             num_heads=num_heads,
                             kv_num_heads=num_heads // 2,
                             head_size=head_size,
@@ -873,11 +882,19 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
             head_size = head_sizes[i % len(head_sizes)]
             format = formats[i % len(formats)]
             packed_qkv = format == InputFormats.QKV_BSN3H
+
+            non_prompt_len = 1
+            if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+                # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+                non_prompt_len = batch_size
+
+            query_sequence_length = non_prompt_len if has_past_kv else sequence_length
+
            config = SparseAttentionConfig(
                 batch_size=batch_size,
-                sequence_length=1 if has_past_kv else sequence_length,
+                sequence_length=query_sequence_length,
                 max_sequence_length=256,
-                past_sequence_length=sequence_length if has_past_kv else 0,
+                past_sequence_length=min(256 - query_sequence_length, sequence_length) if has_past_kv else 0,
                 num_heads=num_heads,
                 kv_num_heads=num_heads // 2,
                 head_size=head_size,
@@ -927,6 +944,10 @@ def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig):
     def test_sparse_att_token_gpu(self, config):
         self.run_one_relevance_test(config)
 
+    @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
+    def test_sparse_att_token_cpu(self, config):
+        self.run_one_relevance_test(config)
+
     @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
     def test_sparse_att_prompt_cpu(self, config):
         self.run_one_relevance_test(config)
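
As a quick sanity check on the updated test configuration (a standalone sketch; the names mirror the test but nothing here is the test's API): capping past_sequence_length at max_sequence_length - query_sequence_length keeps the shared KV cache buffer of 256 large enough for the past tokens plus the multi-token query.

# Hedged sketch: verify the capped past length always fits the cache buffer.
max_sequence_length = 256
for sequence_length in (129, 192, 256):      # lengths that trigger the CPU multi-token case
    for query_sequence_length in (1, 2, 4):  # e.g. batch_size when non_prompt_len = batch_size
        past_sequence_length = min(max_sequence_length - query_sequence_length, sequence_length)
        assert past_sequence_length + query_sequence_length <= max_sequence_length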
