Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Support capture and replay for jsep #18989

Merged
merged 33 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ccc89af
[js/webgpu] Add record/replay support
qjia7 Dec 27, 2023
c675f71
Add record/replay support in c++
qjia7 Dec 28, 2023
547e005
support record/replay in js
qjia7 Dec 29, 2023
62d64b8
remove unused codes
qjia7 Jan 3, 2024
be58e40
add EP option graphCaptureEnabled
qjia7 Jan 5, 2024
012066b
add sessionId to capture/replay methods
qjia7 Jan 8, 2024
db02615
Add releaseSession interface
qjia7 Jan 8, 2024
2d0c878
Create an internal buffer for each external buffer
qjia7 Jan 8, 2024
c00b29b
Revert "Create an internal buffer for each external buffer"
qjia7 Jan 8, 2024
e1a4bc4
throw errors when not supported
qjia7 Jan 9, 2024
387ff44
only bind input/output once for IOBinding when graphCaptureEnabled =
qjia7 Jan 9, 2024
79f392c
nits
qjia7 Jan 10, 2024
030d347
update name and annotation
qjia7 Jan 10, 2024
c4cfde0
fix format issues
qjia7 Jan 10, 2024
d105c52
fix lint/format errors
qjia7 Jan 11, 2024
c5137c6
Merge branch 'main' into record_and_replay
qjia7 Jan 15, 2024
a5adf02
Merge branch 'main' into record_and_replay
qjia7 Jan 15, 2024
cc2ff91
nits
qjia7 Jan 15, 2024
e630dbf
enable timestamp query
qjia7 Jan 18, 2024
4c313ad
address Yulong's comments
qjia7 Jan 19, 2024
3f3c6df
reuse the storage buffer
qjia7 Jan 19, 2024
b992f6c
Revert "reuse the storage buffer"
qjia7 Jan 22, 2024
2f13fcd
Merge branch 'main' into record_and_replay
qjia7 Jan 23, 2024
2172984
integrate setQueryType changes
qjia7 Jan 23, 2024
6e0ef20
flush the left commands before status changed
qjia7 Jan 23, 2024
b785a05
address comments
qjia7 Jan 25, 2024
3a80d5c
rename to enableGraphCapture and move to SessionOptions
qjia7 Jan 26, 2024
b6f5d95
Merge branch 'main' into record_and_replay
qjia7 Jan 26, 2024
dc8cc2b
nits
qjia7 Jan 26, 2024
b0de471
Merge branch 'main' into record_and_replay
qjia7 Jan 27, 2024
dff25fa
Address Yulong's comments
qjia7 Jan 29, 2024
1beb3f1
further simplify if (!enableGraphCapture || !inputOutputBound)
qjia7 Jan 29, 2024
59f5f92
call OrtClearBoundOutputs when release session
qjia7 Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export declare namespace InferenceSession {
optimizedModelFilePath?: string;

/**
* Wether enable profiling.
* Whether to enable profiling.
*
* This setting is a placeholder for a future use.
*/
Expand Down Expand Up @@ -154,6 +154,12 @@ export declare namespace InferenceSession {
*/
preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};

/**
* Whether to enable graph capture.
* This setting is available only in ONNXRuntime Web for WebGPU EP.
*/
enableGraphCapture?: boolean;

/**
* Store configurations for a session. See
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/
Expand Down Expand Up @@ -238,7 +244,6 @@ export declare namespace InferenceSession {
export interface WebGpuExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webgpu';
preferredLayout?: 'NCHW'|'NHWC';
graphCaptureEnabled?: boolean;
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webnn';
Expand Down
15 changes: 9 additions & 6 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ export declare namespace JSEP {
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
type CaptureBeginFunction = (sessionHandle: number) => void;
type CaptureEndFunction = (sessionHandle: number) => void;
type ReplayFunction = (sessionHandle: number) => void;
type CaptureBeginFunction = () => void;
type CaptureEndFunction = () => void;
type ReplayFunction = () => void;
}

export interface OrtWasmModule extends EmscriptenModule {
Expand Down Expand Up @@ -181,11 +181,14 @@ export interface OrtWasmModule extends EmscriptenModule {
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Called when InferenceSession.run started.
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
* _OrtRun[WithBinding]() is called.
* @param sessionId - specify the session ID.
*/
jsepOnRunStart: () => void;
jsepOnRunStart: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Release a session.
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
* called.
* @param sessionId - specify the session ID.
* @returns
*/
Expand Down
54 changes: 27 additions & 27 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, StatusType, TimestampQuery} from './webgpu/types';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';

