Skip to content

Commit

Permalink
[js/webgpu] support error pop and kernel name (#17260)
Browse files Browse the repository at this point in the history
### Description
This PR contains changes to support error pop and kernel name.

- Add a function `JsepGetNodeName` that lets JS read the kernel (node) name
from C++
- When in debug mode ( `env.debug = true;` ) or in profiling mode (
`env.webgpu.profilingMode = 'default';` ), kernel name will be read from
ORT; otherwise use the kernel pointer ( a number ) as kernel name to
save calls from JS to C++.
- When in debug mode, WebGPU validation errors will be recorded, and if
any error occurs, `inferenceSession.run()` will fail (the Promise gets
rejected). Behavior outside debug mode is unchanged. This is
because recording errors is not zero-overhead, and GPU validation
errors should occur consistently whether or not debug mode is enabled.
- Add `jsepOnRunStart()` and `jsepOnRunEnd()` hooks to:
   - allow implementation of the features mentioned above.
   - pass session ID to backend.
  • Loading branch information
fs-eire authored Aug 25, 2023
1 parent da180b2 commit 79c4ed9
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 40 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ else()

set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']")
if (onnxruntime_USE_JSEP)
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput")
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()
Expand Down
11 changes: 9 additions & 2 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

declare namespace JSEP {
export declare namespace JSEP {
type BackendType = unknown;
type AllocFunction = (size: number) => number;
type FreeFunction = (size: number) => number;
type UploadFunction = (dataOffset: number, gpuDataId: number, size: number) => void;
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction = (kernel: number, contextDataOffset: number) => number;
type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number;
/**
 * Per-run state object handed to the JSEP run function as its third argument.
 *
 * It is created by the `jsepOnRunStart` hook and discarded by `jsepOnRunEnd`.
 */
export interface SessionState {
// The session ID that was passed to `jsepOnRunStart()` for the current run.
sessionId: number;
// Pending results of GPU error scopes, one entry per kernel executed while error recording is enabled.
// Each promise resolves to an error message string, or to null when no validation error was captured.
errors: Array<Promise<string|null>>;
}
}

export interface OrtWasmModule extends EmscriptenModule {
Expand Down Expand Up @@ -71,7 +75,10 @@ export interface OrtWasmModule extends EmscriptenModule {
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;

_JsepOutput(context: number, index: number, data: number): number;
_JsepGetNodeName(kernel: number): number;

jsepOnRunStart?(sessionId: number): void;
jsepOnRunEnd?(sessionId: number): void;
jsepRunPromise?: Promise<number>;
// #endregion
}
Expand Down
34 changes: 23 additions & 11 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ export class WebGpuBackend {
}

/**
* a KernelID -> kernel info mapping. value is [ name, run function, [optional] preprocess_attribute_once function ]
* a KernelID -> kernel info mapping. value is
* [ op_type, name, run function, [optional] preprocess_attribute_once function ]
*/
kernels: Map<number, [string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
kernels: Map<number, [string, string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;

commandEncoder: GPUCommandEncoder|null = null;
computePassEncoder: GPUComputePassEncoder|null = null;
Expand Down Expand Up @@ -313,13 +314,13 @@ export class WebGpuBackend {
return this.gpuDataManager.release(ptr);
}

createKernel(name: string, kernelId: number, attribute: unknown): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(name);
createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(opType);
if (!op) {
throw new Error(`kernel not implemented: ${name}`);
throw new Error(`kernel not implemented: ${opType}`);
}

this.kernels.set(kernelId, [name, op[0], [op[1], attribute]]);
this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]);
}

releaseKernel(kernelId: number): void {
Expand All @@ -335,14 +336,14 @@ export class WebGpuBackend {
this.kernels.delete(kernelId);
}

computeKernel(kernelId: number, context: ComputeContext): number {
computeKernel(kernelId: number, context: ComputeContext, errors: Array<Promise<string|null>>): number {
const kernel = this.kernels.get(kernelId);
if (!kernel) {
throw new Error(`kernel not created: ${kernelId}`);
}
const [name, kernelEntry, attributes] = kernel;
const [opType, nodeName, kernelEntry, attributes] = kernel;
if (this.currentKernelId !== null) {
throw new Error(`kernel "${name}" is not allowed to be called recursively`);
throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`);
}
this.currentKernelId = kernelId;

Expand All @@ -352,16 +353,27 @@ export class WebGpuBackend {
attributes[0] = undefined;
}

LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "${name}"...`);
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`);

const useErrorScope = this.env.debug;

this.temporaryData = [];
try {
if (useErrorScope) {
this.device.pushErrorScope('validation');
}

kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`);
LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`);
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
errors.push(this.device.popErrorScope().then(
err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null));
}

for (const data of this.temporaryData) {
this.gpuDataManager.release(data.id);
}
Expand Down
16 changes: 11 additions & 5 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Env} from 'onnxruntime-common';

import {OrtWasmModule} from '../binding/ort-wasm';
import {JSEP, OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';

import {WebGpuBackend} from './backend-webgpu';
Expand Down Expand Up @@ -169,16 +169,22 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
},

// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(name, kernel, attribute),
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
`${kernel}`),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),

// jsepRun
(kernel: number, contextDataOffset: number) => {
LOG_DEBUG('verbose', () => `[WebGPU] jsepRun: kernel=${kernel}, contextDataOffset=${contextDataOffset}`);
(kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context);
return backend.computeKernel(kernel, context, sessionState.errors);
});
}
};
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ export class ProgramManager {
this.backend.flush();

const kernelId = this.backend.currentKernelId!;
const kernelName = this.backend.kernels.get(kernelId)![0];
const kernelInfo = this.backend.kernels.get(kernelId)!;
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;

syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
Expand Down
5 changes: 4 additions & 1 deletion js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,15 @@ export const run = async(
wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
}

wasm.jsepOnRunStart?.(sessionId);

// support RunOptions
let errorCode = wasm._OrtRun(
sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset, runOptionsHandle);

// eslint-disable-next-line @typescript-eslint/naming-convention
wasm.jsepOnRunEnd?.(sessionId);

const runPromise = wasm.jsepRunPromise;
if (runPromise && typeof runPromise.then !== 'undefined') {
errorCode = await runPromise;
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/js/js_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

#include "core/framework/op_kernel.h"

const void* JsepOutput(void* context, int index, void* data) {
uint32_t* data_offset = reinterpret_cast<uint32_t*>(data);
const void* JsepOutput(void* context, int index, const void* data) {
const uint32_t* data_offset = reinterpret_cast<const uint32_t*>(data);
uint32_t dim = *data_offset++;
size_t dim_size = static_cast<size_t>(dim);
std::vector<int64_t> dims;
Expand All @@ -24,3 +24,8 @@ const void* JsepOutput(void* context, int index, void* data) {
LOGF_DEFAULT(VERBOSE, "JsepOutput -- data=%zu", (size_t)(r));
return r;
}

// Returns the name of the graph node that owns the given kernel, as a C string.
//
// `kernel` is an opaque pointer that is reinterpreted as `const onnxruntime::OpKernel*`;
// it is not checked for null.
// NOTE(review): the returned pointer presumably aliases a string owned by the node
// (i.e. `Name()` returns a reference) and must not be freed by the caller -- confirm.
const void* JsepGetNodeName(const void* kernel) {
  const auto* op_kernel = reinterpret_cast<const onnxruntime::OpKernel*>(kernel);
  return op_kernel->Node().Name().c_str();
}
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/js/js_export.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

#include <stddef.h>

// TODO: Move to api.h

extern "C" {
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, void* data);
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, const void* data);
// Returns the node name of the given kernel as a C string.
// `kernel` is the kernel pointer (the same value JS uses as the kernel ID), not a kernel context.
const void* EMSCRIPTEN_KEEPALIVE JsepGetNodeName(const void* kernel);
};
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/js/js_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ class JsKernel : public OpKernel {
return status;
}

int status_code = EM_ASM_INT({ return Module.jsepRun($0, $1); }, this, reinterpret_cast<int32_t>(p_serialized_kernel_context));
int status_code = EM_ASM_INT(
{ return Module.jsepRunKernel($0, $1, Module.jsepSessionState); },
this, reinterpret_cast<int32_t>(p_serialized_kernel_context));

LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
<< (size_t)(context->Output<Tensor>(0)->DataRaw()) << ".";
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/wasm/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,9 @@ int OrtRun(OrtSession* session,
const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count,
const char** output_names, size_t output_count, ort_tensor_handle_t* outputs,
OrtRunOptions* run_options) {
#if defined(USE_JSEP)
EM_ASM({ Module["jsepRunPromise"] = new Promise(function(r) { Module.jsepRunPromiseResolve = r; }); });
#endif
auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
#if defined(USE_JSEP)
EM_ASM({ Module.jsepRunPromiseResolve($0); }, status_code);
EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code);
#endif
return status_code;
}
Expand Down
58 changes: 49 additions & 9 deletions onnxruntime/wasm/js_internal_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,53 @@
'use strict';

// init JSEP
Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) {
Module.jsepBackend = backend;
Module.jsepAlloc = alloc;
Module.jsepFree = free;
Module.jsepCopy = copy;
Module.jsepCopyAsync = copyAsync;
Module.jsepCreateKernel = createKernel;
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRun = run;
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => {
Module.jsepBackend = backend;
Module.jsepAlloc = alloc;
Module.jsepFree = free;
Module.jsepCopy = copy;
Module.jsepCopyAsync = copyAsync;
Module.jsepCreateKernel = createKernel;
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRunKernel = runKernel;

/**
 * Hook called right before `_OrtRun()`: opens the per-run JSEP session state.
 *
 * Installs a fresh `Module.jsepRunPromise` (resolved later through
 * `Module.jsepRunPromiseResolve`) and an empty error list for the run.
 *
 * @param sessionId ID of the session that is about to run.
 * @throws {Error} if a previous run has not been closed by `jsepOnRunEnd()`.
 */
Module['jsepOnRunStart'] = sessionId => {
  // Validate first: a rejected start must not replace the run promise/resolver
  // that belong to the run which is still in progress.
  if (Module.jsepSessionState) {
    throw new Error('Session already started');
  }

  Module['jsepRunPromise'] = new Promise(r => {
    Module.jsepRunPromiseResolve = r;
  });

  Module.jsepSessionState = {
    sessionId,
    errors: []
  };
};

/**
 * Hook called right after `_OrtRun()` returns: closes the per-run JSEP session state.
 *
 * If any error promises were recorded during the run, `Module.jsepRunPromise` is
 * replaced by a promise that waits for all of them and then either rejects with the
 * collected messages (joined by newlines) or settles like the original run promise.
 *
 * @param sessionId ID of the session whose run just ended.
 * @throws {Error} if no run is open, or if it was opened for a different session ID.
 */
Module['jsepOnRunEnd'] = sessionId => {
  const sessionState = Module.jsepSessionState;
  // Fail with an explicit message instead of a TypeError on a null state.
  if (!sessionState) {
    throw new Error('Session not started');
  }
  if (sessionState.sessionId !== sessionId) {
    throw new Error('Session ID mismatch');
  }

  Module.jsepSessionState = null;

  const errorPromises = sessionState.errors;
  if (errorPromises.length > 0) {
    const runPromise = Module['jsepRunPromise'];
    // Chain directly on Promise.all; a rejection of any error promise propagates as-is.
    Module['jsepRunPromise'] = Promise.all(errorPromises).then(results => {
      // Entries are null when the corresponding kernel reported no error.
      const messages = results.filter(e => e);
      if (messages.length > 0) {
        throw new Error(messages.join('\n'));
      }
      // Returning the original promise makes the new one adopt its outcome.
      return runPromise;
    });
  }
};
};

0 comments on commit 79c4ed9

Please sign in to comment.