diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index b207b804416aa..6a86ca7aca6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -35,30 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); - // WebNN Softmax only support 2d input shape, reshape input to 2d. - if (input_size != 2) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node); + if (node.SinceVersion() < 13) { int32_t axis = helper.Get("axis", 1); - if (node.SinceVersion() >= 13) - // Opset 13 has default value -1. - axis = helper.Get("axis", -1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + if (input_size != 2) { + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape output to the same shape of input. + if (input_size != 2) { + emscripten::val new_shape = emscripten::val::array(); + for (size_t i = 0; i < input_size; i++) { + new_shape.call("push", static_cast(input_shape[i])); + } + output = model_builder.GetBuilder().call("reshape", output, new_shape); + } + } else { + int32_t axis = helper.Get("axis", -1); axis = static_cast(HandleNegativeAxis(axis, input_size)); - int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, - 1, std::multiplies())); - int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), - 1, std::multiplies())); - emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); - input = model_builder.GetBuilder().call("reshape", input, new_shape); - } - output = model_builder.GetBuilder().call("softmax", input); - // Reshape output to the same shape of input. - if (input_size != 2) { - emscripten::val new_shape = emscripten::val::array(); - for (size_t i = 0; i < input_size; i++) { - new_shape.call("push", static_cast(input_shape[i])); + // Wraparound for transpose the target axis to the last. + // WebNN compute the softmax values of the 2-D input tensor along axis 1. + // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.erase(permutation.begin() + axis); + permutation.push_back(axis); + options.set("permutation", emscripten::val::array(permutation)); + input = model_builder.GetBuilder().call("transpose", input, options); + } + // Wraparound for reshape input tensor to 2-D. + if (input_shape.size() != 2) { + uint32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + first_dim *= static_cast(std::reduce(input_shape.begin() + axis + 1, input_shape.end(), + 1, std::multiplies())); + uint32_t second_dim = static_cast(input_shape[axis]); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Transpose back to the axis. + if (input_shape.size() != 2) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + new_shape.push_back(static_cast(input_shape[axis])); + output = model_builder.GetBuilder().call("reshape", + output, emscripten::val::array(new_shape)); + } + // Reshape to the original shape. + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.pop_back(); + permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + options.set("permutation", emscripten::val::array(permutation)); + output = model_builder.GetBuilder().call("transpose", output, options); } - output = model_builder.GetBuilder().call("reshape", output, new_shape); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -80,14 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali << input_size << "d shape"; return false; } - NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis", 1); - // WebNN softmax only support reshape for the last axis or version before 13. - // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1). - if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { - LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; - return false; - } return true; }