From 3d45e5bf52f6279b0bfbb06732c54f0d9b16711e Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Sat, 19 Oct 2024 01:04:06 -0700 Subject: [PATCH 1/2] [WebNN EP] Add cache for `MLContext`s in the `WebNNBackend` ### Description This change adds a cache of `MLContext`s keyed by their options to the `WebNNBackend`. This makes is so that multiple `InferenceSession`s create with the same options will share the same context. ### Motivation and Context Since `MLTensor`s are tied `MLContext`s, developer can't easily share tensors between `InferenceSession` (outside of manually an `MLContext` and specifying the `context` options). This leads strange behaviors such as, ```js const sessionsA = ort.InferenceSession.create(urlA, { executionProviders: ["webnn"], preferredOutputLocation: "ml-buffer", }); const sessionsB = ort.InferenceSession.create(urlB, { executionProviders: ["webnn"], }); const temp = await sessionA.run({/* arguments */}); const result = await sessionB.run({"input":temp["output"]}); // ERROR: Failed to execute 'dispatch' on 'MLContext': Invalid inputs: The context of MLGraph doesn't match the context of the MLTensor with name "input". ``` We encountered this behavior when updating the transformers.js version in the developer preview demos. microsoft/webnn-developer-preview#46 --- js/web/lib/wasm/jsep/backend-webnn.ts | 70 +++++++++++++++++++++++++++ js/web/lib/wasm/wasm-core-impl.ts | 6 +-- js/web/lib/wasm/wasm-types.ts | 7 +++ onnxruntime/wasm/pre-jsep.js | 5 +- 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index d13136d252d2a..4250c0c39c992 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -30,6 +30,33 @@ const onnxDataTypeToWebnnDataType = new Map([ [DataType.bool, 'uint8'], ]); +type MLContextEntry = { + gpuDevice?: GPUDevice; + options?: MLContextOptions; + mlContext: MLContext; +}; + +const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => { + if (a === b) { + return true; + } + if (a === undefined || b === undefined) { + return false; + } + const aKeys = Object.keys(a).sort(); + const bKeys = Object.keys(b).sort(); + if (aKeys.length !== bKeys.length) { + return false; + } + type GeneticObject = { [key: string]: object }; + for (const key of aKeys) { + if ((a as GeneticObject)[key] !== (b as GeneticObject)[key]) { + return false; + } + } + return true; +}; + /** * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. @@ -47,6 +74,10 @@ export class WebNNBackend { * Maps from MLContext to session ids. */ private sessionIdsByMLContext = new Map>(); + /** + * Cache of MLContexts. + */ + private mlContextCache: MLContextEntry[] = []; /** * Current session id. */ @@ -67,6 +98,41 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { + if (optionsOrDevice instanceof GPUDevice) { + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext }); + return mlContext; + } + } else if (optionsOrDevice === undefined) { + const mlContextIndex = this.mlContextCache.findIndex( + (entry) => entry.options === undefined && entry.gpuDevice === undefined, + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(); + this.mlContextCache.push({ mlContext }); + return mlContext; + } + } + + const mlContextIndex = this.mlContextCache.findIndex((entry) => + compareMLContextOptions(entry.options, optionsOrDevice), + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ options: optionsOrDevice, mlContext }); + return mlContext; + } + } + public get currentContext(): MLContext { const mlContext = this.getMLContext(this.currentSessionId); if (!mlContext) { @@ -97,6 +163,10 @@ export class WebNNBackend { sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext); + if (mlContextIndex !== -1) { + this.mlContextCache.splice(mlContextIndex, 1); + } } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5f219f63aaf61..c0ed891623d45 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -301,12 +301,12 @@ export const createSession = async ( if (context) { wasm.currentContext = context as MLContext; } else if (gpuDevice) { - wasm.currentContext = await navigator.ml.createContext(gpuDevice); + wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference }); + wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference }); } } else { - wasm.currentContext = await navigator.ml.createContext(); + wasm.currentContext = await wasm.jsepCreateMLContext!(); } break; } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 3e08fe97f559d..834f99e5bc7a9 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -219,6 +219,13 @@ export declare namespace JSEP { * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + + /** + * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. + * @param optionsOrGpuDevice - specify the options or GPUDevice. + * @returns + */ + jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise; } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 68332d07a9782..97f1a4627b9d8 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -234,6 +234,9 @@ Module['jsepInit'] = (name, params) => { } Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); - } + }; + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); + }; } }; From 96e116b653beff9d3a8cd0085babf4f45477e26e Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Wed, 23 Oct 2024 10:54:53 -0700 Subject: [PATCH 2/2] PR feedback * Changed to a clearer and more robust way to compare `MLContextOption`s --- js/web/lib/wasm/jsep/backend-webnn.ts | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 4250c0c39c992..607517e4d8a53 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -43,18 +43,9 @@ const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): bo if (a === undefined || b === undefined) { return false; } - const aKeys = Object.keys(a).sort(); - const bKeys = Object.keys(b).sort(); - if (aKeys.length !== bKeys.length) { - return false; - } - type GeneticObject = { [key: string]: object }; - for (const key of aKeys) { - if ((a as GeneticObject)[key] !== (b as GeneticObject)[key]) { - return false; - } - } - return true; + const aKeys = Object.keys(a).sort() as Array; + const bKeys = Object.keys(b).sort() as Array; + return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]); }; /**