[js/webgpu] support IO binding (microsoft#17480)
<del>
**This PR is based on a few prerequisite PRs, listed below:**
- microsoft#17465
- microsoft#17469
- microsoft#17470
- microsoft#17472
- microsoft#17473
- microsoft#17484

Please review the current change by looking only at commit
e2e6623 and later.


</del>

### Description

This PR introduces WebGPU IO binding. This new feature allows
onnxruntime-web users to use tensors created on the GPU as model
inputs/outputs, so that model inference can run without unnecessary
data copies between CPU and GPU.

### Examples

An E2E demo/example is in progress.

Below is a simple demo with code snippets.

Let's first recap how it works today:
```js
// STEP.1 - create an inference session:
const mySession = await ort.InferenceSession.create('./my_model.onnx', { executionProviders: ['webgpu'] });

// STEP.2 - create model input: (supposing myImageCpuData is a Float32Array)
const feeds = {
  'input_image:0': new ort.Tensor('float32', myImageCpuData, [1, 224, 224, 3])
};

// STEP.3 - run model
const myResults = await mySession.run(feeds);

// STEP.4 - get output data
const myData = myResults['output_image:0'].data; // Float32Array

```

#### for inputs (GPU tensor)

Now, with IO binding, you can create a tensor from a GPU buffer, and
feed it to the model:
```js
// new STEP.2.A - create model input from a GPU buffer: (supposing myInputGpuBuffer is a `GPUBuffer` object with input data)
const feeds = {
  'input_image:0': ort.Tensor.fromGpuBuffer(myInputGpuBuffer, { dataType: 'float32', dims: [1, 224, 224, 3] })
};
```
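
For reference, here is a minimal sketch of how such a `GPUBuffer` might
be created and filled with the input data (the `device` variable and the
exact usage flags are assumptions for illustration, not part of this PR):

```js
// assuming `device` is a GPUDevice and `myImageCpuData` is a Float32Array
const myInputGpuBuffer = device.createBuffer({
  size: myImageCpuData.byteLength,
  // STORAGE so the WebGPU EP can bind the buffer in its compute shaders;
  // COPY_DST so the CPU data can be uploaded into it
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(myInputGpuBuffer, 0, myImageCpuData);
```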

#### for outputs (pre-allocated GPU tensor)

You can also do this for an output, **if you know the output shape**:
```js
// new STEP.2.B - create model output from a GPU buffer: (supposing myOutputGpuBuffer is a pre-allocated `GPUBuffer` object)
const fetches = {
  'output_image:0': ort.Tensor.fromGpuBuffer(myOutputGpuBuffer, { dataType: 'float32', dims: [1, 512, 512, 3] })
};

// new STEP.3 - run model with pre-allocated output (fetches)
const myResults = await mySession.run(feeds, fetches);
```
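
The pre-allocated buffer must be large enough to hold the output. A
sketch of the allocation, under the same assumptions as above:

```js
// 1 * 512 * 512 * 3 float32 elements, 4 bytes each = 3,145,728 bytes
const myOutputGpuBuffer = device.createBuffer({
  size: 1 * 512 * 512 * 3 * 4,
  // STORAGE so the EP can write the result; COPY_SRC so it can be read back
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
});
```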

#### for outputs (specify location)

If you do not know the output shape, you can specify the output location
when creating the session:

```js
// new STEP.1 - create an inference session with an option "preferredOutputLocation":
const mySession = await ort.InferenceSession.create('./my_model.onnx', {
    executionProviders: ['webgpu'],
    preferredOutputLocation: "gpu-buffer"
});
```

If the model has multiple outputs, you can specify the location for each output separately:
```js
// new STEP.1 - create an inference session with an option "preferredOutputLocation":
const mySession = await ort.InferenceSession.create('./my_model.onnx', {
    executionProviders: ['webgpu'],
    preferredOutputLocation: {
         "output_image:0": "gpu-buffer"
    }
});
```

Now you don't need to prepare the `fetches` object; onnxruntime-web
will put the output data at the specified location.
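
For example (reusing the `feeds` object from STEP.2):

```js
// run the model; the specified outputs stay on the GPU
const myResults = await mySession.run(feeds);
const myOutputTensor = myResults['output_image:0']; // data lives in a GPU buffer
```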

#### read data

When you get the output tensor, you can:
```js
// get the GPU buffer object:
const gpuBuffer = myOutputTensor.gpuBuffer; // GPUBuffer

// get the CPU data asynchronously:
const cpuData = await myOutputTensor.getData();

// get the CPU data asynchronously and release the underlying GPU resources:
const cpuDataWithRelease = await myOutputTensor.getData(true);

// dispose the tensor (release the underlying GPU resources). This tensor object will be invalid after dispose() is called.
myOutputTensor.dispose();
```

#### resource management

JavaScript has GC, so you don't need to worry about managing JavaScript
objects. But there are two types of resources that are not managed by GC:
- GPU buffers used in tensors
- underlying ORT native resources

To keep things simple, most unmanaged resources are handled inside ORT
Web. But a few resources need to be managed by users:
- All external GPU resources, including the GPU buffers inside all
tensors created by `Tensor.fromGpuBuffer()`, are not managed by ORT.
Users should manage those GPU buffers themselves.
- When a session is created with `preferredOutputLocation` set to
"gpu-buffer" in the session options, and the corresponding output is
not pre-allocated, users need to call the output tensor's `dispose()`
or `getData(true)` to manually release the underlying GPU buffers (see
the sketch after this list).
- ORT internal errors (including providing a pre-allocated output
tensor with the wrong type/dims) invalidate the whole wasm memory and
are not recoverable. An exception is thrown in this situation.
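
A minimal sketch of that disposal pattern, assuming a session created
with `preferredOutputLocation: 'gpu-buffer'` (`renderWith()` is a
hypothetical consumer of the buffer, not an API of onnxruntime-web):

```js
const myResults = await mySession.run(feeds);
const myOutputTensor = myResults['output_image:0'];
try {
  // use the GPU buffer directly, e.g. bind it to a render pipeline
  renderWith(myOutputTensor.gpuBuffer);
} finally {
  // release the GPU buffer that ORT allocated for this output;
  // the tensor object is invalid after this call
  myOutputTensor.dispose();
}
```
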
fs-eire authored Sep 29, 2023
1 parent b4fbc25 commit 561aca9
Showing 18 changed files with 1,177 additions and 288 deletions.
80 changes: 69 additions & 11 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import type {Tensor} from 'onnxruntime-common';

export declare namespace JSEP {
type BackendType = unknown;
type AllocFunction = (size: number) => number;
@@ -9,11 +11,8 @@ export declare namespace JSEP {
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number;
export interface SessionState {
sessionId: number;
errors: Array<Promise<string|null>>;
}
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
}

export interface OrtWasmModule extends EmscriptenModule {
@@ -40,14 +39,23 @@

_OrtFree(stringHandle: number): void;

_OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number):
number;
_OrtCreateTensor(
dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number,
dataLocation: number): number;
_OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number):
number;
_OrtReleaseTensor(tensorHandle: number): void;
_OrtCreateBinding(sessionHandle: number): number;
_OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise<number>;
_OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number;
_OrtClearBoundOutputs(ioBindingHandle: number): void;
_OrtReleaseBinding(ioBindingHandle: number): void;
_OrtRunWithBinding(
sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number,
runOptionsHandle: number): Promise<number>;
_OrtRun(
sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number,
outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number;
outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise<number>;

_OrtCreateSessionOptions(
graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number,
@@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule {
// #endregion

// #region JSEP
/**
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
* This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
*/
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;

/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
*
* @param context - specify the kernel context pointer.
* @param index - specify the index of the output.
* @param data - specify the pointer to encoded data of type and dims.
*/
_JsepOutput(context: number, index: number, data: number): number;
/**
* [exported from wasm] Get name of an operator node.
*
* @param kernel - specify the kernel pointer.
* @returns the pointer to a C-style UTF8 encoded string representing the node name.
*/
_JsepGetNodeName(kernel: number): number;

jsepOnRunStart?(sessionId: number): void;
jsepOnRunEnd?(sessionId: number): Promise<void>;
jsepRunPromise?: Promise<number>;
/**
* [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
*
* @param sessionId - specify the session ID.
* @param index - specify an integer to represent which input/output it is registering for. For input, it is the
* input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
* corresponding to the session's ouputNames.
* @param buffer - specify the GPU buffer to register.
* @param size - specify the original data size in byte.
* @returns the GPU data ID for the registered GPU buffer.
*/
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
/**
* [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
*
* @param sessionId - specify the session ID.
*/
jsepUnregisterBuffers?: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
*
* @param dataId - specify the GPU data ID
* @returns the GPU buffer.
*/
jsepGetBuffer: (dataId: number) => GPUBuffer;
/**
* [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
*
* @param gpuBuffer - specify the GPU buffer
* @param size - specify the original data size in byte.
* @param type - specify the tensor type.
* @returns the generated downloader function.
*/
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
// #endregion
}

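The doc comment on `jsepRegisterBuffer` above describes an index
convention: inputs are registered at their input index, outputs at
`inputCount + output_index`. A hypothetical helper (not part of this
commit; `module` is assumed to be an initialized `OrtWasmModule`) to
illustrate it:

```js
// Hypothetical glue code illustrating jsepRegisterBuffer's index convention.
// `inputs` / `outputs`: arrays of { buffer: GPUBuffer, byteSize: number }
function registerIoBuffers(module, sessionId, inputs, outputs) {
  const ids = [];
  inputs.forEach(({buffer, byteSize}, i) => {
    // inputs use input_index directly
    ids.push(module.jsepRegisterBuffer(sessionId, i, buffer, byteSize));
  });
  outputs.forEach(({buffer, byteSize}, i) => {
    // outputs are offset by the session's input count
    ids.push(module.jsepRegisterBuffer(sessionId, inputs.length + i, buffer, byteSize));
  });
  return ids;
}
```
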
66 changes: 53 additions & 13 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Env} from 'onnxruntime-common';
import {Env, Tensor} from 'onnxruntime-common';

import {configureLogger, LOG_DEBUG} from './log';
import {TensorView} from './tensor-view';
import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager';
import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types';
@@ -98,6 +98,11 @@ export class WebGpuBackend {

env: Env;

/**
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
*/
sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();

async initialize(env: Env): Promise<void> {
if (!navigator.gpu) {
// WebGPU is not available.
@@ -192,11 +197,13 @@ }
}

flush(): void {
this.endComputePass();
this.device.queue.submit([this.getCommandEncoder().finish()]);
this.gpuDataManager.refreshPendingBuffers();
this.commandEncoder = null;
this.pendingDispatchNumber = 0;
if (this.commandEncoder) {
this.endComputePass();
this.device.queue.submit([this.getCommandEncoder().finish()]);
this.gpuDataManager.refreshPendingBuffers();
this.commandEncoder = null;
this.pendingDispatchNumber = 0;
}
}

/**
@@ -304,12 +311,9 @@ }
}

async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise<void> {
const arrayBuffer = await this.gpuDataManager.download(gpuDataId);

// the underlying buffer may be changed after the async function is called. so we use a getter function to make sure
// the buffer is up-to-date.
const data = getTargetBuffer();
data.set(new Uint8Array(arrayBuffer, 0, data.byteLength));
await this.gpuDataManager.download(gpuDataId, getTargetBuffer);
}

alloc(size: number): number {
@@ -372,7 +376,7 @@ kernelEntry(context, attributes[1]);
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`);
errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`));
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
@@ -387,4 +391,40 @@
this.currentKernelId = null;
}
}

// #region external buffer
registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number {
let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
if (!sessionInputOutputMapping) {
sessionInputOutputMapping = new Map();
this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping);
}

const previousBuffer = sessionInputOutputMapping.get(index);
const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]);
sessionInputOutputMapping.set(index, [id, buffer]);
return id;
}
unregisterBuffers(sessionId: number): void {
const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
if (sessionInputOutputMapping) {
sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1]));
this.sessionExternalDataMapping.delete(sessionId);
}
}
getBuffer(gpuDataId: number): GPUBuffer {
const gpuData = this.gpuDataManager.get(gpuDataId);
if (!gpuData) {
throw new Error(`no GPU data for buffer: ${gpuDataId}`);
}
return gpuData.buffer;
}
createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes):
() => Promise<Tensor.DataType> {
return async () => {
const data = await downloadGpuData(this, gpuBuffer, size);
return createView(data.buffer, type);
};
}
// #endregion
}
15 changes: 10 additions & 5 deletions js/web/lib/wasm/jsep/init.ts
@@ -3,7 +3,7 @@

import {Env} from 'onnxruntime-common';

import {JSEP, OrtWasmModule} from '../binding/ort-wasm';
import {OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';

import {WebGpuBackend} from './backend-webgpu';
@@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext {
this.module.HEAPU32[offset++] = dims[i];
}
return this.module._JsepOutput(this.opKernelContext, index, data);
} catch (e) {
throw new Error(
`Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
`Error: ${e}`);
} finally {
this.module.stackRestore(stack);
}
@@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {

init(
// backend
{backend},
backend,

// jsepAlloc()
(size: number) => backend.alloc(size),
@@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
(kernel: number) => backend.releaseKernel(kernel),

// jsepRun
(kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => {
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, sessionState.errors);
return backend.computeKernel(kernel, context, errors);
});
}
};