From 92a725f28f0ec192754ed5c170c22a2080b70b70 Mon Sep 17 00:00:00 2001 From: zesongw Date: Sat, 7 Oct 2023 14:10:41 +0800 Subject: [PATCH] [WebNN EP] Update Op Softmax for readability Improve readability by fixing misplaced comments and utilizing std::rotate. --- .../webnn/builders/impl/softmax_op_builder.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 6a86ca7aca6e9..beee8b1d77cee 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -69,8 +69,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); std::vector permutation(input_shape.size()); std::iota(permutation.begin(), permutation.end(), 0); - permutation.erase(permutation.begin() + axis); - permutation.push_back(axis); + std::rotate(permutation.begin() + axis, permutation.begin() + axis + 1, permutation.end()); options.set("permutation", emscripten::val::array(permutation)); input = model_builder.GetBuilder().call("transpose", input, options); } @@ -87,7 +86,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, output = model_builder.GetBuilder().call("softmax", input); - // Transpose back to the axis. + // Restore from 2-D to the original shape. if (input_shape.size() != 2) { std::vector new_shape; std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), @@ -98,13 +97,12 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); } - // Reshape to the original shape. + // Restore the corresponding axis back to the initial position from the last position. if (axis != static_cast(input_shape.size() - 1)) { emscripten::val options = emscripten::val::object(); std::vector permutation(input_shape.size()); std::iota(permutation.begin(), permutation.end(), 0); - permutation.pop_back(); - permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + std::rotate(permutation.rbegin(), permutation.rbegin() + 1, permutation.rend() - axis); options.set("permutation", emscripten::val::array(permutation)); output = model_builder.GetBuilder().call("transpose", output, options); }