Skip to content

Commit

Permalink
[js/webgpu] Donot record with computePassEncoder when capturing
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 26, 2024
1 parent 09d2ee6 commit 30d3986
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 30 deletions.
51 changes: 37 additions & 14 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ import {
TimestampQuery,
} from './webgpu/types';

interface CommandInfo {
interface ComputeCommand {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
}

interface MemcpyCommand {
readonly source: GPUBuffer;
readonly dest: GPUBuffer;
readonly size: number;
}

type Command = ComputeCommand | MemcpyCommand;

interface KernelInfo {
readonly kernelType: string;
readonly kernelName: string;
Expand Down Expand Up @@ -234,9 +242,9 @@ export class WebGpuBackend {
env: Env;
sessionStatus: SessionState = 'default';
/**
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
* a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();
capturedCommandList: Map<number, Command[]> = new Map();

/**
* a SessionID -> PendingKernelInfo[] mapping for profiling.
Expand Down Expand Up @@ -837,13 +845,19 @@ export class WebGpuBackend {
}
return gpuData.buffer;
}

async replayAndDownloadGpuData(gpuBuffer: GPUBuffer, originalSize: number): Promise<Uint8Array> {
this.replay();
return downloadGpuData(this, gpuBuffer, originalSize);
}

createDownloader(
gpuBuffer: GPUBuffer,
size: number,
type: Tensor.GpuBufferDataTypes,
): () => Promise<Tensor.DataType> {
return async () => {
const data = await downloadGpuData(this, gpuBuffer, size);
const data = await this.replayAndDownloadGpuData(gpuBuffer, size);
return createView(data.buffer, type);
};
}
Expand Down Expand Up @@ -909,18 +923,27 @@ export class WebGpuBackend {
for (let i = 0; i < length; i++) {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
this.pendingKernels.push(sessionPendingKernels![i]);
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
if ('bindGroup' in command) {
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);

this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
this.pendingKernels.push(sessionPendingKernels![i]);
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
this.endComputePass();
}
} else {
const commandEncoder = this.getCommandEncoder();
this.pendingDispatchNumber++;
this.endComputePass();
commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size);
}

if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
this.flush();
}
Expand Down
32 changes: 22 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,28 @@ class GpuDataManagerImpl implements GpuDataManager {

const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize);

// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
if (this.backend.sessionStatus === 'capturing') {
const command = {
source: sourceGpuDataCache.gpuData.buffer,
dest: destinationGpuDataCache.gpuData.buffer,
size,
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(command);

this.backend.pendingDispatchNumber++;
} else {
// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
}
}

registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {
Expand Down
15 changes: 9 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ export class ProgramManager {
): void {
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
const entries = [];
for (const input of inputs) {
entries.push({ binding: entries.length, resource: { buffer: input.buffer } });
Expand All @@ -68,12 +66,16 @@ export class ProgramManager {
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
} else {
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);

this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
}

computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

if (
Expand All @@ -85,6 +87,7 @@ export class ProgramManager {
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
this.backend.flush();
}

TRACE_FUNC_END(buildArtifact.programInfo.name);
}
dispose(): void {
Expand Down

0 comments on commit 30d3986

Please sign in to comment.