Skip to content

Commit

Permalink
[js/webgpu] Refactor timestamp-query and introduce timestamp-query-in…
Browse files Browse the repository at this point in the history
…side-passes (#18894)

We submit kernels in a batch (a fixed size of 16 is used, except for the
last batch) for better performance. However, timestamp-query support is
at pass level, so the previous implementation disabled batch execution
in profiling mode. In fact, we can have multiple passes within a single
batch, so batch execution no longer needs to be disabled; this is the
first enhancement of this PR.
Furthermore, WebGPU has an extension that supports timestamp queries
inside passes, though it isn't supported on all platforms (e.g., Windows
supports it, while macOS doesn't). This is expected to have a lower cost
than the multiple-passes solution, so this PR also introduces this
support where available.
This PR also refactors some implementation related to kernelInfo and
tries to unify the related kernel names.
  • Loading branch information
Yang Gu authored Jan 13, 2024
1 parent 78e796b commit e803f8e
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 123 deletions.
1 change: 1 addition & 0 deletions js/common/lib/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export declare namespace Env {
kernelId: number;
kernelType: string;
kernelName: string;
programName: string;
startTime: number;
endTime: number;
}
Expand Down
233 changes: 187 additions & 46 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Env, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {tensorDataTypeEnumToString} from '../wasm-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency} from './webgpu/types';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types';

// Per-kernel metadata registered via createKernel(); stored in the
// WebGpuBackend.kernels map keyed by kernel ID.
interface KernelInfo {
  readonly kernelType: string;  // operator type used to look up WEBGPU_OP_RESOLVE_RULES
  readonly kernelName: string;  // node name resolved from the model graph at creation time
  readonly kernelEntry: RunFunction;
  // [optional one-time attribute-preprocess function (cleared after the first
  // run), raw attribute value]
  readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown];
}

// Snapshot of a dispatched kernel whose timestamp-query results are not yet
// resolved; accumulated per batch and consumed when the query read-back
// buffer is mapped in flush().
interface PendingKernelInfo {
  readonly kernelId: number;
  readonly programName: string;
  readonly inputTensorViews: readonly TensorView[];
  readonly outputTensorViews: readonly TensorView[];
}

const getProgramInputTensorInfoDependencyKey =
(inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => {
Expand Down Expand Up @@ -122,20 +138,21 @@ export class WebGpuBackend {
return data;
}

/**
* a KernelID -> kernel info mapping. value is
* [ op_type, name, run function, [optional] preprocess_attribute_once function ]
*/
kernels: Map<number, [string, string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;

// KernelID -> kernelInfo mapping
kernels: Map<number, KernelInfo>;
private commandEncoder: GPUCommandEncoder|null = null;
private computePassEncoder: GPUComputePassEncoder|null = null;
maxDispatchNumber = 16;
pendingDispatchNumber = 0;

queryData?: GpuData;
querySet?: GPUQuerySet;
querySetCount = 2;
queryTimeBase?: bigint;
// info of kernels pending submission for a single batch
private pendingKernels: PendingKernelInfo[] = [];
// queryReadBuffer -> pendingKernels mapping for all the batches
private pendingQueries: Map<GPUBuffer, PendingKernelInfo[]> = new Map();
private queryResolveBuffer?: GPUBuffer;
private querySet?: GPUQuerySet;
private queryTimeBase?: bigint;
queryType: TimestampQuery;

env: Env;

Expand All @@ -161,7 +178,9 @@ export class WebGpuBackend {
requiredFeatures,
};

if (adapter.features.has('timestamp-query')) {
if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
} else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
Expand All @@ -188,6 +207,9 @@ export class WebGpuBackend {
};

Object.defineProperty(this.env.webgpu, 'device', {value: this.device});

// init queryType, which is necessary for createKernel
this.setQueryType();
}

dispose(): void {
Expand All @@ -200,24 +222,31 @@ export class WebGpuBackend {
getCommandEncoder(): GPUCommandEncoder {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();

// refresh queryType, as sometimes we only need to enable query for a specific run
this.setQueryType();
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: this.maxDispatchNumber * 2,
});
this.queryResolveBuffer = this.device.createBuffer(
// eslint-disable-next-line no-bitwise
{size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE});
}
}
return this.commandEncoder;
}

