From 561aca97cfcf76ce6d190a2403cae34c17bee75a Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 29 Sep 2023 11:24:42 -0700
Subject: [PATCH] [js/webgpu] support IO binding (#17480)
**This PR is based on a few prerequisite PRs, listed below:**
- #17465
- #17469
- #17470
- #17472
- #17473
- #17484
Please review the current change by looking only at commit
e2e6623e673ec6de55a5c1f8edcbd3a46b535a89 and later.
### Description
This PR introduces WebGPU IO binding. This new feature allows
onnxruntime-web users to use GPU-resident tensors as model
inputs/outputs, so that inference can run without unnecessary data
copies between CPU and GPU for model input/output.
### Examples
An E2E demo/example is being worked on.
Below are some simple demos with code snippets.
Let's first look at how it works today, without IO binding:
```js
// STEP.1 - create an inference session:
const mySession = await ort.InferenceSession.create('./my_model.onnx', { executionProviders: ['webgpu'] });
// STEP.2 - create model input: (supposing myImageCpuData is a Float32Array)
const feeds = {
'input_image:0': new ort.Tensor('float32', myImageCpuData, [1, 224, 224, 3])
};
// STEP.3 - run model
const myResults = await mySession.run(feeds);
// STEP.4 - get output data
const myData = myResults['output_image:0'].data; // Float32Array
```
#### For inputs (GPU tensor)
Now, with IO binding, you can create a tensor from a GPU buffer, and
feed it to the model:
```js
// new STEP.2.A - create model input from a GPU buffer: (supposing myInputGpuBuffer is a `GPUBuffer` object with input data)
const feeds = {
'input_image:0': ort.Tensor.fromGpuBuffer(myInputGpuBuffer, { dataType: 'float32', dims: [1, 224, 224, 3] })
};
```
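For reference, here is a sketch of how such a buffer might be created (`myDevice` is a hypothetical `GPUDevice`, and the usage flags are an assumption; the WebGPU EP operates on storage buffers, so STORAGE plus the copy usages is a reasonable choice):
```js
// hypothetical setup: upload myImageCpuData (a Float32Array) into a new GPU buffer
const myInputGpuBuffer = myDevice.createBuffer({
  // padding to a multiple of 16 bytes matches ORT web's buffer size normalization
  size: Math.ceil(myImageCpuData.byteLength / 16) * 16,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
myDevice.queue.writeBuffer(myInputGpuBuffer, 0, myImageCpuData);
```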
#### For outputs (pre-allocated GPU tensor)
You can also do this for outputs, **if you know the output shape**:
```js
// new STEP.2.B - create model output from a GPU buffer: (supposing myOutputGpuBuffer is a pre-allocated `GPUBuffer` object)
const fetches = {
'output_image:0': ort.Tensor.fromGpuBuffer(myOutputGpuBuffer, { dataType: 'float32', dims: [1, 512, 512, 3] })
};
// new STEP.3 - run model with pre-allocated output (fetches)
const myResults = await mySession.run(feeds, fetches);
```
#### For outputs (specify location)
If you do not know the output shape, you can specify the output location
when creating the session:
```js
// new STEP.1 - create an inference session with an option "preferredOutputLocation":
const mySession = await ort.InferenceSession.create('./my_model.onnx', {
executionProviders: ['webgpu'],
preferredOutputLocation: "gpu-buffer"
});
```
If the model has multiple outputs, you can specify a location for each of them separately:
```js
// new STEP.1 - create an inference session with an option "preferredOutputLocation":
const mySession = await ort.InferenceSession.create('./my_model.onnx', {
executionProviders: ['webgpu'],
preferredOutputLocation: {
"output_image:0": "gpu-buffer"
}
});
```
Now you don't need to prepare the `fetches` object; onnxruntime-web
will place the output data in the specified location.
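Continuing the snippets above (the model and output names are the hypothetical ones used earlier), running the model now looks like:
```js
// no `fetches` object is needed; outputs are placed at the preferred location
const myResults = await mySession.run(feeds);
// the output tensor's data stays on the GPU
const myOutputTensor = myResults['output_image:0'];
console.log(myOutputTensor.location); // 'gpu-buffer'
```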
#### Read data
When you get the output tensor, you can:
```js
// get the GPU buffer object:
const gpuBuffer = myOutputTensor.gpuBuffer; // GPUBuffer
// get the CPU data asynchronously:
const cpuData = await myOutputTensor.getData();
// get the CPU data asynchronously and release the underlying GPU resources:
const cpuDataReleased = await myOutputTensor.getData(true);
// dispose the tensor (release the underlying GPU resources). The tensor object is invalid after dispose() is called.
myOutputTensor.dispose();
```
#### Resource management
JavaScript is garbage-collected, so you don't need to worry about
managing JavaScript objects. But there are two types of resources that
are not managed by the GC:
- GPU buffers used in tensors
- underlying ORT native resources
To simplify, most of the unmanaged resources are handled inside ORT web.
But a few resources need to be managed by users (see the sketch after
this list):
- All external GPU resources, including the GPU buffers inside all
tensors created by `Tensor.fromGpuBuffer()`, will not be managed by ORT.
Users should manage those GPU buffers themselves.
- When a session is created with `preferredOutputLocation` ==
"gpu-buffer" specified in session options, and the corresponding output
is not pre-allocated, users need to call the output tensor's `dispose()`
or `getData(true)` to manually release the underlying GPU buffers.
- ORT internal errors (including providing a pre-allocated output tensor
with the wrong type/dims) invalidate the whole wasm memory and are not
recoverable. An exception is thrown in this situation.
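As a minimal sketch of the second point (assuming a session created with `preferredOutputLocation: 'gpu-buffer'` and the hypothetical output name from the earlier snippets), one possible disposal pattern is:
```js
const myResults = await mySession.run(feeds);
const myOutputTensor = myResults['output_image:0'];
try {
  // getData(true) downloads a CPU copy and releases the underlying GPU buffer in one call
  const cpuData = await myOutputTensor.getData(true);
  // ... consume cpuData ...
} catch (e) {
  // if the download did not complete, release the GPU buffer explicitly
  myOutputTensor.dispose();
  throw e;
}
```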
---
js/web/lib/wasm/binding/ort-wasm.d.ts | 80 +++-
js/web/lib/wasm/jsep/backend-webgpu.ts | 66 ++-
js/web/lib/wasm/jsep/init.ts | 15 +-
.../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 125 ++++-
js/web/lib/wasm/proxy-messages.ts | 32 +-
js/web/lib/wasm/proxy-wrapper.ts | 30 +-
js/web/lib/wasm/session-handler.ts | 51 +-
js/web/lib/wasm/wasm-common.ts | 32 ++
js/web/lib/wasm/wasm-core-impl.ts | 434 ++++++++++++------
js/web/script/test-runner-cli-args.ts | 16 +
js/web/script/test-runner-cli.ts | 21 +-
js/web/test/test-runner.ts | 181 +++++++-
js/web/test/test-types.ts | 14 +
onnxruntime/core/providers/js/js_kernel.h | 2 +-
onnxruntime/wasm/api.cc | 113 ++++-
onnxruntime/wasm/api.h | 58 ++-
onnxruntime/wasm/js_internal_api.js | 178 +++++--
.../azure-pipelines/templates/win-web-ci.yml | 17 +-
18 files changed, 1177 insertions(+), 288 deletions(-)
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 59da1369e152e..b7b2ff4537095 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import type {Tensor} from 'onnxruntime-common';
+
export declare namespace JSEP {
type BackendType = unknown;
type AllocFunction = (size: number) => number;
@@ -9,11 +11,8 @@ export declare namespace JSEP {
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
type ReleaseKernelFunction = (kernel: number) => void;
- type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number;
- export interface SessionState {
- sessionId: number;
- errors: Array<Promise<string|null>>;
- }
+ type RunFunction =
+ (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
}
export interface OrtWasmModule extends EmscriptenModule {
@@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtFree(stringHandle: number): void;
- _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number):
- number;
+ _OrtCreateTensor(
+ dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number,
+ dataLocation: number): number;
_OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number):
number;
_OrtReleaseTensor(tensorHandle: number): void;
+ _OrtCreateBinding(sessionHandle: number): number;
+ _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise<number>;
+ _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number;
+ _OrtClearBoundOutputs(ioBindingHandle: number): void;
+ _OrtReleaseBinding(ioBindingHandle: number): void;
+ _OrtRunWithBinding(
+ sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number,
+ runOptionsHandle: number): Promise<number>;
_OrtRun(
sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number,
- outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number;
+ outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise<number>;
_OrtCreateSessionOptions(
graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number,
@@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule {
// #endregion
// #region JSEP
+ /**
+ * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
+ * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
+ */
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
+ /**
+ * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
+ *
+ * @param context - specify the kernel context pointer.
+ * @param index - specify the index of the output.
+ * @param data - specify the pointer to encoded data of type and dims.
+ */
_JsepOutput(context: number, index: number, data: number): number;
+ /**
+ * [exported from wasm] Get name of an operator node.
+ *
+ * @param kernel - specify the kernel pointer.
+ * @returns the pointer to a C-style UTF8 encoded string representing the node name.
+ */
_JsepGetNodeName(kernel: number): number;
- jsepOnRunStart?(sessionId: number): void;
- jsepOnRunEnd?(sessionId: number): Promise;
- jsepRunPromise?: Promise;
+ /**
+ * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
+ *
+ * @param sessionId - specify the session ID.
+ * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
+ * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
+ * corresponding to the session's outputNames.
+ * @param buffer - specify the GPU buffer to register.
+ * @param size - specify the original data size in byte.
+ * @returns the GPU data ID for the registered GPU buffer.
+ */
+ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
+ /**
+ * [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
+ *
+ * @param sessionId - specify the session ID.
+ */
+ jsepUnregisterBuffers?: (sessionId: number) => void;
+ /**
+ * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
+ *
+ * @param dataId - specify the GPU data ID
+ * @returns the GPU buffer.
+ */
+ jsepGetBuffer: (dataId: number) => GPUBuffer;
+ /**
+ * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
+ *
+ * @param gpuBuffer - specify the GPU buffer
+ * @param size - specify the original data size in byte.
+ * @param type - specify the tensor type.
+ * @returns the generated downloader function.
+ */
+ jsepCreateDownloader:
+ (gpuBuffer: GPUBuffer, size: number,
+ type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
// #endregion
}
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 5e77a0343b4ee..5bec562b157ac 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Env} from 'onnxruntime-common';
+import {Env, Tensor} from 'onnxruntime-common';
import {configureLogger, LOG_DEBUG} from './log';
-import {TensorView} from './tensor-view';
-import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager';
+import {createView, TensorView} from './tensor-view';
+import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types';
@@ -98,6 +98,11 @@ export class WebGpuBackend {
env: Env;
+ /**
+ * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
+ */
+ sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();
+
async initialize(env: Env): Promise<void> {
if (!navigator.gpu) {
// WebGPU is not available.
@@ -192,11 +197,13 @@ export class WebGpuBackend {
}
flush(): void {
- this.endComputePass();
- this.device.queue.submit([this.getCommandEncoder().finish()]);
- this.gpuDataManager.refreshPendingBuffers();
- this.commandEncoder = null;
- this.pendingDispatchNumber = 0;
+ if (this.commandEncoder) {
+ this.endComputePass();
+ this.device.queue.submit([this.getCommandEncoder().finish()]);
+ this.gpuDataManager.refreshPendingBuffers();
+ this.commandEncoder = null;
+ this.pendingDispatchNumber = 0;
+ }
}
/**
@@ -304,12 +311,9 @@ export class WebGpuBackend {
}
async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise<void> {
- const arrayBuffer = await this.gpuDataManager.download(gpuDataId);
-
// the underlying buffer may be changed after the async function is called. so we use a getter function to make sure
// the buffer is up-to-date.
- const data = getTargetBuffer();
- data.set(new Uint8Array(arrayBuffer, 0, data.byteLength));
+ await this.gpuDataManager.download(gpuDataId, getTargetBuffer);
}
alloc(size: number): number {
@@ -372,7 +376,7 @@ export class WebGpuBackend {
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
- LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`);
+ errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`));
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
@@ -387,4 +391,40 @@ export class WebGpuBackend {
this.currentKernelId = null;
}
}
+
+ // #region external buffer
+ registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number {
+ let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
+ if (!sessionInputOutputMapping) {
+ sessionInputOutputMapping = new Map();
+ this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping);
+ }
+
+ const previousBuffer = sessionInputOutputMapping.get(index);
+ const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]);
+ sessionInputOutputMapping.set(index, [id, buffer]);
+ return id;
+ }
+ unregisterBuffers(sessionId: number): void {
+ const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
+ if (sessionInputOutputMapping) {
+ sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1]));
+ this.sessionExternalDataMapping.delete(sessionId);
+ }
+ }
+ getBuffer(gpuDataId: number): GPUBuffer {
+ const gpuData = this.gpuDataManager.get(gpuDataId);
+ if (!gpuData) {
+ throw new Error(`no GPU data for buffer: ${gpuDataId}`);
+ }
+ return gpuData.buffer;
+ }
+ createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes):
+ () => Promise<Tensor.DataType> {
+ return async () => {
+ const data = await downloadGpuData(this, gpuBuffer, size);
+ return createView(data.buffer, type);
+ };
+ }
+ // #endregion
}
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 78316cbe1c825..6ff3971d720fd 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -3,7 +3,7 @@
import {Env} from 'onnxruntime-common';
-import {JSEP, OrtWasmModule} from '../binding/ort-wasm';
+import {OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';
import {WebGpuBackend} from './backend-webgpu';
@@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext {
this.module.HEAPU32[offset++] = dims[i];
}
return this.module._JsepOutput(this.opKernelContext, index, data);
+ } catch (e) {
+ throw new Error(
+ `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
+ 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
+ `Error: ${e}`);
} finally {
this.module.stackRestore(stack);
}
@@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
init(
// backend
- {backend},
+ backend,
// jsepAlloc()
(size: number) => backend.alloc(size),
@@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
(kernel: number) => backend.releaseKernel(kernel),
// jsepRun
- (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => {
+ (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
- () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${
+ () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
- return backend.computeKernel(kernel, context, sessionState.errors);
+ return backend.computeKernel(kernel, context, errors);
});
}
};
diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
index 92fdd5abc3892..131f7a9bfa29b 100644
--- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
@@ -35,7 +35,7 @@ export interface GpuDataManager {
/**
* copy data from GPU to CPU.
*/
- download(id: GpuDataId): Promise<ArrayBuffer>;
+ download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise<void>;
/**
* refresh the buffers that marked for release.
@@ -46,6 +46,19 @@ export interface GpuDataManager {
*/
refreshPendingBuffers(): void;
+ /**
+ * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID.
+ *
+ * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of
+ * the external buffer.
+ */
+ registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number;
+
+ /**
+ * unregister an external buffer for IO Binding.
+ */
+ unregisterExternalBuffer(buffer: GPUBuffer): void;
+
/**
* destroy all gpu buffers. Call this when the session.release is called.
*/
@@ -62,12 +75,56 @@ interface StorageCacheValue {
*/
const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16;
-let guid = 0;
+let guid = 1;
const createNewGpuDataId = () => guid++;
+/**
+ * exported standard download function. This function is used by the session to download the data from GPU, and also by
+ * factory to create GPU tensors with the capacity of downloading data from GPU.
+ *
+ * @param backend - the WebGPU backend
+ * @param gpuBuffer - the GPU buffer to download
+ * @param originalSize - the original size of the data
+ * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer
+ * will be created and returned.
+ */
+export const downloadGpuData =
+ async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array):
+ Promise<Uint8Array> => {
+ const bufferSize = calcNormalizedBufferSize(originalSize);
+ const gpuReadBuffer = backend.device.createBuffer(
+ // eslint-disable-next-line no-bitwise
+ {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ});
+ try {
+ const commandEncoder = backend.getCommandEncoder();
+ backend.endComputePass();
+ commandEncoder.copyBufferToBuffer(
+ gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */,
+ 0 /* destination offset */, bufferSize /* size */
+ );
+ backend.flush();
+
+ await gpuReadBuffer.mapAsync(GPUMapMode.READ);
+
+ const arrayBuffer = gpuReadBuffer.getMappedRange();
+ if (getTargetBuffer) {
+ // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer.
+ const targetBuffer = getTargetBuffer();
+ targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize));
+ return targetBuffer;
+ } else {
+ // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the
+ // ArrayBuffer.
+ return new Uint8Array(arrayBuffer.slice(0, originalSize));
+ }
+ } finally {
+ gpuReadBuffer.destroy();
+ }
+ };
+
class GpuDataManagerImpl implements GpuDataManager {
// GPU Data ID => GPU Data ( storage buffer )
- storageCache: Map<GpuDataId, StorageCacheValue>;
+ private storageCache: Map<GpuDataId, StorageCacheValue>;
// pending buffers for uploading ( data is unmapped )
private buffersForUploadingPending: GPUBuffer[];
@@ -77,11 +134,15 @@ class GpuDataManagerImpl implements GpuDataManager {
// The reusable storage buffers for computing.
private freeBuffers: Map<number, GPUBuffer[]>;
+ // The external buffers registered by users for IO Binding.
+ private externalBuffers: Map<GPUBuffer, number>;
+
constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
+ this.externalBuffers = new Map();
}
upload(id: GpuDataId, data: Uint8Array): void {
@@ -143,6 +204,42 @@ class GpuDataManagerImpl implements GpuDataManager {
sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size);
}
+ registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number {
+ let id: number|undefined;
+ if (previousBuffer) {
+ id = this.externalBuffers.get(previousBuffer);
+ if (id === undefined) {
+ throw new Error('previous buffer is not registered');
+ }
+ if (buffer === previousBuffer) {
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
+ id}, buffer is the same, skip.`);
+ return id;
+ }
+ this.externalBuffers.delete(previousBuffer);
+ } else {
+ id = createNewGpuDataId();
+ }
+
+ this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize});
+ this.externalBuffers.set(buffer, id);
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`);
+ return id;
+ }
+
+ unregisterExternalBuffer(buffer: GPUBuffer): void {
+ const id = this.externalBuffers.get(buffer);
+ if (id !== undefined) {
+ this.storageCache.delete(id);
+ this.externalBuffers.delete(buffer);
+ LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`);
+ }
+ }
+
// eslint-disable-next-line no-bitwise
create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData {
const bufferSize = calcNormalizedBufferSize(size);
@@ -193,31 +290,13 @@ class GpuDataManagerImpl implements GpuDataManager {
return cachedData.originalSize;
}
- async download(id: GpuDataId): Promise<ArrayBuffer> {
+ async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise<void> {
const cachedData = this.storageCache.get(id);
if (!cachedData) {
throw new Error('data does not exist');
}
- const commandEncoder = this.backend.getCommandEncoder();
- this.backend.endComputePass();
- const bufferSize = calcNormalizedBufferSize(cachedData.originalSize);
- const gpuReadBuffer = this.backend.device.createBuffer(
- // eslint-disable-next-line no-bitwise
- {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ});
- commandEncoder.copyBufferToBuffer(
- cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */,
- 0 /* destination offset */, bufferSize /* size */
- );
- this.backend.flush();
-
- return new Promise<ArrayBuffer>((resolve) => {
- gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => {
- const data = gpuReadBuffer.getMappedRange().slice(0);
- gpuReadBuffer.destroy();
- resolve(data);
- });
- });
+ await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer);
}
refreshPendingBuffers(): void {
diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts
index e5a2d8c2351b8..43f70c23f7193 100644
--- a/js/web/lib/wasm/proxy-messages.ts
+++ b/js/web/lib/wasm/proxy-messages.ts
@@ -3,20 +3,24 @@
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
-/**
- * tuple elements are: ORT element type; dims; tensor data
- */
-export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType];
+export type SerializableTensorMetadata =
+ [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu'];
-/**
- * tuple elements are: InferenceSession handle; input names; output names
- */
-export type SerializableSessionMetadata = [number, string[], string[]];
+export type GpuBufferMetadata = {
+ gpuBuffer: Tensor.GpuBufferType;
+ download?: () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
+ dispose?: () => void;
+};
-/**
- * tuple elements are: modeldata.offset, modeldata.length
- */
-export type SerializableModeldata = [number, number];
+export type UnserializableTensorMetadata =
+ [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
+ [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];
+
+export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;
+
+export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]];
+
+export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number];
interface MessageError {
err?: string;
@@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError {
interface MessageRun extends MessageError {
type: 'run';
in ?: {
- sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[];
+ sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[];
options: InferenceSession.RunOptions;
};
- out?: SerializableTensor[];
+ out?: SerializableTensorMetadata[];
}
interface MesssageEndProfiling extends MessageError {
diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts
index 815b223e40379..202209ed3bfed 100644
--- a/js/web/lib/wasm/proxy-wrapper.ts
+++ b/js/web/lib/wasm/proxy-wrapper.ts
@@ -3,7 +3,7 @@
import {Env, env, InferenceSession} from 'onnxruntime-common';
-import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages';
+import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import * as core from './wasm-core-impl';
import {initializeWebAssembly} from './wasm-factory';
@@ -22,7 +22,7 @@ const createSessionAllocateCallbacks: Array<PromiseCallbacks<SerializableModeldata>> = [];
const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata>> = [];
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
-const runCallbacks: Array<PromiseCallbacks<SerializableTensor[]>> = [];
+const runCallbacks: Array<PromiseCallbacks<TensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const ensureWorker = (): void => {
@@ -177,6 +177,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt
export const createSession =
async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
+ // check unsupported options
+ if (options?.preferredOutputLocation) {
+ throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
+ }
ensureWorker();
return new Promise<SerializableSessionMetadata>((resolve, reject) => {
createSessionCallbacks.push([resolve, reject]);
@@ -202,17 +206,27 @@ export const releaseSession = async(sessionId: number): Promise<void> => {
};
export const run = async(
- sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[],
- options: InferenceSession.RunOptions): Promise<SerializableTensor[]> => {
+ sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[],
+ outputs: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
+ // check inputs location
+ if (inputs.some(t => t[3] !== 'cpu')) {
+ throw new Error('input tensor on GPU is not supported for proxy.');
+ }
+ // check outputs location
+ if (outputs.some(t => t)) {
+ throw new Error('pre-allocated output tensor is not supported for proxy.');
+ }
ensureWorker();
- return new Promise<SerializableTensor[]>((resolve, reject) => {
+ return new Promise<TensorMetadata[]>((resolve, reject) => {
runCallbacks.push([resolve, reject]);
- const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}};
- proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs));
+ const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU.
+ const message: OrtWasmMessage =
+ {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}};
+ proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs));
});
} else {
- return core.run(sessionId, inputIndices, inputs, outputIndices, options);
+ return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options);
}
};
diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts
index d8c5ae7886fe4..4e00878d0063b 100644
--- a/js/web/lib/wasm/session-handler.ts
+++ b/js/web/lib/wasm/session-handler.ts
@@ -5,12 +5,41 @@ import {readFile} from 'fs';
import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common';
import {promisify} from 'util';
-import {SerializableModeldata} from './proxy-messages';
+import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
+import {isGpuBufferSupportedType} from './wasm-common';
let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;
+const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
+ switch (tensor.location) {
+ case 'cpu':
+ return [tensor.type, tensor.dims, tensor.data, 'cpu'];
+ case 'gpu-buffer':
+ return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer'];
+ default:
+ throw new Error(`invalid data location: ${tensor.location} for ${getName()}`);
+ }
+};
+
+const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
+ switch (tensor[3]) {
+ case 'cpu':
+ return new Tensor(tensor[0], tensor[2], tensor[1]);
+ case 'gpu-buffer': {
+ const dataType = tensor[0];
+ if (!isGpuBufferSupportedType(dataType)) {
+ throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
+ }
+ const {gpuBuffer, download, dispose} = tensor[2];
+ return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose});
+ }
+ default:
+ throw new Error(`invalid data location: ${tensor[3]}`);
+ }
+};
+
export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
private sessionId: number;
@@ -74,25 +103,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
inputIndices.push(index);
});
+ const outputArray: Array<Tensor|null> = [];
const outputIndices: number[] = [];
Object.entries(fetches).forEach(kvp => {
const name = kvp[0];
- // TODO: support pre-allocated output
+ const tensor = kvp[1];
const index = this.outputNames.indexOf(name);
if (index === -1) {
throw new Error(`invalid output '${name}'`);
}
+ outputArray.push(tensor);
outputIndices.push(index);
});
- const outputs =
- await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options);
+ const inputs =
+ inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`));
+ const outputs = outputArray.map(
+ (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null);
+
+ const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
- const result: SessionHandler.ReturnType = {};
- for (let i = 0; i < outputs.length; i++) {
- result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]);
+ const resultMap: SessionHandler.ReturnType = {};
+ for (let i = 0; i < results.length; i++) {
+ resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
- return result;
+ return resultMap;
}
startProfiling(): void {
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index 389773f3e8884..b9eff45e890c4 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
throw new Error(`unsupported logging level: ${logLevel}`);
}
};
+
+/**
+ * Check whether the given tensor type is supported by GPU buffer
+ */
+export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' ||
+ type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32';
+
+/**
+ * Map string data location to integer value
+ */
+export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => {
+ switch (location) {
+ case 'none':
+ return 0;
+ case 'cpu':
+ return 1;
+ case 'cpu-pinned':
+ return 2;
+ case 'texture':
+ return 3;
+ case 'gpu-buffer':
+ return 4;
+ default:
+ throw new Error(`unsupported data location: ${location}`);
+ }
+};
+
+/**
+ * Map integer data location to string value
+ */
+export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined =>
+ (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location];
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index fcca82ab2aa54..5b49a1d4202e3 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -3,10 +3,10 @@
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
-import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages';
+import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
-import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
+import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
@@ -60,9 +60,36 @@ export const initRuntime = async(env: Env): Promise<void> => {
};
/**
- * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded
+ * valid data locations for input/output tensors.
*/
-type SessionMetadata = [number, number[], number[]];
+type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer';
+
+type IOBindingState = {
+ /**
+ * the handle of IO binding.
+ */
+ readonly handle: number;
+
+ /**
+ * the preferred location for each output tensor.
+ *
+ * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'.
+ */
+ readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];
+
+ /**
+ * enum value of the preferred location for each output tensor.
+ */
+ readonly outputPreferredLocationsEncoded: readonly number[];
+};
+
+/**
+ * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState
+ */
+type SessionMetadata = [
+ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[],
+ bindingState: IOBindingState|null
+];
const activeSessions = new Map<number, SessionMetadata>();
@@ -92,6 +119,7 @@ export const createSessionFinalize =
let sessionHandle = 0;
let sessionOptionsHandle = 0;
+ let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];
@@ -108,6 +136,7 @@ export const createSessionFinalize =
const inputNames = [];
const outputNames = [];
+ const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(sessionHandle, i);
if (name === 0) {
@@ -122,15 +151,45 @@ export const createSessionFinalize =
checkLastError('Can\'t get an output name.');
}
outputNamesUTF8Encoded.push(name);
- outputNames.push(wasm.UTF8ToString(name));
+ const nameString = wasm.UTF8ToString(name);
+ outputNames.push(nameString);
+
+ if (!BUILD_DEFS.DISABLE_WEBGPU) {
+ const location = typeof options?.preferredOutputLocation === 'string' ?
+ options.preferredOutputLocation :
+ options?.preferredOutputLocation?.[nameString] ?? 'cpu';
+ if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
+ throw new Error(`Not supported preferred output location: ${location}.`);
+ }
+ outputPreferredLocations.push(location);
+ }
+ }
+
+ // use IO binding only when at least one output is preferred to be on GPU.
+ let bindingState: IOBindingState|null = null;
+ if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
+ ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
+ if (ioBindingHandle === 0) {
+ checkLastError('Can\'t create IO binding.');
+ }
+
+ bindingState = {
+ handle: ioBindingHandle,
+ outputPreferredLocations,
+ outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
+ };
}
- activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]);
+ activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+ if (ioBindingHandle !== 0) {
+ wasm._OrtReleaseBinding(ioBindingHandle);
+ }
+
if (sessionHandle !== 0) {
wasm._OrtReleaseSession(sessionHandle);
}
@@ -161,7 +220,13 @@ export const releaseSession = (sessionId: number): void => {
if (!session) {
throw new Error(`cannot release session. invalid session id: ${sessionId}`);
}
- const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session;
+ const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session;
+
+ if (ioBindingState) {
+ wasm._OrtReleaseBinding(ioBindingState.handle);
+ }
+
+ wasm.jsepUnregisterBuffers?.(sessionId);
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
@@ -169,18 +234,84 @@ export const releaseSession = (sessionId: number): void => {
activeSessions.delete(sessionId);
};
+const prepareInputOutputTensor =
+ (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number):
+ void => {
+ if (!tensor) {
+ tensorHandles.push(0);
+ return;
+ }
+
+ const wasm = getInstance();
+
+ const dataType = tensor[0];
+ const dims = tensor[1];
+ const location = tensor[3];
+
+ let rawData: number;
+ let dataByteLength: number;
+
+ if (dataType === 'string' && location === 'gpu-buffer') {
+ throw new Error('String tensor is not supported on GPU.');
+ }
+
+ if (location === 'gpu-buffer') {
+ const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
+ const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
+ dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
+ rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
+ } else {
+ const data = tensor[2];
+
+ if (Array.isArray(data)) {
+ // string tensor
+ dataByteLength = 4 * data.length;
+ rawData = wasm._malloc(dataByteLength);
+ allocs.push(rawData);
+ let dataIndex = rawData / 4;
+ for (let i = 0; i < data.length; i++) {
+ if (typeof data[i] !== 'string') {
+ throw new TypeError(`tensor data at index ${i} is not a string`);
+ }
+ wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
+ }
+ } else {
+ dataByteLength = data.byteLength;
+ rawData = wasm._malloc(dataByteLength);
+ allocs.push(rawData);
+ wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
+ }
+ }
+
+ const stack = wasm.stackSave();
+ const dimsOffset = wasm.stackAlloc(4 * dims.length);
+ try {
+ let dimIndex = dimsOffset / 4;
+ dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
+ const tensor = wasm._OrtCreateTensor(
+ tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
+ dataLocationStringToEnum(location));
+ if (tensor === 0) {
+ checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
+ }
+ tensorHandles.push(tensor);
+ } finally {
+ wasm.stackRestore(stack);
+ }
+ };
+
/**
* perform inference run
*/
export const run = async(
- sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[],
- options: InferenceSession.RunOptions): Promise<SerializableTensor[]> => {
+ sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
+ outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
const wasm = getInstance();
const session = activeSessions.get(sessionId);
if (!session) {
throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
}
- const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session;
+ const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session;
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
@@ -188,171 +319,200 @@ export const run = async(
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
- const inputValues: number[] = [];
- const inputAllocs: number[] = [];
+ const inputTensorHandles: number[] = [];
+ const outputTensorHandles: number[] = [];
+ const inputOutputAllocs: number[] = [];
+
+ const beforeRunStack = wasm.stackSave();
+ const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
+ const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
+ const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
+ const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
try {
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// create input tensors
for (let i = 0; i < inputCount; i++) {
- const dataType = inputs[i][0];
- const dims = inputs[i][1];
- const data = inputs[i][2];
-
- let dataOffset: number;
- let dataByteLength: number;
-
- if (Array.isArray(data)) {
- // string tensor
- dataByteLength = 4 * data.length;
- dataOffset = wasm._malloc(dataByteLength);
- inputAllocs.push(dataOffset);
- let dataIndex = dataOffset / 4;
- for (let i = 0; i < data.length; i++) {
- if (typeof data[i] !== 'string') {
- throw new TypeError(`tensor data at index ${i} is not a string`);
- }
- wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs);
- }
- } else {
- dataByteLength = data.byteLength;
- dataOffset = wasm._malloc(dataByteLength);
- inputAllocs.push(dataOffset);
- wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset);
- }
+ prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]);
+ }
- const stack = wasm.stackSave();
- const dimsOffset = wasm.stackAlloc(4 * dims.length);
- try {
- let dimIndex = dimsOffset / 4;
- dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
- const tensor = wasm._OrtCreateTensor(
- tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length);
- if (tensor === 0) {
- checkLastError(`Can't create tensor for input[${i}].`);
- }
- inputValues.push(tensor);
- } finally {
- wasm.stackRestore(stack);
- }
+ // create output tensors
+ for (let i = 0; i < outputCount; i++) {
+ prepareInputOutputTensor(
+ outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]);
+ }
+
+ let inputValuesIndex = inputValuesOffset / 4;
+ let inputNamesIndex = inputNamesOffset / 4;
+ let outputValuesIndex = outputValuesOffset / 4;
+ let outputNamesIndex = outputNamesOffset / 4;
+ for (let i = 0; i < inputCount; i++) {
+ wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
+ wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
+ }
+ for (let i = 0; i < outputCount; i++) {
+ wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
+ wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
}
- const beforeRunStack = wasm.stackSave();
- const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
- const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
- const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
- const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
-
- try {
- let inputValuesIndex = inputValuesOffset / 4;
- let inputNamesIndex = inputNamesOffset / 4;
- let outputValuesIndex = outputValuesOffset / 4;
- let outputNamesIndex = outputNamesOffset / 4;
+ if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
+ const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState;
+
+ if (inputNamesUTF8Encoded.length !== inputCount) {
+ throw new Error(`input count from feeds (${
+ inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`);
+ }
+
+ // process inputs
for (let i = 0; i < inputCount; i++) {
- wasm.HEAPU32[inputValuesIndex++] = inputValues[i];
- wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
+ const index = inputIndices[i];
+ const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
+ if (errorCode !== 0) {
+ checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
+ }
}
+
+ // process pre-allocated outputs
for (let i = 0; i < outputCount; i++) {
- wasm.HEAPU32[outputValuesIndex++] = 0;
- wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
+ const index = outputIndices[i];
+ const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.
+
+ if (location) {
+ // output is pre-allocated. bind the tensor.
+ const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
+ if (errorCode !== 0) {
+ checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
+ }
+ } else {
+ // output is not pre-allocated. reset preferred location.
+ const errorCode =
+ wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]);
+ if (errorCode !== 0) {
+ checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
+ }
+ }
}
+ }
- // jsepOnRunStart is only available when JSEP is enabled.
- wasm.jsepOnRunStart?.(sessionId);
+ let errorCode: number;
- // support RunOptions
- let errorCode = wasm._OrtRun(
+ if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
+ errorCode = await wasm._OrtRunWithBinding(
+ sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
+ } else {
+ errorCode = await wasm._OrtRun(
sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset, runOptionsHandle);
+ }
- const runPromise = wasm.jsepRunPromise;
- if (runPromise) {
- // jsepRunPromise is a Promise object. It is only available when JSEP is enabled.
- //
- // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to
- // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than
- // the async functions are finished.
- //
- // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of
- // OrtRun(). If the promise exists, we need to await for the promise to be resolved.
- errorCode = await runPromise;
- }
+ if (errorCode !== 0) {
+ checkLastError('failed to call OrtRun().');
+ }
- const jsepOnRunEnd = wasm.jsepOnRunEnd;
- if (jsepOnRunEnd) {
- // jsepOnRunEnd is only available when JSEP is enabled.
- //
- // it returns a promise, which is resolved or rejected when the following async functions are finished:
- // - collecting GPU validation errors.
- await jsepOnRunEnd(sessionId);
+ const output: TensorMetadata[] = [];
+
+ for (let i = 0; i < outputCount; i++) {
+ const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
+ if (tensor === outputTensorHandles[i]) {
+ // output tensor is pre-allocated. no need to copy data.
+ output.push(outputTensors[i]!);
+ continue;
}
- const output: SerializableTensor[] = [];
+ const beforeGetTensorDataStack = wasm.stackSave();
+ // stack allocate 4 pointer value
+ const tensorDataOffset = wasm.stackAlloc(4 * 4);
- if (errorCode !== 0) {
- checkLastError('failed to call OrtRun().');
- }
+ let keepOutputTensor = false;
+ let type: Tensor.Type|undefined, dataOffset = 0;
+ try {
+ const errorCode = wasm._OrtGetTensorData(
+ tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
+ if (errorCode !== 0) {
+ checkLastError(`Can't access output tensor data on index ${i}.`);
+ }
+ let tensorDataIndex = tensorDataOffset / 4;
+ const dataType = wasm.HEAPU32[tensorDataIndex++];
+ dataOffset = wasm.HEAPU32[tensorDataIndex++];
+ const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
+ const dimsLength = wasm.HEAPU32[tensorDataIndex++];
+ const dims = [];
+ for (let i = 0; i < dimsLength; i++) {
+ dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
+ }
+ wasm._OrtFree(dimsOffset);
- for (let i = 0; i < outputCount; i++) {
- const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
+ const size = dims.reduce((a, b) => a * b, 1);
+ type = tensorDataTypeEnumToString(dataType);
- const beforeGetTensorDataStack = wasm.stackSave();
- // stack allocate 4 pointer value
- const tensorDataOffset = wasm.stackAlloc(4 * 4);
+ const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];
- let type: Tensor.Type|undefined, dataOffset = 0;
- try {
- errorCode = wasm._OrtGetTensorData(
- tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
- if (errorCode !== 0) {
- checkLastError(`Can't access output tensor data on index ${i}.`);
+ if (type === 'string') {
+ if (preferredLocation === 'gpu-buffer') {
+ throw new Error('String tensor is not supported on GPU.');
}
- let tensorDataIndex = tensorDataOffset / 4;
- const dataType = wasm.HEAPU32[tensorDataIndex++];
- dataOffset = wasm.HEAPU32[tensorDataIndex++];
- const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
- const dimsLength = wasm.HEAPU32[tensorDataIndex++];
- const dims = [];
- for (let i = 0; i < dimsLength; i++) {
- dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
+ const stringData: string[] = [];
+ let dataIndex = dataOffset / 4;
+ for (let i = 0; i < size; i++) {
+ const offset = wasm.HEAPU32[dataIndex++];
+ const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
+ stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
- wasm._OrtFree(dimsOffset);
-
- const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b);
- type = tensorDataTypeEnumToString(dataType);
- if (type === 'string') {
- const stringData: string[] = [];
- let dataIndex = dataOffset / 4;
- for (let i = 0; i < size; i++) {
- const offset = wasm.HEAPU32[dataIndex++];
- const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
- stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
+ output.push([type, dims, stringData, 'cpu']);
+ } else {
+ // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
+ // tensor for it. There is no mapped GPU buffer for an empty tensor.
+ if (preferredLocation === 'gpu-buffer' && size > 0) {
+ const gpuBuffer = wasm.jsepGetBuffer(dataOffset);
+ const elementSize = getTensorElementSize(dataType);
+ if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
+ throw new Error(`Unsupported data type: ${type}`);
}
- output.push([type, dims, stringData]);
+
+ // do not release the tensor right now. it will be released when user calls tensor.dispose().
+ keepOutputTensor = true;
+
+ output.push([
+ type, dims, {
+ gpuBuffer,
+ download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type),
+ dispose: () => {
+ wasm._OrtReleaseTensor(tensor);
+ }
+ },
+ 'gpu-buffer'
+ ]);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
- output.push([type, dims, data]);
- }
- } finally {
- wasm.stackRestore(beforeGetTensorDataStack);
- if (type === 'string' && dataOffset) {
- wasm._free(dataOffset);
+ output.push([type, dims, data, 'cpu']);
}
+ }
+ } finally {
+ wasm.stackRestore(beforeGetTensorDataStack);
+ if (type === 'string' && dataOffset) {
+ wasm._free(dataOffset);
+ }
+ if (!keepOutputTensor) {
wasm._OrtReleaseTensor(tensor);
}
}
+ }
- return output;
- } finally {
- wasm.stackRestore(beforeRunStack);
+ if (ioBindingState) {
+ wasm._OrtClearBoundOutputs(ioBindingState.handle);
}
+
+ return output;
} finally {
- inputValues.forEach(v => wasm._OrtReleaseTensor(v));
- inputAllocs.forEach(p => wasm._free(p));
+ wasm.stackRestore(beforeRunStack);
+
+ inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
+ outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
+ inputOutputAllocs.forEach(p => wasm._free(p));
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
@@ -380,11 +540,11 @@ export const endProfiling = (sessionId: number): void => {
wasm._OrtFree(profileFileName);
};
-export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => {
+export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => {
const buffers: ArrayBufferLike[] = [];
for (const tensor of tensors) {
const data = tensor[2];
- if (!Array.isArray(data) && data.buffer) {
+ if (!Array.isArray(data) && 'buffer' in data) {
buffers.push(data.buffer);
}
}
diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts
index f90f568879146..3f903515694db 100644
--- a/js/web/script/test-runner-cli-args.ts
+++ b/js/web/script/test-runner-cli-args.ts
@@ -51,6 +51,10 @@ Options:
-P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug.
This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10.
-c, --file-cache Enable file cache.
+ -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following:
+ none (default)
+ gpu-tensor use pre-allocated GPU tensors for inputs and outputs
+ gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer'
*** Session Options ***
-u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model.
@@ -109,6 +113,7 @@ export declare namespace TestRunnerCliArgs {
type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn';
type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs';
type BundleMode = 'prod'|'dev'|'perf';
+ type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location';
}
export interface TestRunnerCliArgs {
@@ -140,6 +145,8 @@ export interface TestRunnerCliArgs {
*/
bundleMode: TestRunnerCliArgs.BundleMode;
+ ioBindingMode: TestRunnerCliArgs.IOBindingMode;
+
logConfig: Test.Config['log'];
/**
@@ -416,6 +423,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}});
}
+ // Option: -i=<...>, --io-binding=<...>
+ const ioBindingArg = args['io-binding'] || args.i;
+ const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg;
+ if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) {
+ throw new Error(`not supported io binding mode ${ioBindingMode}`);
+ }
+
// Option: -u, --optimized-model-file-path
const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined;
if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') {
@@ -455,6 +469,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`);
npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`);
npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`);
+ npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`);
npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE');
return {
@@ -467,6 +482,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
logConfig,
profile,
times: perf ? times : undefined,
+ ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'],
optimizedModelFilePath,
graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'],
fileCache,
diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index f3764e63fcf45..d8fecec1b8084 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -257,7 +257,7 @@ async function main() {
times?: number): Test.ModelTest {
if (times === 0) {
npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`);
- return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []};
+ return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode};
}
let modelUrl: string|null = null;
@@ -323,6 +323,16 @@ async function main() {
}
}
+ let ioBinding: Test.IOBindingMode;
+ if (backend !== 'webgpu' && args.ioBindingMode !== 'none') {
+ npmlog.warn(
+ 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`);
+ ioBinding = 'none';
+ } else {
+ ioBinding = args.ioBindingMode;
+ }
+
npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.');
npmlog.verbose('TestRunnerCli.Init.Model', '===============================================================');
npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`);
@@ -330,7 +340,7 @@ async function main() {
npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`);
npmlog.verbose('TestRunnerCli.Init.Model', '===============================================================');
- return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases};
+ return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding};
}
function tryLocateModelTestFolder(searchPattern: string): string {
@@ -390,6 +400,13 @@ async function main() {
for (const test of tests) {
test.backend = backend;
test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION};
+ if (backend !== 'webgpu' && args.ioBindingMode !== 'none') {
+ npmlog.warn(
+ 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`);
+ test.ioBinding = 'none';
+ } else {
+ test.ioBinding = args.ioBindingMode;
+ }
}
npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.');
npmlog.verbose('TestRunnerCli.Init.Op', '===============================================================');
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index 46d80a9f56f35..628e5408150f8 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators';
import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx';
import {Tensor} from '../lib/onnxjs/tensor';
import {ProtoUtil} from '../lib/onnxjs/util';
-import {tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common';
+import {createView} from '../lib/wasm/jsep/tensor-view';
+import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common';
import {base64toBuffer, createMockGraph, readFile} from './test-shared';
import {Test} from './test-types';
@@ -136,8 +137,8 @@ async function loadTensors(
}
async function initializeSession(
- modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions,
- fileCache?: FileCacheBuffer): Promise<ort.InferenceSession> {
+ modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean,
+ sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise<ort.InferenceSession> {
const preloadModelData: Uint8Array|undefined =
fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined;
Logger.verbose(
@@ -146,8 +147,14 @@ async function initializeSession(
preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`);
const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined;
- const sessionConfig =
- {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile};
+ const sessionConfig = {
+ ...sessionOptions,
+ executionProviders: [backendHint],
+ profiler: profilerConfig,
+ enableProfiling: profile,
+ preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined
+ };
+
let session: ort.InferenceSession;
try {
@@ -181,6 +188,7 @@ export class ModelTestContext {
readonly session: ort.InferenceSession,
readonly backend: string,
readonly perfData: ModelTestContext.ModelTestPerfData,
+ readonly ioBinding: Test.IOBindingMode,
private readonly profile: boolean,
) {}
@@ -232,8 +240,8 @@ export class ModelTestContext {
this.initializing = true;
const initStart = now();
- const session =
- await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache);
+ const session = await initializeSession(
+ modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache);
const initEnd = now();
for (const testCase of modelTest.cases) {
@@ -244,6 +252,7 @@ export class ModelTestContext {
session,
modelTest.backend!,
{init: initEnd - initStart, firstRun: -1, runs: [], count: 0},
+ modelTest.ioBinding,
profile,
);
} finally {
@@ -481,6 +490,130 @@ export class TensorResultValidator {
}
}
+function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
+ if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
+ throw new Error(`createGpuTensorForInput cannot be used with a ${cpuTensor.type} tensor`);
+ }
+ const device = ort.env.webgpu.device as GPUDevice;
+ const gpuBuffer = device.createBuffer({
+ // eslint-disable-next-line no-bitwise
+ usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
+ size: Math.ceil(cpuTensor.data.byteLength / 16) * 16,
+ mappedAtCreation: true
+ });
+ const arrayBuffer = gpuBuffer.getMappedRange();
+ new Uint8Array(arrayBuffer)
+ .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength));
+ gpuBuffer.unmap();
+
+ // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data?
+
+ return ort.Tensor.fromGpuBuffer(
+ gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()});
+}
+
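
On the TODO above: with `mappedAtCreation`, the data is written through the mapping, so there is no explicit copy command to await. One hedged option is `device.queue.onSubmittedWorkDone()`, which fences all submitted work; it is a full queue sync, so it could skew timing in the other direction:

```js
// Possible fence before starting the timer (coarse: waits for the whole queue).
await ort.env.webgpu.device.queue.onSubmittedWorkDone();
```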
+function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
+ if (!isGpuBufferSupportedType(type)) {
+ throw new Error(`createGpuTensorForOutput cannot be used with a ${type} tensor`);
+ }
+
+ const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!;
+ const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
+
+ const device = ort.env.webgpu.device as GPUDevice;
+ const gpuBuffer = device.createBuffer({
+ // eslint-disable-next-line no-bitwise
+ usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
+ size: Math.ceil(size / 16) * 16
+ });
+
+ return ort.Tensor.fromGpuBuffer(gpuBuffer, {
+ dataType: type,
+ dims,
+ dispose: () => gpuBuffer.destroy(),
+ download: async () => {
+ const stagingBuffer = device.createBuffer({
+ // eslint-disable-next-line no-bitwise
+ usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+ size: gpuBuffer.size
+ });
+ const encoder = device.createCommandEncoder();
+ encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size);
+ device.queue.submit([encoder.finish()]);
+
+ await stagingBuffer.mapAsync(GPUMapMode.READ);
+ const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size);
+ stagingBuffer.destroy();
+
+ return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes];
+ }
+ });
+}
+
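
The `download` callback above is what backs `getData()` for these outputs: copy the storage buffer into a `MAP_READ` staging buffer, map it, and wrap the bytes in a typed view. A hypothetical round trip (all names here are made up):

```js
// Sketch: pre-allocate a GPU output, run, then read it back on the CPU.
const myGpuOutput = createGpuTensorForOutput('float32', [1, 4]);
const results = await session.run(feeds, {'output_0': myGpuOutput});
// getData(true) runs the download() above, then releases the GPU buffer.
const cpuData = await results['output_0'].getData(true);
```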
+export async function sessionRun(options: {
+ session: ort.InferenceSession; feeds: Record<string, ort.Tensor>;
+ outputsMetaInfo: Record<string, Pick<ort.Tensor, 'dims'|'type'>>;
+ ioBinding: Test.IOBindingMode;
+}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> {
+ const session = options.session;
+ const feeds = options.feeds;
+ const fetches: Record<string, ort.Tensor> = {};
+
+ // currently we only support IO Binding for WebGPU
+ //
+ // For inputs, we create GPU tensors in both 'gpu-tensor' and 'gpu-location' binding testing modes.
+ // For outputs, we create GPU tensors in 'gpu-tensor' binding testing mode only.
+ // In 'gpu-location' binding mode, outputs are not pre-allocated.
+ const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location';
+ const shouldUploadOutput = options.ioBinding === 'gpu-tensor';
+ try {
+ if (shouldUploadInput) {
+ // replace the CPU tensors in feeds with GPU tensors
+ for (const name in feeds) {
+ if (Object.hasOwnProperty.call(feeds, name)) {
+ feeds[name] = createGpuTensorForInput(feeds[name]);
+ }
+ }
+ }
+
+ if (shouldUploadOutput) {
+ for (const name in options.outputsMetaInfo) {
+ if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) {
+ const {type, dims} = options.outputsMetaInfo[name];
+ fetches[name] = createGpuTensorForOutput(type, dims);
+ }
+ }
+ }
+
+ const start = now();
+ Logger.verbose('TestRunner', `Timestamp before session run: ${start}`);
+ const outputs = await (
+ shouldUploadOutput ? session.run(feeds, fetches) :
+ session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo)));
+ const end = now();
+ Logger.verbose('TestRunner', `Timestamp after session run: ${end}`);
+
+ // download each output tensor if needed
+ for (const name in outputs) {
+ if (Object.hasOwnProperty.call(outputs, name)) {
+ const tensor = outputs[name];
+ // Tensor.getData(true) releases the underlying GPU resources
+ await tensor.getData(true);
+ }
+ }
+
+ return [start, end, outputs];
+ } finally {
+ // dispose the GPU tensors in feeds
+ for (const name in feeds) {
+ if (Object.hasOwnProperty.call(feeds, name)) {
+ const tensor = feeds[name];
+ tensor.dispose();
+ }
+ }
+ }
+}
+
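
A caller therefore drives `sessionRun` roughly as in the following sketch (feed/output names and shapes are hypothetical):

```js
// Sketch: time one run with GPU-uploaded inputs and pre-allocated GPU outputs.
const [start, end, outputs] = await sessionRun({
  session,
  feeds: {'input_0': new ort.Tensor('float32', myInputData, [1, 4])},
  outputsMetaInfo: {'output_0': {type: 'float32', dims: [1, 4]}},
  ioBinding: 'gpu-tensor'
});
// By the time sessionRun returns, output data has been downloaded to the CPU.
console.log(`run took ${end - start} ms`);
```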
/**
* run a single model test case. the inputs/outputs tensors should already been prepared.
*/
@@ -491,12 +624,11 @@ export async function runModelTestSet(
const validator = new TensorResultValidator(context.backend);
try {
const feeds: Record<string, ort.Tensor> = {};
+ const outputsMetaInfo: Record<string, ort.Tensor> = {};
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
- const start = now();
- Logger.verbose('TestRunner', `Timestamp before session run: ${start}`);
- const outputs = await context.session.run(feeds);
- const end = now();
- Logger.verbose('TestRunner', `Timestamp after session run: ${end}`);
+ testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
+ const [start, end, outputs] =
+ await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {
context.perfData.firstRun = end - start;
} else {
@@ -575,6 +707,7 @@ export class ProtoOpTestContext {
private readonly loadedData: Uint8Array; // model data, inputs, outputs
session: ort.InferenceSession;
readonly backendHint: string;
+ readonly ioBindingMode: Test.IOBindingMode;
constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) {
const opsetImport = onnx.OperatorSetIdProto.create(test.opset);
const operator = test.operator;
@@ -713,6 +846,7 @@ export class ProtoOpTestContext {
model.graph.name = test.name;
this.backendHint = test.backend!;
+ this.ioBindingMode = test.ioBinding;
this.loadedData = onnx.ModelProto.encode(model).finish();
// in debug mode, open a new tab in browser for the generated onnx model.
@@ -729,8 +863,11 @@ export class ProtoOpTestContext {
}
}
async init(): Promise<void> {
- this.session = await ort.InferenceSession.create(
- this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions});
+ this.session = await ort.InferenceSession.create(this.loadedData, {
+ executionProviders: [this.backendHint],
+ preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined,
+ ...this.sessionOptions
+ });
}
async dispose(): Promise<void> {
@@ -739,10 +876,11 @@ export class ProtoOpTestContext {
}
async function runProtoOpTestcase(
- session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise<void> {
+ session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode,
+ validator: TensorResultValidator): Promise<void> {
const feeds: Record = {};
- const fetches: string[] = [];
- testCase.inputs!.forEach((input, i) => {
+ const fetches: Record<string, Pick<ort.Tensor, 'dims'|'type'>> = {};
+ testCase.inputs.forEach((input, i) => {
if (input.data) {
let data: number[]|BigUint64Array|BigInt64Array = input.data;
if (input.type === 'uint64') {
@@ -756,7 +894,7 @@ async function runProtoOpTestcase(
const outputs: ort.Tensor[] = [];
const expectedOutputNames: string[] = [];
- testCase.outputs!.forEach((output, i) => {
+ testCase.outputs.forEach((output, i) => {
if (output.data) {
let data: number[]|BigUint64Array|BigInt64Array = output.data;
if (output.type === 'uint64') {
@@ -766,11 +904,11 @@ async function runProtoOpTestcase(
}
outputs.push(new ort.Tensor(output.type, data, output.dims));
expectedOutputNames.push(`output_${i}`);
- fetches.push(`output_${i}`);
+ fetches[`output_${i}`] = {dims: output.dims, type: output.type};
}
});
- const results = await session.run(feeds, fetches);
+ const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode});
const actualOutputNames = Object.getOwnPropertyNames(results);
expect(actualOutputNames.length).to.equal(expectedOutputNames.length);
@@ -821,7 +959,8 @@ async function runOpTestcase(
export async function runOpTest(
testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise<void> {
if (context instanceof ProtoOpTestContext) {
- await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint));
+ await runProtoOpTestcase(
+ context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint));
} else {
await runOpTestcase(
context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint));
diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts
index 1f95d1cd8e682..88915e7972383 100644
--- a/js/web/test/test-types.ts
+++ b/js/web/test/test-types.ts
@@ -43,6 +43,18 @@ export declare namespace Test {
*/
export type PlatformCondition = string;
+ /**
+ * The IOBindingMode represents how to test a model with GPU data.
+ *
+ * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation`
+ * will not be set.
+ * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated;
+ * `preferredOutputLocation` will be set to `gpu-buffer`.
+ * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation`
+ * will not be set.
+ */
+ export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location';
+
export interface ModelTestCase {
name: string;
dataFiles: readonly string[];
@@ -54,6 +66,7 @@ export declare namespace Test {
name: string;
modelUrl: string;
backend?: string; // value should be populated at build time
+ ioBinding: IOBindingMode;
platformCondition?: PlatformCondition;
cases: readonly ModelTestCase[];
}
@@ -82,6 +95,7 @@ export declare namespace Test {
inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray;
opset?: OperatorTestOpsetImport;
backend?: string; // value should be populated at build time
+ ioBinding: IOBindingMode;
platformCondition?: PlatformCondition;
attributes?: readonly AttributeValue[];
cases: readonly OperatorTestCase[];
diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h
index 177c0a9e691ed..fdd5c7dee5bfc 100644
--- a/onnxruntime/core/providers/js/js_kernel.h
+++ b/onnxruntime/core/providers/js/js_kernel.h
@@ -196,7 +196,7 @@ class JsKernel : public OpKernel {
}
int status_code = EM_ASM_INT(
- { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); },
+ { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); },
this, reinterpret_cast<size_t>(p_serialized_kernel_context));
LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
index 174edabbc91fe..968eece361724 100644
--- a/onnxruntime/wasm/api.cc
+++ b/onnxruntime/wasm/api.cc
@@ -9,6 +9,7 @@
#include "api.h"
#include <iostream>
+#include <sstream>
#include <vector>
namespace {
@@ -17,6 +18,14 @@ OrtErrorCode g_last_error_code;
std::string g_last_error_message;
} // namespace
+enum DataLocation {
+ DATA_LOCATION_NONE = 0,
+ DATA_LOCATION_CPU = 1,
+ DATA_LOCATION_CPU_PINNED = 2,
+ DATA_LOCATION_TEXTURE = 3,
+ DATA_LOCATION_GPU_BUFFER = 4
+};
+
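
These numeric values must agree with what the JS layer passes into `OrtCreateTensor()` and `OrtBindOutput()`. A sketch of the kind of mapping the JS side needs (the helper name is hypothetical):

```js
// Hypothetical JS mirror of the DataLocation enum above.
const dataLocationToEnum = location => {
  switch (location) {
    case 'none': return 0;
    case 'cpu': return 1;
    case 'cpu-pinned': return 2;
    case 'texture': return 3;
    case 'gpu-buffer': return 4;
    default: throw new Error(`unsupported data location: ${location}`);
  }
};
```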
static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same.");
static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32).");
@@ -223,13 +232,23 @@ void OrtFree(void* ptr) {
}
}
-OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length) {
+OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) {
+ if (data_location != DATA_LOCATION_CPU &&
+ data_location != DATA_LOCATION_CPU_PINNED &&
+ data_location != DATA_LOCATION_GPU_BUFFER) {
+ std::ostringstream ostr;
+ ostr << "Invalid data location: " << data_location;
+ CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str()));
+ return nullptr;
+ }
+
std::vector<int64_t> shapes(dims_length);
for (size_t i = 0; i < dims_length; i++) {
shapes[i] = dims[i];
}
if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
+ // data_location is ignored for string tensors; they are always on CPU.
OrtAllocator* allocator = nullptr;
RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator);
@@ -244,12 +263,16 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t*
return UNREGISTER_AUTO_RELEASE(value);
} else {
- OrtMemoryInfo* memoryInfo = nullptr;
- RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memoryInfo);
- REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memoryInfo);
+ OrtMemoryInfo* memory_info = nullptr;
+ if (data_location != DATA_LOCATION_GPU_BUFFER) {
+ RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info);
+ } else {
+ RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+ }
+ REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info);
OrtValue* value = nullptr;
- int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memoryInfo, data, data_length,
+ int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memory_info, data, data_length,
dims_length > 0 ? shapes.data() : nullptr, dims_length,
static_cast<ONNXTensorElementDataType>(data_type), &value);
@@ -373,15 +396,85 @@ void OrtReleaseRunOptions(OrtRunOptions* run_options) {
Ort::GetApi().ReleaseRunOptions(run_options);
}
+OrtIoBinding* OrtCreateBinding(OrtSession* session) {
+ OrtIoBinding* binding = nullptr;
+ int error_code = CHECK_STATUS(CreateIoBinding, session, &binding);
+ return (error_code == ORT_OK) ? binding : nullptr;
+}
+
+int EMSCRIPTEN_KEEPALIVE OrtBindInput(OrtIoBinding* io_binding,
+ const char* name,
+ OrtValue* input) {
+ return CHECK_STATUS(BindInput, io_binding, name, input);
+}
+
+int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding,
+ const char* name,
+ OrtValue* output,
+ int output_location) {
+ if (output) {
+ return CHECK_STATUS(BindOutput, io_binding, name, output);
+ } else {
+ if (output_location != DATA_LOCATION_NONE &&
+ output_location != DATA_LOCATION_CPU &&
+ output_location != DATA_LOCATION_CPU_PINNED &&
+ output_location != DATA_LOCATION_GPU_BUFFER) {
+ std::ostringstream ostr;
+ ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\".";
+ return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str()));
+ }
+
+ OrtMemoryInfo* memory_info = nullptr;
+ if (output_location != DATA_LOCATION_GPU_BUFFER) {
+ RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info);
+ } else {
+ RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+ }
+ REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info);
+ return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info);
+ }
+}
+
+void OrtClearBoundOutputs(OrtIoBinding* io_binding) {
+ Ort::GetApi().ClearBoundOutputs(io_binding);
+}
+
+void OrtReleaseBinding(OrtIoBinding* io_binding) {
+ Ort::GetApi().ReleaseIoBinding(io_binding);
+}
+
+int OrtRunWithBinding(OrtSession* session,
+ OrtIoBinding* io_binding,
+ size_t output_count,
+ OrtValue** outputs,
+ OrtRunOptions* run_options) {
+ RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding);
+
+ OrtAllocator* allocator = nullptr;
+ RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator);
+
+ size_t binding_output_count = 0;
+ OrtValue** binding_outputs = nullptr;
+ RETURN_ERROR_CODE_IF_ERROR(GetBoundOutputValues, io_binding, allocator, &binding_outputs, &binding_output_count);
+ REGISTER_AUTO_RELEASE_BUFFER(OrtValue*, binding_outputs, allocator);
+
+ if (binding_output_count != output_count) {
+ return CheckStatus(
+ Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Output count is inconsistent with IO Binding output data."));
+ }
+
+ for (size_t i = 0; i < output_count; i++) {
+ outputs[i] = binding_outputs[i];
+ }
+
+ return ORT_OK;
+}
+
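
On the JS side, these exports compose into a run-with-binding flow roughly as sketched below (the `wasm._Ort*` export names follow from the declarations in this patch; the handles, name pointers, and counts are hypothetical):

```js
// Sketch: bind inputs, request GPU-buffer outputs, run, then clean up.
const binding = wasm._OrtCreateBinding(sessionHandle);
try {
  // _OrtBindInput is asyncified because it may copy data across devices.
  await wasm._OrtBindInput(binding, inputNamePtr, inputTensorHandle);
  // No pre-allocated tensor (0); request DATA_LOCATION_GPU_BUFFER (4).
  wasm._OrtBindOutput(binding, outputNamePtr, 0, 4);
  const errorCode = await wasm._OrtRunWithBinding(
      sessionHandle, binding, outputCount, outputValuesOffset, runOptionsHandle);
} finally {
  wasm._OrtClearBoundOutputs(binding);
  wasm._OrtReleaseBinding(binding);
}
```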
int OrtRun(OrtSession* session,
const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count,
const char** output_names, size_t output_count, ort_tensor_handle_t* outputs,
OrtRunOptions* run_options) {
- auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
-#if defined(USE_JSEP)
- EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code);
-#endif
- return status_code;
+ return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
}
char* OrtEndProfiling(ort_session_handle_t session) {
diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h
index 398c901e0e5ed..9a0664697f0ff 100644
--- a/onnxruntime/wasm/api.h
+++ b/onnxruntime/wasm/api.h
@@ -15,6 +15,9 @@
struct OrtSession;
using ort_session_handle_t = OrtSession*;
+struct OrtIoBinding;
+using ort_io_binding_handle_t = OrtIoBinding*;
+
struct OrtSessionOptions;
using ort_session_options_handle_t = OrtSessionOptions*;
@@ -164,9 +167,10 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr);
* @param data_length size of the buffer 'data' in bytes.
* @param dims a pointer to an array of dims. the array should contain (dims_length) element(s).
* @param dims_length the length of the tensor's dimension
+ * @param data_location specify the memory location of the tensor data. Must be a DataLocation value: 1 for CPU, 4 for GPU buffer.
* @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor().
*/
-ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length);
+ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location);
/**
* get type, shape info and data of the specified tensor.
@@ -216,6 +220,58 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio
*/
void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options);
+/**
+ * create an instance of ORT IO binding.
+ */
+ort_io_binding_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateBinding(ort_session_handle_t session);
+
+/**
+ * bind an input tensor to the IO binding instance. A cross-device copy will be performed if necessary.
+ * @param io_binding handle of the IO binding
+ * @param name name of the input
+ * @param input handle of the input tensor
+ * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
+ */
+int EMSCRIPTEN_KEEPALIVE OrtBindInput(ort_io_binding_handle_t io_binding,
+ const char* name,
+ ort_tensor_handle_t input);
+
+/**
+ * bind an output tensor or location to the IO binding instance.
+ * @param io_binding handle of the IO binding
+ * @param name name of the output
+ * @param output handle of the output tensor. nullptr for output location binding.
+ * @param output_location specify the memory location of the output tensor data.
+ * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
+ */
+int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding,
+ const char* name,
+ ort_tensor_handle_t output,
+ int output_location);
+
+/**
+ * clear all bound outputs.
+ */
+void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding);
+
+/**
+ * release the specified ORT IO binding.
+ */
+void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding);
+
+/**
+ * run inference on the model using the specified IO binding.
+ * @param session handle of the specified session
+ * @param io_binding handle of the IO binding
+ * @param run_options handle of the run options
+ * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
+ */
+int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session,
+ ort_io_binding_handle_t io_binding,
+ size_t output_count,
+ ort_tensor_handle_t* outputs,
+ ort_run_options_handle_t run_options);
+
/**
* inference the model.
* @param session handle of the specified session
diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js
index 15d393f4ce62d..427ad6f6d14f3 100644
--- a/onnxruntime/wasm/js_internal_api.js
+++ b/onnxruntime/wasm/js_internal_api.js
@@ -14,40 +14,156 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRunKernel = runKernel;
- Module['jsepOnRunStart'] = sessionId => {
- Module['jsepRunPromise'] = new Promise(r => {
- Module.jsepRunPromiseResolve = r;
- });
-
- if (Module.jsepSessionState) {
- throw new Error('Session already started');
- }
-
- Module.jsepSessionState = {
- sessionId,
- errors: []
+ // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1).
+ // It removes some overhead in cwrap() and ccall() that we don't need.
+ //
+ // Currently in JSEP build, we only use this for the following functions:
+ // - OrtRun()
+ // - OrtRunWithBinding()
+ // - OrtBindInput()
+ //
+ // Note: about parameters "getFunc" and "setFunc":
+ // - Emscripten has different behaviors in Debug and Release builds when generating exported function wrappers.
+ //
+ // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a
+ // wrapper for OrtRun() like this:
+ // ```
+ // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun");
+ // ```
+ //
+ // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates
+ // a wrapper for OrtRun() like this (minified):
+ // ```
+ // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q);
+ // ```
+ //
+ // The behavior of these two wrappers is different. The debug build will assign `Module["_OrtRun"]` only once,
+ // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will
+ // reset d._OrtRun to J.ka the first time it is called.
+ //
+ // This difference matters because the async wrapper must be designed to handle both cases.
+ //
+ // Now, let's look at how the async wrapper is designed to work for both cases:
+ //
+ // - Debug build:
+ // 1. When the WebAssembly module is loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`.
+ // 2. The first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async
+ // wrapper function.
+ // Value of `Module["_OrtRun"]` will not be changed again.
+ //
+ // - Release build:
+ // 1. When the WebAssembly module is loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function.
+ // 2. The first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async
+ // wrapper function.
+ // 3. The first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this
+ // function:
+ // ```
+ // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q);
+ // ```
+ // This function will assign d._OrtRun (i.e. the minified `Module["_OrtRun"]`) to the real function (J.ka).
+ // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored
+ // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper.
+ // Value of `Module["_OrtRun"]` will not be changed again.
+ //
+ // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release
+ // build.
+ //
+ // This is why we need the `getFunc` and `setFunc` parameters: they are used to get the current value of an
+ // exported function and to set a new value for it.
+ //
+ const jsepWrapAsync = (func, getFunc, setFunc) => {
+ return (...args) => {
+ // cache the async data before calling the function.
+ const previousAsync = Asyncify.currData;
+
+ const previousFunc = getFunc?.();
+ const ret = func(...args);
+ const newFunc = getFunc?.();
+ if (previousFunc !== newFunc) {
+ // The exported function has been updated.
+ // Set the sync function reference to the new function.
+ func = newFunc;
+ // Set the exported function back to the async wrapper.
+ setFunc(previousFunc);
+ // Remove getFunc and setFunc. They are no longer needed.
+ setFunc = null;
+ getFunc = null;
+ }
+
+ // If the async data has been changed, it means that the function started an async operation.
+ if (Asyncify.currData != previousAsync) {
+ // returns the promise
+ return Asyncify.whenDone();
+ }
+ // the function is synchronous. returns the result.
+ return ret;
};
};
- Module['jsepOnRunEnd'] = sessionId => {
- if (Module.jsepSessionState.sessionId !== sessionId) {
- throw new Error('Session ID mismatch');
- }
-
- const errorPromises = Module.jsepSessionState.errors;
- Module.jsepSessionState = null;
-
- return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => {
- Promise.all(errorPromises).then(errors => {
- errors = errors.filter(e => e);
- if (errors.length > 0) {
- reject(new Error(errors.join('\n')));
- } else {
- resolve();
+ // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly.
+ const runAsync = (runAsyncFunc) => {
+ return async (...args) => {
+ try {
+ // Module.jsepSessionState should be null, unless we are in the middle of a session.
+ // If it is not null, it means that the previous session has not finished yet.
+ if (Module.jsepSessionState) {
+ throw new Error('Session already started');
+ }
+ const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []};
+
+ // Run the asyncified function: OrtRun() or OrtRunWithBinding()
+ const ret = await runAsyncFunc(...args);
+
+ // Check if the session is still valid. This object should be the same as the one we set above.
+ if (Module.jsepSessionState !== state) {
+ throw new Error('Session mismatch');
+ }
+
+ // Flush the backend. This will submit all pending commands to the GPU.
+ backend['flush']();
+
+ // Await all pending promises. This includes GPU validation promises for diagnostic purposes.
+ const errorPromises = state.errors;
+ if (errorPromises.length > 0) {
+ let errors = await Promise.all(errorPromises);
+ errors = errors.filter(e => e);
+ if (errors.length > 0) {
+ throw new Error(errors.join('\n'));
+ }
}
- }, reason => {
- reject(reason);
- });
- });
+
+ return ret;
+ } finally {
+ Module.jsepSessionState = null;
+ }
+ };
+ };
+
+ // replace the original functions with asyncified versions
+ Module['_OrtRun'] = runAsync(jsepWrapAsync(
+ Module['_OrtRun'],
+ () => Module['_OrtRun'],
+ v => Module['_OrtRun'] = v));
+ Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync(
+ Module['_OrtRunWithBinding'],
+ () => Module['_OrtRunWithBinding'],
+ v => Module['_OrtRunWithBinding'] = v));
+ Module['_OrtBindInput'] = jsepWrapAsync(
+ Module['_OrtBindInput'],
+ () => Module['_OrtBindInput'],
+ v => Module['_OrtBindInput'] = v);
+
+ // expose webgpu backend functions
+ Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
+ return backend['registerBuffer'](sessionId, index, buffer, size);
+ };
+ Module['jsepUnregisterBuffers'] = sessionId => {
+ backend['unregisterBuffers'](sessionId);
+ };
+ Module['jsepGetBuffer'] = (dataId) => {
+ return backend['getBuffer'](dataId);
+ };
+ Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
+ return backend['createDownloader'](gpuBuffer, size, type);
};
};
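
The four `jsep*` helpers exposed at the end are the glue the session handler uses for GPU IO binding. A hedged sketch of the intended call pattern (ids, sizes, and the surrounding handler code are hypothetical):

```js
// Sketch: register a GPU buffer so the WebGPU EP can address it by id,
// and create a downloader for reading a GPU output back to the CPU.
const gpuDataId = Module.jsepRegisterBuffer(sessionId, inputIndex, tensor.gpuBuffer, byteLength);
// ... pass gpuDataId as the "data" of a DATA_LOCATION_GPU_BUFFER OrtValue ...
const download = Module.jsepCreateDownloader(Module.jsepGetBuffer(outputDataId), byteLength, 'float32');
const cpuData = await download();
// When the session is released: Module.jsepUnregisterBuffers(sessionId);
```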
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
index d737376eb99b5..788b02f539821 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
@@ -29,6 +29,7 @@ jobs:
pool: ${{ parameters.PoolName }}
variables:
+ webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de'
runCodesignValidationInjection: false
timeoutInMinutes: 60
workspace:
@@ -159,12 +160,22 @@ jobs:
npm test -- -e=edge -b=webgl,wasm,xnnpack
workingDirectory: '$(Build.SourcesDirectory)\js\web'
displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)'
- condition: ne('${{ parameters.RunWebGpuTests }}', 'true')
+ condition: eq('${{ parameters.RunWebGpuTests }}', 'false')
- script: |
- npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de
+ npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags)
workingDirectory: '$(Build.SourcesDirectory)\js\web'
displayName: 'Run ort-web tests (ALL backends)'
- condition: ne('${{ parameters.RunWebGpuTests }}', 'false')
+ condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
+ - script: |
+ npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags)
+ workingDirectory: '$(Build.SourcesDirectory)\js\web'
+ displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)'
+ condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
+ - script: |
+ npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags)
+ workingDirectory: '$(Build.SourcesDirectory)\js\web'
+ displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)'
+ condition: eq('${{ parameters.RunWebGpuTests }}', 'true')
- script: |
npm test -- --webgl-texture-pack-mode -b=webgl -e=edge
workingDirectory: '$(Build.SourcesDirectory)\js\web'