Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN EP] Validate MLContext creation early to allow fallback to other EP #20735

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions js/common/lib/backend-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ export const registerBackend = (name: string, backend: Backend, priority: number
* @param backendName - the name of the backend.
* @returns the backend instance if resolved and initialized successfully, or an error message if failed.
*/
const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
const tryResolveAndInitializeBackend = async(
backendName: string,
webnnOptions?: InferenceSession.WebNNExecutionProviderOption,
): Promise<Backend|string> => {
const backendInfo = backends.get(backendName);
if (!backendInfo) {
return 'backend not found.';
Expand All @@ -81,7 +84,7 @@ const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backe
const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
backendInfo.initPromise = backendInfo.backend.init(backendName);
backendInfo.initPromise = backendInfo.backend.init(backendName, webnnOptions);
}
await backendInfo.initPromise;
backendInfo.initialized = true;
Expand Down Expand Up @@ -109,17 +112,28 @@ const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backe
*/
export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
// extract backend hints from session options
// extract backend hints from session options.
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;

// try to resolve and initialize all requested backends
if (backendNames.filter(name => name === 'webgpu').length > 1) {
throw new Error(`Registering duplicate 'webgpu' backends in the session options is not permitted`);
}
if (backendNames.filter(name => name === 'webnn').length > 1) {
throw new Error(`Registering duplicate 'webnn' backends in the session options is not permitted`);
}
Comment on lines +122 to +124

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this disable trying to do this?

  executionProviders: [
    {
      name: "webnn",
      deviceType: "gpu",
    },
    {
      name: "webnn",
      deviceType: "cpu",
    },
    ...
  ]

Copy link
Contributor Author

@Honry Honry May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating multiple WebNN or WebGPU eps in a single session is not allowed, @fs-eire, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether that is technically possible. But in practice, I think there is no use case for multiple instances of the same EP being used in one session.

// try to resolve and initialize all requested backends.
let backend: Backend|undefined;
const errors = [];
const availableBackendNames = new Set<string>();
for (const backendName of backendNames) {
const resolveResult = await tryResolveAndInitializeBackend(backendName);
// initialize webnn backend requires additional WebNNExecutionProviderOption.
let webnnOptions: InferenceSession.WebNNExecutionProviderOption|undefined;
if (backendName === 'webnn') {
webnnOptions = eps.find(e => typeof e !== 'string' && e.name === 'webnn') as
InferenceSession.WebNNExecutionProviderOption;
}
const resolveResult = await tryResolveAndInitializeBackend(backendName, webnnOptions);
if (typeof resolveResult === 'string') {
errors.push({name: backendName, err: resolveResult});
} else {
Expand Down
2 changes: 1 addition & 1 deletion js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export interface Backend {
/**
* Initialize the backend asynchronously. Should throw when failed.
*/
init(backendName: string): Promise<void>;
init(backendName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise<void>;

createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ export class OnnxruntimeWebAssemblyBackend implements Backend {
*
* @param backendName - the registered backend name.
*/
async init(backendName: string): Promise<void> {
async init(backendName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise<void> {
// populate wasm flags
initializeFlags();

// init wasm
await initializeWebAssemblyAndOrtRuntime();

// performe EP specific initialization
await initializeOrtEp(backendName);
await initializeOrtEp(backendName, webnnOptions);
}
createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ interface MessageInitWasm extends MessageError {

interface MessageInitEp extends MessageError {
type: 'init-ep';
in ?: {env: Env; epName: string};
in ?: {env: Env; epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption};
out?: never;
}

Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
});
break;
case 'init-ep': {
const {epName, env} = message!;
initEp(env, epName)
const {epName, env, webnnOptions} = message!;
initEp(env, epName, webnnOptions)
.then(
() => {
postMessage({type});
Expand Down
7 changes: 4 additions & 3 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,17 @@ export const initializeWebAssemblyAndOrtRuntime = async(): Promise<void> => {
}
};

export const initializeOrtEp = async(epName: string): Promise<void> => {
export const initializeOrtEp =
async(epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<void>((resolve, reject) => {
enqueueCallbacks('init-ep', [resolve, reject]);
const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}};
const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env, webnnOptions}};
proxyWorker!.postMessage(message);
});
} else {
await core.initEp(env, epName);
await core.initEp(env, epName, webnnOptions);
}
};

Expand Down
16 changes: 14 additions & 2 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param env
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
export const initEp =
async(env: Env, epName: string, webnnOptions?: InferenceSession.WebNNExecutionProviderOption): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU) {
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;
Expand Down Expand Up @@ -128,9 +129,20 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
await initJsep('webgpu', getInstance(), env, adapter);
}
if (epName === 'webnn') {
// perform WebNN availability check
// perform WebNN availability check.
if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) {
throw new Error('WebNN is not supported in current environment');
} else {
try {
if (webnnOptions?.powerPreference === 'default') {
// current implementation of WebNN API in Chromium does not support "default" powerPreference.
webnnOptions.powerPreference = undefined;
}
// validate if WebNN MLContext can be created with current options.
await (navigator as any).ml.createContext(webnnOptions);
} catch (e) {
throw(e);
}
}

await initJsep('webnn', getInstance(), env);
Expand Down