Skip to content

Commit

Permalink
[WebNN EP] Move MLContext creation to a singleton
Browse files Browse the repository at this point in the history
In order to enable I/O Binding with the upcoming MLBuffer changes to
the WebNN specification, we need to share the same MLContext across
multiple sessions. This is because MLBuffers are tied to the MLContext
where they were created.
  • Loading branch information
egalli committed May 10, 2024
1 parent 5a18818 commit 81695d1
Showing 1 changed file with 98 additions and 13 deletions.
111 changes: 98 additions & 13 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,101 @@
#include "core/graph/graph_viewer.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/common/safeint.h"
#include "core/common/inlined_containers.h"
#include "core/common/hash_combine.h"

#include "builders/model.h"
#include "builders/helper.h"
#include "builders/model_builder.h"

namespace onnxruntime {

// Options used to create a WebNN MLContext. A field is only forwarded to the
// browser when it has been explicitly set; an unset field lets the browser
// pick its own default.
struct WebNNContextOptions {
  std::optional<std::string> device_type;
  std::optional<int> num_threads;
  std::optional<std::string> power_preference;

  // Build the JavaScript options object expected by
  // navigator.ml.createContext(). Disengaged fields are omitted entirely
  // rather than passed as null/undefined.
  [[nodiscard]] emscripten::val AsVal() const {
    emscripten::val options = emscripten::val::object();
    if (device_type) {
      options.set("deviceType", *device_type);
    }
    if (num_threads) {
      options.set("numThreads", *num_threads);
    }
    if (power_preference) {
      options.set("powerPreference", *power_preference);
    }
    return options;
  }

  // Two option sets compare equal when every field matches, where "unset"
  // only matches "unset" (std::optional equality semantics).
  bool operator==(const WebNNContextOptions& other) const {
    return device_type == other.device_type &&
           num_threads == other.num_threads &&
           power_preference == other.power_preference;
  }
};

} // namespace onnxruntime

// Specialize std::hash for WebNNContextOptions so it can key hash containers.
template <>
struct std::hash<::onnxruntime::WebNNContextOptions> {
  size_t operator()(const ::onnxruntime::WebNNContextOptions& options) const {
    size_t hash{0xbc9f1d34};  // arbitrary fixed seed
    // Mix in each field's engaged-state as well as its value. Skipping
    // disengaged fields entirely (as a naive implementation would) makes,
    // e.g., {device_type = "x"} and {power_preference = "x"} hash
    // identically, since both would combine the same single string into the
    // seed. Equal option sets still produce equal hashes, as required by the
    // std::hash contract relative to operator==.
    onnxruntime::HashCombine(options.device_type.has_value(), hash);
    if (options.device_type.has_value()) {
      onnxruntime::HashCombine(options.device_type.value(), hash);
    }
    onnxruntime::HashCombine(options.num_threads.has_value(), hash);
    if (options.num_threads.has_value()) {
      onnxruntime::HashCombine(options.num_threads.value(), hash);
    }
    onnxruntime::HashCombine(options.power_preference.has_value(), hash);
    if (options.power_preference.has_value()) {
      onnxruntime::HashCombine(options.power_preference.value(), hash);
    }
    return hash;
  }
};

namespace onnxruntime {

// Process-wide singleton that creates and caches MLContexts. Sessions
// requesting identical context options receive the same MLContext instance,
// which is what allows MLBuffers (tied to the context that created them) to
// be shared across sessions.
class WebNNContextManager {
 public:
  // Meyers singleton: constructed on first use, lives for the process.
  static WebNNContextManager& GetInstance() {
    static WebNNContextManager instance;
    return instance;
  }

  // Singletons are neither copyable nor movable.
  WebNNContextManager(const WebNNContextManager&) = delete;
  WebNNContextManager& operator=(const WebNNContextManager&) = delete;
  WebNNContextManager(WebNNContextManager&&) = delete;
  WebNNContextManager& operator=(WebNNContextManager&&) = delete;

  // Returns the cached MLContext for |options|, creating one through
  // navigator.ml.createContext() (and caching it) on first request.
  // Throws if navigator.ml is unavailable or context creation fails.
  emscripten::val GetContext(const WebNNContextOptions& options) {
    const auto cached = contexts_.find(options);
    if (cached != contexts_.end()) {
      return cached->second;
    }

    emscripten::val navigator_ml = emscripten::val::global("navigator")["ml"];
    if (!navigator_ml.as<bool>()) {
      ORT_THROW("Failed to get ml from navigator.");
    }

    emscripten::val created =
        navigator_ml.call<emscripten::val>("createContext", options.AsVal()).await();
    if (!created.as<bool>()) {
      ORT_THROW("Failed to create WebNN context.");
    }
    contexts_.emplace(options, created);
    return created;
  }

 private:
  WebNNContextManager() = default;
  ~WebNNContextManager() = default;

  // Cache of previously created contexts, keyed by their creation options.
  InlinedHashMap<WebNNContextOptions, emscripten::val> contexts_;
};

WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
const std::string& webnn_threads_number, const std::string& webnn_power_flags)
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
Expand All @@ -25,34 +113,31 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (!ml.as<bool>()) {
ORT_THROW("Failed to get ml from navigator.");
}
emscripten::val context_options = emscripten::val::object();
context_options.set("deviceType", emscripten::val(webnn_device_flags));
WebNNContextOptions context_options;
context_options.device_type = webnn_device_flags;
// WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
if (webnn_device_flags.compare("cpu") == 0) {
if (webnn_device_flags == "cpu") {
preferred_layout_ = DataLayout::NHWC;
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
// Set "numThreads" if it's not default 0.
if (webnn_threads_number.compare("0") != 0) {
context_options.set("numThreads", stoi(webnn_threads_number));
if (webnn_threads_number != "0") {
context_options.num_threads = stoi(webnn_threads_number);
}
} else {
preferred_layout_ = DataLayout::NCHW;
if (webnn_device_flags.compare("gpu") == 0) {
if (webnn_device_flags == "gpu") {
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
} else if (webnn_device_flags.compare("npu") == 0) {
} else if (webnn_device_flags == "npu") {
wnn_device_type_ = webnn::WebnnDeviceType::NPU;
} else {
ORT_THROW("Unknown WebNN deviceType.");
}
}
if (webnn_power_flags.compare("default") != 0) {
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
if (webnn_power_flags != "default") {
context_options.power_preference = webnn_power_flags;
}

wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
wnn_context_ = WebNNContextManager::GetInstance().GetContext(context_options);
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(wnn_context_);
if (!wnn_builder_.as<bool>()) {
ORT_THROW("Failed to create WebNN builder.");
Expand Down

0 comments on commit 81695d1

Please sign in to comment.