Skip to content

Commit

Permalink
[WebNN EP] Add cache for MLContexts in the WebNNBackend
Browse files Browse the repository at this point in the history
### Description
This change adds a cache of `MLContext`s keyed by their options to the `WebNNBackend`. This makes it so that multiple `InferenceSession`s created with the same options will share the same context.

### Motivation and Context
Since `MLTensor`s are tied to `MLContext`s, developers can't easily share tensors between `InferenceSession`s (outside of manually creating an `MLContext` and passing it via the `context` option). This leads to strange behaviors such as:
```js
const sessionA = await ort.InferenceSession.create(urlA, {
  executionProviders: ["webnn"],
  preferredOutputLocation: "ml-buffer",
});
const sessionB = await ort.InferenceSession.create(urlB, {
  executionProviders: ["webnn"],
});
const temp = await sessionA.run({/* arguments */});
const result = await sessionB.run({"input":temp["output"]}); // ERROR: Failed to execute 'dispatch' on 'MLContext': Invalid inputs: The context of MLGraph doesn't match the context of the MLTensor with name "input".
```
We encountered this behavior when updating the transformers.js version in the developer preview demos. microsoft/webnn-developer-preview#46
  • Loading branch information
egalli committed Oct 19, 2024
1 parent 5aabc53 commit 3d45e5b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 4 deletions.
70 changes: 70 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
[DataType.bool, 'uint8'],
]);

/**
 * A cached MLContext together with the creation arguments that produced it.
 * Exactly one of `gpuDevice` / `options` is set (or neither, for the default
 * context); the cache is searched by whichever argument form the caller used.
 */
type MLContextEntry = {
  // GPUDevice the context was created from, when created via createContext(gpuDevice).
  gpuDevice?: GPUDevice;
  // Options the context was created with, when created via createContext(options).
  options?: MLContextOptions;
  // The cached context itself.
  mlContext: MLContext;
};

/**
 * Shallowly compares two MLContextOptions objects for equality.
 *
 * Two options objects are considered equal when they have exactly the same set
 * of keys and strictly-equal values for each key. Both-undefined compares
 * equal; undefined vs. an object compares unequal.
 *
 * @param a - the first options object, or undefined.
 * @param b - the second options object, or undefined.
 * @returns true if the two options objects are equivalent.
 */
const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
  if (a === b) {
    return true;
  }
  if (a === undefined || b === undefined) {
    return false;
  }
  const aKeys = Object.keys(a).sort();
  const bKeys = Object.keys(b).sort();
  // Compare key names pairwise (not just key counts) so that objects with
  // different keys whose values happen to be undefined do not compare equal.
  return (
    aKeys.length === bKeys.length &&
    aKeys.every(
      (key, i) =>
        key === bKeys[i] && (a as Record<string, unknown>)[key] === (b as Record<string, unknown>)[key],
    )
  );
};

/**
* WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track
* of the current MLContext being used by the sessions.
Expand All @@ -47,6 +74,10 @@ export class WebNNBackend {
* Maps from MLContext to session ids.
*/
private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
/**
* Cache of MLContexts.
*/
private mlContextCache: MLContextEntry[] = [];
/**
* Current session id.
*/
Expand All @@ -67,6 +98,41 @@ export class WebNNBackend {
this.activeSessionId = sessionId;
}

/**
 * Creates an MLContext for the given GPUDevice or MLContextOptions, reusing a
 * previously created context from the cache when the creation arguments match.
 * This lets sessions created with identical options share one MLContext, so
 * MLTensors can be dispatched across them.
 *
 * @param optionsOrDevice - the GPUDevice or MLContextOptions to create the context from;
 *                          omit to create the default context.
 * @returns a Promise resolving to the (possibly cached) MLContext.
 */
public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
  // GPUDevice-based contexts are matched by device identity.
  if (optionsOrDevice instanceof GPUDevice) {
    const cached = this.mlContextCache.find((entry) => entry.gpuDevice === optionsOrDevice);
    if (cached !== undefined) {
      return cached.mlContext;
    }
    const mlContext = await navigator.ml.createContext(optionsOrDevice);
    this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext });
    return mlContext;
  }

  // No arguments: match the cached "default" context (no device, no options).
  if (optionsOrDevice === undefined) {
    const cached = this.mlContextCache.find(
      (entry) => entry.options === undefined && entry.gpuDevice === undefined,
    );
    if (cached !== undefined) {
      return cached.mlContext;
    }
    const mlContext = await navigator.ml.createContext();
    this.mlContextCache.push({ mlContext });
    return mlContext;
  }

  // Options-based contexts are matched by structural comparison of the options.
  const cached = this.mlContextCache.find((entry) => compareMLContextOptions(entry.options, optionsOrDevice));
  if (cached !== undefined) {
    return cached.mlContext;
  }
  const mlContext = await navigator.ml.createContext(optionsOrDevice);
  this.mlContextCache.push({ options: optionsOrDevice, mlContext });
  return mlContext;
}

public get currentContext(): MLContext {
const mlContext = this.getMLContext(this.currentSessionId);
if (!mlContext) {
Expand Down Expand Up @@ -97,6 +163,10 @@ export class WebNNBackend {
sessionIds.delete(sessionId);
if (sessionIds.size === 0) {
this.sessionIdsByMLContext.delete(mlContext);
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
if (mlContextIndex !== -1) {
this.mlContextCache.splice(mlContextIndex, 1);
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,12 @@ export const createSession = async (
if (context) {
wasm.currentContext = context as MLContext;
} else if (gpuDevice) {
wasm.currentContext = await navigator.ml.createContext(gpuDevice);
wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice);
} else {
wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference });
wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference });
}
} else {
wasm.currentContext = await navigator.ml.createContext();
wasm.currentContext = await wasm.jsepCreateMLContext!();
}
break;
}
Expand Down
7 changes: 7 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ export declare namespace JSEP {
* @returns the MLTensor ID for the external MLTensor.
*/
jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;

/**
* [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions.
* @param optionsOrGpuDevice - specify the options or GPUDevice.
* @returns a Promise that resolves to the created (or cached) MLContext.
*/
jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
}
}

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/wasm/pre-jsep.js
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ Module['jsepInit'] = (name, params) => {
}
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
return backend['registerMLTensor'](tensor, dataType, shape);
}
};
Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => {
return backend['createMLContext'](optionsOrGpuDevice);
};
}
};

0 comments on commit 3d45e5b

Please sign in to comment.