Skip to content

Commit

Permalink
[WebNN] Add ScatterElements and GatherElements (#22534)
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 authored Oct 30, 2024
1 parent 86b3b89 commit 46ff240
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 4 deletions.
2 changes: 2 additions & 0 deletions js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape ||| |
| Floor | ai.onnx(7-12, 13+) | floor ||| |
| Gather | ai.onnx(7-10, 11-12, 13+) | gather ||| |
| GatherElements | ai.onnx(11-12, 13+) | gatherElements ||| |
| GatherND | ai.onnx(11, 12, 13+) | gatherND ||| Only supports 'batch_dims' == 0 |
| Gelu | ai.onnx(20+) | gelu ||| |
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm ||| Only supports 1-D 'C' input |
Expand Down Expand Up @@ -80,6 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Relu | ai.onnx(7-12, 13, 14+) | relu ||| |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape ||| Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
| ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements ||| Only supports 'reduction' == 'none' |
| ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND ||| Only supports 'reduction' == 'none' |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
Expand Down
8 changes: 4 additions & 4 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2254,10 +2254,10 @@
// // "test_round",
// // "test_scan_sum",
// // "test_scan9_sum",
// // "test_scatter_elements_with_axis",
// // "test_scatter_elements_with_duplicate_indices",
// // "test_scatter_elements_with_negative_indices",
// // "test_scatter_elements_without_axis",
"test_scatter_elements_with_axis",
"test_scatter_elements_with_duplicate_indices",
"test_scatter_elements_with_negative_indices",
"test_scatter_elements_without_axis",
// // "test_scatter_with_axis",
// // "test_scatter_without_axis",
"test_scatternd_add",
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Flatten", "reshape"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherElements", "gatherElements"},
{"GatherND", "gatherND"},
{"Gelu", "gelu"},
{"Gemm", "gemm"},
Expand Down Expand Up @@ -261,6 +262,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"ScatterElements", "scatterElements"},
{"ScatterND", "scatterND"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class GatherElementsOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Add operator related.

Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const size_t rank = input_shape.size();
NodeAttrHelper helper(node);
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
options.set("axis", axis);

emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gatherElements", data, indices, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.

bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<GatherElementsOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class ScatterElementsOpBuilder : public BaseOpBuilder {
// Add operator related.
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Add operator related.

Status ScatterElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
emscripten::val data = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val updates = model_builder.GetOperand(input_defs[2]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const size_t rank = input_shape.size();
NodeAttrHelper helper(node);
const uint32_t axis = static_cast<uint32_t>(HandleNegativeAxis(helper.Get("axis", 0), rank));
options.set("axis", axis);

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("scatterElements", data, indices, updates, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.

bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
NodeAttrHelper helper(node);
if (helper.Get("reduction", "none") != "none") {
LOGS(logger, VERBOSE) << "ScatterElements: WebNN only supports reduction type none (default)";
return false;
}

return true;
}

bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& updates = *node.InputDefs()[2];
const auto& op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
int32_t updates_type;
if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger) ||
!GetType(updates, updates_type, logger)) {
return false;
}

if (data_type != updates_type) {
return false;
}

return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) &&
IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger);
}

void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<ScatterElementsOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateGatherOpBuilder("Gather", op_registrations);
}

{ // GatherElements
CreateGatherElementsOpBuilder("GatherElements", op_registrations);
}

{ // GatherND
CreateGatherNDOpBuilder("GatherND", op_registrations);
}
Expand Down Expand Up @@ -174,6 +178,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateResizeOpBuilder("Resize", op_registrations);
}

{ // ScatterElements
CreateScatterElementsOpBuilder("ScatterElements", op_registrations);
}

{ // ScatterND
CreateScatterNDOpBuilder("ScatterND", op_registrations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderR
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand All @@ -44,6 +45,7 @@ void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r
void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 46ff240

Please sign in to comment.