Skip to content

Commit

Permalink
[WebNN EP] Update Op Softmax for readability
Browse files Browse the repository at this point in the history
Improve readability by fixing misplaced comments and utilizing std::rotate.
  • Loading branch information
zesongw committed Oct 7, 2023
1 parent 7201def commit 92a725f
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val options = emscripten::val::object();
std::vector<uint32_t> permutation(input_shape.size());
std::iota(permutation.begin(), permutation.end(), 0);
permutation.erase(permutation.begin() + axis);
permutation.push_back(axis);
std::rotate(permutation.begin() + axis, permutation.begin() + axis + 1, permutation.end());
options.set("permutation", emscripten::val::array(permutation));
input = model_builder.GetBuilder().call<emscripten::val>("transpose", input, options);
}
Expand All @@ -87,7 +86,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);

// Transpose back to the axis.
// Restore from 2-D to the original shape.
if (input_shape.size() != 2) {
std::vector<uint32_t> new_shape;
std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape),
Expand All @@ -98,13 +97,12 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output, emscripten::val::array(new_shape));
}
// Reshape to the original shape.
// Restore the corresponding axis back to the initial position from the last position.
if (axis != static_cast<int>(input_shape.size() - 1)) {
emscripten::val options = emscripten::val::object();
std::vector<uint32_t> permutation(input_shape.size());
std::iota(permutation.begin(), permutation.end(), 0);
permutation.pop_back();
permutation.insert(permutation.begin() + axis, input_shape.size() - 1);
std::rotate(permutation.rbegin(), permutation.rbegin() + 1, permutation.rend() - axis);
options.set("permutation", emscripten::val::array(permutation));
output = model_builder.GetBuilder().call<emscripten::val>("transpose", output, options);
}
Expand Down

0 comments on commit 92a725f

Please sign in to comment.