diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..dbebae5b5e9fe 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -23,13 +23,21 @@ import { TimestampQuery, } from './webgpu/types'; -interface CommandInfo { +interface ComputeCommand { readonly kernelId: number; readonly computePipeline: GPUComputePipeline; readonly bindGroup: GPUBindGroup; readonly dispatchGroup: [number, number, number]; } +interface MemcpyCommand { + readonly source: GPUBuffer; + readonly dest: GPUBuffer; + readonly size: number; +} + +type Command = ComputeCommand | MemcpyCommand; + interface KernelInfo { readonly kernelType: string; readonly kernelName: string; @@ -234,9 +242,9 @@ export class WebGpuBackend { env: Env; sessionStatus: SessionState = 'default'; /** - * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + * a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session. */ - capturedCommandList: Map = new Map(); + capturedCommandList: Map = new Map(); /** * a SessionID -> PendingKernelInfo[] mapping for profiling. @@ -837,13 +845,19 @@ export class WebGpuBackend { } return gpuData.buffer; } + + async replayAndDownloadGpuData(gpuBuffer: GPUBuffer, originalSize: number): Promise { + this.replay(); + return downloadGpuData(this, gpuBuffer, originalSize); + } + createDownloader( gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes, ): () => Promise { return async () => { - const data = await downloadGpuData(this, gpuBuffer, size); + const data = await this.replayAndDownloadGpuData(gpuBuffer, size); return createView(data.buffer, type); }; } @@ -909,18 +923,27 @@ export class WebGpuBackend { for (let i = 0; i < length; i++) { const computePassEncoder = this.getComputePassEncoder(); const command = sessionCommandList![i]; - this.writeTimestamp(this.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(command.computePipeline); - computePassEncoder.setBindGroup(0, command.bindGroup); - computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); - this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); - this.pendingDispatchNumber++; - if (this.queryType !== 'none') { - this.pendingKernels.push(sessionPendingKernels![i]); - } - if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + if ('bindGroup' in command) { + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + } else { + const commandEncoder = this.getCommandEncoder(); + this.pendingDispatchNumber++; this.endComputePass(); + commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size); } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { this.flush(); } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 1c6016500e7d3..b468e665bd8cb 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -274,16 +274,28 @@ class GpuDataManagerImpl implements GpuDataManager { const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize); - // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - commandEncoder.copyBufferToBuffer( - sourceGpuDataCache.gpuData.buffer, - 0, - destinationGpuDataCache.gpuData.buffer, - 0, - size, - ); + if (this.backend.sessionStatus === 'capturing') { + const command = { + source: sourceGpuDataCache.gpuData.buffer, + dest: destinationGpuDataCache.gpuData.buffer, + size, + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(command); + + this.backend.pendingDispatchNumber++; + } else { + // GPU copy + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + sourceGpuDataCache.gpuData.buffer, + 0, + destinationGpuDataCache.gpuData.buffer, + 0, + size, + ); + } } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..702d49ccd20b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -41,8 +41,6 @@ export class ProgramManager { ): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); - this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { entries.push({ binding: entries.length, resource: { buffer: input.buffer } }); @@ -68,12 +66,16 @@ export class ProgramManager { }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); sessionCommandList!.push(commandInfo); + } else { + const computePassEncoder = this.backend.getComputePassEncoder(); + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); + computePassEncoder.dispatchWorkgroups(...dispatchGroup); + + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); } - computePassEncoder.setPipeline(buildArtifact.computePipeline); - computePassEncoder.setBindGroup(0, bindGroup); - computePassEncoder.dispatchWorkgroups(...dispatchGroup); - this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; if ( @@ -85,6 +87,7 @@ export class ProgramManager { if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) { this.backend.flush(); } + TRACE_FUNC_END(buildArtifact.programInfo.name); } dispose(): void {