Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize threading of mha #20088

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,26 @@
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

TensorOpCost unit_cost;
const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
unit_cost.bytes_loaded = double(sequence_length * head_size * sizeof(T) + head_size * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);

if (mask_data != nullptr) {
unit_cost.bytes_loaded += double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_stored += double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_loaded += double(probs_matrix_bytes);

Check warning on line 155 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:155: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
unit_cost.bytes_stored += double(probs_matrix_bytes);

Check warning on line 156 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:156: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
}

if (present || present_key) {
size_t bytes_to_copy_key = sizeof(T) * size_t(past || past_key ? present_chunk_length : present_chunk_length - past_chunk_length);
unit_cost.bytes_loaded += double(bytes_to_copy_key);
unit_cost.bytes_stored += double(bytes_to_copy_key);
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_chunk_length);
unit_cost.bytes_loaded += bytes_to_copy_key;
unit_cost.bytes_stored += bytes_to_copy_key;
}

if (relative_position_bias_data != nullptr) {
unit_cost.compute_cycles += double(sequence_length * total_sequence_length);

Check warning on line 166 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:166: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
unit_cost.bytes_loaded += probs_matrix_bytes;
yufenglee marked this conversation as resolved.
Show resolved Hide resolved
unit_cost.bytes_stored += probs_matrix_bytes;
}

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
Expand All @@ -173,7 +180,7 @@
if (mask_data != nullptr) {
memcpy(output,
mask_data + mask_offset,
static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
probs_matrix_bytes);
}

const T* k = K + kv_input_chunk_length * i;
Expand Down Expand Up @@ -243,20 +250,20 @@
// The cost of Gemm
TensorOpCost unit_cost;
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
unit_cost.bytes_loaded = double(sequence_length * total_sequence_length * sizeof(T) + v_head_size * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = double(sequence_length * v_head_size * sizeof(T));
unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));

if (present || present_value) {
size_t bytes_to_copy_value = sizeof(T) * size_t(past || past_value ? present_chunk_length : present_chunk_length - past_chunk_length);
unit_cost.bytes_loaded += double(bytes_to_copy_value);
unit_cost.bytes_stored += double(bytes_to_copy_value);
double bytes_to_copy_value = static_cast<double>(present_chunk_length * sizeof(T));
unit_cost.bytes_loaded += bytes_to_copy_value;
unit_cost.bytes_stored += bytes_to_copy_value;
}

const size_t bytes_to_copy_trans = SafeInt<size_t>(v_head_size) * sizeof(T);
unit_cost.bytes_loaded += double(SafeInt<size_t>(sequence_length) * bytes_to_copy_trans);

Check warning on line 263 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:263: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
unit_cost.bytes_stored += double(SafeInt<size_t>(sequence_length) * bytes_to_copy_trans);

Check warning on line 264 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:264: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]

ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {

Check warning on line 266 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:266: Lines should be <= 120 characters long [whitespace/line_length] [2]
for (std::ptrdiff_t i = begin; i != end; ++i) {
const T* v = V + kv_input_chunk_length * i;
if (nullptr != present) {
Expand Down
Loading