diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index d13136d252d2a..4250c0c39c992 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -30,6 +30,33 @@ const onnxDataTypeToWebnnDataType = new Map([
   [DataType.bool, 'uint8'],
 ]);
 
+type MLContextEntry = {
+  gpuDevice?: GPUDevice;
+  options?: MLContextOptions;
+  mlContext: MLContext;
+};
+
+const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
+  if (a === b) {
+    return true;
+  }
+  if (a === undefined || b === undefined) {
+    return false;
+  }
+  const aKeys = Object.keys(a).sort();
+  const bKeys = Object.keys(b).sort();
+  if (aKeys.length !== bKeys.length) {
+    return false;
+  }
+  type GenericObject = { [key: string]: object };
+  for (const key of aKeys) {
+    if ((a as GenericObject)[key] !== (b as GenericObject)[key]) {
+      return false;
+    }
+  }
+  return true;
+};
+
 /**
  * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track
  * of the current MLContext being used by the sessions.
@@ -47,6 +74,10 @@ export class WebNNBackend {
    * Maps from MLContext to session ids.
    */
   private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
+  /**
+   * Cache of MLContexts.
+   */
+  private mlContextCache: MLContextEntry[] = [];
   /**
    * Current session id.
    */
@@ -67,6 +98,41 @@ export class WebNNBackend {
     this.activeSessionId = sessionId;
   }
 
+  public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
+    if (optionsOrDevice instanceof GPUDevice) {
+      const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice);
+      if (mlContextIndex !== -1) {
+        return this.mlContextCache[mlContextIndex].mlContext;
+      } else {
+        const mlContext = await navigator.ml.createContext(optionsOrDevice);
+        this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext });
+        return mlContext;
+      }
+    } else if (optionsOrDevice === undefined) {
+      const mlContextIndex = this.mlContextCache.findIndex(
+        (entry) => entry.options === undefined && entry.gpuDevice === undefined,
+      );
+      if (mlContextIndex !== -1) {
+        return this.mlContextCache[mlContextIndex].mlContext;
+      } else {
+        const mlContext = await navigator.ml.createContext();
+        this.mlContextCache.push({ mlContext });
+        return mlContext;
+      }
+    }
+
+    const mlContextIndex = this.mlContextCache.findIndex((entry) =>
+      compareMLContextOptions(entry.options, optionsOrDevice),
+    );
+    if (mlContextIndex !== -1) {
+      return this.mlContextCache[mlContextIndex].mlContext;
+    } else {
+      const mlContext = await navigator.ml.createContext(optionsOrDevice);
+      this.mlContextCache.push({ options: optionsOrDevice, mlContext });
+      return mlContext;
+    }
+  }
+
   public get currentContext(): MLContext {
     const mlContext = this.getMLContext(this.currentSessionId);
     if (!mlContext) {
@@ -97,6 +163,10 @@ export class WebNNBackend {
     sessionIds.delete(sessionId);
     if (sessionIds.size === 0) {
       this.sessionIdsByMLContext.delete(mlContext);
+      const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
+      if (mlContextIndex !== -1) {
+        this.mlContextCache.splice(mlContextIndex, 1);
+      }
     }
   }
 
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5f219f63aaf61..c0ed891623d45 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -301,12 +301,12 @@ export const createSession = async (
       if (context) {
         wasm.currentContext = context as MLContext;
       } else if (gpuDevice) {
-        wasm.currentContext = await navigator.ml.createContext(gpuDevice);
+        wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice);
       } else {
-        wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference });
+        wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference });
       }
     } else {
-      wasm.currentContext = await navigator.ml.createContext();
+      wasm.currentContext = await wasm.jsepCreateMLContext!();
     }
     break;
   }
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index 3e08fe97f559d..834f99e5bc7a9 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -219,6 +219,13 @@ export declare namespace JSEP {
      * @returns the MLTensor ID for the external MLTensor.
      */
     jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;
+
+    /**
+     * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions.
+     * @param optionsOrGpuDevice - specify the options or GPUDevice.
+     * @returns a Promise that resolves to the created MLContext.
+     */
+    jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
   }
 }
 
diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js
index 68332d07a9782..97f1a4627b9d8 100644
--- a/onnxruntime/wasm/pre-jsep.js
+++ b/onnxruntime/wasm/pre-jsep.js
@@ -234,6 +234,9 @@ Module['jsepInit'] = (name, params) => {
     }
     Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
       return backend['registerMLTensor'](tensor, dataType, shape);
-    }
+    };
+    Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => {
+      return backend['createMLContext'](optionsOrGpuDevice);
+    };
   }
 };