From 004bd36f3d4dbaeb2b0ebb8c9c06e91807d4660e Mon Sep 17 00:00:00 2001 From: mingmingtasd Date: Sat, 5 Oct 2024 15:56:55 +0800 Subject: [PATCH] [WebNN EP] Support Tile operator (#22148) PTAL, thanks! @Honry , @fdwr thanks! --- js/web/docs/webnn-operators.md | 1 + js/web/test/suite-test-list.jsonc | 4 +- .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/tile_op_builder.cc | 98 +++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 6 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 6c50f3752737b..f696264aeead7 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -92,6 +92,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | | | Tan | ai.onnx(7+) | tan | ✓ | ✓ | | | Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | | +| Tile | ai.onnx(7-12, 13+) | tile | ✗ | ✓ | Input 'repeats' should be a constant | | Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | | | Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index ae708467be8a2..dcfc8ccc3928f 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -2498,8 +2498,8 @@ // "test_thresholdedrelu_default", // "test_thresholdedrelu_example", // "test_thresholdedrelu", - // "test_tile_precomputed", - // "test_tile", + "test_tile_precomputed", + "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 529463f0808ad..aecb1f7a03bb9 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -237,6 +237,7 @@ static const InlinedHashMap op_map = { {"Sub", "sub"}, {"Tan", "tan"}, {"Tanh", "tanh"}, + {"Tile", "tile"}, {"Transpose", "transpose"}, {"Trilu", "triangular"}, {"Unsqueeze", "reshape"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc new file mode 100644 index 0000000000000..672a3a510d54d --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TileOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void TileOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers(model_builder.GetInitializerTensors()); + const auto& repetitions_initializer = *initializers.at(input_defs[1]->Name()); + const int64_t* raw_repetitions_data = repetitions_initializer.int64_data().empty() + ? reinterpret_cast(repetitions_initializer.raw_data().data()) + : repetitions_initializer.int64_data().data(); + const auto size = repetitions_initializer.dims()[0]; + TensorShapeVector repetitions_data{raw_repetitions_data, raw_repetitions_data + size}; + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + std::vector repetitions; + std::transform(repetitions_data.cbegin(), repetitions_data.cend(), + std::back_inserter(repetitions), + [](int64_t repetition) -> uint32_t { return SafeInt(repetition); }); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("tile", + input, + emscripten::val::array(repetitions), + options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool TileOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& repetitions_name = input_defs[1]->Name(); + if (!Contains(initializers, repetitions_name)) { + LOGS(logger, VERBOSE) << "Repetitions of tile must be a constant initializer"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; + return false; + } + + return true; +} + +void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 9df09af01ba67..8baa4790247ec 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -191,6 +191,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSqueezeUnsqueezeOpBuilder("Unsqueeze", op_registrations); } + { // Tile + CreateTileOpBuilder("Tile", op_registrations); + } + { // Transpose CreateTransposeOpBuilder("Transpose", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 398dfc2d3f1c7..990be04d42107 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -49,6 +49,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);