Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webnn] update API of session options for WebNN #20816

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,55 @@ export declare namespace InferenceSession {
readonly name: 'webgpu';
preferredLayout?: 'NCHW'|'NHWC';
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {

// #region WebNN options

interface WebNNExecutionProviderName extends ExecutionProviderOption {
readonly name: 'webnn';
}

/**
* Represents a set of options for creating a WebNN MLContext.
*
* @see https://www.w3.org/TR/webnn/#dictdef-mlcontextoptions
*/
export interface WebNNContextOptions extends WebNNExecutionProviderName {
deviceType?: 'cpu'|'gpu'|'npu';
numThreads?: number;
powerPreference?: 'default'|'low-power'|'high-performance';
}

/**
* Represents a set of options for WebNN execution provider with MLContext.
*
* When MLContext is provided, the deviceType is also required so that the WebNN EP can determine the preferred
* channel layout.
*
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext
*/
export interface WebNNOptionsWithMLContext extends WebNNExecutionProviderName,
Omit<WebNNContextOptions, 'deviceType'>,
Required<Pick<WebNNContextOptions, 'deviceType'>> {
context: unknown /* MLContext */;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Represents a set of options for WebNN execution provider with MLContext which is created from GPUDevice.
*
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice
*/
export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName {
context: unknown /* MLContext */;
gpuDevice: unknown /* GPUDevice */;
}

/**
* Options for WebNN execution provider.
*/
export type WebNNExecutionProviderOption = WebNNContextOptions|WebNNOptionsWithMLContext|WebNNOptionsWebGpu;

// #endregion

export interface QnnExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'qnn';
// TODO add flags
Expand Down
30 changes: 16 additions & 14 deletions js/web/lib/wasm/session-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,36 @@ const setExecutionProviders =
epName = 'WEBNN';
if (typeof ep !== 'string') {
const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
if (webnnOptions?.deviceType) {
// const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
if (deviceType) {
const keyDataOffset = allocWasmString('deviceType', allocs);
const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs);
const valueDataOffset = allocWasmString(deviceType, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`);
checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
}
}
if (webnnOptions?.numThreads) {
let numThreads = webnnOptions.numThreads;
if (numThreads !== undefined) {
// Just ignore invalid webnnOptions.numThreads.
if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
numThreads = 0;
}
const validatedNumThreads =
(typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
numThreads;
const keyDataOffset = allocWasmString('numThreads', allocs);
const valueDataOffset = allocWasmString(numThreads.toString(), allocs);
const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`);
checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
}
}
if (webnnOptions?.powerPreference) {
if (powerPreference) {
const keyDataOffset = allocWasmString('powerPreference', allocs);
const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
const valueDataOffset = allocWasmString(powerPreference, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(
`Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}.`);
checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
}
}
}
Expand Down
Loading