getComputePassEncoder(): GPUComputePassEncoder {
if (!this.computePassEncoder) {
const computePassDescriptor: GPUComputePassDescriptor = {};
if (this.isQueryEnabled()) {
if (typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
count: this.querySetCount,
});
}

if (this.queryType === 'at-passes') {
computePassDescriptor.timestampWrites = {
querySet: this.querySet,
beginningOfPassWriteIndex: 0,
endOfPassWriteIndex: 1,
querySet: this.querySet!,
beginningOfPassWriteIndex: this.pendingDispatchNumber * 2,
endOfPassWriteIndex: this.pendingDispatchNumber * 2 + 1,
};
}

Expand All @@ -234,19 +263,95 @@ export class WebGpuBackend {
}

flush(): void {
if (this.commandEncoder) {
this.endComputePass();
this.device.queue.submit([this.getCommandEncoder().finish()]);
this.gpuDataManager.refreshPendingBuffers();
this.commandEncoder = null;
this.pendingDispatchNumber = 0;
if (!this.commandEncoder) {
return;
}
}

isQueryEnabled(): boolean {
return this.device.features.has('timestamp-query') &&
(this.env.webgpu.profiling?.mode === 'default' ||
(!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default'));
TRACE_FUNC_BEGIN();

this.endComputePass();
let queryReadBuffer: GPUBuffer;
if (this.queryType !== 'none') {
this.commandEncoder.resolveQuerySet(
this.querySet!, 0, this.pendingDispatchNumber * 2, this.queryResolveBuffer!, 0);

queryReadBuffer = this.device.createBuffer(
// eslint-disable-next-line no-bitwise
{size: this.pendingDispatchNumber * 2 * 8, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST});

this.pendingQueries.set(queryReadBuffer, this.pendingKernels);
this.pendingKernels = [];
this.commandEncoder.copyBufferToBuffer(
this.queryResolveBuffer!, 0, queryReadBuffer, 0, this.pendingDispatchNumber * 2 * 8);
}

this.device.queue.submit([this.commandEncoder.finish()]);
this.gpuDataManager.refreshPendingBuffers();
this.commandEncoder = null;
this.pendingDispatchNumber = 0;

if (this.queryType !== 'none') {
void queryReadBuffer!.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(queryReadBuffer.getMappedRange());
const pendingKernels = this.pendingQueries.get(queryReadBuffer)!;
for (let i = 0; i < mappedData.length / 2; i++) {
const pendingKernelInfo = pendingKernels[i];
const kernelId = pendingKernelInfo.kernelId;
const kernelInfo = this.kernels.get(kernelId)!;
const kernelType = kernelInfo.kernelType;
const kernelName = kernelInfo.kernelName;
const programName = pendingKernelInfo.programName;
const inputTensorViews = pendingKernelInfo.inputTensorViews;
const outputTensorViews = pendingKernelInfo.outputTensorViews;
const startTimeU64 = mappedData[i * 2];
const endTimeU64 = mappedData[i * 2 + 1];

if (typeof this.queryTimeBase === 'undefined') {
this.queryTimeBase = startTimeU64;
}

const startTime = Number(startTimeU64 - this.queryTimeBase);
const endTime = Number(endTimeU64 - this.queryTimeBase);

if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) {
throw new RangeError('incorrect timestamp range');
}

if (this.env.webgpu.profiling?.ondata) {
this.env.webgpu.profiling.ondata({
version: 1,
inputsMetadata: inputTensorViews.map(
value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
outputsMetadata: outputTensorViews.map(
value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
kernelId,
kernelType,
kernelName,
programName,
startTime,
endTime,
});
} else {
// if no callback is provided, print the profiling message to console
let inputShapes = '';
inputTensorViews.forEach((value, i) => {
inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
let outputShapes = '';
outputTensorViews.forEach((value, i) => {
outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
// eslint-disable-next-line no-console
console.log(`[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${
outputShapes}execution time: ${endTime - startTime} ns`);
}
TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`);
}
queryReadBuffer.unmap();
this.pendingQueries.delete(queryReadBuffer);
});
}
TRACE_FUNC_END();
}

/**
Expand Down Expand Up @@ -384,9 +489,18 @@ export class WebGpuBackend {
'info',
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);
this.programManager.run(
artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup,
uniformBufferBinding);

if (this.queryType !== 'none') {
const pendingKernelInfo: PendingKernelInfo = {
kernelId: this.currentKernelId!,
programName: artifact.programInfo.name,
inputTensorViews,
outputTensorViews,
};
this.pendingKernels.push(pendingKernelInfo);
}

this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);

TRACE_FUNC_END(program.name);
return outputTensorViews;
Expand Down Expand Up @@ -414,13 +528,19 @@ export class WebGpuBackend {
return this.gpuDataManager.release(ptr);
}

createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(opType);
createKernel(kernelType: string, kernelId: number, attribute: unknown, kernelName: string): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(kernelType);
if (!op) {
throw new Error(`kernel not implemented: ${opType}`);
throw new Error(`kernel not implemented: ${kernelType}`);
}

this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]);
const kernelInfo: KernelInfo = {
kernelType,
kernelName,
kernelEntry: op[0],
attributes: [op[1], attribute],
};
this.kernels.set(kernelId, kernelInfo);
}

releaseKernel(kernelId: number): void {
Expand All @@ -441,9 +561,12 @@ export class WebGpuBackend {
if (!kernel) {
throw new Error(`kernel not created: ${kernelId}`);
}
const [opType, nodeName, kernelEntry, attributes] = kernel;
const kernelType = kernel.kernelType;
const kernelName = kernel.kernelName;
const kernelEntry = kernel.kernelEntry;
const attributes = kernel.attributes;
if (this.currentKernelId !== null) {
throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`);
throw new Error(`kernel "[${kernelType}] ${kernelName}" is not allowed to be called recursively`);
}
this.currentKernelId = kernelId;

Expand All @@ -453,7 +576,7 @@ export class WebGpuBackend {
attributes[0] = undefined;
}

LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`);
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${kernelType}] ${kernelName}"...`);

const useErrorScope = this.env.debug;

Expand All @@ -466,12 +589,12 @@ export class WebGpuBackend {
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`));
errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`));
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
errors.push(this.device.popErrorScope().then(
err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null));
err => err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null));
}

for (const data of this.temporaryData) {
Expand Down Expand Up @@ -516,5 +639,23 @@ export class WebGpuBackend {
return createView(data.buffer, type);
};
}
/**
 * Writes a timestamp into the query set at the given slot index.
 * No-op unless the 'inside-passes' timestamp-query flavor is active
 * (i.e. the chromium-experimental-timestamp-query-inside-passes feature).
 */
writeTimestamp(index: number): void {
  if (this.queryType !== 'inside-passes') {
    return;
  }

  // writeTimestamp on a compute pass encoder is a Chromium extension that is
  // not part of the standard WebGPU type definitions, hence the cast.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  (this.computePassEncoder as any).writeTimestamp(this.querySet, index);
}
/**
 * Determines which timestamp-query flavor to use for subsequent command
 * encoders: 'inside-passes' (Chromium extension, expected lower cost) when
 * the device supports it, otherwise 'at-passes' (standard pass-level
 * timestamps), or 'none' when neither profiling nor tracing is requested.
 * Called at backend construction and refreshed per command encoder, since
 * profiling may be enabled for only a specific run.
 */
setQueryType(): void {
  this.queryType = 'none';
  if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) {
    if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
      this.queryType = 'inside-passes';
    } else if (this.device.features.has('timestamp-query')) {
      this.queryType = 'at-passes';
    }
  }
}
// #endregion
}
5 changes: 2 additions & 3 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
},

// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
(kernelType: string, kernelId: number, attribute: unknown) =>
backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
Expand Down
Loading

0 comments on commit e803f8e

Please sign in to comment.