Skip to content

Commit

Permalink
[WebNN EP] Support Dropout op (#21586)
Browse files Browse the repository at this point in the history
### Description
WebNN only supports test mode, so we don't care about other inputs or
attributes about training mode, use WebNN's identity op to implement the
Dropout op directly.
  • Loading branch information
Honry authored Aug 2, 2024
1 parent 45b7c41 commit 8c641d7
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d ||| Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group |
| Cos | ai.onnx(7+) | cos ||| |
| Div | ai.onnx(7-12, 13, 14+) | div ||| |
| Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity ||| Only supports test mode |
| Elu | ai.onnx(7+) | elu ||| WebNN CPU backend only supports 'alpha' value is 1.0 |
| Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal ||| |
| Erf | ai.onnx(7-9, 10-12, 13+) | erf ||| |
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Cos", {"cos", true}},
{"Div", {"div", true}},
{"DequantizeLinear", {"dequantizeLinear", false}},
{"Dropout", {"identity", true}},
{"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}},
{"Elu", {"elu", true}},
{"Equal", {"equal", true}},
Expand Down
101 changes: 101 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Intel Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/webnn/builders/helper.h"
#include "core/providers/webnn/builders/model_builder.h"
#include "core/providers/webnn/builders/op_builder_factory.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace webnn {

class DropoutOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// Add operator related.

void DropoutOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
// Skip ratio and training_mode if present.
for (size_t i = 1; i < node.InputDefs().size(); i++) {
const auto input_name = node.InputDefs()[i]->Name();
model_builder.AddInitializerToSkip(input_name);
model_builder.AddInputToSkip(input_name);
}
}

Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& output_defs = node.OutputDefs();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

// WebNN EP only supports test mode. So we don't need to care about other inputs or
// attributes about training mode. Simply use WebNN's identity op to copy the input.
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("identity", input, options);

model_builder.AddOperand(output_defs[0]->Name(), std::move(output));

// If mask output is requested as output it will contain all ones (bool tensor).
if (output_defs.size() > 1) {
std::vector<int64_t> mask_shape;
ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape");
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(mask_shape);

emscripten::val desc = emscripten::val::object();
desc.set("dataType", "uint8");
desc.set("dimensions", emscripten::val::array(dims));
const auto num_elements = narrow<uint32_t>(Product(mask_shape));
emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
ones_buffer.call<void>("fill", 1);

emscripten::val mask_output = model_builder.GetBuilder().call<emscripten::val>("constant", desc, ones_buffer);

emscripten::val options = emscripten::val::object();
options.set("label", output_defs[1]->Name() + "_identity");
// Add additional identity op in case the mask is the output of a WebNN graph,
// beacuse WebNN does not support a constant operand as output.
mask_output = model_builder.GetBuilder().call<emscripten::val>("identity", mask_output, options);
model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output));
}
return Status::OK();
}

// Operator support related.
bool DropoutOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

return true;
}

void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<DropoutOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace webnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateConcatOpBuilder("Concat", op_registrations);
}

{ // Dropout
CreateDropoutOpBuilder("Dropout", op_registrations);
}

{ // Quantize/Dequantize
CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations);
CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down

0 comments on commit 8c641d7

Please sign in to comment.