Skip to content

Commit

Permalink
[WebNN EP] Support Softmax since version 13 (microsoft#17714)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
WebNN's softmax only accepts a 2-D input tensor and computes along axis 1. For now, we
wrap the op with Transpose and Reshape to convert the input into a compatible form.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Enable more models to run on WebNN.
  • Loading branch information
zesongw authored Oct 2, 2023
1 parent ac4e726 commit f158f39
Showing 1 changed file with 69 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const auto input_size = input_shape.size();
// WebNN Softmax only support 2d input shape, reshape input to 2d.
if (input_size != 2) {
NodeAttrHelper helper(node);
NodeAttrHelper helper(node);
if (node.SinceVersion() < 13) {
int32_t axis = helper.Get("axis", 1);
if (node.SinceVersion() >= 13)
// Opset 13 has default value -1.
axis = helper.Get("axis", -1);
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
// Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}].
if (input_size != 2) {
int32_t first_dim = static_cast<int32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
1, std::multiplies<int64_t>()));
int32_t second_dim = static_cast<int32_t>(std::reduce(input_shape.begin() + axis, input_shape.end(),
1, std::multiplies<int64_t>()));
emscripten::val new_shape = emscripten::val::array(std::vector<int32_t>{first_dim, second_dim});
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
}

output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);

// Reshape output to the same shape of input.
if (input_size != 2) {
emscripten::val new_shape = emscripten::val::array();
for (size_t i = 0; i < input_size; i++) {
new_shape.call<void>("push", static_cast<int32_t>(input_shape[i]));
}
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, new_shape);
}
} else {
int32_t axis = helper.Get("axis", -1);
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
int32_t first_dim = static_cast<int32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
1, std::multiplies<int64_t>()));
int32_t second_dim = static_cast<int32_t>(std::reduce(input_shape.begin() + axis, input_shape.end(),
1, std::multiplies<int64_t>()));
emscripten::val new_shape = emscripten::val::array(std::vector<int32_t>{first_dim, second_dim});
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
}
output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);
// Reshape output to the same shape of input.
if (input_size != 2) {
emscripten::val new_shape = emscripten::val::array();
for (size_t i = 0; i < input_size; i++) {
new_shape.call<void>("push", static_cast<int32_t>(input_shape[i]));
// Wraparound for transpose the target axis to the last.
// WebNN compute the softmax values of the 2-D input tensor along axis 1.
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method
if (axis != static_cast<int>(input_shape.size() - 1)) {
emscripten::val options = emscripten::val::object();
std::vector<uint32_t> permutation(input_shape.size());
std::iota(permutation.begin(), permutation.end(), 0);
permutation.erase(permutation.begin() + axis);
permutation.push_back(axis);
options.set("permutation", emscripten::val::array(permutation));
input = model_builder.GetBuilder().call<emscripten::val>("transpose", input, options);
}
// Wraparound for reshape input tensor to 2-D.
if (input_shape.size() != 2) {
uint32_t first_dim = static_cast<uint32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
1, std::multiplies<int64_t>()));
first_dim *= static_cast<uint32_t>(std::reduce(input_shape.begin() + axis + 1, input_shape.end(),
1, std::multiplies<int64_t>()));
uint32_t second_dim = static_cast<uint32_t>(input_shape[axis]);
emscripten::val new_shape = emscripten::val::array(std::vector<uint32_t>{first_dim, second_dim});
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
}

output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);

// Transpose back to the axis.
if (input_shape.size() != 2) {
std::vector<uint32_t> new_shape;
std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return static_cast<uint32_t>(dim); });
std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return static_cast<uint32_t>(dim); });
new_shape.push_back(static_cast<int32_t>(input_shape[axis]));
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output, emscripten::val::array(new_shape));
}
// Reshape to the original shape.
if (axis != static_cast<int>(input_shape.size() - 1)) {
emscripten::val options = emscripten::val::object();
std::vector<uint32_t> permutation(input_shape.size());
std::iota(permutation.begin(), permutation.end(), 0);
permutation.pop_back();
permutation.insert(permutation.begin() + axis, input_shape.size() - 1);
options.set("permutation", emscripten::val::array(permutation));
output = model_builder.GetBuilder().call<emscripten::val>("transpose", output, options);
}
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, new_shape);
}
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand All @@ -80,14 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
<< input_size << "d shape";
return false;
}
NodeAttrHelper helper(node);
const int64_t axis = helper.Get("axis", 1);
// WebNN softmax only support reshape for the last axis or version before 13.
// TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1).
if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) {
LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis;
return false;
}

return true;
}
Expand Down

0 comments on commit f158f39

Please sign in to comment.