diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 47304fdc64ae4..d13c663651127 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -32,6 +32,24 @@ const onnxDataTypeToWebnnDataType = new Map([ [DataType.bool, 'uint8'], ]); +type MLContextEntry = { + gpuDevice?: GPUDevice; + options?: MLContextOptions; + mlContext: MLContext; +}; + +const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => { + if (a === b) { + return true; + } + if (a === undefined || b === undefined) { + return false; + } + const aKeys = Object.keys(a).sort() as Array; + const bKeys = Object.keys(b).sort() as Array; + return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]); +}; + /** * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. @@ -49,6 +67,10 @@ export class WebNNBackend { * Maps from MLContext to session ids. */ private sessionIdsByMLContext = new Map>(); + /** + * Cache of MLContexts. + */ + private mlContextCache: MLContextEntry[] = []; /** * Current session id. */ @@ -69,6 +91,41 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { + if (optionsOrDevice instanceof GPUDevice) { + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext }); + return mlContext; + } + } else if (optionsOrDevice === undefined) { + const mlContextIndex = this.mlContextCache.findIndex( + (entry) => entry.options === undefined && entry.gpuDevice === undefined, + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(); + this.mlContextCache.push({ mlContext }); + return mlContext; + } + } + + const mlContextIndex = this.mlContextCache.findIndex((entry) => + compareMLContextOptions(entry.options, optionsOrDevice), + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ options: optionsOrDevice, mlContext }); + return mlContext; + } + } + public get currentContext(): MLContext { const mlContext = this.getMLContext(this.currentSessionId); if (!mlContext) { @@ -99,6 +156,10 @@ export class WebNNBackend { sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext); + if (mlContextIndex !== -1) { + this.mlContextCache.splice(mlContextIndex, 1); + } } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index eb74aa44b3a72..f3794a72efbe8 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -303,12 +303,12 @@ export const createSession = async ( if (context) { wasm.currentContext = context as MLContext; } else if (gpuDevice) { - wasm.currentContext = await navigator.ml.createContext(gpuDevice); + wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference }); + wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference }); } } else { - wasm.currentContext = await navigator.ml.createContext(); + wasm.currentContext = await wasm.jsepCreateMLContext!(); } break; } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index dff3ca74de5a4..40c614fdf866a 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -225,6 +225,13 @@ export declare namespace JSEP { * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + + /** + * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. + * @param optionsOrGpuDevice - specify the options or GPUDevice. + * @returns + */ + jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise; } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 0efbcab3a3238..213f0fbc1e458 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -237,11 +237,13 @@ Module['jsepInit'] = (name, params) => { } Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); - } - + }; + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); + }; Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => { return backend['registerMLConstant']( externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); - } + }; } };