diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index 8c1e69a68ca7e..c7760692eed00 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -241,6 +241,7 @@ export declare namespace InferenceSession {
   export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'webnn';
     deviceType?: 'cpu'|'gpu';
+    numThreads?: number;
     powerPreference?: 'default'|'low-power'|'high-performance';
   }
   export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts
index 02ff229cc4954..45ea48a2df209 100644
--- a/js/web/lib/wasm/session-options.ts
+++ b/js/web/lib/wasm/session-options.ts
@@ -75,6 +75,19 @@ const setExecutionProviders =
             checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`);
           }
         }
+        if (webnnOptions?.numThreads) {
+          let numThreads = webnnOptions.numThreads;
+          // Just ignore an invalid webnnOptions.numThreads; 0 lets the backend choose.
+          if (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
+            numThreads = 0;
+          }
+          const keyDataOffset = allocWasmString('numThreads', allocs);
+          const valueDataOffset = allocWasmString(numThreads.toString(), allocs);
+          if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
+              0) {
+            checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`);
+          }
+        }
         if (webnnOptions?.powerPreference) {
           const keyDataOffset = allocWasmString('powerPreference', allocs);
           const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index 02a3d16b5b64f..4da54aaad3a33 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -17,8 +17,8 @@
 
 namespace onnxruntime {
 
-WebNNExecutionProvider::WebNNExecutionProvider(
-    const std::string& webnn_device_flags, const std::string& webnn_power_flags)
+WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
+                                               const std::string& webnn_threads_number, const std::string& webnn_power_flags)
     : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} {
   // Create WebNN context and graph builder.
   const emscripten::val ml = emscripten::val::global("navigator")["ml"];
@@ -31,6 +31,10 @@ WebNNExecutionProvider::WebNNExecutionProvider(
   if (webnn_device_flags.compare("cpu") == 0) {
     preferred_layout_ = DataLayout::NHWC;
     wnn_device_type_ = webnn::WebnnDeviceType::CPU;
+    // Set "numThreads" only when it is not the default 0 (0 = let WebNN decide).
+    if (webnn_threads_number.compare("0") != 0) {
+      context_options.set("numThreads", std::stoi(webnn_threads_number));
+    }
   } else {
     preferred_layout_ = DataLayout::NCHW;
     wnn_device_type_ = webnn::WebnnDeviceType::GPU;
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h
index f8d9a1c33f6c8..13a475327dc0c 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h
@@ -18,7 +18,8 @@ class Model;
 
 class WebNNExecutionProvider : public IExecutionProvider {
  public:
-  WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags);
+  WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
+                         const std::string& webnn_power_flags);
   virtual ~WebNNExecutionProvider();
 
   std::vector<std::unique_ptr<ComputeCapability>>
diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
index 4d6b04c8e76d8..11acec8b1f354 100644
--- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
+++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
@@ -10,23 +10,26 @@
 using namespace onnxruntime;
 
 namespace onnxruntime {
 struct WebNNProviderFactory : IExecutionProviderFactory {
-  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_power_flags)
-      : webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {}
+  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
+                       const std::string& webnn_power_flags)
+      : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
   ~WebNNProviderFactory() override {}
   std::unique_ptr<IExecutionProvider> CreateProvider() override;
   std::string webnn_device_flags_;
+  std::string webnn_threads_number_;
   std::string webnn_power_flags_;
 };
 
 std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
-  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_power_flags_);
+  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
 }
 
 std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
     const ProviderOptions& provider_options) {
   return std::make_shared<WebNNProviderFactory>(provider_options.at("deviceType"),
+                                                provider_options.at("numThreads"),
                                                 provider_options.at("powerPreference"));
 }
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 4649ac35c3647..cb51a0c460d9a 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -104,8 +104,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
   } else if (strcmp(provider_name, "WEBNN") == 0) {
 #if defined(USE_WEBNN)
     std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
+    std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
     std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
     provider_options["deviceType"] = deviceType;
+    provider_options["numThreads"] = numThreads;
     provider_options["powerPreference"] = powerPreference;
     options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
 #else