Skip to content

Commit

Permalink
[js/webnn] update API of session options for WebNN (#20816)
Browse files Browse the repository at this point in the history
### Description

This PR is an API-only change to address the requirements being
discussed in #20729.

There are multiple ways that users may create an ORT session by
specifying the session options differently.

All the code snippet below will use the variable `webnnOptions` as this:
```js
const myWebnnSession = await ort.InferenceSession.create('./model.onnx', {
   executionProviders: [
     webnnOptions
   ]
});
```

### The old way (backward-compatibility)

```js
// all-default, name only
const webnnOptions_0 = 'webnn';

// all-default, properties omitted
const webnnOptions_1 = { name: 'webnn' };

// partial
const webnnOptions_2 = {
  name: 'webnn',
  deviceType: 'cpu'
};

// full
const webnnOptions_3 = {
  name: 'webnn',
  deviceType: 'gpu',
  numThreads: 1,
  powerPreference: 'high-performance'
};
```

### The new way (specify with MLContext)

```js
// options to create MLcontext
const options = {
  deviceType: 'gpu',
  powerPreference: 'high-performance'
};

const myMlContext = await navigator.ml.createContext(options);

// options for session options
const webnnOptions = {
  name: 'webnn',
  context: myMlContext,
  ...options
};
```

This should throw (because no deviceType is specified):
```js
const myMlContext = await navigator.ml.createContext({ ... });
const webnnOptions = {
  name: 'webnn',
  context: myMlContext
};
```

### Interop with WebGPU
```js
// get WebGPU device
const adaptor = await navigator.gpu.requestAdapter({ ... });
const device = await adaptor.requestDevice({ ... });

// set WebGPU adaptor and device
ort.env.webgpu.adaptor = adaptor;
ort.env.webgpu.device = device;

const myMlContext = await navigator.ml.createContext(device);
const webnnOptions = {
  name: 'webnn',
  context: myMlContext,
  gpuDevice: device
};
```

This should throw (because cannot specify both gpu device and MLContext
option at the same time):
```js
const webnnOptions = {
  name: 'webnn',
  context: myMlContext,
  gpuDevice: device,
  deviceType: 'gpu'
};
```
  • Loading branch information
fs-eire authored May 31, 2024
1 parent 67bc943 commit 35697d2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
52 changes: 51 additions & 1 deletion js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,62 @@ export declare namespace InferenceSession {
readonly name: 'webgpu';
preferredLayout?: 'NCHW'|'NHWC';
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {

// #region WebNN options

interface WebNNExecutionProviderName extends ExecutionProviderOption {
readonly name: 'webnn';
}

/**
* Represents a set of options for creating a WebNN MLContext.
*
* @see https://www.w3.org/TR/webnn/#dictdef-mlcontextoptions
*/
export interface WebNNContextOptions {
deviceType?: 'cpu'|'gpu'|'npu';
numThreads?: number;
powerPreference?: 'default'|'low-power'|'high-performance';
}

/**
* Represents a set of options for WebNN execution provider without MLContext.
*/
export interface WebNNOptionsWithoutMLContext extends WebNNExecutionProviderName, WebNNContextOptions {
context?: never;
}

/**
* Represents a set of options for WebNN execution provider with MLContext.
*
* When MLContext is provided, the deviceType is also required so that the WebNN EP can determine the preferred
* channel layout.
*
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext
*/
export interface WebNNOptionsWithMLContext extends WebNNExecutionProviderName,
Omit<WebNNContextOptions, 'deviceType'>,
Required<Pick<WebNNContextOptions, 'deviceType'>> {
context: unknown /* MLContext */;
}

/**
* Represents a set of options for WebNN execution provider with MLContext which is created from GPUDevice.
*
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice
*/
export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName {
context: unknown /* MLContext */;
gpuDevice: unknown /* GPUDevice */;
}

/**
* Options for WebNN execution provider.
*/
export type WebNNExecutionProviderOption = WebNNOptionsWithoutMLContext|WebNNOptionsWithMLContext|WebNNOptionsWebGpu;

// #endregion

export interface QnnExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'qnn';
// TODO add flags
Expand Down
30 changes: 16 additions & 14 deletions js/web/lib/wasm/session-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,36 @@ const setExecutionProviders =
epName = 'WEBNN';
if (typeof ep !== 'string') {
const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
if (webnnOptions?.deviceType) {
// const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
if (deviceType) {
const keyDataOffset = allocWasmString('deviceType', allocs);
const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs);
const valueDataOffset = allocWasmString(deviceType, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`);
checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
}
}
if (webnnOptions?.numThreads) {
let numThreads = webnnOptions.numThreads;
if (numThreads !== undefined) {
// Just ignore invalid webnnOptions.numThreads.
if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
numThreads = 0;
}
const validatedNumThreads =
(typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
numThreads;
const keyDataOffset = allocWasmString('numThreads', allocs);
const valueDataOffset = allocWasmString(numThreads.toString(), allocs);
const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`);
checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
}
}
if (webnnOptions?.powerPreference) {
if (powerPreference) {
const keyDataOffset = allocWasmString('powerPreference', allocs);
const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
const valueDataOffset = allocWasmString(powerPreference, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(
`Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}.`);
checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
}
}
}
Expand Down

0 comments on commit 35697d2

Please sign in to comment.