diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 34f57c1655cc2..8ae7b4589d677 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -69,9 +69,8 @@ class AttentionCPUBase : public AttentionBase {
     BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
 
     const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data<int32_t>() : nullptr;
-    gsl::span<const int64_t> mask_index_dims = mask_index != nullptr
-                                                   ? mask_index->Shape().GetDims()
-                                                   : gsl::span<const int64_t>{};
+    gsl::span<const int64_t> mask_index_dims =
+        mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span<const int64_t>{};
     const T* past_data = past != nullptr ? past->Data<T>() : nullptr;
     T* present_data = present != nullptr ? present->MutableData<T>() : nullptr;
     const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
@@ -84,22 +83,19 @@ class AttentionCPUBase : public AttentionBase {
       relative_position_bias_data = relative_position_bias->Data<T>();
     }
 
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
-                             mask_index_data, mask_index_dims, static_cast<T*>(mask_data), causal,
-                             batch_size, sequence_length, kv_sequence_length, past_sequence_length,
-                             qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
-                             present_data, present_key_data, tp, relative_position_bias_data);
+    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K, mask_index_data, mask_index_dims,
+                             static_cast<T*>(mask_data), causal, batch_size, sequence_length, kv_sequence_length,
+                             past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data,
+                             past_key_data, present_data, present_key_data, tp, relative_position_bias_data);
 
     // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     auto out_tmp_data =
         allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
     BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));
 
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data),
-                            static_cast<T*>(attention_probs), V,
-                            batch_size, sequence_length, kv_sequence_length, past_sequence_length,
-                            v_head_size, v_hidden_size, past_data, past_value_data,
-                            present_data, present_value_data, tp);
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+                            V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
+                            v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp);
 
     return Status::OK();
   }
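The hunks that follow widen offset and cost arithmetic from int to ptrdiff_t / SafeInt before products of
batch, head, and sequence dimensions are formed. A minimal standalone sketch of the hazard being guarded
against, with illustrative shape values that are not taken from this patch:

    // overflow_sketch.cc -- illustrative only; shapes chosen to mirror a large attention workload.
    #include <cstdint>
    #include <iostream>

    int main() {
      const int batch_size = 4, num_heads = 32;
      const int sequence_length = 8000, total_sequence_length = 8000;

      // In 32-bit arithmetic this product is 8,192,000,000, far above INT32_MAX (2,147,483,647), so an
      // offset such as i * sequence_length * total_sequence_length computed in int overflows (undefined
      // behavior) before any cast of the result can help.
      // int bad = batch_size * num_heads * sequence_length * total_sequence_length;

      // Widening the first operand forces the whole product into 64-bit arithmetic, which is what
      // SafeInt<ptrdiff_t>(...) achieves in the hunks below (with an overflow check on top).
      const int64_t total_probs_elements =
          static_cast<int64_t>(batch_size) * num_heads * sequence_length * total_sequence_length;
      std::cout << total_probs_elements << '\n';  // 8192000000
      return 0;
    }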
@@ -138,16 +134,17 @@ class AttentionCPUBase : public AttentionBase {
     {
       // mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxT
       if (mask_data != nullptr) {
-        PrepareMask(mask_index, mask_index_dims, mask_data,
-                    causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
+        PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length,
+                    mask_filter_value_);
       }
 
       const int loop_len = batch_size * num_heads_;
       const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
 
       TensorOpCost unit_cost;
-      const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
-      unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
+      const ptrdiff_t probs_matrix_bytes = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * sizeof(T);
+      unit_cost.compute_cycles =
+          static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * total_sequence_length);
       unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
       unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
 
@@ -172,15 +169,13 @@ class AttentionCPUBase : public AttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const int batch_index = static_cast<int>(i) / num_heads_;
 
-        const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
-        const int mask_offset = batch_index * sequence_length * total_sequence_length;
+        const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * total_sequence_length;
+        const ptrdiff_t mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
         T* output = attention_probs + output_offset;
 
         // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
         if (mask_data != nullptr) {
-          memcpy(output,
-                 mask_data + mask_offset,
-                 probs_matrix_bytes);
+          memcpy(output, mask_data + mask_offset, probs_matrix_bytes);
         }
 
         const T* k = K + kv_input_chunk_length * i;
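For context on the unit_cost fields touched above: ThreadPool::TryParallelFor uses the per-unit cycle and
byte estimates to decide how finely to shard the batch x head loop. A standalone mock of that pattern,
using hypothetical stand-in types rather than the real onnxruntime threadpool API:

    // cost_model_sketch.cc -- hypothetical stand-ins; not the onnxruntime API itself.
    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>

    struct CostPerUnit {  // plays the role of TensorOpCost: estimated cycles and bytes per work unit
      double compute_cycles = 0.0;
      double bytes_loaded = 0.0;
      double bytes_stored = 0.0;
    };

    // Serial stand-in for TryParallelFor; a real pool would derive a shard count from the cost.
    void for_each_unit(std::ptrdiff_t total, const CostPerUnit& cost,
                       const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
      std::cout << "per-unit cycles: " << cost.compute_cycles << '\n';
      fn(0, total);  // run everything as one shard in this sketch
    }

    int main() {
      const int64_t S = 512, T = 512, H = 64;  // example sequence, total sequence, head size
      CostPerUnit cost;
      cost.compute_cycles = static_cast<double>(2 * S * H * T);  // one S x H times H x T GEMM per unit
      cost.bytes_loaded = static_cast<double>((S + T) * H * sizeof(float));
      cost.bytes_stored = static_cast<double>(S * T * sizeof(float));
      for_each_unit(/* batch x heads */ 8, cost, [](std::ptrdiff_t begin, std::ptrdiff_t end) {
        for (std::ptrdiff_t i = begin; i < end; ++i) { /* process one (batch, head) pair */ }
      });
      return 0;
    }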
@@ -197,8 +192,8 @@ class AttentionCPUBase : public AttentionBase {
         // B: K'                 (B x N x) T x H          (B x N x) H x T        H x T
         // C: attention_probs    (B x N x) S x T          (B x N x) S x T        S x T
         math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
-                                  Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
-                                  output, nullptr);
+                                  Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output,
+                                  nullptr);
 
         if (relative_position_bias_data != nullptr) {
           for (int j = 0; j < sequence_length * total_sequence_length; j++) {
@@ -249,8 +244,10 @@ class AttentionCPUBase : public AttentionBase {
 
     // The cost of Gemm
     TensorOpCost unit_cost;
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * v_head_size * total_sequence_length);
+    unit_cost.bytes_loaded =
+        static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + v_head_size) * total_sequence_length * sizeof(T));
     unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));
 
     if (present || present_value) {
@@ -264,35 +261,36 @@ class AttentionCPUBase : public AttentionBase {
     unit_cost.bytes_loaded += bytes_to_copy_trans_all;
     unit_cost.bytes_stored += bytes_to_copy_trans_all;
 
-    ThreadPool::TryParallelFor(tp, SafeInt<std::ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const T* v = V + kv_input_chunk_length * i;
-        if (nullptr != present) {
-          // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
-          v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
-        } else if (nullptr != present_value) {
-          v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
-        }
+    ThreadPool::TryParallelFor(
+        tp, SafeInt<std::ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+          for (std::ptrdiff_t i = begin; i != end; ++i) {
+            const T* v = V + kv_input_chunk_length * i;
+            if (nullptr != present) {
+              // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
+              v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
+            } else if (nullptr != present_value) {
+              v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
+            }
 
-        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
-        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
-        math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
-                        attention_probs + attention_probs_offset,
-                        v, current_tmp_data, nullptr);
-
-        // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-        const int batch_index = static_cast<int>(i / num_heads_);
-        const int head_index = static_cast<int>(i % num_heads_);
-        T* src = current_tmp_data;
-        ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
-        T* dest = output + dest_offset;
-        for (int j = 0; j < sequence_length; j++) {
-          memcpy(dest, src, bytes_to_copy_trans);
-          src += v_head_size;
-          dest += v_hidden_size;
-        }
-      }
-    });
+            T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
+            ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
+            math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
+                            attention_probs + attention_probs_offset, v, current_tmp_data, nullptr);
+
+            // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
+            const int batch_index = static_cast<int>(i / num_heads_);
+            const int head_index = static_cast<int>(i % num_heads_);
+            T* src = current_tmp_data;
+            ptrdiff_t dest_offset =
+                (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
+            T* dest = output + dest_offset;
+            for (int j = 0; j < sequence_length; j++) {
+              memcpy(dest, src, bytes_to_copy_trans);
+              src += v_head_size;
+              dest += v_hidden_size;
+            }
+          }
+        });
   }
 };
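The reindented loop above copies each head's S x H_v block from the per-head temporary into the interleaved
B x S x (N * H_v) output. A standalone sketch of that copy, with made-up shapes and buffers:

    // transpose_sketch.cc -- standalone illustration of the (B, N, S, H) -> (B, S, N, H) copy done by the
    // memcpy loop above; shapes and values here are invented for the example.
    #include <cstddef>
    #include <cstring>
    #include <vector>

    int main() {
      const int B = 2, N = 3, S = 4, H = 5;
      std::vector<float> per_head(B * N * S * H, 1.0f);    // out_tmp: B x N x S x H
      std::vector<float> interleaved(B * S * N * H, 0.0f);  // output:  B x S x (N*H)

      for (int b = 0; b < B; ++b) {
        for (int n = 0; n < N; ++n) {
          const float* src = per_head.data() + ((static_cast<ptrdiff_t>(b) * N + n) * S) * H;
          float* dest = interleaved.data() + (static_cast<ptrdiff_t>(b) * S * N + n) * H;
          for (int s = 0; s < S; ++s) {
            std::memcpy(dest, src, H * sizeof(float));  // one head-row at a time
            src += H;       // next row within this head (v_head_size stride)
            dest += N * H;  // next token row in the interleaved output (v_hidden_size stride)
          }
        }
      }
      return 0;
    }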
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index fa80efffc9ea1..6b0c5f395cab0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -63,17 +63,16 @@ class GQAAttentionBase : public AttentionBase {
     bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
 
     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
-                             seqlens_k->Data<int32_t>(),
-                             batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
-                             head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
+    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+                             sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
+                             present_key_data, past_present_share_buffer, packed_qkv, tp);
 
     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs),
-                            v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
-                            seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
-                            past_present_share_buffer, packed_qkv, tp);
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                            hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
+                            tp);
 
     return Status::OK();
   }
@@ -98,7 +97,9 @@ class GQAAttentionBase : public AttentionBase {
                              bool packed_qkv,            // whether Q, K, V are packed
                              ThreadPool* tp) const {     // thread pool
     const bool is_prompt = sequence_length != 1;
-    const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
+    const ptrdiff_t packed_batch_stride =
+        packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+                   : SafeInt<ptrdiff_t>(0);
     const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
     const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;   // S x H
     const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;  // L x H
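As the call sites above show, a packed QKV buffer lays out, per batch, num_heads_ Q chunks followed by
kv_num_heads_ K chunks and kv_num_heads_ V chunks, so the per-batch stride is
(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size, now computed in 64-bit. A standalone sketch
of those offsets, with invented shapes:

    // packed_qkv_sketch.cc -- illustrates the packed-QKV offsets used above; shapes are invented.
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t num_heads = 8, kv_num_heads = 2, S = 16, H = 4;
      const int64_t batch_stride = (num_heads + 2 * kv_num_heads) * S * H;  // kept 64-bit, as the patch now does
      const int64_t B = 2;
      std::vector<float> qkv(static_cast<size_t>(B * batch_stride), 0.f);

      const int64_t b = 1;  // some batch index
      const float* q = qkv.data() + b * batch_stride;           // Q block: num_heads x S x H
      const float* k = q + num_heads * S * H;                   // K block: kv_num_heads x S x H
      const float* v = q + (num_heads + kv_num_heads) * S * H;  // V block: kv_num_heads x S x H
      (void)k;
      (void)v;
      return 0;
    }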
@@ -113,9 +114,12 @@ class GQAAttentionBase : public AttentionBase {
     const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
 
     TensorOpCost unit_cost;
-    const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
+    const ptrdiff_t probs_matrix_bytes =
+        SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * present_buffer_sequence_length);
+    unit_cost.bytes_loaded =
+        static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
     unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
 
     unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
@@ -131,11 +135,12 @@ class GQAAttentionBase : public AttentionBase {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const int batch_index = static_cast<int>(i) / num_heads_;
         const int head_index = static_cast<int>(i) % num_heads_;
-        const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
+        const int past_seqlen =
+            sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
         const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
         const int total_seqlen = seqlens_k[batch_index] + 1;
 
-        const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
+        const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
         T* output = attention_probs + output_offset;
 
         const T* k;
@@ -161,11 +166,9 @@ class GQAAttentionBase : public AttentionBase {
         } else {
           q = Q + q_input_chunk_length * i;
         }
-        math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans,
-                                        sequence_length, total_seqlen, head_size, alpha,
-                                        q, head_size, k, head_size,
-                                        0.0f /*bata*/,
-                                        output, present_buffer_sequence_length, nullptr);
+        math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+                                        head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length,
+                                        nullptr);
 
         // compute Softmax
         T* output_softmax = output;
@@ -175,7 +178,8 @@ class GQAAttentionBase : public AttentionBase {
           for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) {
             output_softmax[total_seq_id] = 0.f;
           }
-          ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr);
+          ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1,
+                                         local_window_size_ + 1, nullptr);
         } else {
           ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr);
         }
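The reflowed call above implements a sliding-window causal softmax: scores older than local_window_size_
are zeroed and the softmax is taken over the last local_window_size_ + 1 positions only. A standalone
sketch of the same masking, with a plain (not in-place optimized) softmax and example values:

    // local_softmax_sketch.cc -- standalone sketch of the sliding-window softmax applied above; the
    // sequence length, window size, and score values are examples.
    #include <algorithm>
    #include <cmath>
    #include <iostream>
    #include <vector>

    int main() {
      const int seq_causal_length = 10;  // number of keys this query may attend to
      const int window = 4;              // mirrors local_window_size_
      std::vector<float> scores(seq_causal_length, 1.0f);

      // Positions older than the window get probability zero outright.
      const int start = std::max(seq_causal_length - window - 1, 0);
      for (int t = 0; t < start; ++t) scores[t] = 0.f;

      // Softmax over the remaining window + 1 positions.
      float max_v = scores[start];
      for (int t = start; t < seq_causal_length; ++t) max_v = std::max(max_v, scores[t]);
      float sum = 0.f;
      for (int t = start; t < seq_causal_length; ++t) {
        scores[t] = std::exp(scores[t] - max_v);
        sum += scores[t];
      }
      for (int t = start; t < seq_causal_length; ++t) scores[t] /= sum;

      for (float p : scores) std::cout << p << ' ';
      std::cout << '\n';
      return 0;
    }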
@@ -208,7 +212,9 @@ class GQAAttentionBase : public AttentionBase {
                              bool packed_qkv,  // whether Q, K, V are packed
                              ThreadPool* tp) const {
     const bool is_prompt = sequence_length != 1;
-    const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
+    const ptrdiff_t packed_batch_stride =
+        packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
+                   : SafeInt<ptrdiff_t>(0);
     const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
     const int kv_input_chunk_length = sequence_length * head_size;                                  // L x H
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;  // L x H
@@ -220,8 +226,10 @@ class GQAAttentionBase : public AttentionBase {
 
     // The cost of Gemm
     TensorOpCost unit_cost;
-    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
-    unit_cost.bytes_loaded = static_cast<double>((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T));
+    unit_cost.compute_cycles =
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * present_buffer_sequence_length);
+    unit_cost.bytes_loaded = static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + head_size) *
+                                                 present_buffer_sequence_length * sizeof(T));
     unit_cost.bytes_stored = static_cast<double>(sequence_length * head_size * sizeof(T));
 
     if (present_value) {
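In the hunk that follows, the V GEMM writes with leading dimension hidden_size, so each (batch, head) pair
lands directly in the interleaved B x S x (N * H) output and no separate transpose pass is needed. A
standalone sketch of that strided write, with naive loops standing in for the GEMM and invented shapes:

    // strided_output_sketch.cc -- illustrates writing one head's S x H block into an interleaved (S, N*H)
    // output via a row stride of N*H, the role played by ldc = hidden_size below; shapes are invented.
    #include <vector>

    int main() {
      const int S = 4, H = 3, N = 2;  // tokens, head size, heads
      const int hidden = N * H;       // row stride (leading dimension) of the output
      std::vector<float> out(S * hidden, 0.f);

      for (int head = 0; head < N; ++head) {
        float* out_current = out.data() + head * H;  // start of this head's columns
        for (int s = 0; s < S; ++s) {
          for (int h = 0; h < H; ++h) {
            // A GEMM with ldc == hidden writes element (s, h) here, so the heads interleave into the
            // final (S, N*H) layout without a separate transpose pass.
            out_current[s * hidden + h] = static_cast<float>(head);
          }
        }
      }
      return 0;
    }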
@@ -235,39 +243,37 @@ class GQAAttentionBase : public AttentionBase {
     unit_cost.bytes_loaded += bytes_to_copy_trans_all;
     unit_cost.bytes_stored += bytes_to_copy_trans_all;
 
-    ThreadPool::TryParallelFor(tp, SafeInt<std::ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const int batch_index = static_cast<int>(i / num_heads_);
-        const int head_index = static_cast<int>(i % num_heads_);
-        const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
-        const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
-        const int total_seqlen = seqlens_k[batch_index] + 1;
+    ThreadPool::TryParallelFor(
+        tp, SafeInt<std::ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+          for (std::ptrdiff_t i = begin; i != end; ++i) {
+            const int batch_index = static_cast<int>(i / num_heads_);
+            const int head_index = static_cast<int>(i % num_heads_);
+            const int past_seqlen =
+                sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
+            const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
+            const int total_seqlen = seqlens_k[batch_index] + 1;
+
+            const T* v;
+            if (packed_qkv) {
+              v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+            } else {
+              v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
+            }
+            if (nullptr != present_value) {
+              v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
+                                      past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+                                      i / kv_num_heads_factor);
+            }
 
-        const T* v;
-        if (packed_qkv) {
-          v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
-        } else {
-          v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
-        }
-        if (nullptr != present_value) {
-          v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
-                                  past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
-                                  i / kv_num_heads_factor);
-        }
+            T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+            ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
 
-        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
-        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
-
-        math::GemmEx<float, ThreadPool>(CblasNoTrans,
-                                        CblasNoTrans,
-                                        sequence_length, head_size, total_seqlen,
-                                        1.f, /*alpha*/
-                                        attention_probs + attention_probs_offset, present_buffer_sequence_length,
-                                        v, head_size,
-                                        0.0f /*beta*/,
-                                        output_current, hidden_size, nullptr);
-      }
-    });
+            math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                            1.f, /*alpha*/
+                                            attention_probs + attention_probs_offset, present_buffer_sequence_length, v,
+                                            head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr);
+          }
+        });
   }
 };
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index 4df1ac1cc2b7e..b6b8aee15852f 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -1775,6 +1775,7 @@ def test_gqa_no_past(self):
                 (2000, 2000),
                 (200, 200),
                 (240, 240),
+                (8000, 8000),
             ]
         )
         num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
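The new (8000, 8000) case is long enough that the pre-patch 32-bit arithmetic overflows; for example, the
old compute_cycles expression 2 * sequence_length * head_size * total_sequence_length exceeds INT32_MAX at
these lengths for any head_size of 17 or more. A standalone check of that arithmetic, with an illustrative
head_size:

    // overflow_check.cc -- arithmetic behind the new 8000 x 8000 test case; head_size is illustrative.
    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      const int64_t sequence_length = 8000, total_sequence_length = 8000, head_size = 128;
      const int64_t cycles = 2 * sequence_length * head_size * total_sequence_length;  // 16,384,000,000
      std::cout << cycles << " > " << std::numeric_limits<int32_t>::max() << '\n';     // 16384000000 > 2147483647
      return 0;
    }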