From 8c2ee7b32e28837ef4e109ea200fc7f1404ebb91 Mon Sep 17 00:00:00 2001
From: Wanming Lin
Date: Fri, 2 Aug 2024 00:15:31 +0800
Subject: [PATCH] [WebNN EP] Create MLGraphBuilder for every model builder
 (#21514)

Currently the WebNN spec only allows MLGraphBuilder.build() to be called
once, so we need to create a new MLGraphBuilder for every subgraph in the
WebNN EP.

Spec change: https://github.com/webmachinelearning/webnn/pull/717
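
For illustration only (not part of the patch): the per-subgraph builder
lifecycle that ModelBuilder now owns looks roughly like the sketch below.
The MLContext (wnn_context), the record of named output operands
(named_outputs), and the BuildSubgraph helper are assumed names for this
sketch; operand construction is elided.

    #include <emscripten/val.h>

    // Build one WebNN subgraph with a throwaway MLGraphBuilder.
    emscripten::val BuildSubgraph(const emscripten::val& wnn_context,
                                  const emscripten::val& named_outputs) {
      // A fresh builder per subgraph: MLGraphBuilder.build() may only be called once.
      emscripten::val builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context);
      // ... define this subgraph's operands and operators against the builder here ...
      // build() returns a promise that resolves to an MLGraph; await it.
      emscripten::val graph = builder.call<emscripten::val>("build", named_outputs).await();
      // Drop the builder so it can be garbage collected; the next subgraph creates its own.
      builder = emscripten::val::undefined();
      return graph;
    }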
---
 .../core/providers/webnn/builders/helper.cc   |  4 ++--
 .../core/providers/webnn/builders/helper.h    |  6 +++---
 .../providers/webnn/builders/model_builder.cc | 16 ++++++++++----
 .../providers/webnn/builders/model_builder.h  |  8 +++----
 .../webnn/webnn_execution_provider.cc         | 21 +++++++------------
 .../webnn/webnn_execution_provider.h          |  1 -
 6 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc
index 44e6953db438e..d3c1d06818db2 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.cc
+++ b/onnxruntime/core/providers/webnn/builders/helper.cc
@@ -84,7 +84,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons
 }
 
 std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
-                                                      const emscripten::val& wnn_builder_,
+                                                      const emscripten::val& wnn_builder,
                                                       const WebnnDeviceType device_type,
                                                       const logging::Logger& logger) {
   std::vector<std::vector<size_t>> supported_node_groups;
@@ -103,7 +103,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
     const auto* node(graph_viewer.GetNode(node_idx));
     bool supported = false;
     // Firstly check if platform supports the WebNN op.
-    if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) {
+    if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
       LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser";
       supported = IsNodeSupported(*node, graph_viewer, device_type, logger);
     }
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 63fd97abb9a9a..05b783fd17902 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -151,7 +151,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c
 
 // Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
 std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
-                                                      const emscripten::val& wnn_builder_,
+                                                      const emscripten::val& wnn_builder,
                                                       const WebnnDeviceType device_type,
                                                       const logging::Logger& logger);
 static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
@@ -241,14 +241,14 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
   {"Where", {"where", true}},
 };
 
-inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_,
+inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
                           const WebnnDeviceType device_type) {
   // Returns false if the op_type is not listed in the op_map.
   if (op_map.find(op_type) == op_map.end()) {
     return false;
   }
   // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
-  if (!wnn_builder_[op_map.find(op_type)->second.opName].as<bool>()) {
+  if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
     return false;
   }
   // The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index 6b0e1495f552d..b21f717eedc7a 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -20,14 +20,20 @@ namespace onnxruntime {
 namespace webnn {
 
 ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
-                           const emscripten::val& context, const emscripten::val& builder,
-                           const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type)
+                           const emscripten::val& context, const DataLayout preferred_layout,
+                           const WebnnDeviceType wnn_device_type)
     : graph_viewer_(graph_viewer),
       logger_(logger),
       wnn_context_(context),
-      wnn_builder_(builder),
       preferred_layout_(preferred_layout),
-      wnn_device_type_(wnn_device_type) {}
+      wnn_device_type_(wnn_device_type) {
+  // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build()
+  // is only allowed to be called once.
+  wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context);
+  if (!wnn_builder_.as<bool>()) {
+    ORT_THROW("Failed to create WebNN builder.");
+  }
+}
 
 Status ModelBuilder::Initialize() {
   PreprocessInitializers();
@@ -332,6 +338,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   if (!wnn_graph.as<bool>()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
   }
+  // Explicitly release the WebNN builder to free memory.
+  wnn_builder_ = emscripten::val::undefined();
   model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
   model->SetInputs(std::move(input_names_));
   model->SetOutputs(std::move(output_names_));
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h
index 6a1688f16d2a6..b1561f009aa25 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.h
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -22,8 +22,8 @@ class IOpBuilder;
 class ModelBuilder {
  public:
   ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger,
-               const emscripten::val& context, const emscripten::val& builder,
-               const DataLayout preferred_layout, const WebnnDeviceType wnn_device_type);
+               const emscripten::val& context, const DataLayout preferred_layout,
+               const WebnnDeviceType wnn_device_type);
   ~ModelBuilder() = default;
 
   Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
@@ -62,8 +62,8 @@ class ModelBuilder {
   const GraphViewer& graph_viewer_;
   const logging::Logger& logger_;
 
-  emscripten::val wnn_context_ = emscripten::val::object();
-  emscripten::val wnn_builder_ = emscripten::val::object();
+  emscripten::val wnn_context_ = emscripten::val::undefined();
+  emscripten::val wnn_builder_ = emscripten::val::undefined();
   DataLayout preferred_layout_;
   WebnnDeviceType wnn_device_type_;
   InlinedHashMap<std::string, emscripten::val> wnn_operands_;
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index 0da0dfc6dfb26..1cd382c1e75e9 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -38,10 +38,6 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }
-  wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
-  if (!wnn_builder_.as<bool>()) {
-    ORT_THROW("Failed to create WebNN builder.");
-  }
 }
 
 WebNNExecutionProvider::~WebNNExecutionProvider() {}
@@ -81,14 +77,13 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
 
   const auto& logger = *GetLogger();
 
-  if (!wnn_builder_.as<bool>()) {
-    // The GetCapability function may be called again after Compile due to the logic in the
-    // PartitionOnnxFormatModel function (see onnxruntime/core/framework/graph_partitioner.cc).
-    // We need to re-create the wnn_builder_ here to avoid it's been released in last Compile.
-    wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
+  emscripten::val wnn_builder = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
+  if (!wnn_builder.as<bool>()) {
+    ORT_THROW("Failed to create WebNN builder.");
   }
 
-  const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder_, wnn_device_type_, logger);
+  const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger);
+  wnn_builder = emscripten::val::undefined();
 
   if (node_groups.empty()) {
     return result;
@@ -218,9 +213,10 @@ common::Status WebNNExecutionProvider::Compile(const std::vector
     const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
 
-    webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_, wnn_builder_,
-                                preferred_layout_, wnn_device_type_);
+    webnn::ModelBuilder builder(graph_viewer, *GetLogger(), wnn_context_,
+                                preferred_layout_, wnn_device_type_);
     std::unique_ptr<webnn::Model> model;
     ORT_RETURN_IF_ERROR(builder.Compile(model));
+
     // Build map from input name to its index in input definitions.
     {
       InlinedHashMap<std::string, size_t> input_map;
@@ -329,9 +325,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector