Skip to content

Commit

Permalink
[WebNN EP] Support Tile operator (#22148)
Browse files Browse the repository at this point in the history
PTAL, thanks! @Honry , @fdwr thanks!
  • Loading branch information
mingmingtasd authored Oct 5, 2024
1 parent 98a7590 commit 004bd36
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 2 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Sub | ai.onnx(7-12, 13, 14+) | sub ||| |
| Tan | ai.onnx(7+) | tan ||| |
| Tanh | ai.onnx(7-12, 13+) | tanh ||| |
| Tile | ai.onnx(7-12, 13+) | tile ||| Input 'repeats' should be a constant |
| Transpose | ai.onnx(7-12, 13-20, 21+) | transpose ||| |
| Trilu | ai.onnx(14+) | triangular ||| Input 'k' (option 'diagonal' for WebNN) if present should be a constant |
| Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape ||| |
Expand Down
4 changes: 2 additions & 2 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -2498,8 +2498,8 @@
// "test_thresholdedrelu_default",
// "test_thresholdedrelu_example",
// "test_thresholdedrelu",
// "test_tile_precomputed",
// "test_tile",
"test_tile_precomputed",
"test_tile",
// // "test_top_k_negative_axis",
// // "test_top_k_smallest",
// // "test_top_k",
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Sub", "sub"},
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Tile", "tile"},
{"Transpose", "transpose"},
{"Trilu", "triangular"},
{"Unsqueeze", "reshape"},
Expand Down
98 changes: 98 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/framework/tensorprotoutils.h"
#include "core/optimizer/initializer.h"
#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class TileOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// Add operator related.

void TileOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());
}

Status TileOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& initializers(model_builder.GetInitializerTensors());
const auto& repetitions_initializer = *initializers.at(input_defs[1]->Name());
const int64_t* raw_repetitions_data = repetitions_initializer.int64_data().empty()
? reinterpret_cast<const int64_t*>(repetitions_initializer.raw_data().data())
: repetitions_initializer.int64_data().data();
const auto size = repetitions_initializer.dims()[0];
TensorShapeVector repetitions_data{raw_repetitions_data, raw_repetitions_data + size};
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
std::vector<uint32_t> repetitions;
std::transform(repetitions_data.cbegin(), repetitions_data.cend(),
std::back_inserter(repetitions),
[](int64_t repetition) -> uint32_t { return SafeInt<uint32_t>(repetition); });

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("tile",
input,
emscripten::val::array(repetitions),
options);
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}

// Operator support related.

bool TileOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& repetitions_name = input_defs[1]->Name();
if (!Contains(initializers, repetitions_name)) {
LOGS(logger, VERBOSE) << "Repetitions of tile must be a constant initializer";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

if (input_shape.empty()) {
LOGS(logger, VERBOSE) << "Tile does not support empty input shape";
return false;
}

return true;
}

void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<TileOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateSqueezeUnsqueezeOpBuilder("Unsqueeze", op_registrations);
}

{ // Tile
CreateTileOpBuilder("Tile", op_registrations);
}

{ // Transpose
CreateTransposeOpBuilder("Transpose", op_registrations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations&
void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateSqueezeUnsqueezeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateTriangularOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 004bd36

Please sign in to comment.