diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index e90efd7b97c29..e4bfd2f565912 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -67,7 +67,10 @@ export const registerBackend = (name: string, backend: Backend, priority: number * @param backendName - the name of the backend. * @returns the backend instance if resolved and initialized successfully, or an error message if failed. */ -const tryResolveAndInitializeBackend = async(backendName: string): Promise => { +const tryResolveAndInitializeBackend = async( + backendName: string, + webnnOptions?: InferenceSession.WebNNExecutionProviderOption, + ): Promise => { const backendInfo = backends.get(backendName); if (!backendInfo) { return 'backend not found.'; @@ -81,7 +84,7 @@ const tryResolveAndInitializeBackend = async(backendName: string): Promise => { - // extract backend hints from session options + // extract backend hints from session options. const eps = options.executionProviders || []; const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - - // try to resolve and initialize all requested backends + if (backendNames.filter(name => name === 'webgpu').length > 1) { + throw new Error(`Registering duplicate 'webgpu' backends in the session options is not permitted`); + } + if (backendNames.filter(name => name === 'webnn').length > 1) { + throw new Error(`Registering duplicate 'webnn' backends in the session options is not permitted`); + } + // try to resolve and initialize all requested backends. let backend: Backend|undefined; const errors = []; const availableBackendNames = new Set(); for (const backendName of backendNames) { - const resolveResult = await tryResolveAndInitializeBackend(backendName); + // initialize webnn backend requires additional WebNNExecutionProviderOption. + let webnnOptions: InferenceSession.WebNNExecutionProviderOption|undefined; + if (backendName === 'webnn') { + webnnOptions = eps.find(e => typeof e !== 'string' && e.name === 'webnn') as + InferenceSession.WebNNExecutionProviderOption; + } + const resolveResult = await tryResolveAndInitializeBackend(backendName, webnnOptions); if (typeof resolveResult === 'string') { errors.push({name: backendName, err: resolveResult}); } else { diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 8c07bdd5c5c4a..2b024abdca788 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -71,7 +71,7 @@ export interface Backend { /** * Initialize the backend asynchronously. Should throw when failed. */ - init(backendName: string): Promise; + init(backendName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise; createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 31ecffb07e40c..6215b5d723cac 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -51,7 +51,7 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { * * @param backendName - the registered backend name. */ - async init(backendName: string): Promise { + async init(backendName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise { // populate wasm flags initializeFlags(); @@ -59,7 +59,7 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { await initializeWebAssemblyAndOrtRuntime(); // performe EP specific initialization - await initializeOrtEp(backendName); + await initializeOrtEp(backendName, webnnOptions); } createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 02246c9ee4767..54ae16b65b61e 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -50,7 +50,7 @@ interface MessageInitWasm extends MessageError { interface MessageInitEp extends MessageError { type: 'init-ep'; - in ?: {env: Env; epName: string}; + in ?: {env: Env; epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption}; out?: never; } diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 3ce37a2d6b652..6649578e858f6 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -60,8 +60,8 @@ self.onmessage = (ev: MessageEvent): void => { }); break; case 'init-ep': { - const {epName, env} = message!; - initEp(env, epName) + const {epName, env, webnnOptions} = message!; + initEp(env, epName, webnnOptions) .then( () => { postMessage({type}); diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 6ff4e86b1235e..b646ef0f2ca1f 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -118,16 +118,17 @@ export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { } }; -export const initializeOrtEp = async(epName: string): Promise => { +export const initializeOrtEp = + async(epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('init-ep', [resolve, reject]); - const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}}; + const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env, webnnOptions}}; proxyWorker!.postMessage(message); }); } else { - await core.initEp(env, epName); + await core.initEp(env, epName, webnnOptions); } }; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 9b27051f1b9fe..887f3e2b9c185 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -83,7 +83,8 @@ export const initRuntime = async(env: Env): Promise => { * @param env * @param epName */ -export const initEp = async(env: Env, epName: string): Promise => { +export const initEp = + async(env: Env, epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise => { if (!BUILD_DEFS.DISABLE_WEBGPU) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; @@ -128,9 +129,20 @@ export const initEp = async(env: Env, epName: string): Promise => { await initJsep('webgpu', getInstance(), env, adapter); } if (epName === 'webnn') { - // perform WebNN availability check + // perform WebNN availability check. if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) { throw new Error('WebNN is not supported in current environment'); + } else { + try { + if (webnnOptions?.powerPreference === 'default') { + // current implementation of WebNN API in Chromium does not support "default" powerPreference. + webnnOptions.powerPreference = undefined; + } + // validate if WebNN MLContext can be created with current options. + await (navigator as any).ml.createContext(webnnOptions); + } catch (e) { + throw(e); + } } await initJsep('webnn', getInstance(), env);