
Commit

optimize threading of mha
yufenglee committed Mar 26, 2024
1 parent 7d976cf commit 75787c5
Showing 1 changed file with 57 additions and 67 deletions.
124 changes: 57 additions & 67 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -145,45 +145,40 @@ class AttentionCPUBase : public AttentionBase {
    const int loop_len = batch_size * num_heads_;
    const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

-    // The cost of Gemm
-    const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;
-
-    ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const int batch_index = static_cast<int>(i) / num_heads_;
-
-        const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
-        const int mask_offset = batch_index * sequence_length * total_sequence_length;
-        T* output = attention_probs + output_offset;
-
-        // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
-        if (mask_data != nullptr) {
-          memcpy(output,
-                 mask_data + mask_offset,
-                 static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
-        }
+    ThreadPool::TrySimpleParallelFor(tp, loop_len, [&](std::ptrdiff_t batch_head_id) {
+      const int batch_index = static_cast<int>(batch_head_id) / num_heads_;
+
+      const int output_offset = static_cast<int>(batch_head_id) * sequence_length * total_sequence_length;
+      const int mask_offset = batch_index * sequence_length * total_sequence_length;
+      T* output = attention_probs + output_offset;
+
+      // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
+      if (mask_data != nullptr) {
+        memcpy(output,
+               mask_data + mask_offset,
+               static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
+      }

-        const T* k = K + kv_input_chunk_length * i;
-        if (nullptr != present) {
-          // Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
-          k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
-        } else if (nullptr != present_key) {
-          k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i);
-        }
+      const T* k = K + kv_input_chunk_length * batch_head_id;
+      if (nullptr != present) {
+        // Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
+        k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, batch_head_id);
+      } else if (nullptr != present_key) {
+        k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, batch_head_id);
+      }

-        // Compute Q*K' + AttentionMask
-        //                     original                 transposed             each iteration
-        // A: Q                (B x N x) S x H          (B x N x) S x H        S x H
-        // B: K'               (B x N x) T x H          (B x N x) H x T        H x T
-        // C: attention_probs  (B x N x) S x T          (B x N x) S x T        S x T
-        math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
-                                  Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
-                                  output, nullptr);
-
-        if (relative_position_bias_data != nullptr) {
-          for (int j = 0; j < sequence_length * total_sequence_length; j++) {
-            output[j] += relative_position_bias_data[output_offset + j];
-          }
-        }
-      }
+      // Compute Q*K' + AttentionMask
+      //                     original                 transposed             each iteration
+      // A: Q                (B x N x) S x H          (B x N x) S x H        S x H
+      // B: K'               (B x N x) T x H          (B x N x) H x T        H x T
+      // C: attention_probs  (B x N x) S x T          (B x N x) S x T        S x T
+      math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
+                                Q + q_input_chunk_length * batch_head_id, k, mask_data != nullptr ? 1.0f : 0.0f,
+                                output, nullptr);
+
+      if (relative_position_bias_data != nullptr) {
+        for (int j = 0; j < sequence_length * total_sequence_length; j++) {
+          output[j] += relative_position_bias_data[output_offset + j];
+        }
+      }
    });
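
The hunk above replaces the cost-based ThreadPool::TryParallelFor, whose callback receives a [begin, end) slice of the batch*head index space, with ThreadPool::TrySimpleParallelFor, whose callback is invoked once per index. The following is a minimal, self-contained sketch of the two call shapes; the helper functions, their sequential execution, and main() are illustrative stand-ins for this note, not onnxruntime's ThreadPool.

#include <cstddef>
#include <cstdio>
#include <functional>

// Range-based shape (old code path): the caller supplies a per-iteration cost
// estimate and each worker receives a [begin, end) slice of the index space.
void RangeParallelFor(std::ptrdiff_t total, double /*cost_per_unit*/,
                      const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
  fn(0, total);  // executed here as a single slice, sequentially
}

// Simple shape (new code path): one callback per index; the pool owns the
// partitioning decision instead of relying on the caller's cost estimate.
void SimpleParallelFor(std::ptrdiff_t total,
                       const std::function<void(std::ptrdiff_t)>& fn) {
  for (std::ptrdiff_t i = 0; i < total; ++i) fn(i);
}

int main() {
  const int batch_size = 2;
  const int num_heads = 4;
  const std::ptrdiff_t loop_len = batch_size * num_heads;

  // Old shape: the worker loops over its slice of (batch, head) pairs.
  RangeParallelFor(loop_len, /*cost=*/1.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t i = begin; i != end; ++i) {
      std::printf("range form:  batch %td head %td\n", i / num_heads, i % num_heads);
    }
  });

  // New shape: each invocation handles exactly one (batch, head) pair.
  SimpleParallelFor(loop_len, [&](std::ptrdiff_t batch_head_id) {
    std::printf("simple form: batch %td head %td\n",
                batch_head_id / num_heads, batch_head_id % num_heads);
  });
  return 0;
}

The per-(batch, head) body is otherwise unchanged; only the scheduling entry point and the loop variable name (i becomes batch_head_id) differ.
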
@@ -227,37 +222,32 @@
      present += SafeInt<ptrdiff_t>(batch_size) * num_heads_ * total_sequence_length * v_head_size;
    }

-    const double cost =
-        static_cast<double>(sequence_length) * static_cast<double>(v_head_size) * static_cast<double>(sequence_length);
-
-    ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const T* v = V + kv_input_chunk_length * i;
-        if (nullptr != present) {
-          // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
-          v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
-        } else if (nullptr != present_value) {
-          v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
-        }
+    ThreadPool::TrySimpleParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, [&](std::ptrdiff_t batch_head_id) {
+      const T* v = V + kv_input_chunk_length * batch_head_id;
+      if (nullptr != present) {
+        // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
+        v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, batch_head_id);
+      } else if (nullptr != present_value) {
+        v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, batch_head_id);
+      }

-        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
-        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
-        math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
-                        attention_probs + attention_probs_offset,
-                        v, current_tmp_data, nullptr);
-
-        // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-        const int batch_index = static_cast<int>(i / num_heads_);
-        const int head_index = static_cast<int>(i % num_heads_);
-        T* src = current_tmp_data;
-        ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
-        T* dest = output + dest_offset;
-        const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
-        for (int j = 0; j < sequence_length; j++) {
-          memcpy(dest, src, bytes_to_copy);
-          src += v_head_size;
-          dest += v_hidden_size;
-        }
-      }
+      T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * batch_head_id;
+      ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * batch_head_id;
+      math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
+                      attention_probs + attention_probs_offset,
+                      v, current_tmp_data, nullptr);
+
+      // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
+      const int batch_index = static_cast<int>(batch_head_id / num_heads_);
+      const int head_index = static_cast<int>(batch_head_id % num_heads_);
+      T* src = current_tmp_data;
+      ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
+      T* dest = output + dest_offset;
+      const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
+      for (int j = 0; j < sequence_length; j++) {
+        memcpy(dest, src, bytes_to_copy);
+        src += v_head_size;
+        dest += v_hidden_size;
+      }
    });
  }
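
The second hunk applies the same scheduling change to the attention_probs x V step: each (batch, head) callback runs an S x T by T x H_v MatMul into a scratch buffer, then copies the result row by row into the interleaved (B, S, N, H_v) output. Below is a self-contained sketch of that row-wise copy; the sizes and buffer names are illustrative assumptions for this note, not values taken from the operator.

#include <cstddef>
#include <cstring>
#include <vector>

int main() {
  const int B = 1, N = 2, S = 3, Hv = 4;
  const int v_hidden_size = N * Hv;  // stride between consecutive rows of one head in the output

  std::vector<float> tmp(static_cast<std::size_t>(B) * N * S * Hv);  // (B, N, S, Hv) per-head MatMul results
  std::vector<float> out(tmp.size());                                // (B, S, N, Hv) final layout
  for (std::size_t i = 0; i < tmp.size(); ++i) tmp[i] = static_cast<float>(i);

  for (int batch_head_id = 0; batch_head_id < B * N; ++batch_head_id) {
    const int batch_index = batch_head_id / N;
    const int head_index = batch_head_id % N;
    const float* src = tmp.data() + static_cast<std::size_t>(batch_head_id) * S * Hv;
    float* dest = out.data() +
                  (static_cast<std::size_t>(batch_index) * S * N + head_index) * Hv;
    for (int j = 0; j < S; ++j) {
      std::memcpy(dest, src, sizeof(float) * Hv);  // copy one row of Hv values
      src += Hv;               // next row of this head in the scratch buffer
      dest += v_hidden_size;   // skip past the other heads' slots in the output row
    }
  }
  return 0;
}

Here v_hidden_size plays the same role as in the diff: the destination advances by num_heads * v_head_size per copied row, so each head lands in its own slice of every output row.
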
