From 035d40f7adb337ccebe8c26d3e9628255d6e4382 Mon Sep 17 00:00:00 2001 From: mingmingtasd Date: Thu, 19 Sep 2024 13:56:27 +0800 Subject: [PATCH 1/3] support tile in WebNN EP --- js/web/docs/webnn-operators.md | 1 + js/web/test/suite-test-list.jsonc | 4 +- .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/tile_op_builder.cc | 98 +++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 6 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 6fd4f9af20432..a001248f68c85 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -91,6 +91,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | | | Tan | ai.onnx(7+) | tan | ✓ | ✓ | | | Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | | +| Tile | ai.onnx(7-12, 13+) | tile | ✗ | ✓ | | | Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | | | Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 5c1e2e27a6eff..de70398081b9c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -2498,8 +2498,8 @@ // "test_thresholdedrelu_default", // "test_thresholdedrelu_example", // "test_thresholdedrelu", - // "test_tile_precomputed", - // "test_tile", + "test_tile_precomputed", + "test_tile", // // "test_top_k_negative_axis", // // "test_top_k_smallest", // // "test_top_k", diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index dd4a8acc662ef..b98d4b36fb0b6 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -234,6 +234,7 @@ static const InlinedHashMap op_map = { {"Sub", "sub"}, {"Tan", "tan"}, {"Tanh", "tanh"}, + {"Tile", "tile"}, {"Transpose", "transpose"}, {"Trilu", "triangular"}, {"Unsqueeze", "reshape"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc new file mode 100644 index 0000000000000..5ee56422675be --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class TileOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; +}; + +// Add operator related. + +void TileOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); +} + +Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& initializers(model_builder.GetInitializerTensors()); + const auto& repetitions_initializer = *initializers.at(input_defs[1]->Name()); + const int64_t* raw_repetitions_data = repetitions_initializer.int64_data().empty() + ? reinterpret_cast(repetitions_initializer.raw_data().data()) + : repetitions_initializer.int64_data().data(); + const auto size = repetitions_initializer.dims()[0]; + TensorShapeVector repetitions_data{raw_repetitions_data, raw_repetitions_data + size}; + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + std::vector repetitions; + std::transform(repetitions_data.cbegin(), repetitions_data.cend(), + std::back_inserter(repetitions), + [](int64_t repetition) -> uint32_t { return SafeInt(repetition); }); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + emscripten::val output = model_builder.GetBuilder().call("tile", + input, + emscripten::val::array(repetitions), + options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool TileOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& repetitions_name = input_defs[1]->Name(); + if (!Contains(initializers, repetitions_name)) { + LOGS(logger, VERBOSE) << "Repetitions of tile must be a constant initializer"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; + return false; + } + + return true; +} + +void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 93a2b232a7d51..baf475f7f8a75 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -187,6 +187,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSqueezeUnsqueezeOpBuilder("Unsqueeze", op_registrations); } + { // Tile + CreateTileOpBuilder("Tile", op_registrations); + } + { // Transpose CreateTransposeOpBuilder("Transpose", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 61fe6d936e9d1..1fe501959b673 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -48,6 +48,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); From c27717539853427ba79c3e2f4869103a1dbe2c2e Mon Sep 17 00:00:00 2001 From: mingmingtasd Date: Thu, 19 Sep 2024 16:11:20 +0800 Subject: [PATCH 2/3] address wanming's comments to use uint32_t repeats --- js/web/docs/webnn-operators.md | 2 +- .../core/providers/webnn/builders/impl/tile_op_builder.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index a001248f68c85..95ee9d26ee40c 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -91,7 +91,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | | | Tan | ai.onnx(7+) | tan | ✓ | ✓ | | | Tanh | ai.onnx(7-12, 13+) | tanh | ✓ | ✓ | | -| Tile | ai.onnx(7-12, 13+) | tile | ✗ | ✓ | | +| Tile | ai.onnx(7-12, 13+) | tile | ✗ | ✓ | Input 'repeats' should be a constant | | Transpose | ai.onnx(7-12, 13-20, 21+) | transpose | ✓ | ✓ | | | Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant | | Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 5ee56422675be..f6f47e7567721 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -52,7 +52,7 @@ Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector repetitions; std::transform(repetitions_data.cbegin(), repetitions_data.cend(), std::back_inserter(repetitions), - [](int64_t repetition) -> uint32_t { return SafeInt(repetition); }); + [](int64_t repetition) -> uint32_t { return SafeInt(repetition); }); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); From 1a314e9b05e1ab420878ad553f52cb16140f55df Mon Sep 17 00:00:00 2001 From: mingmingtasd Date: Thu, 19 Sep 2024 16:19:23 +0800 Subject: [PATCH 3/3] fix nits --- .../core/providers/webnn/builders/impl/tile_op_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index f6f47e7567721..672a3a510d54d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -49,7 +49,7 @@ Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto size = repetitions_initializer.dims()[0]; TensorShapeVector repetitions_data{raw_repetitions_data, raw_repetitions_data + size}; emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - std::vector repetitions; + std::vector repetitions; std::transform(repetitions_data.cbegin(), repetitions_data.cend(), std::back_inserter(repetitions), [](int64_t repetition) -> uint32_t { return SafeInt(repetition); });