Skip to content

Commit

Permalink
[js/webgpu] allow setting env.webgpu.adapter (#19940)
Browse files Browse the repository at this point in the history
### Description
Allow user to set `env.webgpu.adapter` before creating the first
inference session.

Feature request:
#19857 (comment)

@xenova
  • Loading branch information
fs-eire authored Mar 19, 2024
1 parent 8293aa1 commit 01c7aaf
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 16 deletions.
10 changes: 7 additions & 3 deletions js/common/lib/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,20 @@ export declare namespace Env {
*/
forceFallbackAdapter?: boolean;
/**
* Get the adapter for WebGPU.
* Set or get the adapter for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
* Setting this property only has effect before the first WebGPU inference session is created. The value will be
* used as the GPU adapter for the underlying WebGPU backend to create GPU device.
*
* If this property is not set, it will be available to get after the first WebGPU inference session is created. The
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
* Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
*
* see comments on {@link Tensor.GpuBufferType}
*/
readonly adapter: unknown;
adapter: unknown;
/**
* Get the device for WebGPU.
*
Expand Down
6 changes: 4 additions & 2 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ export class WebGpuBackend {
}
};

Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter});
Object.defineProperty(
this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
Object.defineProperty(
this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});

// init queryType, which is necessary for InferenceSession.create
this.setQueryType();
Expand Down
35 changes: 24 additions & 11 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise<void> => {
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
}
const powerPreference = env.webgpu?.powerPreference;
if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') {
throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
}
const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter;
if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
}
const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});

let adapter = env.webgpu.adapter as GPUAdapter | null;
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
// if adapter is not set, request a new adapter.
const powerPreference = env.webgpu.powerPreference;
if (powerPreference !== undefined && powerPreference !== 'low-power' &&
powerPreference !== 'high-performance') {
throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
}
const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
}
adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. ' +
'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
}
} else {
// if adapter is set, validate it.
if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' ||
typeof adapter.requestDevice !== 'function') {
throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
}
}

if (!env.wasm.simd) {
Expand Down

0 comments on commit 01c7aaf

Please sign in to comment.