diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index c617533319a18..67de265f4e907 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -145,45 +145,40 @@ class AttentionCPUBase : public AttentionBase {
     const int loop_len = batch_size * num_heads_;
     const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
 
-    // The cost of Gemm
-    const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;
-
-    ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const int batch_index = static_cast<int>(i) / num_heads_;
-
-        const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
-        const int mask_offset = batch_index * sequence_length * total_sequence_length;
-        T* output = attention_probs + output_offset;
-
-        // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
-        if (mask_data != nullptr) {
-          memcpy(output,
-                 mask_data + mask_offset,
-                 static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
-        }
+    ThreadPool::TrySimpleParallelFor(tp, loop_len, [&](std::ptrdiff_t batch_head_id) {
+      const int batch_index = static_cast<int>(batch_head_id) / num_heads_;
+
+      const int output_offset = static_cast<int>(batch_head_id) * sequence_length * total_sequence_length;
+      const int mask_offset = batch_index * sequence_length * total_sequence_length;
+      T* output = attention_probs + output_offset;
+
+      // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
+      if (mask_data != nullptr) {
+        memcpy(output,
+               mask_data + mask_offset,
+               static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
+      }
 
-        const T* k = K + kv_input_chunk_length * i;
-        if (nullptr != present) {
-          // Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
-          k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
-        } else if (nullptr != present_key) {
-          k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i);
-        }
+      const T* k = K + kv_input_chunk_length * batch_head_id;
+      if (nullptr != present) {
+        // Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
+        k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, batch_head_id);
+      } else if (nullptr != present_key) {
+        k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, batch_head_id);
+      }
 
-        // Compute Q*K' + AttentionMask
-        //                     original                 transposed             each iteration
-        // A: Q                (B x N x) S x H          (B x N x) S x H        S x H
-        // B: K'               (B x N x) T x H          (B x N x) H x T        H x T
-        // C: attention_probs  (B x N x) S x T          (B x N x) S x T        S x T
-        math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
-                                  Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
-                                  output, nullptr);
-
-        if (relative_position_bias_data != nullptr) {
-          for (int j = 0; j < sequence_length * total_sequence_length; j++) {
-            output[j] += relative_position_bias_data[output_offset + j];
-          }
+      // Compute Q*K' + AttentionMask
+      //                     original                 transposed             each iteration
+      // A: Q                (B x N x) S x H          (B x N x) S x H        S x H
+      // B: K'               (B x N x) T x H          (B x N x) H x T        H x T
+      // C: attention_probs  (B x N x) S x T          (B x N x) S x T        S x T
+      math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
+                                Q + q_input_chunk_length * batch_head_id, k, mask_data != nullptr ? 1.0f : 0.0f,
+                                output, nullptr);
+
+      if (relative_position_bias_data != nullptr) {
+        for (int j = 0; j < sequence_length * total_sequence_length; j++) {
+          output[j] += relative_position_bias_data[output_offset + j];
        }
      }
    });
@@ -227,37 +222,32 @@ class AttentionCPUBase : public AttentionBase {
       present += SafeInt<ptrdiff_t>(batch_size) * num_heads_ * total_sequence_length * v_head_size;
     }
 
-    const double cost =
-        static_cast<double>(sequence_length) * static_cast<double>(v_head_size) * static_cast<double>(sequence_length);
-
-    ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
-      for (std::ptrdiff_t i = begin; i != end; ++i) {
-        const T* v = V + kv_input_chunk_length * i;
-        if (nullptr != present) {
-          // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
-          v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
-        } else if (nullptr != present_value) {
-          v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
-        }
+    ThreadPool::TrySimpleParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, [&](std::ptrdiff_t batch_head_id) {
+      const T* v = V + kv_input_chunk_length * batch_head_id;
+      if (nullptr != present) {
+        // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
+        v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, batch_head_id);
+      } else if (nullptr != present_value) {
+        v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, batch_head_id);
+      }
 
-        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
-        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
-        math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
-                        attention_probs + attention_probs_offset,
-                        v, current_tmp_data, nullptr);
-
-        // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-        const int batch_index = static_cast<int>(i / num_heads_);
-        const int head_index = static_cast<int>(i % num_heads_);
-        T* src = current_tmp_data;
-        ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
-        T* dest = output + dest_offset;
-        const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
-        for (int j = 0; j < sequence_length; j++) {
-          memcpy(dest, src, bytes_to_copy);
-          src += v_head_size;
-          dest += v_hidden_size;
-        }
+      T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * batch_head_id;
+      ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * batch_head_id;
+      math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
+                      attention_probs + attention_probs_offset,
+                      v, current_tmp_data, nullptr);
+
+      // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
+      const int batch_index = static_cast<int>(batch_head_id / num_heads_);
+      const int head_index = static_cast<int>(batch_head_id % num_heads_);
+      T* src = current_tmp_data;
+      ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
+      T* dest = output + dest_offset;
+      const auto bytes_to_copy = SafeInt<size_t>(v_head_size) * sizeof(T);
+      for (int j = 0; j < sequence_length; j++) {
+        memcpy(dest, src, bytes_to_copy);
+        src += v_head_size;
+        dest += v_hidden_size;
      }
    });
  }
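
Note on the change: both hunks replace ThreadPool::TryParallelFor, which takes a per-iteration cost estimate and hands each worker a variable-sized [begin, end) slice of [0, loop_len), with ThreadPool::TrySimpleParallelFor, which dispatches exactly one task per batch*head index. Since each iteration already performs a full S x T GEMM, per-iteration scheduling overhead is negligible and the cost model adds no value. Below is a minimal sketch contrasting the two scheduling styles, assuming onnxruntime's concurrency::ThreadPool API; RunAllHeads and ProcessHead are hypothetical stand-ins for the per-head work in the diff, not code from the PR.

#include <cstddef>

#include "core/platform/threadpool.h"  // onnxruntime thread pool

using onnxruntime::concurrency::ThreadPool;

// Hypothetical per-head worker standing in for the mask/GEMM/bias body above.
void ProcessHead(std::ptrdiff_t batch_head_id) {
  (void)batch_head_id;  // placeholder: real code would compute one head's attention
}

void RunAllHeads(ThreadPool* tp, std::ptrdiff_t batch_size, std::ptrdiff_t num_heads,
                 double cost_per_head) {
  const std::ptrdiff_t loop_len = batch_size * num_heads;

  // Before: the pool uses cost_per_head to decide how many iterations each
  // worker range should cover, so every callback loops over a slice.
  ThreadPool::TryParallelFor(tp, loop_len, cost_per_head,
                             [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
                               for (std::ptrdiff_t i = begin; i != end; ++i) {
                                 ProcessHead(i);
                               }
                             });

  // After: one task per batch*head index, no cost estimate to maintain;
  // each iteration is a whole GEMM, so the granularity is already coarse.
  ThreadPool::TrySimpleParallelFor(tp, loop_len,
                                   [&](std::ptrdiff_t batch_head_id) {
                                     ProcessHead(batch_head_id);
                                   });
}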