diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
index 0f6f1a757c8f7..e337f41a8688d 100644
--- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc
@@ -82,10 +82,11 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
   output_shape[2] = static_cast<int64_t>(q_hidden_size);
   Tensor* output = context->Output(0, output_shape);
 
-  parameters.past_present_share_buffer = true;  // Only supports share kv cache buffer for past and present for now.
+  constexpr bool past_present_share_buffer = true;  // Only supports share buffer for past and present for now.
+  parameters.past_present_share_buffer = past_present_share_buffer;
 
   int head_size = parameters.head_size;
-  const int cache_length = parameters.past_present_share_buffer
+  const int cache_length = past_present_share_buffer
                                ? parameters.max_cache_sequence_length
                                : parameters.total_sequence_length;
   std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
@@ -100,7 +101,7 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
   Tensor* present_value = context->Output(2, present_v_shape);
 
   // Check past and present share buffer.
-  if (parameters.past_present_share_buffer) {
+  if (past_present_share_buffer) {
     ORT_ENFORCE(past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw());
   }
 
@@ -142,13 +143,22 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
     rotary_params.transposed = true;
     auto* tp = context->GetOperatorThreadPool();
 
-    std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
-    if (sequence_length == 1) {
+    const bool is_prompt = parameters.total_sequence_length == parameters.sequence_length;
+    std::vector<int64_t> pos_ids(is_prompt ? 1 : batch_size * sequence_length);
+    if (is_prompt) {
+      pos_ids[0] = static_cast<int64_t>(0);
+    } else if (sequence_length == 1) {
       for (int b = 0; b < batch_size; b++) {
         pos_ids[b] = static_cast<int64_t>(total_key_lengths->Data<int32_t>()[b]) - 1;
       }
     } else {
-      pos_ids[0] = static_cast<int64_t>(0);
+      // This supports a rare case that sequence_length > 1 when it is not prompt.
+      for (int b = 0; b < batch_size; b++) {
+        for (int s = 0; s < sequence_length; s++) {
+          pos_ids[b * sequence_length + s] = static_cast<int64_t>(total_key_lengths->Data<int32_t>()[b]) -
+                                             (sequence_length - s);
+        }
+      }
     }
 
     const T* q_input;
diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
index 82baa3b9a4d51..ca69370b4ce17 100644
--- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h
@@ -128,12 +128,6 @@ Status CheckInputs(void* params,
   }
   int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
 
-  // // Make sure that query sequence length is 1 when it is not prompt.
-  // if (total_sequence_length > sequence_length && sequence_length != 1) {
-  //   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-  //                          "sequence_length shall be 1 when total_sequence_length > sequence_length.");
-  // }
-
   // Check block_row_indices
   const auto& block_row_indices_dim = block_row_indices->Shape().GetDims();
   if (!(block_row_indices_dim.size() == 2 &&
diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index c95a69e8a1fbe..64877fb257e20 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -843,11 +843,20 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
                     for head_size in head_sizes:
                         for format in formats:
                             packed_qkv = format == InputFormats.QKV_BSN3H
+
+                            non_prompt_len = 1
+                            if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+                                # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+                                non_prompt_len = batch_size
+
+                            query_sequence_length = non_prompt_len if has_past_kv else sequence_length
                             config = SparseAttentionConfig(
                                 batch_size=batch_size,
-                                sequence_length=1 if has_past_kv else sequence_length,
+                                sequence_length=query_sequence_length,
                                 max_sequence_length=256,
-                                past_sequence_length=min(255, sequence_length) if has_past_kv else 0,
+                                past_sequence_length=(
+                                    min(256 - query_sequence_length, sequence_length) if has_past_kv else 0
+                                ),
                                 num_heads=num_heads,
                                 kv_num_heads=num_heads // 2,
                                 head_size=head_size,
@@ -873,11 +882,19 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
             head_size = head_sizes[i % len(head_sizes)]
             format = formats[i % len(formats)]
             packed_qkv = format == InputFormats.QKV_BSN3H
+
+            non_prompt_len = 1
+            if provider == "CPUExecutionProvider" and sequence_length > 128 and not do_rotary:
+                # Generate case of sequence_length > 1 when it is not prompt for CPU provider.
+                non_prompt_len = batch_size
+
+            query_sequence_length = non_prompt_len if has_past_kv else sequence_length
+
            config = SparseAttentionConfig(
                batch_size=batch_size,
-               sequence_length=1 if has_past_kv else sequence_length,
+               sequence_length=query_sequence_length,
                max_sequence_length=256,
-               past_sequence_length=sequence_length if has_past_kv else 0,
+               past_sequence_length=min(256 - query_sequence_length, sequence_length) if has_past_kv else 0,
                num_heads=num_heads,
                kv_num_heads=num_heads // 2,
                head_size=head_size,
@@ -927,6 +944,10 @@ def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig):
     def test_sparse_att_token_gpu(self, config):
         self.run_one_relevance_test(config)
 
+    @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
+    def test_sparse_att_token_cpu(self, config):
+        self.run_one_relevance_test(config)
+
     @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
     def test_sparse_att_prompt_cpu(self, config):
         self.run_one_relevance_test(config)
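Illustrative note (not part of the diff): the sketch below mirrors in Python the three `pos_ids` branches that the new code in `sparse_attention.cc` selects between, using the per-batch totals from the `total_key_lengths` input. `build_position_ids` and its arguments are hypothetical names introduced only for this example, not ONNX Runtime API.

```python
# Hypothetical helper mirroring the pos_ids layout in the diff above; not ORT API.
def build_position_ids(is_prompt, batch_size, sequence_length, total_key_lengths):
    if is_prompt:
        # Prompt: the kernel stores a single zero.
        return [0]
    if sequence_length == 1:
        # Token generation: one position per batch entry (the last cached position).
        return [total - 1 for total in total_key_lengths]
    # Rare non-prompt case with sequence_length > 1: one position per (batch, token);
    # the new tokens occupy the tail of each sequence.
    pos_ids = [0] * (batch_size * sequence_length)
    for b in range(batch_size):
        for s in range(sequence_length):
            pos_ids[b * sequence_length + s] = total_key_lengths[b] - (sequence_length - s)
    return pos_ids


# Example: batch_size=2, sequence_length=3, total key lengths [7, 5]
# -> [4, 5, 6, 2, 3, 4]
print(build_position_ids(False, 2, 3, [7, 5]))
```

On the test side, `past_sequence_length=min(256 - query_sequence_length, sequence_length)` simply keeps past plus new tokens within the `max_sequence_length=256` cache configured for these cases.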