diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 0a4e43b1697e1..0445c9509053c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -146,19 +146,26 @@ class AttentionCPUBase : public AttentionBase { const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; + const size_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * total_sequence_length); - unit_cost.bytes_loaded = double(sequence_length * head_size * sizeof(T) + head_size * total_sequence_length * sizeof(T)); - unit_cost.bytes_stored = double(sequence_length * total_sequence_length * sizeof(T)); + unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); if (mask_data != nullptr) { - unit_cost.bytes_loaded += double(sequence_length * total_sequence_length * sizeof(T)); - unit_cost.bytes_stored += double(sequence_length * total_sequence_length * sizeof(T)); + unit_cost.bytes_loaded += double(probs_matrix_bytes); + unit_cost.bytes_stored += double(probs_matrix_bytes); } if (present || present_key) { - size_t bytes_to_copy_key = sizeof(T) * size_t(past || past_key ? present_chunk_length : present_chunk_length - past_chunk_length); - unit_cost.bytes_loaded += double(bytes_to_copy_key); - unit_cost.bytes_stored += double(bytes_to_copy_key); + double bytes_to_copy_key = static_cast(sizeof(T) * present_chunk_length); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + } + + if (relative_position_bias_data != nullptr) { + unit_cost.compute_cycles += double(sequence_length * total_sequence_length); + unit_cost.bytes_loaded += probs_matrix_bytes; + unit_cost.bytes_stored += probs_matrix_bytes; } ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { @@ -173,7 +180,7 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { memcpy(output, mask_data + mask_offset, - static_cast(sequence_length) * total_sequence_length * sizeof(T)); + probs_matrix_bytes); } const T* k = K + kv_input_chunk_length * i; @@ -243,13 +250,13 @@ class AttentionCPUBase : public AttentionBase { // The cost of Gemm TensorOpCost unit_cost; unit_cost.compute_cycles = static_cast(2 * sequence_length * v_head_size * total_sequence_length); - unit_cost.bytes_loaded = double(sequence_length * total_sequence_length * sizeof(T) + v_head_size * total_sequence_length * sizeof(T)); - unit_cost.bytes_stored = double(sequence_length * v_head_size * sizeof(T)); + unit_cost.bytes_loaded = static_cast((sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); if (present || present_value) { - size_t bytes_to_copy_value = sizeof(T) * size_t(past || past_value ? present_chunk_length : present_chunk_length - past_chunk_length); - unit_cost.bytes_loaded += double(bytes_to_copy_value); - unit_cost.bytes_stored += double(bytes_to_copy_value); + double bytes_to_copy_value = static_cast(present_chunk_length * sizeof(T)); + unit_cost.bytes_loaded += bytes_to_copy_value; + unit_cost.bytes_stored += bytes_to_copy_value; } const size_t bytes_to_copy_trans = SafeInt(v_head_size) * sizeof(T);