diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 1221b52cd4985..4f85c3b46e253 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 24d7062c85fcb..5dd715191c830 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,6 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, + captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; /** * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule { * @returns the GPU data ID for the registered GPU buffer. */ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; /** * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. * @@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule { (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns */ - jsepOnRunStart: () => void; + jsepOnReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a48fe99570abf..e1faecfc046e3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} interface KernelInfo { readonly kernelType: string; @@ -103,6 +110,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -155,6 +169,16 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -228,6 +252,7 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -238,7 +263,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -494,7 +519,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -502,6 +527,9 @@ export class WebGpuBackend { outputTensorViews, }; this.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); @@ -672,7 +700,71 @@ export class WebGpuBackend { } } } - onRunStart(): void { + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + if (!sessionCommandList) { + sessionCommandList = []; + this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); + sessionPendingKernels = []; + this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + this.flush(); + } + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index f1794d71579bf..786ae41646554 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay()); }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. - this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 72eb9713e26a8..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,7 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,8 +50,20 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 789ac70a6913a..a34b6190b7244 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,6 +5,8 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 41ab2d52ca209..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 046336dc9cac0..37b9ed6a1002f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -139,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,6 +235,8 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; + const inputNames = []; const outputNames = []; const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; @@ -256,12 +258,20 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } @@ -281,7 +291,9 @@ export const createSession = async( }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -313,13 +325,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -328,70 +343,75 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -404,7 +424,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -427,13 +452,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -449,7 +476,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -486,9 +513,12 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -595,10 +625,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index af9658271d210..308e1c7d952d9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -3,6 +3,7 @@ #include "js_execution_provider.h" +#include #include #include #include @@ -681,9 +682,13 @@ std::unique_ptr RegisterKernels() { using namespace js; -JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, preferred_data_layout_{info.data_layout} { + if (session_options) { + enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; + LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; + } } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -751,4 +756,46 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } +Status JsExecutionProvider::OnRunStart() { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; + EM_ASM({ Module.jsepCaptureBegin(); }); + } + return Status::OK(); +} + +Status JsExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + EM_ASM({ Module.jsepCaptureEnd(); }); + is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool JsExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status JsExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + EM_ASM({ Module.jsepReplay(); }); + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 39d43498c0717..91a3256ec2bd5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -5,6 +5,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" @@ -38,7 +39,7 @@ struct JsExecutionProviderInfo { class JsExecutionProvider : public IExecutionProvider { public: - JsExecutionProvider(const JsExecutionProviderInfo& info); + JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ~JsExecutionProvider() override; std::vector> GetCapability( @@ -57,7 +58,22 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + + Status OnRunStart() override; + Status OnRunEnd(bool sync_stream) override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc index 5b7329a87cf6a..cbdf99f702150 100644 --- a/onnxruntime/core/providers/js/js_provider_factory.cc +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -10,21 +10,22 @@ namespace onnxruntime { struct JsProviderFactory : IExecutionProviderFactory { - JsProviderFactory(const ProviderOptions& provider_options) - : info_{provider_options} { + JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) + : info_{provider_options}, session_options_(session_options) { } std::unique_ptr CreateProvider() override { - return std::make_unique(info_); + return std::make_unique(info_, session_options_); } private: JsExecutionProviderInfo info_; + const SessionOptions* session_options_; }; std::shared_ptr JsProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - return std::make_shared(provider_options); + const ProviderOptions& provider_options, const SessionOptions* session_options) { + return std::make_shared(provider_options, session_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h index dbabe255c2d7b..510b0fb4248ca 100644 --- a/onnxruntime/core/providers/js/js_provider_factory_creator.h +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -9,9 +9,11 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; struct JsProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c8fc812fe1238..1dfba6039ded7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -146,28 +146,29 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { - bool nodes_on_cpu_and_cuda_eps_only = true; +static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { + bool nodes_on_cpu_and_cuda_and_js_eps_only = true; for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || + node_provider == kJsExecutionProvider) && node_provider != kCpuExecutionProvider) { - nodes_on_cpu_and_cuda_eps_only = false; + nodes_on_cpu_and_cuda_and_js_eps_only = false; break; } } - // If we see nodes assigned to EPs other than CPU or CUDA + // If we see nodes assigned to EPs other than CPU, or CUDA/JS // (or) if there are Memcpy nodes, then all compute nodes have - // not been parititoned to the CUDA EP. + // not been parititoned to the CUDA/JS EP. // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy // involved as shape subgraphs will be forced onto CPU and these will not have // Memcpy nodes involved. - return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); } static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { @@ -1761,7 +1762,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently graph capture is only considered by CUDA EP, TRT EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1774,47 +1775,62 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for JS EP: + // If the JS EP is part of the providers list for this session AND + // The JS EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the JS EP, + // Then the JS EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kJsExecutionProvider}; for (auto& it : cuda_graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // Graphs capture can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << "as the model has control flow nodes which can't be supported by " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the graph capture feature as requested by the user " + "as the model has control flow nodes which can't be supported by" + + target_ep->Type())); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { - // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { + // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the graph capture feature as requested by the user " + " as all compute graph nodes have not been partitioned to the " + + target_ep->Type())); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the graph capture feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the graph capture feature."; } } else { // Following code path is for TRT EP currently. diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ac059bfd00668..de5d93ef0a434 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -149,7 +149,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } - options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7e9c0a6f99c32..cbc60c70b57aa 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => { /** * init JSEP */ -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => { +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; @@ -33,6 +33,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; + Module.jsepCaptureBegin = captureBegin; + Module.jsepCaptureEnd = captureEnd; + Module.jsepReplay = replay; // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. @@ -181,16 +184,16 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module['jsepUnregisterBuffers'] = sessionId => { - backend['unregisterBuffers'](sessionId); - }; Module['jsepGetBuffer'] = (dataId) => { return backend['getBuffer'](dataId); }; Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepOnRunStart'] = () => { - return backend['onRunStart'](); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; };