diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 353d93bbc34ae..069fd9b49e484 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -242,12 +242,62 @@ export declare namespace InferenceSession { readonly name: 'webgpu'; preferredLayout?: 'NCHW'|'NHWC'; } - export interface WebNNExecutionProviderOption extends ExecutionProviderOption { + + // #region WebNN options + + interface WebNNExecutionProviderName extends ExecutionProviderOption { readonly name: 'webnn'; + } + + /** + * Represents a set of options for creating a WebNN MLContext. + * + * @see https://www.w3.org/TR/webnn/#dictdef-mlcontextoptions + */ + export interface WebNNContextOptions { deviceType?: 'cpu'|'gpu'|'npu'; numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } + + /** + * Represents a set of options for WebNN execution provider without MLContext. + */ + export interface WebNNOptionsWithoutMLContext extends WebNNExecutionProviderName, WebNNContextOptions { + context?: never; + } + + /** + * Represents a set of options for WebNN execution provider with MLContext. + * + * When MLContext is provided, the deviceType is also required so that the WebNN EP can determine the preferred + * channel layout. + * + * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext + */ + export interface WebNNOptionsWithMLContext extends WebNNExecutionProviderName, + Omit, + Required> { + context: unknown /* MLContext */; + } + + /** + * Represents a set of options for WebNN execution provider with MLContext which is created from GPUDevice. + * + * @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice + */ + export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName { + context: unknown /* MLContext */; + gpuDevice: unknown /* GPUDevice */; + } + + /** + * Options for WebNN execution provider. + */ + export type WebNNExecutionProviderOption = WebNNOptionsWithoutMLContext|WebNNOptionsWithMLContext|WebNNOptionsWebGpu; + + // #endregion + export interface QnnExecutionProviderOption extends ExecutionProviderOption { readonly name: 'qnn'; // TODO add flags diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 48eac57494726..4d2b80e31a47e 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -64,34 +64,36 @@ const setExecutionProviders = epName = 'WEBNN'; if (typeof ep !== 'string') { const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption; - if (webnnOptions?.deviceType) { + // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context; + const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference; + if (deviceType) { const keyDataOffset = allocWasmString('deviceType', allocs); - const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs); + const valueDataOffset = allocWasmString(deviceType, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`); + checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`); } } - if (webnnOptions?.numThreads) { - let numThreads = webnnOptions.numThreads; + if (numThreads !== undefined) { // Just ignore invalid webnnOptions.numThreads. - if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) { - numThreads = 0; - } + const validatedNumThreads = + (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 : + numThreads; const keyDataOffset = allocWasmString('numThreads', allocs); - const valueDataOffset = allocWasmString(numThreads.toString(), allocs); + const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`); + checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`); } } - if (webnnOptions?.powerPreference) { + if (powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); - const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); + const valueDataOffset = allocWasmString(powerPreference, allocs); if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { - checkLastError( - `Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}.`); + checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`); } } }