From 81695d14b9a89563bdd3df36e40d518c910382cf Mon Sep 17 00:00:00 2001
From: Enrico Galli <enrico.galli@intel.com>
Date: Fri, 10 May 2024 11:17:54 -0700
Subject: [PATCH] [WebNN EP] Move MLContext creation to a singleton

In order to enable I/O Binding with the upcoming MLBuffer changes to the
WebNN specification, we need to share the same MLContext across multiple
sessions. This is because MLBuffers are tied to the MLContext where they
were created.
---
 .../webnn/webnn_execution_provider.cc         | 111 ++++++++++++++++--
 1 file changed, 98 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index d72abf1a721c8..c5707045e989e 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -10,6 +10,8 @@
 #include "core/graph/graph_viewer.h"
 #include "core/session/onnxruntime_cxx_api.h"
 #include "core/common/safeint.h"
+#include "core/common/inlined_containers.h"
+#include "core/common/hash_combine.h"
 #include "builders/model.h"
 #include "builders/helper.h"
 
@@ -17,6 +19,92 @@
 namespace onnxruntime {
 
+struct WebNNContextOptions {
+  std::optional<std::string> device_type;
+  std::optional<int> num_threads;
+  std::optional<std::string> power_preference;
+
+  [[nodiscard]] emscripten::val AsVal() const {
+    emscripten::val options = emscripten::val::object();
+    if (device_type.has_value()) {
+      options.set("deviceType", device_type.value());
+    }
+    if (num_threads.has_value()) {
+      options.set("numThreads", num_threads.value());
+    }
+    if (power_preference.has_value()) {
+      options.set("powerPreference", power_preference.value());
+    }
+    return options;
+  }
+
+  bool operator==(const WebNNContextOptions& other) const {
+    return std::tie(device_type, num_threads, power_preference) ==
+           std::tie(other.device_type, other.num_threads, other.power_preference);
+  }
+};
+
+}  // namespace onnxruntime
+
+// Specialize std::hash for WebNNContextOptions.
+template <>
+struct std::hash<::onnxruntime::WebNNContextOptions> {
+  size_t operator()(const ::onnxruntime::WebNNContextOptions& options) const {
+    size_t hash{0xbc9f1d34};  // seed
+    if (options.device_type.has_value()) {
+      onnxruntime::HashCombine(options.device_type.value(), hash);
+    }
+    if (options.num_threads.has_value()) {
+      onnxruntime::HashCombine(options.num_threads.value(), hash);
+    }
+    if (options.power_preference.has_value()) {
+      onnxruntime::HashCombine(options.power_preference.value(), hash);
+    }
+    return hash;
+  }
+};
+
+namespace onnxruntime {
+
+// WebNNContextManager is a singleton object that is used to create MLContexts.
+class WebNNContextManager {
+ public:
+  static WebNNContextManager& GetInstance() {
+    static WebNNContextManager instance;
+    return instance;
+  }
+
+  WebNNContextManager(const WebNNContextManager&) = delete;
+  WebNNContextManager& operator=(const WebNNContextManager&) = delete;
+  WebNNContextManager(WebNNContextManager&&) = delete;
+  WebNNContextManager& operator=(WebNNContextManager&&) = delete;
+
+  emscripten::val GetContext(const WebNNContextOptions& options) {
+    auto it = contexts_.find(options);
+    if (it != contexts_.end()) {
+      return it->second;
+    }
+
+    emscripten::val ml = emscripten::val::global("navigator")["ml"];
+    if (!ml.as<bool>()) {
+      ORT_THROW("Failed to get ml from navigator.");
+    }
+
+    emscripten::val context = ml.call<emscripten::val>("createContext", options.AsVal()).await();
+    if (!context.as<bool>()) {
+      ORT_THROW("Failed to create WebNN context.");
+    }
+    contexts_.emplace(options, context);
+    return context;
+  }
+
+ private:
+  WebNNContextManager() = default;
+  ~WebNNContextManager() = default;
+
+  InlinedHashMap<WebNNContextOptions, emscripten::val> contexts_;
+};
+
 WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
                                               const std::string& webnn_threads_number,
                                               const std::string& webnn_power_flags)
     : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
@@ -25,34 +113,31 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
   if (!ml.as<bool>()) {
     ORT_THROW("Failed to get ml from navigator.");
   }
-  emscripten::val context_options = emscripten::val::object();
-  context_options.set("deviceType", emscripten::val(webnn_device_flags));
+  WebNNContextOptions context_options;
+  context_options.device_type = webnn_device_flags;
   // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
-  if (webnn_device_flags.compare("cpu") == 0) {
+  if (webnn_device_flags == "cpu") {
     preferred_layout_ = DataLayout::NHWC;
     wnn_device_type_ = webnn::WebnnDeviceType::CPU;
     // Set "numThreads" if it's not default 0.
-    if (webnn_threads_number.compare("0") != 0) {
-      context_options.set("numThreads", stoi(webnn_threads_number));
+    if (webnn_threads_number != "0") {
+      context_options.num_threads = stoi(webnn_threads_number);
     }
   } else {
     preferred_layout_ = DataLayout::NCHW;
-    if (webnn_device_flags.compare("gpu") == 0) {
+    if (webnn_device_flags == "gpu") {
       wnn_device_type_ = webnn::WebnnDeviceType::GPU;
-    } else if (webnn_device_flags.compare("npu") == 0) {
+    } else if (webnn_device_flags == "npu") {
       wnn_device_type_ = webnn::WebnnDeviceType::NPU;
     } else {
       ORT_THROW("Unknown WebNN deviceType.");
     }
   }
-  if (webnn_power_flags.compare("default") != 0) {
-    context_options.set("powerPreference", emscripten::val(webnn_power_flags));
+  if (webnn_power_flags != "default") {
+    context_options.power_preference = webnn_power_flags;
   }
-  wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
-  if (!wnn_context_.as<bool>()) {
-    ORT_THROW("Failed to create WebNN context.");
-  }
+  wnn_context_ = WebNNContextManager::GetInstance().GetContext(context_options);
   wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
   if (!wnn_builder_.as<bool>()) {
     ORT_THROW("Failed to create WebNN builder.");