From df236c7894ddf05c12ca78ebe1d24f5b135eb8ff Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Wed, 30 Oct 2024 10:26:33 -0700 Subject: [PATCH] [WebNN EP] Add cache for `MLContext`s in the `WebNNBackend` (#22510) ### Description This change adds a cache of `MLContext`s keyed by their options to the `WebNNBackend`. This makes is so that multiple `InferenceSession`s create with the same options will share the same context. ### Motivation and Context Since `MLTensor`s are tied `MLContext`s, developer can't easily share tensors between `InferenceSession` (outside of manually an `MLContext` and specifying the `context` options). This leads strange behaviors such as, ```js const sessionsA = ort.InferenceSession.create(urlA, { executionProviders: ["webnn"], preferredOutputLocation: "ml-buffer", }); const sessionsB = ort.InferenceSession.create(urlB, { executionProviders: ["webnn"], }); const temp = await sessionA.run({/* arguments */}); const result = await sessionB.run({"input":temp["output"]}); // ERROR: Failed to execute 'dispatch' on 'MLContext': Invalid inputs: The context of MLGraph doesn't match the context of the MLTensor with name "input". ``` We encountered this behavior when updating the transformers.js version in the developer preview demos. microsoft/webnn-developer-preview#46 --- js/web/lib/wasm/jsep/backend-webnn.ts | 61 +++++++++++++++++++++++++++ js/web/lib/wasm/wasm-core-impl.ts | 6 +-- js/web/lib/wasm/wasm-types.ts | 7 +++ onnxruntime/wasm/pre-jsep.js | 8 ++-- 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 47304fdc64ae4..d13c663651127 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -32,6 +32,24 @@ const onnxDataTypeToWebnnDataType = new Map([ [DataType.bool, 'uint8'], ]); +type MLContextEntry = { + gpuDevice?: GPUDevice; + options?: MLContextOptions; + mlContext: MLContext; +}; + +const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => { + if (a === b) { + return true; + } + if (a === undefined || b === undefined) { + return false; + } + const aKeys = Object.keys(a).sort() as Array; + const bKeys = Object.keys(b).sort() as Array; + return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]); +}; + /** * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. @@ -49,6 +67,10 @@ export class WebNNBackend { * Maps from MLContext to session ids. */ private sessionIdsByMLContext = new Map>(); + /** + * Cache of MLContexts. + */ + private mlContextCache: MLContextEntry[] = []; /** * Current session id. */ @@ -69,6 +91,41 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { + if (optionsOrDevice instanceof GPUDevice) { + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext }); + return mlContext; + } + } else if (optionsOrDevice === undefined) { + const mlContextIndex = this.mlContextCache.findIndex( + (entry) => entry.options === undefined && entry.gpuDevice === undefined, + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(); + this.mlContextCache.push({ mlContext }); + return mlContext; + } + } + + const mlContextIndex = this.mlContextCache.findIndex((entry) => + compareMLContextOptions(entry.options, optionsOrDevice), + ); + if (mlContextIndex !== -1) { + return this.mlContextCache[mlContextIndex].mlContext; + } else { + const mlContext = await navigator.ml.createContext(optionsOrDevice); + this.mlContextCache.push({ options: optionsOrDevice, mlContext }); + return mlContext; + } + } + public get currentContext(): MLContext { const mlContext = this.getMLContext(this.currentSessionId); if (!mlContext) { @@ -99,6 +156,10 @@ export class WebNNBackend { sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); + const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext); + if (mlContextIndex !== -1) { + this.mlContextCache.splice(mlContextIndex, 1); + } } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index eb74aa44b3a72..f3794a72efbe8 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -303,12 +303,12 @@ export const createSession = async ( if (context) { wasm.currentContext = context as MLContext; } else if (gpuDevice) { - wasm.currentContext = await navigator.ml.createContext(gpuDevice); + wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice); } else { - wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference }); + wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference }); } } else { - wasm.currentContext = await navigator.ml.createContext(); + wasm.currentContext = await wasm.jsepCreateMLContext!(); } break; } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index dff3ca74de5a4..40c614fdf866a 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -225,6 +225,13 @@ export declare namespace JSEP { * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + + /** + * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. + * @param optionsOrGpuDevice - specify the options or GPUDevice. + * @returns + */ + jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise; } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 0efbcab3a3238..213f0fbc1e458 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -237,11 +237,13 @@ Module['jsepInit'] = (name, params) => { } Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); - } - + }; + Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { + return backend['createMLContext'](optionsOrGpuDevice); + }; Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => { return backend['registerMLConstant']( externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); - } + }; } };