Skip to content

Commit

Permalink
refine the code
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee committed Apr 1, 2024
1 parent dc3d01b commit d4a32d3
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,26 @@ class AttentionCPUBase : public AttentionBase {
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

TensorOpCost unit_cost;
const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
unit_cost.bytes_loaded = double(sequence_length * head_size * sizeof(T) + head_size * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);

if (mask_data != nullptr) {
unit_cost.bytes_loaded += double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_stored += double(sequence_length * total_sequence_length * sizeof(T));
unit_cost.bytes_loaded += double(probs_matrix_bytes);

Check warning on line 155 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:155: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
unit_cost.bytes_stored += double(probs_matrix_bytes);

Check warning on line 156 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:156: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
}

if (present || present_key) {
size_t bytes_to_copy_key = sizeof(T) * size_t(past || past_key ? present_chunk_length : present_chunk_length - past_chunk_length);
unit_cost.bytes_loaded += double(bytes_to_copy_key);
unit_cost.bytes_stored += double(bytes_to_copy_key);
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_chunk_length);
unit_cost.bytes_loaded += bytes_to_copy_key;
unit_cost.bytes_stored += bytes_to_copy_key;
}

if (relative_position_bias_data != nullptr) {
unit_cost.compute_cycles += double(sequence_length * total_sequence_length);

Check warning on line 166 in onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h:166: Using deprecated casting style. Use static_cast<double>(...) instead [readability/casting] [4]
unit_cost.bytes_loaded += probs_matrix_bytes;
unit_cost.bytes_stored += probs_matrix_bytes;
}

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
Expand All @@ -173,7 +180,7 @@ class AttentionCPUBase : public AttentionBase {
if (mask_data != nullptr) {
memcpy(output,
mask_data + mask_offset,
static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
probs_matrix_bytes);
}

const T* k = K + kv_input_chunk_length * i;
Expand Down Expand Up @@ -243,13 +250,13 @@ class AttentionCPUBase : public AttentionBase {
// The cost of Gemm
TensorOpCost unit_cost;
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
unit_cost.bytes_loaded = double(sequence_length * total_sequence_length * sizeof(T) + v_head_size * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = double(sequence_length * v_head_size * sizeof(T));
unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));

if (present || present_value) {
size_t bytes_to_copy_value = sizeof(T) * size_t(past || past_value ? present_chunk_length : present_chunk_length - past_chunk_length);
unit_cost.bytes_loaded += double(bytes_to_copy_value);
unit_cost.bytes_stored += double(bytes_to_copy_value);
double bytes_to_copy_value = static_cast<double>(present_chunk_length * sizeof(T));
unit_cost.bytes_loaded += bytes_to_copy_value;
unit_cost.bytes_stored += bytes_to_copy_value;
}

const size_t bytes_to_copy_trans = SafeInt<size_t>(v_head_size) * sizeof(T);
Expand Down

0 comments on commit d4a32d3

Please sign in to comment.