Skip to content

Commit

Permalink
Remove memset for the case no any mask
Browse files Browse the repository at this point in the history
  • Loading branch information
yihonglyu committed Mar 7, 2024
1 parent 1ce5bfb commit 92870eb
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,6 @@ class AttentionCPUBase : public AttentionBase {
if (mask_data != nullptr) {
PrepareMask(mask_index, mask_index_dims, mask_data,
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
} else { // no any mask
const int memset_loop_len = batch_size * num_heads_;
const double memset_cost = static_cast<double>(sequence_length) * total_sequence_length;

ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;
memset(output, 0, static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
}
});
}

const int loop_len = batch_size * num_heads_;
Expand Down Expand Up @@ -188,7 +177,7 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
Q + q_input_chunk_length * i, k, 1.0,
Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0 : 0.0,
output, nullptr);

if (relative_position_bias_data != nullptr) {
Expand Down

0 comments on commit 92870eb

Please sign in to comment.