[js/webgpu] Introduce trace support
This leverages console.timeStamp to add a single marker to the browser's
performance tool (only Chromium and Firefox support it). With this support, we
can dump both CPU and GPU timestamps and use a post-processing tool to clearly
understand the calibrated timeline. A demo tool can be found at
https://github.com/webatintel/ort-test, and more detailed info can be found at
https://docs.google.com/document/d/1TuVxjE8jnELBXdhI4QGFgMnUqQn6Q53QA9y4a_dH688/edit.
Yang Gu committed Dec 25, 2023
1 parent 9dd9461 commit 5a9ddea
Showing 8 changed files with 75 additions and 2 deletions.
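A minimal opt-in sketch (not part of the commit): set the new flag before any
session work, then record a profile in the browser's performance tool
(Chromium or Firefox) while inference runs.

// Sketch, assuming the onnxruntime-web package; the flag defaults to false.
import * as ort from 'onnxruntime-web';

ort.env.wasm.trace = true;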
7 changes: 7 additions & 0 deletions js/common/lib/env.ts
@@ -33,6 +33,13 @@ export declare namespace Env {
*/
simd?: boolean;

/**
* Set or get a boolean value indicating whether to enable trace.
*
* @defaultValue `false`
*/
trace?: boolean;

/**
* Set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds. A zero
* value indicates no timeout is set.
1 change: 1 addition & 0 deletions js/common/lib/index.ts
@@ -21,5 +21,6 @@ export * from './backend.js';
export * from './env.js';
export * from './inference-session.js';
export * from './tensor.js';
export * from './trace.js';
export * from './onnx-value.js';
export * from './training-session.js';
5 changes: 5 additions & 0 deletions js/common/lib/inference-session-impl.ts
@@ -6,6 +6,7 @@ import {InferenceSessionHandler} from './backend.js';
import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js';

type SessionOptions = InferenceSessionInterface.SessionOptions;
type RunOptions = InferenceSessionInterface.RunOptions;
@@ -20,6 +21,7 @@ export class InferenceSession implements InferenceSessionInterface {
run(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
TRACE_FUNC_BEGIN();
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@@ -117,6 +119,7 @@ export class InferenceSession implements InferenceSessionInterface {
}
}
}
TRACE_FUNC_END();
return returnValue;
}

@@ -132,6 +135,7 @@ export class InferenceSession implements InferenceSessionInterface {
static async create(
arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number,
arg3?: SessionOptions): Promise<InferenceSessionInterface> {
TRACE_FUNC_BEGIN();
// either load from a file or buffer
let filePathOrUint8Array: string|Uint8Array;
let options: SessionOptions = {};
@@ -196,6 +200,7 @@ export class InferenceSession implements InferenceSessionInterface {
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
TRACE_FUNC_END();
return new InferenceSession(handler);
}

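With create() and run() both instrumented, a profile shows one
FUNC_BEGIN/FUNC_END pair around session creation and one around every run. A
sketch of a traced inference (not part of the commit; 'model.onnx' and the
feed name 'input' are hypothetical):

// Sketch: the flag is checked at call time, so set it before create()/run().
import * as ort from 'onnxruntime-web';

async function tracedRun() {
  ort.env.wasm.trace = true;
  // brackets backend resolution and session construction with a marker pair
  const session = await ort.InferenceSession.create('model.onnx', {executionProviders: ['webgpu']});
  const input = new ort.Tensor('float32', new Float32Array([1, 2, 3, 4]), [1, 4]);
  // brackets the whole run, including the per-program WebGPU markers added below
  return session.run({input});
}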
44 changes: 44 additions & 0 deletions js/common/lib/trace.ts
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from './env-impl.js';

export const TRACE = (deviceType: string, label: string) => {
if (!env.wasm.trace) {
return;
}
// eslint-disable-next-line no-console
console.timeStamp(`${deviceType}::ORT::${label}`);
};

const TRACE_FUNC = (msg: string, extraMsg?: string) => {
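// Walk the captured stack past the TRACE_FUNC* wrapper frames and label the
// marker with the first caller frame's function name.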
const stack = new Error().stack?.split(/\r\n|\r|\n/g) || [];
let hasTraceFunc = false;
for (let i = 0; i < stack.length; i++) {
if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) {
let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`;
if (extraMsg) {
label += `::${extraMsg}`;
}
TRACE('CPU', label);
return;
}
if (stack[i].includes('TRACE_FUNC')) {
hasTraceFunc = true;
}
}
};

export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
if (!env.wasm.trace) {
return;
}
TRACE_FUNC('BEGIN', extraMsg);
};

export const TRACE_FUNC_END = (extraMsg?: string) => {
if (!env.wasm.trace) {
return;
}
TRACE_FUNC('END', extraMsg);
};
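The helpers derive the function-name segment of each label by walking the
Error stack past the TRACE_FUNC* frames to the caller. A runnable sketch of
what they emit (not part of the commit; assumes tracing is enabled and a
V8-style stack of the form "at fn (...)" — Firefox formats frames as
"fn@file", so the parsed name segment can differ there):

import {env, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

env.wasm.trace = true;

function convKernel() {
  TRACE_FUNC_BEGIN('Conv');       // marker "CPU::ORT::FUNC_BEGIN::convKernel::Conv"
  TRACE('GPU', 'Conv.dispatch');  // marker "GPU::ORT::Conv.dispatch"
  TRACE_FUNC_END('Conv');         // marker "CPU::ORT::FUNC_END::convKernel::Conv"
}
convKernel();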
4 changes: 4 additions & 0 deletions js/web/lib/backend-wasm.ts
@@ -26,6 +26,10 @@ export const initializeFlags = (): void => {
env.wasm.proxy = false;
}

if (typeof env.wasm.trace !== 'boolean') {
env.wasm.trace = false;
}

if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Env, Tensor} from 'onnxruntime-common';
import {Env, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
@@ -263,6 +263,7 @@ export class WebGpuBackend {
run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[],
createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] {
TRACE_FUNC_BEGIN(program.name);
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
@@ -387,6 +388,7 @@ export class WebGpuBackend {
artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup,
uniformBufferBinding);

TRACE_FUNC_END(program.name);
return outputTensorViews;
}

6 changes: 6 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {tensorDataTypeEnumToString} from '../../wasm-common';
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';
@@ -35,6 +37,7 @@ export class ProgramManager {
run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[],
inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number],
uniformBufferBinding: GPUBindingResource|undefined): void {
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;

const computePassEncoder = this.backend.getComputePassEncoder();
@@ -128,11 +131,13 @@
if (this.backend.pendingDispatchNumber >= 16) {
this.backend.flush();
}
TRACE_FUNC_END(buildArtifact.programInfo.name);
}
dispose(): void {
// this.repo.forEach(a => this.glContext.deleteProgram(a.program));
}
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
@@ -147,6 +152,7 @@
const computePipeline = device.createComputePipeline(
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});

TRACE_FUNC_END(programInfo.name);
return {programInfo, computePipeline};
}

6 changes: 5 additions & 1 deletion js/web/lib/wasm/session-handler-inference.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.

import {readFile} from 'node:fs/promises';
import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
@@ -54,6 +54,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
TRACE_FUNC_BEGIN();
let model: Parameters<typeof createSession>[0];

if (typeof pathOrBuffer === 'string') {
@@ -70,6 +71,7 @@
}

[this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
TRACE_FUNC_END();
}

async dispose(): Promise<void> {
@@ -78,6 +80,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan

async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise<SessionHandler.ReturnType> {
TRACE_FUNC_BEGIN();
const inputArray: Tensor[] = [];
const inputIndices: number[] = [];
Object.entries(feeds).forEach(kvp => {
@@ -115,6 +118,7 @@
for (let i = 0; i < results.length; i++) {
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
TRACE_FUNC_END();
return resultMap;
}

Expand Down
