diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index c617533319a18..34f57c1655cc2 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -145,10 +145,30 @@ class AttentionCPUBase : public AttentionBase {
     const int loop_len = batch_size * num_heads_;
     const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
 
-    // The cost of Gemm
-    const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;
+    TensorOpCost unit_cost;
+    const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
+    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
+    unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
+    unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
 
-    ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    if (mask_data != nullptr) {
+      unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
+      unit_cost.bytes_stored += static_cast<double>(probs_matrix_bytes);
+    }
+
+    if (present || present_key) {
+      double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_chunk_length);
+      unit_cost.bytes_loaded += bytes_to_copy_key;
+      unit_cost.bytes_stored += bytes_to_copy_key;
+    }
+
+    if (relative_position_bias_data != nullptr) {
+      unit_cost.compute_cycles += static_cast<double>(sequence_length * total_sequence_length);
+      unit_cost.bytes_loaded += probs_matrix_bytes * 2;
+      unit_cost.bytes_stored += probs_matrix_bytes;
+    }
+
+    ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const int batch_index = static_cast<int>(i) / num_heads_;
@@ -160,7 +180,7 @@ class AttentionCPUBase : public AttentionBase {
 
         if (mask_data != nullptr) {
           memcpy(output, mask_data + mask_offset,
-                 static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
+                 probs_matrix_bytes);
         }
 
         const T* k = K + kv_input_chunk_length * i;
@@ -227,10 +247,24 @@ class AttentionCPUBase : public AttentionBase {
       present += SafeInt<ptrdiff_t>(batch_size) * num_heads_ * total_sequence_length * v_head_size;
     }
 
-    const double cost =
-        static_cast<double>(sequence_length) * static_cast<double>(v_head_size) * static_cast<double>(sequence_length);
+    // The cost of Gemm
+    TensorOpCost unit_cost;
+    unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
+    unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
+    unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));
+
+    if (present || present_value) {
+      double bytes_to_copy_value = static_cast<double>(present_chunk_length * sizeof(T));
+      unit_cost.bytes_loaded += bytes_to_copy_value;
+      unit_cost.bytes_stored += bytes_to_copy_value;
+    }
+
+    const size_t bytes_to_copy_trans = SafeInt<size_t>(v_head_size) * sizeof(T);
+    double bytes_to_copy_trans_all = static_cast<double>(sequence_length * bytes_to_copy_trans);
+    unit_cost.bytes_loaded += bytes_to_copy_trans_all;
+    unit_cost.bytes_stored += bytes_to_copy_trans_all;
 
-    ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const T* v = V + kv_input_chunk_length * i;
 
         if (nullptr != present) {
@@ -252,9 +286,8 @@ class AttentionCPUBase : public AttentionBase {
           T* src = current_tmp_data;
          ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
           T* dest = output + dest_offset;
-          const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
           for (int j = 0; j < sequence_length; j++) {
-            memcpy(dest, src, bytes_to_copy);
+            memcpy(dest, src, bytes_to_copy_trans);
             src += v_head_size;
             dest += v_hidden_size;
           }
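Context (not part of the patch): the unit_cost fields above follow a standard GEMM cost model, so ThreadPool::TryParallelFor can weigh arithmetic work against memory traffic when deciding how many threads to use for the batch_size * num_heads_ loop. Below is a minimal sketch of that model; TensorOpCostSketch and EstimateGemmCost are illustrative names for this note, not onnxruntime APIs.

```cpp
#include <cstddef>

// Illustrative stand-in for onnxruntime's TensorOpCost (double fields for
// bytes moved and arithmetic work), which ThreadPool::TryParallelFor uses
// to decide how many threads a parallel loop is worth.
struct TensorOpCostSketch {
  double bytes_loaded = 0.0;
  double bytes_stored = 0.0;
  double compute_cycles = 0.0;
};

// Hypothetical helper: per-unit cost of one (M x K) * (K x N) GEMM.
// 2*M*K*N multiply-adds; M*K and K*N input elements loaded; M*N elements stored.
template <typename T>
TensorOpCostSketch EstimateGemmCost(std::ptrdiff_t M, std::ptrdiff_t K, std::ptrdiff_t N) {
  TensorOpCostSketch cost;
  cost.compute_cycles = 2.0 * static_cast<double>(M) * static_cast<double>(K) * static_cast<double>(N);
  cost.bytes_loaded = static_cast<double>(M * K + K * N) * sizeof(T);
  cost.bytes_stored = static_cast<double>(M * N) * sizeof(T);
  return cost;
}

// The Q*K' probs GEMM in the first hunk corresponds to
//   EstimateGemmCost<float>(sequence_length, head_size, total_sequence_length);
// the probs*V GEMM in the third hunk to
//   EstimateGemmCost<float>(sequence_length, total_sequence_length, v_head_size);
// the extra bytes_loaded/bytes_stored terms in the patch then account for the
// mask copy, present key/value copies, relative position bias add, and the
// transpose copy into the output layout.
```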