diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
new file mode 100644
index 0000000000000..d7b7f8e07678b
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -0,0 +1,391 @@
+interface NavigatorML {
+  readonly ml: ML;
+}
+interface Navigator extends NavigatorML {
+}
+interface WorkerNavigator extends NavigatorML {
+}
+type MLDeviceType = "cpu" | "gpu" | "npu";
+type MLPowerPreference = "default" | "high-performance" | "low-power";
+interface MLContextOptions {
+  deviceType?: MLDeviceType;
+  powerPreference?: MLPowerPreference;
+  numThreads?: number;
+}
+interface ML {
+  createContext(options?: MLContextOptions): Promise<MLContext>;
+  createContext(gpuDevice: GPUDevice): Promise<MLContext>;
+}
+type MLNamedArrayBufferViews = Record<string, ArrayBufferView>;
+interface MLComputeResult {
+  inputs?: MLNamedArrayBufferViews;
+  outputs?: MLNamedArrayBufferViews;
+}
+interface MLContext {
+  compute(graph: MLGraph, inputs: MLNamedArrayBufferViews, outputs: MLNamedArrayBufferViews): Promise<MLComputeResult>;
+}
+interface MLGraph {
+}
+type MLInputOperandLayout = "nchw" | "nhwc";
+type MLOperandDataType = "float32" | "float16" | "int32" | "uint32" | "int64" | "uint64" | "int8" | "uint8";
+interface MLOperandDescriptor {
+  dataType: MLOperandDataType;
+  dimensions?: Array<number>;
+}
+interface MLOperand {
+  dataType(): MLOperandDataType;
+  shape(): Array<number>;
+}
+interface MLActivation {
+}
+type MLNamedOperands = Record<string, MLOperand>;
+interface MLGraphBuilder {
+  new (context: MLContext): MLGraphBuilder;
+  input(name: string, descriptor: MLOperandDescriptor): MLOperand;
+  constant(descriptor: MLOperandDescriptor, bufferView: ArrayBufferView): MLOperand;
+  constant(type: MLOperandDataType, value: number): MLOperand;
+  build(outputs: MLNamedOperands): Promise<MLGraph>;
+}
+interface MLArgMinMaxOptions {
+  axes?: Array<number>;
+  keepDimensions?: boolean;
+  selectLastIndex?: boolean;
+}
+interface MLGraphBuilder {
+  argMin(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand;
+  argMax(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand;
+}
+interface MLBatchNormalizationOptions {
+  scale?: MLOperand;
+  bias?: MLOperand;
+  axis?: number;
+  epsilon?: number;
+}
+interface MLGraphBuilder {
+  batchNormalization(input: MLOperand, mean: MLOperand, variance: MLOperand, options?: MLBatchNormalizationOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  cast(input: MLOperand, type: MLOperandDataType): MLOperand;
+}
+interface MLClampOptions {
+  minValue?: number;
+  maxValue?: number;
+}
+interface MLGraphBuilder {
+  clamp(input: MLOperand, options?: MLClampOptions): MLOperand;
+  clamp(options?: MLClampOptions): MLActivation;
+}
+interface MLGraphBuilder {
+  concat(inputs: Array<MLOperand>, axis: number): MLOperand;
+}
+type MLConv2dFilterOperandLayout = "oihw" | "hwio" | "ohwi" | "ihwo";
+interface MLConv2dOptions {
+  padding?: Array<number>;
+  strides?: Array<number>;
+  dilations?: Array<number>;
+  groups?: number;
+  inputLayout?: MLInputOperandLayout;
+  filterLayout?: MLConv2dFilterOperandLayout;
+  bias?: MLOperand;
+}
+interface MLGraphBuilder {
+  conv2d(input: MLOperand, filter: MLOperand, options?: MLConv2dOptions): MLOperand;
+}
+type MLConvTranspose2dFilterOperandLayout = "iohw" | "hwoi" | "ohwi";
+interface MLConvTranspose2dOptions {
+  padding?: Array<number>;
+  strides?: Array<number>;
+  dilations?: Array<number>;
+  outputPadding?: Array<number>;
+  outputSizes?: Array<number>;
+  groups?: number;
+  inputLayout?: MLInputOperandLayout;
+  filterLayout?: MLConvTranspose2dFilterOperandLayout;
+  bias?: MLOperand;
+}
+interface MLGraphBuilder {
+  convTranspose2d(input: MLOperand, filter: MLOperand, options?: MLConvTranspose2dOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  add(a: MLOperand, b: MLOperand): MLOperand;
+  sub(a: MLOperand, b: MLOperand): MLOperand;
+  mul(a: MLOperand, b: MLOperand): MLOperand;
+  div(a: MLOperand, b: MLOperand): MLOperand;
+  max(a: MLOperand, b: MLOperand): MLOperand;
+  min(a: MLOperand, b: MLOperand): MLOperand;
+  pow(a: MLOperand, b: MLOperand): MLOperand;
+}
+interface MLGraphBuilder {
+  equal(a: MLOperand, b: MLOperand): MLOperand;
+  greater(a: MLOperand, b: MLOperand): MLOperand;
+  greaterOrEqual(a: MLOperand, b: MLOperand): MLOperand;
+  lesser(a: MLOperand, b: MLOperand): MLOperand;
+  lesserOrEqual(a: MLOperand, b: MLOperand): MLOperand;
+  logicalNot(a: MLOperand): MLOperand;
+}
+interface MLGraphBuilder {
+  abs(input: MLOperand): MLOperand;
+  ceil(input: MLOperand): MLOperand;
+  cos(input: MLOperand): MLOperand;
+  erf(input: MLOperand): MLOperand;
+  exp(input: MLOperand): MLOperand;
+  floor(input: MLOperand): MLOperand;
+  identity(input: MLOperand): MLOperand;
+  log(input: MLOperand): MLOperand;
+  neg(input: MLOperand): MLOperand;
+  reciprocal(input: MLOperand): MLOperand;
+  sin(input: MLOperand): MLOperand;
+  sqrt(input: MLOperand): MLOperand;
+  tan(input: MLOperand): MLOperand;
+}
+interface MLEluOptions {
+  alpha?: number;
+}
+interface MLGraphBuilder {
+  elu(input: MLOperand, options?: MLEluOptions): MLOperand;
+  elu(options?: MLEluOptions): MLActivation;
+}
+interface MLGraphBuilder {
+  expand(input: MLOperand, newShape: Array<number>): MLOperand;
+}
+interface MLGatherOptions {
+  axis?: number;
+}
+interface MLGraphBuilder {
+  gather(input: MLOperand, indices: MLOperand, options?: MLGatherOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  gelu(input: MLOperand): MLOperand;
+  gelu(): MLActivation;
+}
+interface MLGemmOptions {
+  c?: MLOperand;
+  alpha?: number;
+  beta?: number;
+  aTranspose?: boolean;
+  bTranspose?: boolean;
+}
+interface MLGraphBuilder {
+  gemm(a: MLOperand, b: MLOperand, options?: MLGemmOptions): MLOperand;
+}
+type MLGruWeightLayout = "zrn" | "rzn";
+type MLRecurrentNetworkDirection = "forward" | "backward" | "both";
+interface MLGruOptions {
+  bias?: MLOperand;
+  recurrentBias?: MLOperand;
+  initialHiddenState?: MLOperand;
+  resetAfter?: boolean;
+  returnSequence?: boolean;
+  direction?: MLRecurrentNetworkDirection;
+  layout?: MLGruWeightLayout;
+  activations?: Array<MLActivation>;
+}
+interface MLGraphBuilder {
+  gru(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, options?: MLGruOptions): Array<MLOperand>;
+}
+interface MLGruCellOptions {
+  bias?: MLOperand;
+  recurrentBias?: MLOperand;
+  resetAfter?: boolean;
+  layout?: MLGruWeightLayout;
+  activations?: Array<MLActivation>;
+}
+interface MLGraphBuilder {
+  gruCell(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, hiddenSize: number, options?: MLGruCellOptions): MLOperand;
+}
+interface MLHardSigmoidOptions {
+  alpha?: number;
+  beta?: number;
+}
+interface MLGraphBuilder {
+  hardSigmoid(input: MLOperand, options?: MLHardSigmoidOptions): MLOperand;
+  hardSigmoid(options?: MLHardSigmoidOptions): MLActivation;
+}
+interface MLGraphBuilder {
+  hardSwish(input: MLOperand): MLOperand;
+  hardSwish(): MLActivation;
+}
+interface MLInstanceNormalizationOptions {
+  scale?: MLOperand;
+  bias?: MLOperand;
+  epsilon?: number;
+  layout?: MLInputOperandLayout;
+}
+interface MLGraphBuilder {
+  instanceNormalization(input: MLOperand, options?: MLInstanceNormalizationOptions): MLOperand;
+}
+interface MLLayerNormalizationOptions {
+  scale?: MLOperand;
+  bias?: MLOperand;
+  axes?: Array<number>;
+  epsilon?: number;
+}
+interface MLGraphBuilder {
+  layerNormalization(input: MLOperand, options?: MLLayerNormalizationOptions): MLOperand;
+}
+interface MLLeakyReluOptions {
+  alpha?: number;
+}
+interface MLGraphBuilder {
+  leakyRelu(input: MLOperand, options?: MLLeakyReluOptions): MLOperand;
+  leakyRelu(options?: MLLeakyReluOptions): MLActivation;
+}
+interface MLLinearOptions {
+  alpha?: number;
+  beta?: number;
+}
+interface MLGraphBuilder {
+  linear(input: MLOperand, options?: MLLinearOptions): MLOperand;
+  linear(options?: MLLinearOptions): MLActivation;
+}
+type MLLstmWeightLayout = "iofg" | "ifgo";
+interface MLLstmOptions {
+  bias?: MLOperand;
+  recurrentBias?: MLOperand;
+  peepholeWeight?: MLOperand;
+  initialHiddenState?: MLOperand;
+  initialCellState?: MLOperand;
+  returnSequence?: boolean;
+  direction?: MLRecurrentNetworkDirection;
+  layout?: MLLstmWeightLayout;
+  activations?: Array<MLActivation>;
+}
+interface MLGraphBuilder {
+  lstm(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number, options?: MLLstmOptions): Array<MLOperand>;
+}
+interface MLLstmCellOptions {
+  bias?: MLOperand;
+  recurrentBias?: MLOperand;
+  peepholeWeight?: MLOperand;
+  layout?: MLLstmWeightLayout;
+  activations?: Array<MLActivation>;
+}
+interface MLGraphBuilder {
+  lstmCell(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, cellState: MLOperand, hiddenSize: number, options?: MLLstmCellOptions): Array<MLOperand>;
+}
+interface MLGraphBuilder {
+  matmul(a: MLOperand, b: MLOperand): MLOperand;
+}
+type MLPaddingMode = "constant" | "edge" | "reflection" | "symmetric";
+interface MLPadOptions {
+  mode?: MLPaddingMode;
+  value?: number;
+}
+interface MLGraphBuilder {
+  pad(input: MLOperand, beginningPadding: Array<number>, endingPadding: Array<number>, options?: MLPadOptions): MLOperand;
+}
+type MLRoundingType = "floor" | "ceil";
+interface MLPool2dOptions {
+  windowDimensions?: Array<number>;
+  padding?: Array<number>;
+  strides?: Array<number>;
+  dilations?: Array<number>;
+  layout?: MLInputOperandLayout;
+  roundingType?: MLRoundingType;
+  outputSizes?: Array<number>;
+}
+interface MLGraphBuilder {
+  averagePool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
+  l2Pool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
+  maxPool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  prelu(input: MLOperand, slope: MLOperand): MLOperand;
+}
+interface MLReduceOptions {
+  axes?: Array<number>;
+  keepDimensions?: boolean;
+}
+interface MLGraphBuilder {
+  reduceL1(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceL2(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceLogSum(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceLogSumExp(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceMax(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceMean(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceMin(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceProduct(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceSum(input: MLOperand, options?: MLReduceOptions): MLOperand;
+  reduceSumSquare(input: MLOperand, options?: MLReduceOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  relu(input: MLOperand): MLOperand;
+  relu(): MLActivation;
+}
+type MLInterpolationMode = "nearest-neighbor" | "linear";
+interface MLResample2dOptions {
+  mode?: MLInterpolationMode;
+  scales?: Array<number>;
+  sizes?: Array<number>;
+  axes?: Array<number>;
+}
+interface MLGraphBuilder {
+  resample2d(input: MLOperand, options?: MLResample2dOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  reshape(input: MLOperand, newShape: Array<number>): MLOperand;
+}
+interface MLGraphBuilder {
+  sigmoid(input: MLOperand): MLOperand;
+  sigmoid(): MLActivation;
+}
+interface MLGraphBuilder {
+  slice(input: MLOperand, starts: Array<number>, sizes: Array<number>): MLOperand;
+}
+interface MLGraphBuilder {
+  softmax(input: MLOperand, axis: number): MLOperand;
+  softmax(axis: number): MLActivation;
+}
+interface MLGraphBuilder {
+  softplus(input: MLOperand): MLOperand;
+  softplus(): MLActivation;
+}
+interface MLGraphBuilder {
+  softsign(input: MLOperand): MLOperand;
+  softsign(): MLActivation;
+}
+interface MLSplitOptions {
+  axis?: number;
+}
+interface MLGraphBuilder {
+  split(input: MLOperand, splits: number | Array<number>, options?: MLSplitOptions): Array<MLOperand>;
+}
+interface MLGraphBuilder {
+  tanh(input: MLOperand): MLOperand;
+  tanh(): MLActivation;
+}
+interface MLTransposeOptions {
+  permutation?: Array<number>;
+}
+interface MLGraphBuilder {
+  transpose(input: MLOperand, options?: MLTransposeOptions): MLOperand;
+}
+interface MLTriangularOptions {
+  upper?: boolean;
+  diagonal?: number;
+}
+interface MLGraphBuilder {
+  triangular(input: MLOperand, options?: MLTriangularOptions): MLOperand;
+}
+interface MLGraphBuilder {
+  where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand;
+}
+
+// Experimental MLBuffer interface
+
+type MLSize64Out = number;
+interface MLBuffer {
+  readonly size: MLSize64Out;
+  destroy(): void;
+}
+type MLSize64 = number;
+interface MLBufferDescriptor {
+  size: MLSize64;
+}
+type MLNamedBuffers = Record<string, MLBuffer>;
+interface MLContext {
+  createBuffer(descriptor: MLBufferDescriptor): MLBuffer;
+  writeBuffer(dstBuffer: MLBuffer, srcData: ArrayBufferView | ArrayBuffer, srcElementOffset?: MLSize64, srcElementSize?: MLSize64): void;
+  readBuffer(srcBuffer: MLBuffer): Promise<ArrayBuffer>;
+  dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void;
+}
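Reviewer note: the declarations above are deliberately minimal typings for the WebNN API. A rough sketch of how they compose is below; it is not part of this change, and the function name, the 'x'/'y'/'output' keys, and the shapes are made up for illustration. The `MLGraphBuilder` constructor is assumed to exist as a runtime global in a WebNN-enabled browser (only its interface shape is declared in the .d.ts).

```typescript
// Hedged sketch only: builds and runs a one-op WebNN graph with the typings above.
async function addTwoTensors(): Promise<Float32Array> {
  const context: MLContext = await navigator.ml.createContext({ deviceType: 'gpu' });
  // The constructor is a runtime global; it is not declared as a value in webnn.d.ts.
  const builder: MLGraphBuilder = new (globalThis as any).MLGraphBuilder(context);

  const desc: MLOperandDescriptor = { dataType: 'float32', dimensions: [2, 2] };
  const x = builder.input('x', desc);
  const y = builder.input('y', desc);
  const graph = await builder.build({ output: builder.add(x, y) });

  const inputs: MLNamedArrayBufferViews = {
    x: new Float32Array([1, 2, 3, 4]),
    y: new Float32Array([10, 20, 30, 40]),
  };
  const outputs: MLNamedArrayBufferViews = { output: new Float32Array(4) };

  // compute() resolves with views holding the results (inputs/outputs may be re-allocated).
  const result = await context.compute(graph, inputs, outputs);
  return result.outputs!.output as Float32Array;
}
```

The experimental MLBuffer surface (createBuffer/writeBuffer/readBuffer/dispatch) declared at the end of the file is not exercised in this sketch.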
diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts
index 4d2b80e31a47e..f289fc20bba40 100644
--- a/js/web/lib/wasm/session-options.ts
+++ b/js/web/lib/wasm/session-options.ts
@@ -66,8 +66,6 @@ const setExecutionProviders =
               const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
               // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
               const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
-              const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
-              const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
               if (deviceType) {
                 const keyDataOffset = allocWasmString('deviceType', allocs);
                 const valueDataOffset = allocWasmString(deviceType, allocs);
@@ -76,26 +74,6 @@ const setExecutionProviders =
                   checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
                 }
               }
-              if (numThreads !== undefined) {
-                // Just ignore invalid webnnOptions.numThreads.
-                const validatedNumThreads =
-                    (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
-                                                                                                          numThreads;
-                const keyDataOffset = allocWasmString('numThreads', allocs);
-                const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
-                if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                    0) {
-                  checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
-                }
-              }
-              if (powerPreference) {
-                const keyDataOffset = allocWasmString('powerPreference', allocs);
-                const valueDataOffset = allocWasmString(powerPreference, allocs);
-                if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                    0) {
-                  checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
-                }
-              }
             }
             break;
           case 'webgpu':
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index a483ff09f0003..6ac3e3b3981d1 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
 
 import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
@@ -253,11 +258,43 @@ export const createSession = async(
     await Promise.all(loadingPromises);
   }
 
+  for (const provider of options?.executionProviders ?? []) {
+    const providerName = typeof provider === 'string' ? provider : provider.name;
+    if (providerName === 'webnn') {
+      if (wasm.currentContext) {
+        throw new Error('WebNN execution provider is already set.');
+      }
+      if (typeof provider !== 'string') {
+        const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
+        const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
+        const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
+        const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
+        const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
+        const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
+        if (context) {
+          wasm.currentContext = context as MLContext;
+        } else if (gpuDevice) {
+          wasm.currentContext = await navigator.ml.createContext(gpuDevice);
+        } else {
+          wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference});
+        }
+      } else {
+        wasm.currentContext = await navigator.ml.createContext();
+      }
+      break;
+    }
+  }
+
   sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
   if (sessionHandle === 0) {
     checkLastError('Can\'t create a session.');
   }
 
+  // clear current MLContext after session creation
+  if (wasm.currentContext) {
+    wasm.currentContext = undefined;
+  }
+
   const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
 
   const enableGraphCapture = !!options?.enableGraphCapture;
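Reviewer note: with the change above, `createSession` resolves the WebNN `MLContext` in JS before calling `_OrtCreateSession`, choosing among three execution-provider option shapes. The sketch below only illustrates those shapes; the values are placeholders and the exact option typings live in `onnxruntime-common` (`WebNNContextOptions`, `WebNNOptionsWebGpu`, `WebNNOptionsWithMLContext`).

```typescript
// Hedged sketch: the three 'webnn' execution-provider option shapes the new createSession
// loop understands. All concrete values here are placeholders.
declare const gpuDevice: GPUDevice;  // e.g. from navigator.gpu.requestAdapter()/requestDevice()
declare const mlContext: MLContext;  // e.g. await navigator.ml.createContext({deviceType: 'gpu'})

// 1. Plain context options: ORT creates the MLContext itself; numThreads/powerPreference are
//    consumed here (at MLContext creation) and no longer forwarded as session config entries.
const byDeviceType = {name: 'webnn', deviceType: 'cpu', numThreads: 4, powerPreference: 'low-power'};

// 2. Existing WebGPU device: becomes navigator.ml.createContext(gpuDevice).
const byGpuDevice = {name: 'webnn', gpuDevice};

// 3. Pre-created MLContext: taken as-is and stored in wasm.currentContext.
const byMLContext = {name: 'webnn', context: mlContext};

// Each would be passed as: InferenceSession.create(modelUrl, {executionProviders: [option]}).
```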
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index 9ced89651e844..70728c82e7753 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import type {Tensor} from 'onnxruntime-common';
 
 /* eslint-disable @typescript-eslint/naming-convention */
@@ -19,7 +24,7 @@ export declare namespace JSEP {
   type CaptureEndFunction = () => void;
   type ReplayFunction = () => void;
 
-  export interface Module extends WebGpuModule {
+  export interface Module extends WebGpuModule, WebNnModule {
    /**
     * Mount the external data file to an internal map, which will be used during session initialization.
    *
@@ -106,6 +111,13 @@ export declare namespace JSEP {
     */
    jsepOnReleaseSession: (sessionId: number) => void;
  }
+
+  export interface WebNnModule {
+    /**
+     * Active MLContext used to create WebNN EP.
+     */
+    currentContext: MLContext;
+  }
 }
 
 export interface OrtInferenceAPIs {
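Reviewer note: `currentContext` on the JSEP `Module` is the hand-off point between the two halves of this change. `createSession` (above) sets it just before `_OrtCreateSession`, the native `WebNNExecutionProvider` (next diff) reads it via `emscripten::val::module_property("currentContext")`, and the JS side clears it once the session is created. A minimal sketch of that pairing follows; the helper name is made up, and unlike the real code it uses try/finally only to make the set/clear pairing explicit.

```typescript
// Hedged sketch of the currentContext lifecycle implied by the wasm-types.ts and
// wasm-core-impl.ts changes; the real logic lives inline in createSession.
async function withCurrentContext<T>(
    wasm: {currentContext?: MLContext}, context: MLContext,
    createOrtSession: () => Promise<T>): Promise<T> {
  if (wasm.currentContext) {
    throw new Error('WebNN execution provider is already set.');
  }
  wasm.currentContext = context;  // visible to C++ as Module["currentContext"]
  try {
    // While this runs, WebNNExecutionProvider picks the context up via
    // emscripten::val::module_property("currentContext").
    return await createOrtSession();
  } finally {
    wasm.currentContext = undefined;  // avoid leaking the context into a later session
  }
}
```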
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index 13ed29667debe..6c6250e1c08c9 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -17,24 +17,13 @@
 
 namespace onnxruntime {
 
-WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
-                                               const std::string& webnn_threads_number, const std::string& webnn_power_flags)
+WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
     : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
-  // Create WebNN context and graph builder.
-  const emscripten::val ml = emscripten::val::global("navigator")["ml"];
-  if (!ml.as<bool>()) {
-    ORT_THROW("Failed to get ml from navigator.");
-  }
-  emscripten::val context_options = emscripten::val::object();
-  context_options.set("deviceType", emscripten::val(webnn_device_flags));
+  // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
   if (webnn_device_flags.compare("cpu") == 0) {
     preferred_layout_ = DataLayout::NHWC;
     wnn_device_type_ = webnn::WebnnDeviceType::CPU;
-    // Set "numThreads" if it's not default 0.
-    if (webnn_threads_number.compare("0") != 0) {
-      context_options.set("numThreads", stoi(webnn_threads_number));
-    }
   } else {
     preferred_layout_ = DataLayout::NCHW;
     if (webnn_device_flags.compare("gpu") == 0) {
@@ -45,11 +34,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
       ORT_THROW("Unknown WebNN deviceType.");
     }
   }
-  if (webnn_power_flags.compare("default") != 0) {
-    context_options.set("powerPreference", emscripten::val(webnn_power_flags));
-  }
 
-  wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
+  wnn_context_ = emscripten::val::module_property("currentContext");
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h
index d9cfa5f17c0d4..ec02f249b673c 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h
@@ -19,8 +19,7 @@ class Model;
 
 class WebNNExecutionProvider : public IExecutionProvider {
  public:
-  WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                         const std::string& webnn_power_flags);
+  explicit WebNNExecutionProvider(const std::string& webnn_device_flags);
   virtual ~WebNNExecutionProvider();
 
   std::vector<std::unique_ptr<ComputeCapability>>
diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
index 11acec8b1f354..7792aeabaabf2 100644
--- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
+++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc
@@ -10,27 +10,22 @@ using namespace onnxruntime;
 namespace onnxruntime {
 
 struct WebNNProviderFactory : IExecutionProviderFactory {
-  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                       const std::string& webnn_power_flags)
-      : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
+  explicit WebNNProviderFactory(const std::string& webnn_device_flags)
+      : webnn_device_flags_(webnn_device_flags) {}
   ~WebNNProviderFactory() override {}
 
   std::unique_ptr<IExecutionProvider> CreateProvider() override;
 
   std::string webnn_device_flags_;
-  std::string webnn_threads_number_;
-  std::string webnn_power_flags_;
 };
 
 std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
-  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
+  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_);
 }
 
 std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
     const ProviderOptions& provider_options) {
-  return std::make_shared<WebNNProviderFactory>(provider_options.at("deviceType"),
-                                                provider_options.at("numThreads"),
-                                                provider_options.at("powerPreference"));
+  return std::make_shared<WebNNProviderFactory>(provider_options.at("deviceType"));
 }
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc
index 05408db9884cd..8bfb6c4ad2bcd 100644
--- a/onnxruntime/core/session/provider_registration.cc
+++ b/onnxruntime/core/session/provider_registration.cc
@@ -128,11 +128,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
   } else if (strcmp(provider_name, "WEBNN") == 0) {
 #if defined(USE_WEBNN)
     std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
-    std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
-    std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
     provider_options["deviceType"] = deviceType;
-    provider_options["numThreads"] = numThreads;
-    provider_options["powerPreference"] = powerPreference;
     options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
 #else
     status = create_not_supported_status();