Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize threading of mha #20088

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,30 @@
const int loop_len = batch_size * num_heads_;
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

// The cost of Gemm, estimated per loop iteration (one batch/head pair) and passed to TryParallelFor.
const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
TensorOpCost unit_cost;
const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);

ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
if (mask_data != nullptr) {
unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
unit_cost.bytes_stored += static_cast<double>(probs_matrix_bytes);
}

if (present || present_key) {
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_chunk_length);
unit_cost.bytes_loaded += bytes_to_copy_key;
unit_cost.bytes_stored += bytes_to_copy_key;
}

if (relative_position_bias_data != nullptr) {
unit_cost.compute_cycles += static_cast<double>(sequence_length * total_sequence_length);
unit_cost.bytes_loaded += probs_matrix_bytes * 2;
unit_cost.bytes_stored += probs_matrix_bytes;
}

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i) / num_heads_;

Expand All @@ -160,7 +180,7 @@
if (mask_data != nullptr) {
memcpy(output,
mask_data + mask_offset,
static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
probs_matrix_bytes);
}

const T* k = K + kv_input_chunk_length * i;
Expand Down Expand Up @@ -227,10 +247,24 @@
present += SafeInt<ptrdiff_t>(batch_size) * num_heads_ * total_sequence_length * v_head_size;
}

const double cost =
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
static_cast<double>(sequence_length) * static_cast<double>(v_head_size) * static_cast<double>(sequence_length);
// The cost of Gemm per loop iteration (one batch/head pair); copy traffic is added to it below.
TensorOpCost unit_cost;
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));

if (present || present_value) {
double bytes_to_copy_value = static_cast<double>(present_chunk_length * sizeof(T));
unit_cost.bytes_loaded += bytes_to_copy_value;
unit_cost.bytes_stored += bytes_to_copy_value;
}

const size_t bytes_to_copy_trans = SafeInt<size_t>(v_head_size) * sizeof(T);
double bytes_to_copy_trans_all = static_cast<double>(sequence_length * bytes_to_copy_trans);
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;

ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {

Check warning on line 267 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:267: Lines should be <= 120 characters long [whitespace/line_length] [2]
for (std::ptrdiff_t i = begin; i != end; ++i) {
const T* v = V + kv_input_chunk_length * i;
if (nullptr != present) {
Expand All @@ -252,9 +286,8 @@
T* src = current_tmp_data;
ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
T* dest = output + dest_offset;
const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
for (int j = 0; j < sequence_length; j++) {
memcpy(dest, src, bytes_to_copy);
memcpy(dest, src, bytes_to_copy_trans);
src += v_head_size;
dest += v_hidden_size;
}
Expand Down
Loading