Skip to content

Commit

Permalink
Parallel Transpose_BSNH_to_BNSH
Browse files Browse the repository at this point in the history
Parallelizing the Transpose_BSNH_to_BNSH operation yields a 1.098x speedup in
MultiHeadAttention and a 1.021x end-to-end speedup in the OCR model.
  • Loading branch information
yihonglyu committed Feb 4, 2024
1 parent 68b6064 commit db0cdb9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv,

// Transpose Q/K/V from BxSxNxH to BxNxSxH.
// `qkv`            — input tensor in BxSxNxH layout (not modified).
// `qkv_transposed` — OrtValue whose Tensor receives the BxNxSxH result.
// `tp`             — optional thread pool forwarded to SingleAxisTranspose so the
//                    transpose can run in parallel; defaults to nullptr (serial)
//                    to keep existing two-argument callers working unchanged.
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
                              OrtValue& qkv_transposed,
                              concurrency::ThreadPool* tp = nullptr) {
  // Permutation {0, 2, 1, 3} swaps axis 1 (S) with axis 2 (N).
  std::vector<size_t> permutations({0, 2, 1, 3});
  gsl::span<const size_t> permutations_span{permutations};
  size_t from = 2, to = 1;
  SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
  return Status::OK();
}

Expand Down Expand Up @@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));

// Transpose Q from BxSxNxH to BxNxSxH
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));
auto tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));

return Status::OK();
}
Expand Down

0 comments on commit db0cdb9

Please sign in to comment.