From 4c3c809bdbcde4ea96f0a31a242ca6877a10c40a Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Mon, 8 Jul 2024 10:19:39 -0700 Subject: [PATCH] [js/webnn] Enable user-supplied MLContext (#20600) ### Description This PR enables the API added in #20816 as well as moving context creation to JS. ### Motivation and Context In order to enable I/O Binding with the upcoming [MLBuffer](https://github.com/webmachinelearning/webnn/issues/542) API in the WebNN specification, we need to share the same `MLContext` across multiple sessions. This is because `MLBuffer`s are restricted to the `MLContext` where they were created. This PR enables developers to use the same `MLContext` across multiple sessions. --- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 401 ++++++++++++++++++ js/web/lib/wasm/session-options.ts | 22 - js/web/lib/wasm/wasm-core-impl.ts | 37 ++ js/web/lib/wasm/wasm-types.ts | 14 +- .../webnn/webnn_execution_provider.cc | 19 +- .../webnn/webnn_execution_provider.h | 3 +- .../providers/webnn/webnn_provider_factory.cc | 13 +- .../core/session/provider_registration.cc | 4 - 8 files changed, 458 insertions(+), 55 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webnn/webnn.d.ts diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts new file mode 100644 index 0000000000000..f8a1e1966fd4c --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +interface NavigatorML { + readonly ml: ML; +} +interface Navigator extends NavigatorML {} +interface WorkerNavigator extends NavigatorML {} +type MLDeviceType = 'cpu'|'gpu'|'npu'; +type MLPowerPreference = 'default'|'high-performance'|'low-power'; +interface MLContextOptions { + deviceType?: MLDeviceType; + powerPreference?: MLPowerPreference; + numThreads?: number; +} +interface ML { + createContext(options?: MLContextOptions): Promise; + createContext(gpuDevice: GPUDevice): Promise; +} +type MLNamedArrayBufferViews = Record; +interface MLComputeResult { + inputs?: MLNamedArrayBufferViews; + outputs?: MLNamedArrayBufferViews; +} +interface MLContext { + compute(graph: MLGraph, inputs: MLNamedArrayBufferViews, outputs: MLNamedArrayBufferViews): Promise; +} +interface MLGraph {} +type MLInputOperandLayout = 'nchw'|'nhwc'; +type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; +interface MLOperandDescriptor { + dataType: MLOperandDataType; + dimensions?: number[]; +} +interface MLOperand { + dataType(): MLOperandDataType; + shape(): number[]; +} +interface MLActivation {} +type MLNamedOperands = Record; +interface MLGraphBuilder { + // eslint-disable-next-line @typescript-eslint/no-misused-new + new(context: MLContext): MLGraphBuilder; + input(name: string, descriptor: MLOperandDescriptor): MLOperand; + constant(descriptor: MLOperandDescriptor, bufferView: ArrayBufferView): MLOperand; + constant(type: MLOperandDataType, value: number): MLOperand; + build(outputs: MLNamedOperands): Promise; +} +interface MLArgMinMaxOptions { + axes?: number[]; + keepDimensions?: boolean; + selectLastIndex?: boolean; +} +interface MLGraphBuilder { + argMin(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; + argMax(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand; +} +interface MLBatchNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axis?: number; + epsilon?: number; +} +interface MLGraphBuilder { + batchNormalization(input: MLOperand, mean: MLOperand, variance: MLOperand, options?: MLBatchNormalizationOptions): + MLOperand; +} +interface MLGraphBuilder { + cast(input: MLOperand, type: MLOperandDataType): MLOperand; +} +interface MLClampOptions { + minValue?: number; + maxValue?: number; +} +interface MLGraphBuilder { + clamp(input: MLOperand, options?: MLClampOptions): MLOperand; + clamp(options?: MLClampOptions): MLActivation; +} +interface MLGraphBuilder { + concat(inputs: MLOperand[], axis: number): MLOperand; +} +type MLConv2dFilterOperandLayout = 'oihw'|'hwio'|'ohwi'|'ihwo'; +interface MLConv2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConv2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + conv2d(input: MLOperand, filter: MLOperand, options?: MLConv2dOptions): MLOperand; +} +type MLConvTranspose2dFilterOperandLayout = 'iohw'|'hwoi'|'ohwi'; +interface MLConvTranspose2dOptions { + padding?: number[]; + strides?: number[]; + dilations?: number[]; + outputPadding?: number[]; + outputSizes?: number[]; + groups?: number; + inputLayout?: MLInputOperandLayout; + filterLayout?: MLConvTranspose2dFilterOperandLayout; + bias?: MLOperand; +} +interface MLGraphBuilder { + convTranspose2d(input: MLOperand, filter: MLOperand, options?: MLConvTranspose2dOptions): MLOperand; +} +interface MLGraphBuilder { + add(a: MLOperand, b: MLOperand): MLOperand; + sub(a: MLOperand, b: MLOperand): MLOperand; + mul(a: MLOperand, b: MLOperand): MLOperand; + div(a: MLOperand, b: MLOperand): MLOperand; + max(a: MLOperand, b: MLOperand): MLOperand; + min(a: MLOperand, b: MLOperand): MLOperand; + pow(a: MLOperand, b: MLOperand): MLOperand; +} +interface MLGraphBuilder { + equal(a: MLOperand, b: MLOperand): MLOperand; + greater(a: MLOperand, b: MLOperand): MLOperand; + greaterOrEqual(a: MLOperand, b: MLOperand): MLOperand; + lesser(a: MLOperand, b: MLOperand): MLOperand; + lesserOrEqual(a: MLOperand, b: MLOperand): MLOperand; + logicalNot(a: MLOperand): MLOperand; +} +interface MLGraphBuilder { + abs(input: MLOperand): MLOperand; + ceil(input: MLOperand): MLOperand; + cos(input: MLOperand): MLOperand; + erf(input: MLOperand): MLOperand; + exp(input: MLOperand): MLOperand; + floor(input: MLOperand): MLOperand; + identity(input: MLOperand): MLOperand; + log(input: MLOperand): MLOperand; + neg(input: MLOperand): MLOperand; + reciprocal(input: MLOperand): MLOperand; + sin(input: MLOperand): MLOperand; + sqrt(input: MLOperand): MLOperand; + tan(input: MLOperand): MLOperand; +} +interface MLEluOptions { + alpha?: number; +} +interface MLGraphBuilder { + elu(input: MLOperand, options?: MLEluOptions): MLOperand; + elu(options?: MLEluOptions): MLActivation; +} +interface MLGraphBuilder { + expand(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGatherOptions { + axis?: number; +} +interface MLGraphBuilder { + gather(input: MLOperand, indices: MLOperand, options?: MLGatherOptions): MLOperand; +} +interface MLGraphBuilder { + gelu(input: MLOperand): MLOperand; + gelu(): MLActivation; +} +interface MLGemmOptions { + c?: MLOperand; + alpha?: number; + beta?: number; + aTranspose?: boolean; + bTranspose?: boolean; +} +interface MLGraphBuilder { + gemm(a: MLOperand, b: MLOperand, options?: MLGemmOptions): MLOperand; +} +type MLGruWeightLayout = 'zrn'|'rzn'; +type MLRecurrentNetworkDirection = 'forward'|'backward'|'both'; +interface MLGruOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + initialHiddenState?: MLOperand; + resetAfter?: boolean; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gru(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLGruOptions): MLOperand[]; +} +interface MLGruCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + resetAfter?: boolean; + layout?: MLGruWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + gruCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, hiddenSize: number, + options?: MLGruCellOptions): MLOperand; +} +interface MLHardSigmoidOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + hardSigmoid(input: MLOperand, options?: MLHardSigmoidOptions): MLOperand; + hardSigmoid(options?: MLHardSigmoidOptions): MLActivation; +} +interface MLGraphBuilder { + hardSwish(input: MLOperand): MLOperand; + hardSwish(): MLActivation; +} +interface MLInstanceNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + epsilon?: number; + layout?: MLInputOperandLayout; +} +interface MLGraphBuilder { + instanceNormalization(input: MLOperand, options?: MLInstanceNormalizationOptions): MLOperand; +} +interface MLLayerNormalizationOptions { + scale?: MLOperand; + bias?: MLOperand; + axes?: number[]; + epsilon?: number; +} +interface MLGraphBuilder { + layerNormalization(input: MLOperand, options?: MLLayerNormalizationOptions): MLOperand; +} +interface MLLeakyReluOptions { + alpha?: number; +} +interface MLGraphBuilder { + leakyRelu(input: MLOperand, options?: MLLeakyReluOptions): MLOperand; + leakyRelu(options?: MLLeakyReluOptions): MLActivation; +} +interface MLLinearOptions { + alpha?: number; + beta?: number; +} +interface MLGraphBuilder { + linear(input: MLOperand, options?: MLLinearOptions): MLOperand; + linear(options?: MLLinearOptions): MLActivation; +} +type MLLstmWeightLayout = 'iofg'|'ifgo'; +interface MLLstmOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + initialHiddenState?: MLOperand; + initialCellState?: MLOperand; + returnSequence?: boolean; + direction?: MLRecurrentNetworkDirection; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstm( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, + options?: MLLstmOptions): MLOperand[]; +} +interface MLLstmCellOptions { + bias?: MLOperand; + recurrentBias?: MLOperand; + peepholeWeight?: MLOperand; + layout?: MLLstmWeightLayout; + activations?: MLActivation[]; +} +interface MLGraphBuilder { + lstmCell( + input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, cellState: MLOperand, + hiddenSize: number, options?: MLLstmCellOptions): MLOperand[]; +} +interface MLGraphBuilder { + matmul(a: MLOperand, b: MLOperand): MLOperand; +} +type MLPaddingMode = 'constant'|'edge'|'reflection'|'symmetric'; +interface MLPadOptions { + mode?: MLPaddingMode; + value?: number; +} +interface MLGraphBuilder { + pad(input: MLOperand, beginningPadding: number[], endingPadding: number[], options?: MLPadOptions): MLOperand; +} +type MLRoundingType = 'floor'|'ceil'; +interface MLPool2dOptions { + windowDimensions?: number[]; + padding?: number[]; + strides?: number[]; + dilations?: number[]; + layout?: MLInputOperandLayout; + roundingType?: MLRoundingType; + outputSizes?: number[]; +} +interface MLGraphBuilder { + averagePool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + l2Pool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; + maxPool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand; +} +interface MLGraphBuilder { + prelu(input: MLOperand, slope: MLOperand): MLOperand; +} +interface MLReduceOptions { + axes?: number[]; + keepDimensions?: boolean; +} +interface MLGraphBuilder { + reduceL1(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceL2(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceLogSumExp(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMax(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMean(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceMin(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceProduct(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSum(input: MLOperand, options?: MLReduceOptions): MLOperand; + reduceSumSquare(input: MLOperand, options?: MLReduceOptions): MLOperand; +} +interface MLGraphBuilder { + relu(input: MLOperand): MLOperand; + relu(): MLActivation; +} +type MLInterpolationMode = 'nearest-neighbor'|'linear'; +interface MLResample2dOptions { + mode?: MLInterpolationMode; + scales?: number[]; + sizes?: number[]; + axes?: number[]; +} +interface MLGraphBuilder { + resample2d(input: MLOperand, options?: MLResample2dOptions): MLOperand; +} +interface MLGraphBuilder { + reshape(input: MLOperand, newShape: number[]): MLOperand; +} +interface MLGraphBuilder { + sigmoid(input: MLOperand): MLOperand; + sigmoid(): MLActivation; +} +interface MLGraphBuilder { + slice(input: MLOperand, starts: number[], sizes: number[]): MLOperand; +} +interface MLGraphBuilder { + softmax(input: MLOperand, axis: number): MLOperand; + softmax(axis: number): MLActivation; +} +interface MLGraphBuilder { + softplus(input: MLOperand): MLOperand; + softplus(): MLActivation; +} +interface MLGraphBuilder { + softsign(input: MLOperand): MLOperand; + softsign(): MLActivation; +} +interface MLSplitOptions { + axis?: number; +} +interface MLGraphBuilder { + split(input: MLOperand, splits: number|number[], options?: MLSplitOptions): MLOperand[]; +} +interface MLGraphBuilder { + tanh(input: MLOperand): MLOperand; + tanh(): MLActivation; +} +interface MLTransposeOptions { + permutation?: number[]; +} +interface MLGraphBuilder { + transpose(input: MLOperand, options?: MLTransposeOptions): MLOperand; +} +interface MLTriangularOptions { + upper?: boolean; + diagonal?: number; +} +interface MLGraphBuilder { + triangular(input: MLOperand, options?: MLTriangularOptions): MLOperand; +} +interface MLGraphBuilder { + where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand; +} + +// Experimental MLBuffer interface + +type MLSize64Out = number; +interface MLBuffer { + readonly size: MLSize64Out; + destroy(): void; +} +type MLSize64 = number; +interface MLBufferDescriptor { + size: MLSize64; +} +type MLNamedBuffers = Record; +interface MLContext { + createBuffer(descriptor: MLBufferDescriptor): MLBuffer; + writeBuffer( + dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64, + srcElementSize?: MLSize64): void; + readBuffer(srcBuffer: MLBuffer): Promise; + dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; +} diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 4d2b80e31a47e..f289fc20bba40 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -66,8 +66,6 @@ const setExecutionProviders = const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; - const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; - const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; if (deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); const valueDataOffset = allocWasmString(deviceType, allocs); @@ -76,26 +74,6 @@ const setExecutionProviders = checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } } - if (numThreads !== undefined) { - // Just ignore invalid webnnOptions.numThreads. - const validatedNumThreads = - (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 : - numThreads; - const keyDataOffset = allocWasmString('numThreads', allocs); - const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`); - } - } - if (powerPreference) { - const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(powerPreference, allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`); - } - } } break; case 'webgpu': diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a483ff09f0003..905bbf0621014 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; @@ -253,11 +258,43 @@ export const createSession = async( await Promise.all(loadingPromises); } + for (const provider of options?.executionProviders ?? []) { + const providerName = typeof provider === 'string' ? provider : provider.name; + if (providerName === 'webnn') { + if (wasm.currentContext) { + throw new Error('WebNN execution provider is already set.'); + } + if (typeof provider !== 'string') { + const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption; + const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; + if (context) { + wasm.currentContext = context as MLContext; + } else if (gpuDevice) { + wasm.currentContext = await navigator.ml.createContext(gpuDevice); + } else { + wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference}); + } + } else { + wasm.currentContext = await navigator.ml.createContext(); + } + break; + } + } + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } + // clear current MLContext after session creation + if (wasm.currentContext) { + wasm.currentContext = undefined; + } + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); const enableGraphCapture = !!options?.enableGraphCapture; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9ced89651e844..70728c82e7753 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import type {Tensor} from 'onnxruntime-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -19,7 +24,7 @@ export declare namespace JSEP { type CaptureEndFunction = () => void; type ReplayFunction = () => void; - export interface Module extends WebGpuModule { + export interface Module extends WebGpuModule, WebNnModule { /** * Mount the external data file to an internal map, which will be used during session initialization. * @@ -106,6 +111,13 @@ export declare namespace JSEP { */ jsepOnReleaseSession: (sessionId: number) => void; } + + export interface WebNnModule { + /** + * Active MLContext used to create WebNN EP. + */ + currentContext: MLContext; + } } export interface OrtInferenceAPIs { diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 13ed29667debe..3f8f17048b8a5 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -17,24 +17,12 @@ namespace onnxruntime { -WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, - const std::string& webnn_threads_number, const std::string& webnn_power_flags) +WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { - // Create WebNN context and graph builder. - const emscripten::val ml = emscripten::val::global("navigator")["ml"]; - if (!ml.as()) { - ORT_THROW("Failed to get ml from navigator."); - } - emscripten::val context_options = emscripten::val::object(); - context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; - // Set "numThreads" if it's not default 0. - if (webnn_threads_number.compare("0") != 0) { - context_options.set("numThreads", stoi(webnn_threads_number)); - } } else { preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { @@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f ORT_THROW("Unknown WebNN deviceType."); } } - if (webnn_power_flags.compare("default") != 0) { - context_options.set("powerPreference", emscripten::val(webnn_power_flags)); - } - wnn_context_ = ml.call("createContext", context_options).await(); + wnn_context_ = emscripten::val::module_property("currentContext"); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index d9cfa5f17c0d4..ec02f249b673c 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -19,8 +19,7 @@ class Model; class WebNNExecutionProvider : public IExecutionProvider { public: - WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number, - const std::string& webnn_power_flags); + explicit WebNNExecutionProvider(const std::string& webnn_device_flags); virtual ~WebNNExecutionProvider(); std::vector> diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index 11acec8b1f354..7792aeabaabf2 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -10,27 +10,22 @@ using namespace onnxruntime; namespace onnxruntime { struct WebNNProviderFactory : IExecutionProviderFactory { - WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number, - const std::string& webnn_power_flags) - : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {} + explicit WebNNProviderFactory(const std::string& webnn_device_flags) + : webnn_device_flags_(webnn_device_flags) {} ~WebNNProviderFactory() override {} std::unique_ptr CreateProvider() override; std::string webnn_device_flags_; - std::string webnn_threads_number_; - std::string webnn_power_flags_; }; std::unique_ptr WebNNProviderFactory::CreateProvider() { - return std::make_unique(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_); + return std::make_unique(webnn_device_flags_); } std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { - return std::make_shared(provider_options.at("deviceType"), - provider_options.at("numThreads"), - provider_options.at("powerPreference")); + return std::make_shared(provider_options.at("deviceType")); } } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 688ee76c591f6..db8b97f6d2c13 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -127,11 +127,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } else if (strcmp(provider_name, "WEBNN") == 0) { #if defined(USE_WEBNN) std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); - std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0"); - std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); provider_options["deviceType"] = deviceType; - provider_options["numThreads"] = numThreads; - provider_options["powerPreference"] = powerPreference; options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status();