From db0cdb9eecf2eea77d85ab305799c1a314a5cac1 Mon Sep 17 00:00:00 2001
From: Yi-Hong Lyu
Date: Thu, 1 Feb 2024 21:19:35 +0000
Subject: [PATCH] Parallel Transpose_BSNH_to_BNSH

Achieved a speedup of 1.098 in MultiHeadAttention and an end-to-end
speedup of 1.021 in the OCR model through parallelization of the
Transpose_BSNH_to_BNSH operation.
---
 onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index eb25d0fd7cc1e..c4e4b4ec707fb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv,
 
 // Transpose Q/K/V from BxSxNxH to BxNxSxH
 Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
-                              OrtValue& qkv_transposed) {
+                              OrtValue& qkv_transposed,
+                              concurrency::ThreadPool* tp = nullptr) {
   std::vector<size_t> permutations({0, 2, 1, 3});
   gsl::span<const size_t> permutations_span{permutations};
   size_t from = 2, to = 1;
-  SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to);
+  SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
   return Status::OK();
 }
 
@@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
   ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
 
   // Transpose Q from BxSxNxH to BxNxSxH
-  ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));
+  auto tp = context->GetOperatorThreadPool();
+  ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
 
   return Status::OK();
 }