interface CommandInfo {
readonly kernelId: number;
Expand Down Expand Up @@ -111,9 +111,9 @@ export class WebGpuBackend {
programManager: ProgramManager;

/**
* representing the session ID of which is currently being captured/replay.
* `null` means no session is being captured.
* only valid when captureGraphEnabled = true.
* The ID of the session that is currently being run.
* `null` means no session is being run.
* only valid when session.run is executed.
*/
currentSessionId: number|null = null;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -169,9 +169,9 @@ export class WebGpuBackend {
queryType: TimestampQuery;

env: Env;
status: StatusType = StatusType.default;
sessionStatus: SessionState = 'default';
/**
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
* a SessionID -> CommandInfo[] mapping.
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();

Expand Down Expand Up @@ -519,7 +519,7 @@ export class WebGpuBackend {
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);

if (this.queryType !== 'none' || this.status === StatusType.capture) {
if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
const pendingKernelInfo: PendingKernelInfo = {
kernelId: this.currentKernelId!,
programName: artifact.programInfo.name,
Expand Down Expand Up @@ -701,32 +701,31 @@ export class WebGpuBackend {
}
}

captureBegin(sessionHandle: number): void {
LOG_DEBUG('info', () => `captureBegin ${sessionHandle}`);
this.currentSessionId = sessionHandle;
let sessionCommandList = this.capturedCommandList.get(sessionHandle);
let sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle);
captureBegin(): void {
LOG_DEBUG('info', () => 'captureBegin');
let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);

Check warning

Code scanning / CodeQL

Useless assignment to local variable Warning

The initial value of sessionPendingKernels is unused, since it is always overwritten.
if (!sessionCommandList) {
sessionCommandList = [];
this.capturedCommandList.set(sessionHandle, sessionCommandList);
this.capturedCommandList.set(this.currentSessionId!, sessionCommandList);
sessionPendingKernels = [];
this.capturedPendingKernels.set(sessionHandle, sessionPendingKernels);
this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels);
}
this.status = StatusType.capture;
// flush the remaining commands before we change the status.
this.flush();
this.sessionStatus = 'capturing';
}
captureEnd(sessionHandle: number): void {
LOG_DEBUG('info', () => `captureEnd ${sessionHandle}`);
captureEnd(): void {
LOG_DEBUG('info', () => 'captureEnd');
// flush the remaining commands before we change the status.
this.flush();
this.currentSessionId = null;
this.status = StatusType.default;
this.sessionStatus = 'default';
}
replay(sessionHandle: number): void {
LOG_DEBUG('info', () => `replay ${sessionHandle}`);
this.currentSessionId = sessionHandle;
this.status = StatusType.replay;
const sessionCommandList = this.capturedCommandList.get(sessionHandle);
const sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle);
replay(): void {
LOG_DEBUG('info', () => 'replay');
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
this.sessionStatus = 'replaying';
const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
const length = sessionCommandList!.length;
this.pendingKernels = [];
for (let i = 0; i < length; i++) {
Expand All @@ -750,7 +749,7 @@ export class WebGpuBackend {
}
// flush the remaining commands before we change the status.
this.flush();
this.status = StatusType.default;
this.sessionStatus = 'default';
}

onReleaseSession(sessionId: number): void {
Expand All @@ -764,7 +763,8 @@ export class WebGpuBackend {
this.gpuDataManager.onReleaseSession(sessionId);
}

onRunStart(): void {
onRunStart(sessionId: number): void {
this.currentSessionId = sessionId;
this.setQueryType();
}
}
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
return backend.computeKernel(kernel, context, errors);
},
// jsepCaptureBegin
(sessionHandle: number) => backend.captureBegin(sessionHandle),
() => backend.captureBegin(),
// jsepCaptureEnd
(sessionHandle: number) => backend.captureEnd(sessionHandle),
() => backend.captureEnd(),
// jsepReplay
(sessionHandle: number) => backend.replay(sessionHandle));
() => backend.replay());
};
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';

import {GpuData, GpuDataId, GpuDataType, StatusType} from './types';
import {GpuData, GpuDataId, GpuDataType} from './types';

/**
* manages GpuDataId -> GpuBuffer
Expand Down Expand Up @@ -331,7 +331,7 @@ class GpuDataManagerImpl implements GpuDataManager {
return;
}

if (this.backend.status === StatusType.default) {
if (this.backend.sessionStatus === 'default') {
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';

import {createShaderHelper} from './ops/common';
import {Artifact, GpuData, ProgramInfo, StatusType} from './types';
import {Artifact, GpuData, ProgramInfo} from './types';

/**
* ProgramManager is the main class behind running computations
Expand Down Expand Up @@ -51,7 +51,7 @@ export class ProgramManager {
const bindGroup = device.createBindGroup(
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});

if (this.backend.status === StatusType.capture) {
if (this.backend.sessionStatus === 'capturing') {
const commandInfo = {
kernelId: this.backend.currentKernelId!,
computePipeline: buildArtifact.computePipeline,
Expand Down
6 changes: 1 addition & 5 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ import {TensorView} from '../tensor-view';

import {ShaderHelper} from './ops/common';

export enum StatusType {
default = 0,
capture = 1,
replay = 2
}
export type SessionState = 'default'|'capturing'|'replaying';

export enum GpuDataType {
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
default = 0,
Expand Down
24 changes: 12 additions & 12 deletions js/web/lib/wasm/session-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,6 @@ const setExecutionProviders =
`Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`);
}
}
if (webgpuOptions?.graphCaptureEnabled !== undefined) {
if (typeof webgpuOptions.graphCaptureEnabled !== 'boolean') {
throw new Error(`graphCaptureEnabled must be a boolean value: ${webgpuOptions.graphCaptureEnabled}`);
}
const keyDataOffset = allocWasmString('graphCaptureEnabled', allocs);
const valueDataOffset = allocWasmString(webgpuOptions.graphCaptureEnabled.toString(), allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
checkLastError(`Can't set a session config entry: 'graphCaptureEnabled' - ${
webgpuOptions.graphCaptureEnabled}.`);
}
}
}
break;
case 'wasm':
Expand Down Expand Up @@ -180,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n
setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
}

if (sessionOptions.enableGraphCapture !== undefined) {
if (typeof sessionOptions.enableGraphCapture !== 'boolean') {
throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`);
}
const keyDataOffset = allocWasmString('enableGraphCapture', allocs);
const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs);
if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
checkLastError(
`Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`);
}
}

if (sessionOptions.freeDimensionOverrides) {
for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) {
if (typeof name !== 'string') {
Expand Down
Loading