[js/web] Update API for ort.env.webgpu #23026

Merged · 2 commits · Dec 11, 2024
33 changes: 25 additions & 8 deletions js/common/lib/env.ts
@@ -45,17 +45,19 @@ export declare namespace Env {
*
* This setting is available only when WebAssembly SIMD feature is available in current context.
*
* @defaultValue `true`
*
* @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD
* build is no longer provided. This property will be removed in a future release.
* @defaultValue `true`
*/
simd?: boolean;

/**
* set or get a boolean value indicating whether to enable trace.
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
* @defaultValue `false`
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
*/
trace?: boolean;
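The deprecation note above points to the top-level flag. A one-line sketch, not part of this diff, assuming `ort` is the imported onnxruntime-web module:

ort.env.trace = true; // preferred over the deprecated ort.env.wasm.trace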

@@ -153,7 +155,7 @@ export declare namespace Env {
/**
* Set or get the profiling configuration.
*/
profiling?: {
profiling: {
/**
* Set or get the profiling mode.
*
@@ -176,6 +178,9 @@
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific power preference.
*/
powerPreference?: 'low-power' | 'high-performance';
/**
@@ -187,6 +192,9 @@
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific fallback option.
*/
forceFallbackAdapter?: boolean;
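Both deprecation notes above recommend the same replacement. A minimal sketch, not part of this diff, assuming a browser with WebGPU support, `@webgpu/types` installed, and onnxruntime-web available; the helper name and model path are placeholders:

import * as ort from 'onnxruntime-web/webgpu';

async function createSessionWithHighPerformanceDevice(modelPath: string): Promise<ort.InferenceSession> {
  // Request an adapter yourself instead of using the deprecated
  // env.webgpu.powerPreference / env.webgpu.forceFallbackAdapter flags.
  const adapter = await navigator.gpu.requestAdapter({
    powerPreference: 'high-performance',
    forceFallbackAdapter: false,
  });
  if (!adapter) {
    throw new Error('WebGPU is not available in this browser.');
  }

  // Hand the resulting device to onnxruntime-web before the first session is created.
  ort.env.webgpu.device = await adapter.requestDevice();

  return ort.InferenceSession.create(modelPath, { executionProviders: ['webgpu'] });
}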
/**
@@ -199,16 +207,25 @@
* value will be the GPU adapter created by the underlying WebGPU backend.
*
* When used with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
*
* @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo`
* (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows getting the adapter information from the
* device. When it's available, there is no need to set/get the {@link adapter} property.
*/
adapter: TryGetGlobalType<'GPUAdapter'>;
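A short sketch of the replacement suggested in the deprecation note above (not part of this diff; assumes a browser whose WebGPU implementation exposes `GPUDevice.adapterInfo`):

import * as ort from 'onnxruntime-web/webgpu';

async function logAdapterInfo(): Promise<void> {
  // Resolves to the GPUDevice that the WebGPU backend will use (or already uses).
  const device = await ort.env.webgpu.device;
  // adapterInfo comes from the newer WebGPU spec and may be undefined in older browsers.
  const info = device.adapterInfo;
  if (info) {
    console.log(`vendor: ${info.vendor}, architecture: ${info.architecture}`);
  }
}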
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
* Set or get the GPU device for WebGPU.
*
* When used with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
* There are 3 valid scenarios of accessing this property:
* - Set a value before the first WebGPU inference session is created. The value will be used by the WebGPU backend
* to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown.
* - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice
* instance. Returns a `Promise` that resolves to a `GPUDevice` object.
* - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the
* `GPUDevice` object used by the WebGPU backend.
*/
readonly device: TryGetGlobalType<'GPUDevice'>;
get device(): Promise<TryGetGlobalType<'GPUDevice'>>;
set device(value: TryGetGlobalType<'GPUDevice'>);
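A minimal sketch of the getter side of the contract described above (not part of this diff; assumes `@webgpu/types` is installed so the promise resolves to a typed `GPUDevice`). Scenario 1, setting your own device, is sketched after the `forceFallbackAdapter` note above.

import * as ort from 'onnxruntime-web/webgpu';

// Scenarios 2 and 3: reading the property always yields a Promise, whether or not a
// WebGPU inference session has been created yet.
async function getBackendDevice(): Promise<GPUDevice> {
  return await ort.env.webgpu.device;
}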
/**
* Set or get whether to validate input content.
*
12 changes: 6 additions & 6 deletions js/web/test/test-runner.ts
@@ -586,11 +586,11 @@ export class TensorResultValidator {
}
}

function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise<ort.Tensor> {
if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
@@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
});
}

function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
if (!isGpuBufferSupportedType(type)) {
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
}

const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;

const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
@@ -725,7 +725,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') {
feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]);
} else {
feeds[name] = createGpuTensorForInput(feeds[name]);
feeds[name] = await createGpuTensorForInput(feeds[name]);
}
}
}
@@ -742,7 +742,7 @@
if (options.ioBinding === 'ml-tensor') {
fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims);
} else {
fetches[name] = createGpuTensorForOutput(type, dims);
fetches[name] = await createGpuTensorForOutput(type, dims);
}
}
}
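For reference (not part of the diff), a hedged sketch of what a caller outside the test runner might do now that the device sits behind a Promise: await it, create a storage buffer, and wrap it with `ort.Tensor.fromGpuBuffer`. The helper name and the float32-only size calculation are assumptions for illustration.

import * as ort from 'onnxruntime-web/webgpu';

async function createEmptyGpuOutputTensor(dims: readonly number[]): Promise<ort.Tensor> {
  const device = await ort.env.webgpu.device; // the device is now obtained asynchronously
  const elementCount = dims.reduce((a, b) => a * b, 1);
  // float32 assumed: 4 bytes per element; real code may need type-dependent sizing and padding.
  const size = Math.max(16, elementCount * 4);
  const gpuBuffer = device.createBuffer({
    // eslint-disable-next-line no-bitwise
    usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
    size,
  });
  return ort.Tensor.fromGpuBuffer(gpuBuffer, {
    dataType: 'float32',
    dims,
    dispose: () => gpuBuffer.destroy(),
  });
}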