Skip to content

Commit

Permalink
[WebNN EP] Add cache for MLContexts in the WebNNBackend (microsof…
Browse files Browse the repository at this point in the history
…t#22510)

### Description
This change adds a cache of `MLContext`s keyed by their options to the
`WebNNBackend`. This makes is so that multiple `InferenceSession`s
create with the same options will share the same context.

### Motivation and Context
Since `MLTensor`s are tied `MLContext`s, developer can't easily share
tensors between `InferenceSession` (outside of manually an `MLContext`
and specifying the `context` options). This leads strange behaviors such
as,
```js
const sessionsA = ort.InferenceSession.create(urlA, {
  executionProviders: ["webnn"],
  preferredOutputLocation: "ml-buffer",
});
const sessionsB = ort.InferenceSession.create(urlB, {
  executionProviders: ["webnn"],
});
const temp = await sessionA.run({/* arguments */});
const result = await sessionB.run({"input":temp["output"]}); // ERROR: Failed to execute 'dispatch' on 'MLContext': Invalid inputs: The context of MLGraph doesn't match the context of the MLTensor with name "input".
```
We encountered this behavior when updating the transformers.js version
in the developer preview demos. microsoft/webnn-developer-preview#46
  • Loading branch information
egalli authored and Ishwar Raut committed Nov 19, 2024
1 parent 693452c commit bb700b1
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 6 deletions.
61 changes: 61 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
[DataType.bool, 'uint8'],
]);

type MLContextEntry = {
gpuDevice?: GPUDevice;
options?: MLContextOptions;
mlContext: MLContext;
};

const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
if (a === b) {
return true;
}
if (a === undefined || b === undefined) {
return false;
}
const aKeys = Object.keys(a).sort() as Array<keyof typeof a>;
const bKeys = Object.keys(b).sort() as Array<keyof typeof b>;
return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]);
};

/**
* WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track
* of the current MLContext being used by the sessions.
Expand All @@ -49,6 +67,10 @@ export class WebNNBackend {
* Maps from MLContext to session ids.
*/
private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
/**
* Cache of MLContexts.
*/
private mlContextCache: MLContextEntry[] = [];
/**
* Current session id.
*/
Expand All @@ -69,6 +91,41 @@ export class WebNNBackend {
this.activeSessionId = sessionId;
}

public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
if (optionsOrDevice instanceof GPUDevice) {
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice);
if (mlContextIndex !== -1) {
return this.mlContextCache[mlContextIndex].mlContext;
} else {
const mlContext = await navigator.ml.createContext(optionsOrDevice);
this.mlContextCache.push({ gpuDevice: optionsOrDevice, mlContext });
return mlContext;
}
} else if (optionsOrDevice === undefined) {
const mlContextIndex = this.mlContextCache.findIndex(
(entry) => entry.options === undefined && entry.gpuDevice === undefined,
);
if (mlContextIndex !== -1) {
return this.mlContextCache[mlContextIndex].mlContext;
} else {
const mlContext = await navigator.ml.createContext();
this.mlContextCache.push({ mlContext });
return mlContext;
}
}

const mlContextIndex = this.mlContextCache.findIndex((entry) =>
compareMLContextOptions(entry.options, optionsOrDevice),
);
if (mlContextIndex !== -1) {
return this.mlContextCache[mlContextIndex].mlContext;
} else {
const mlContext = await navigator.ml.createContext(optionsOrDevice);
this.mlContextCache.push({ options: optionsOrDevice, mlContext });
return mlContext;
}
}

public get currentContext(): MLContext {
const mlContext = this.getMLContext(this.currentSessionId);
if (!mlContext) {
Expand Down Expand Up @@ -99,6 +156,10 @@ export class WebNNBackend {
sessionIds.delete(sessionId);
if (sessionIds.size === 0) {
this.sessionIdsByMLContext.delete(mlContext);
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
if (mlContextIndex !== -1) {
this.mlContextCache.splice(mlContextIndex, 1);
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,12 @@ export const createSession = async (
if (context) {
wasm.currentContext = context as MLContext;
} else if (gpuDevice) {
wasm.currentContext = await navigator.ml.createContext(gpuDevice);
wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice);
} else {
wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference });
wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference });
}
} else {
wasm.currentContext = await navigator.ml.createContext();
wasm.currentContext = await wasm.jsepCreateMLContext!();
}
break;
}
Expand Down
7 changes: 7 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ export declare namespace JSEP {
* @returns the MLTensor ID for the external MLTensor.
*/
jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;

/**
* [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions.
* @param optionsOrGpuDevice - specify the options or GPUDevice.
* @returns
*/
jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
}
}

Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/wasm/pre-jsep.js
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,13 @@ Module['jsepInit'] = (name, params) => {
}
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
return backend['registerMLTensor'](tensor, dataType, shape);
}

};
Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => {
return backend['createMLContext'](optionsOrGpuDevice);
};
Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => {
return backend['registerMLConstant'](
externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
}
};
}
};

0 comments on commit bb700b1

Please sign in to comment.