Skip to content

Commit

Permalink
[WebNN EP] Create MLGraphBuilder for every model builder (microsoft#2…
Browse files Browse the repository at this point in the history
…1514)

Currently WebNN spec only allows MLGraphBuilder.build() to be called
once, we need to create new builder for every subgraph in WebNN EP.

Spec change: webmachinelearning/webnn#717
  • Loading branch information
Honry authored Aug 1, 2024
1 parent 3b73ef2 commit 8c2ee7b
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 28 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
}

std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger) {
std::vector<std::vector<size_t>> supported_node_groups;
Expand All @@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const auto* node(graph_viewer.GetNode(node_idx));
bool supported = false;
// Firstly check if platform supports the WebNN op.
if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
}
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
const emscripten::val& wnn_builder_,
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger);
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
Expand Down Expand Up @@ -241,14 +241,14 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Where", {"where", true}},
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
// Returns false if the op_type is not listed in the op_map.
if (op_map.find(op_type) == op_map.end()) {
return false;
}
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
return false;
}
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
Expand Down
16 changes: 12 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@ namespace onnxruntime {
namespace webnn {

ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type)
: graph_viewer_(graph_viewer),
logger_(logger),
wnn_context_(context),
wnn_builder_(builder),
preferred_layout_(preferred_layout),
wnn_device_type_(wnn_device_type) {}
wnn_device_type_(wnn_device_type) {
// Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
// is only allowed to be called once.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}

Status ModelBuilder::Initialize() {
PreprocessInitializers();
Expand Down Expand Up @@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();
model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
model->SetInputs(std::move(input_names_));
model->SetOutputs(std::move(output_names_));
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class IOpBuilder;
class ModelBuilder {
public:
ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
const emscripten::val& context, const emscripten::val& builder,
const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
const emscripten::val& context, const DataLayout preferred_layout,
const WebnnDeviceType wnn_device_type);
~ModelBuilder() = default;

Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
Expand Down Expand Up @@ -62,8 +62,8 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;

emscripten::val wnn_context_ = emscripten::val::object();
emscripten::val wnn_builder_ = emscripten::val::object();
emscripten::val wnn_context_ = emscripten::val::undefined();
emscripten::val wnn_builder_ = emscripten::val::undefined();
DataLayout preferred_layout_;
WebnnDeviceType wnn_device_type_;
InlinedHashMap<std::string, emscripten::val> wnn_operands_;
Expand Down
21 changes: 7 additions & 14 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}
}

WebNNExecutionProvider::~WebNNExecutionProvider() {}
Expand Down Expand Up @@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view

const auto& logger = *GetLogger();

if (!wnn_builder_.as<bool>()) {
// The GetCapability function may be called again after Compile due to the logic in the
// PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
// We need to re-create the wnn_builder_ here to avoid it's been released in last Compile.
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
}

const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
wnn_builder = emscripten::val::undefined();

if (node_groups.empty()) {
return result;
Expand Down Expand Up @@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);

webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
wnn_builder_, preferred_layout_, wnn_device_type_);
preferred_layout_, wnn_device_type_);
std::unique_ptr<webnn::Model> model;
ORT_RETURN_IF_ERROR(builder.Compile(model));

// Build map from input name to its index in input definitions.
{
InlinedHashMap<std::string, size_t> input_map;
Expand Down Expand Up @@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGra
node_compute_funcs.push_back(compute_info);
}

// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();

return Status::OK();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class WebNNExecutionProvider : public IExecutionProvider {

private:
emscripten::val wnn_context_ = emscripten::val::undefined();
mutable emscripten::val wnn_builder_ = emscripten::val::undefined();

DataLayout preferred_layout_;
webnn::WebnnDeviceType wnn_device_type_;
Expand Down

0 comments on commit 8c2ee7b

Please sign in to comment.