From decb3852a02ed5bfcd46572e41609df9c2634613 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:21:36 -0700 Subject: [PATCH] refactor: extract shared util function ComputeBroadcastOutputShape (#21940) ### Description The broadcast output-shape helper `ComputeOutputShape` is used in multiple places and was defined separately, with identical bodies, in the CANN and CUDA execution providers. This change moves it into `core/providers/common.h` as the shared inline function `ComputeBroadcastOutputShape` and updates the call sites. --- .../cuda/collective/distributed_expand.cc | 3 +- onnxruntime/core/providers/cann/cann_utils.cc | 29 ---------------- onnxruntime/core/providers/cann/cann_utils.h | 2 -- .../cann/math/binary_elementwise_ops.cc | 4 ++- onnxruntime/core/providers/common.h | 34 +++++++++++++++++++ .../cuda/math/binary_elementwise_ops.cc | 32 ++--------------- .../cuda/math/binary_elementwise_ops.h | 6 ---- .../cuda/math/variadic_elementwise_ops.cc | 3 +- .../core/providers/cuda/tensor/expand.cc | 4 +-- .../core/providers/cuda/tensor/expand.h | 6 ---- 10 files changed, 46 insertions(+), 77 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc index 3cfa3ab959343..170ded752bf20 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -10,6 +10,7 @@ // ORT system. #include "core/providers/cuda/tensor/expand.h" +#include "core/providers/common.h" // std C++. 
#include @@ -51,7 +52,7 @@ Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; TensorShape original_output_shape(original_output_dims); ORT_ENFORCE( - onnxruntime::cuda::ComputeOutputShape( + onnxruntime::ComputeBroadcastOutputShape( Node().Name(), original_input_shape, original_output_dims, original_output_shape) diff --git a/onnxruntime/core/providers/cann/cann_utils.cc b/onnxruntime/core/providers/cann/cann_utils.cc index b0e61848bac97..95d7a462ca9d9 100644 --- a/onnxruntime/core/providers/cann/cann_utils.cc +++ b/onnxruntime/core/providers/cann/cann_utils.cc @@ -224,34 +224,5 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) { hash_value = hash[0] | (uint64_t(hash[1]) << 32); } -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, - const TensorShape& rhs_shape, TensorShape& out_shape) { - size_t lhs_rank = lhs_shape.NumDimensions(); - size_t rhs_rank = rhs_shape.NumDimensions(); - size_t out_rank = std::max(lhs_rank, rhs_rank); - - std::vector<int64_t> output_dims(out_rank, 0); - for (size_t i = 0; i < out_rank; ++i) { - int64_t lhs_dim = 1; - if (i < lhs_rank) - lhs_dim = lhs_shape[lhs_rank - 1 - i]; - int64_t rhs_dim = 1; - if (i < rhs_rank) - rhs_dim = rhs_shape[rhs_rank - 1 - i]; - int64_t max = std::max(lhs_dim, rhs_dim); - int64_t min = std::min(lhs_dim, rhs_dim); - int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. 
- if (lhs_dim != out_dim && lhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - if (rhs_dim != out_dim && rhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - output_dims[out_rank - 1 - i] = out_dim; - } - out_shape = TensorShape(output_dims); - return Status::OK(); -} - } // namespace cann } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_utils.h b/onnxruntime/core/providers/cann/cann_utils.h index 5eb1873ae32dd..3739924758ea4 100644 --- a/onnxruntime/core/providers/cann/cann_utils.h +++ b/onnxruntime/core/providers/cann/cann_utils.h @@ -124,8 +124,6 @@ Status aclrtblasGemmEx(aclTransType transA, bool FileExist(const std::string& file_name); void GenerateHashValue(const std::string string, HashValue& hash_value); -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, - const TensorShape& rhs_shape, TensorShape& out_shape); std::unique_ptr CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc index d8911a4caa8c8..a0115243446cc 100644 --- a/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cann/math/binary_elementwise_ops.cc @@ -2,6 +2,8 @@ // Copyright (c) Huawei. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/shared_library/provider_api.h" +#include "core/providers/common.h" #include "core/providers/cann/math/binary_elementwise_ops.h" #include #include @@ -20,7 +22,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* ctx, CannPreparation& prepare const Tensor* B = ctx->Input<Tensor>(1); TensorShape output_shape; - ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape)); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), A->Shape(), B->Shape(), output_shape)); Tensor* C = ctx->Output(0, output_shape); void* A_data = const_cast<void*>(A->DataRaw()); diff --git a/onnxruntime/core/providers/common.h b/onnxruntime/core/providers/common.h index 7576dfba5c85e..aa20b88ef40cc 100644 --- a/onnxruntime/core/providers/common.h +++ b/onnxruntime/core/providers/common.h @@ -180,4 +180,38 @@ T Product(const Container& c) { return accumulate(c.cbegin(), c.cend(), static_cast<T>(1), std::multiplies<T>()); } +/// <summary> +/// Compute the output shape for broadcasting the given input shapes of lhs and rhs. +/// </summary> +inline Status ComputeBroadcastOutputShape(const std::string& node_name, + const TensorShape& lhs_shape, + const TensorShape& rhs_shape, + TensorShape& out_shape) { + size_t lhs_rank = lhs_shape.NumDimensions(); + size_t rhs_rank = rhs_shape.NumDimensions(); + size_t out_rank = std::max(lhs_rank, rhs_rank); + + std::vector<int64_t> output_dims(out_rank, 0); + for (size_t i = 0; i < out_rank; ++i) { + int64_t lhs_dim = 1; + if (i < lhs_rank) + lhs_dim = lhs_shape[lhs_rank - 1 - i]; + int64_t rhs_dim = 1; + if (i < rhs_rank) + rhs_dim = rhs_shape[rhs_rank - 1 - i]; + int64_t max = std::max(lhs_dim, rhs_dim); + int64_t min = std::min(lhs_dim, rhs_dim); + int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. 
+ if (lhs_dim != out_dim && lhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + if (rhs_dim != out_dim && rhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + output_dims[out_rank - 1 - i] = out_dim; + } + out_shape = TensorShape(output_dims); + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index 2c38ce2d3ca9a..8aca8635a24fe 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/shared_library/provider_api.h" +#include "core/providers/common.h" #include "core/providers/cuda/math/binary_elementwise_ops.h" #include "core/providers/cuda/math/binary_elementwise_ops_impl.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" @@ -21,34 +23,6 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, return Status::OK(); } -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) { - size_t lhs_rank = lhs_shape.NumDimensions(); - size_t rhs_rank = rhs_shape.NumDimensions(); - size_t out_rank = std::max(lhs_rank, rhs_rank); - - std::vector<int64_t> output_dims(out_rank, 0); - for (size_t i = 0; i < out_rank; ++i) { - int64_t lhs_dim = 1; - if (i < lhs_rank) - lhs_dim = lhs_shape[lhs_rank - 1 - i]; - int64_t rhs_dim = 1; - if (i < rhs_rank) - rhs_dim = rhs_shape[rhs_rank - 1 - i]; - int64_t max = std::max(lhs_dim, rhs_dim); - int64_t min = std::min(lhs_dim, rhs_dim); - int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. 
- if (lhs_dim != out_dim && lhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - if (rhs_dim != out_dim && rhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - output_dims[out_rank - 1 - i] = out_dim; - } - out_shape = TensorShape(output_dims); - return Status::OK(); -} - Status BinaryElementwiseBroadcastPrepare( const Tensor* lhs_tensor, const Tensor* rhs_tensor, @@ -77,7 +51,7 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, Bin const auto& rhs_shape = rhs_tensor->Shape(); TensorShape output_shape; - ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); auto output_tensor = context->Output(0, output_shape); ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p)); diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h index 048887c326de1..d519658aa3ca5 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.h @@ -108,12 +108,6 @@ struct BinaryElementwisePreparation { } }; -Status ComputeOutputShape( - const std::string& node_name, - const TensorShape& lhs_shape, - const TensorShape& rhs_shape, - TensorShape& out_shape); - Status BinaryElementwiseBroadcastPrepare( const Tensor* lhs_tensor, const Tensor* rhs_tensor, diff --git a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc index 69904edb8130b..f543ba9c975e1 
100644 --- a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" +#include "core/providers/common.h" #include "core/providers/cuda/math/variadic_elementwise_ops.h" #include @@ -209,7 +210,7 @@ Status VariadicElementwiseOp TensorShape output_shape; TensorShape previous_output_shape = first_input_tensor.Shape(); for (int index = 1; index < input_count; index++) { - ORT_RETURN_IF_ERROR(ComputeOutputShape( + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape( node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape)); previous_output_shape = output_shape; } diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index 806ecfa1aab17..60e219e6d03e6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()}; TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape)); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape)); auto& output_tensor = *ctx->Output(0, output_shape); if (0 == output_shape.Size()) { return Status::OK(); @@ -202,7 +202,7 @@ std::unique_ptr FuncExpand( TensorShape output_shape(output_dims); ORT_ENFORCE( - ComputeOutputShape( + ComputeBroadcastOutputShape( cuda_kernel->Node().Name(), input_data_tensor->Shape(), output_dims, output_shape) diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index a0b12790017f6..133d17fc78ac0 100644 --- 
a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -14,12 +14,6 @@ class Expand final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; }; -Status ComputeOutputShape( - const std::string& node_name, - const TensorShape& lhs_shape, - const TensorShape& rhs_shape, - TensorShape& out_shape); - Status FuncExpand( const CudaKernel* cuda_kernel, OpKernelContext* ctx,