From a03d49cdef8d7ebaa06afb52ca5293e5f94cbfd1 Mon Sep 17 00:00:00 2001 From: shiyi Date: Thu, 31 Oct 2024 01:20:21 +0800 Subject: [PATCH] [WebNN] Add ScatterElements and GatherElements (#22534) --- js/web/docs/webnn-operators.md | 2 + js/web/test/suite-test-list.jsonc | 8 +- .../core/providers/webnn/builders/helper.h | 2 + .../impl/gatherElements_op_builder.cc | 74 ++++++++++++++ .../impl/scatterElements_op_builder.cc | 97 +++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 8 ++ .../webnn/builders/op_builder_factory.h | 2 + 7 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index ecbfe2e9c41b8..004188bd6646c 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | | | Floor | ai.onnx(7-12, 13+) | floor | ✓ | ✓ | | | Gather | ai.onnx(7-10, 11-12, 13+) | gather | ✓ | ✓ | | +| GatherElements | ai.onnx(11-12, 13+) | gatherElements | ✗ | ✓ | | | GatherND | ai.onnx(11, 12, 13+) | gatherND | ✓ | ✓ | Only supports 'batch_dims' == 0 | | Gelu | ai.onnx(20+) | gelu | ✓ | ✓ | | | Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | ✓ | ✓ | Only supports 1-D 'C' input | @@ -80,6 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | +| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' | | ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index d81d741aa0ebe..d4e7198257c9a 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -2254,10 +2254,10 @@ // // "test_round", // // "test_scan_sum", // // "test_scan9_sum", - // // "test_scatter_elements_with_axis", - // // "test_scatter_elements_with_duplicate_indices", - // // "test_scatter_elements_with_negative_indices", - // // "test_scatter_elements_without_axis", + "test_scatter_elements_with_axis", + "test_scatter_elements_with_duplicate_indices", + "test_scatter_elements_with_negative_indices", + "test_scatter_elements_without_axis", // // "test_scatter_with_axis", // // "test_scatter_without_axis", "test_scatternd_add", diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 5ea16253b6033..adeb34a4c707e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -215,6 +215,7 @@ static const InlinedHashMap op_map = { {"Flatten", "reshape"}, {"Floor", "floor"}, {"Gather", "gather"}, + {"GatherElements", "gatherElements"}, {"GatherND", "gatherND"}, {"Gelu", "gelu"}, {"Gemm", "gemm"}, @@ -261,6 +262,7 @@ static const InlinedHashMap op_map = { {"Relu", "relu"}, {"Reshape", "reshape"}, {"Resize", "resample2d"}, + {"ScatterElements", "scatterElements"}, {"ScatterND", "scatterND"}, {"Shape", "slice"}, {"Sigmoid", "sigmoid"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc new file mode 100644 index 0000000000000..225cfcdfc852c --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GatherElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = model_builder.GetBuilder().call("gatherElements", data, indices, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc new file mode 100644 index 0000000000000..c786aa468736c --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class ScatterElementsOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +// Add operator related. + +Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + emscripten::val data = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name()); + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + const size_t rank = input_shape.size(); + NodeAttrHelper helper(node); + const uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 0), rank)); + options.set("axis", axis); + + emscripten::val output = + model_builder.GetBuilder().call("scatterElements", data, indices, updates, options); + + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); +} + +// Operator support related. + +bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + if (helper.Get("reduction", "none") != "none") { + LOGS(logger, VERBOSE) << "ScatterElements: WebNN only supports reduction type none (default)"; + return false; + } + + return true; +} + +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& data = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; + const auto& updates = *node.InputDefs()[2]; + const auto& op_type = node.OpType(); + + int32_t data_type; + int32_t indices_type; + int32_t updates_type; + if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) || + !GetType(updates, updates_type, logger)) { + return false; + } + + if (data_type != updates_type) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); +} + +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 93fed1704e014..343502e71a200 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGatherOpBuilder("Gather", op_registrations); } + { // GatherElements + CreateGatherElementsOpBuilder("GatherElements", op_registrations); + } + { // GatherND CreateGatherNDOpBuilder("GatherND", op_registrations); } @@ -174,6 +178,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateResizeOpBuilder("Resize", op_registrations); } + { // ScatterElements + CreateScatterElementsOpBuilder("ScatterElements", op_registrations); + } + { // ScatterND CreateScatterNDOpBuilder("ScatterND", op_registrations); } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index 2278571b5a57f..de7f5161071b8 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -44,6 +45,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);