From 79c4ed9a45c81d5fa71847789ccecb554ffdf76e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 25 Aug 2023 08:08:15 -0700 Subject: [PATCH] [js/webgpu] support error pop and kernel name (#17260) ### Description This PR contains changes to support error pop (recording WebGPU validation errors via `popErrorScope()`) and kernel names. - Add a function `JsepGetNodeName` so that JS can read the kernel (node) name from C++. - When in debug mode ( `env.debug = true;` ) or in profiling mode ( `env.webgpu.profilingMode = 'default';` ), the kernel name is read from ORT; otherwise the kernel pointer ( a number ) is used as the kernel name, to save calls from JS to C++. - When in debug mode, WebGPU validation errors are recorded, and if any error occurs, `inferenceSession.run()` fails (the Promise gets rejected). Behavior when not in debug mode is unchanged. This is because recording errors is not zero-overhead, and GPU validation errors should occur consistently whether or not debug mode is enabled. - Add `jsepOnRunStart()` and `jsepOnRunEnd()` hooks to: - allow implementation of the features mentioned above. - pass the session ID to the backend. 
--- cmake/onnxruntime_webassembly.cmake | 2 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 11 +++- js/web/lib/wasm/jsep/backend-webgpu.ts | 34 +++++++---- js/web/lib/wasm/jsep/init.ts | 16 +++-- .../lib/wasm/jsep/webgpu/program-manager.ts | 3 +- js/web/lib/wasm/wasm-core-impl.ts | 5 +- onnxruntime/core/providers/js/js_export.cc | 9 ++- onnxruntime/core/providers/js/js_export.h | 5 +- onnxruntime/core/providers/js/js_kernel.h | 4 +- onnxruntime/wasm/api.cc | 5 +- onnxruntime/wasm/js_internal_api.js | 58 ++++++++++++++++--- 11 files changed, 112 insertions(+), 40 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7804c31cc2a01..4243031045b7b 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -206,7 +206,7 @@ else() set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']") if (onnxruntime_USE_JSEP) - set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput") + set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName") else() set(EXPORTED_FUNCTIONS "_malloc,_free") endif() diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 4f1662199adf3..d04578ca697a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-declare namespace JSEP { +export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; type FreeFunction = (size: number) => number; @@ -9,7 +9,11 @@ declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number) => number; + type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; + export interface SessionState { + sessionId: number; + errors: Array>; + } } export interface OrtWasmModule extends EmscriptenModule { @@ -71,7 +75,10 @@ export interface OrtWasmModule extends EmscriptenModule { releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; _JsepOutput(context: number, index: number, data: number): number; + _JsepGetNodeName(kernel: number): number; + jsepOnRunStart?(sessionId: number): void; + jsepOnRunEnd?(sessionId: number): void; jsepRunPromise?: Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a23220e57ff69..861562d2e0e5b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -82,9 +82,10 @@ export class WebGpuBackend { } /** - * a KernelID -> kernel info mapping. value is [ name, run function, [optional] preprocess_attribute_once function ] + * a KernelID -> kernel info mapping. 
value is + * [ op_type, name, run function, [optional] preprocess_attribute_once function ] */ - kernels: Map unknown) | undefined, unknown]]>; + kernels: Map unknown) | undefined, unknown]]>; commandEncoder: GPUCommandEncoder|null = null; computePassEncoder: GPUComputePassEncoder|null = null; @@ -313,13 +314,13 @@ export class WebGpuBackend { return this.gpuDataManager.release(ptr); } - createKernel(name: string, kernelId: number, attribute: unknown): void { - const op = WEBGPU_OP_RESOLVE_RULES.get(name); + createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void { + const op = WEBGPU_OP_RESOLVE_RULES.get(opType); if (!op) { - throw new Error(`kernel not implemented: ${name}`); + throw new Error(`kernel not implemented: ${opType}`); } - this.kernels.set(kernelId, [name, op[0], [op[1], attribute]]); + this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]); } releaseKernel(kernelId: number): void { @@ -335,14 +336,14 @@ export class WebGpuBackend { this.kernels.delete(kernelId); } - computeKernel(kernelId: number, context: ComputeContext): number { + computeKernel(kernelId: number, context: ComputeContext, errors: Array>): number { const kernel = this.kernels.get(kernelId); if (!kernel) { throw new Error(`kernel not created: ${kernelId}`); } - const [name, kernelEntry, attributes] = kernel; + const [opType, nodeName, kernelEntry, attributes] = kernel; if (this.currentKernelId !== null) { - throw new Error(`kernel "${name}" is not allowed to be called recursively`); + throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`); } this.currentKernelId = kernelId; @@ -352,16 +353,27 @@ export class WebGpuBackend { attributes[0] = undefined; } - LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "${name}"...`); + LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`); + + const useErrorScope = this.env.debug; this.temporaryData = []; try { + if 
(useErrorScope) { + this.device.pushErrorScope('validation'); + } + kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`); + LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); return 1; // ORT_FAIL } finally { + if (useErrorScope) { + errors.push(this.device.popErrorScope().then( + err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null)); + } + for (const data of this.temporaryData) { this.gpuDataManager.release(data.id); } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index a7449e831b649..24ff79cfad3ee 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,7 +3,7 @@ import {Env} from 'onnxruntime-common'; -import {OrtWasmModule} from '../binding/ort-wasm'; +import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; @@ -169,16 +169,22 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { }, // jsepCreateKernel - (name: string, kernel: number, attribute: unknown) => backend.createKernel(name, kernel, attribute), + (name: string, kernel: number, attribute: unknown) => backend.createKernel( + name, kernel, attribute, + env.debug || env.webgpu.profilingMode === 'default' ? 
module.UTF8ToString(module._JsepGetNodeName(kernel)) : + `${kernel}`), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number) => { - LOG_DEBUG('verbose', () => `[WebGPU] jsepRun: kernel=${kernel}, contextDataOffset=${contextDataOffset}`); + (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context); + return backend.computeKernel(kernel, context, sessionState.errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index da710b7dc2596..08ebe8e5e6df3 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -77,7 +77,8 @@ export class ProgramManager { this.backend.flush(); const kernelId = this.backend.currentKernelId!; - const kernelName = this.backend.kernels.get(kernelId)![0]; + const kernelInfo = this.backend.kernels.get(kernelId)!; + const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`; syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index d79b25d7087dd..9dc55b0d12864 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -258,12 +258,15 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } + wasm.jsepOnRunStart?.(sessionId); + // support RunOptions let errorCode = wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); - // 
eslint-disable-next-line @typescript-eslint/naming-convention + wasm.jsepOnRunEnd?.(sessionId); + const runPromise = wasm.jsepRunPromise; if (runPromise && typeof runPromise.then !== 'undefined') { errorCode = await runPromise; diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc index ca0527a2ef89b..2c99e246b69d0 100644 --- a/onnxruntime/core/providers/js/js_export.cc +++ b/onnxruntime/core/providers/js/js_export.cc @@ -5,8 +5,8 @@ #include "core/framework/op_kernel.h" -const void* JsepOutput(void* context, int index, void* data) { - uint32_t* data_offset = reinterpret_cast(data); +const void* JsepOutput(void* context, int index, const void* data) { + const uint32_t* data_offset = reinterpret_cast(data); uint32_t dim = *data_offset++; size_t dim_size = static_cast(dim); std::vector dims; @@ -24,3 +24,8 @@ const void* JsepOutput(void* context, int index, void* data) { LOGF_DEFAULT(VERBOSE, "JsepOutput -- data=%zu", (size_t)(r)); return r; } + +const void* JsepGetNodeName(const void* kernel) { + const auto& name = reinterpret_cast(kernel)->Node().Name(); + return name.c_str(); +} diff --git a/onnxruntime/core/providers/js/js_export.h b/onnxruntime/core/providers/js/js_export.h index bb1eb356cc9d5..9cf196767b7ed 100644 --- a/onnxruntime/core/providers/js/js_export.h +++ b/onnxruntime/core/providers/js/js_export.h @@ -7,8 +7,7 @@ #include -// TODO: Move to api.h - extern "C" { -const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, void* data); +const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, const void* data); +const void* EMSCRIPTEN_KEEPALIVE JsepGetNodeName(const void* context); }; diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index b8fab3bbc5665..3accd80875d1b 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -194,7 +194,9 @@ class JsKernel : public OpKernel { return 
status; } - int status_code = EM_ASM_INT({ return Module.jsepRun($0, $1); }, this, reinterpret_cast(p_serialized_kernel_context)); + int status_code = EM_ASM_INT( + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); }, + this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index a24fb81d496ce..496c9c401f392 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -368,12 +368,9 @@ int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { -#if defined(USE_JSEP) - EM_ASM({ Module["jsepRunPromise"] = new Promise(function(r) { Module.jsepRunPromiseResolve = r; }); }); -#endif auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); #if defined(USE_JSEP) - EM_ASM({ Module.jsepRunPromiseResolve($0); }, status_code); + EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code); #endif return status_code; } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 6c2c3522c7db2..c7bc0e39fc3eb 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -4,13 +4,53 @@ 'use strict'; // init JSEP -Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) { - Module.jsepBackend = backend; - Module.jsepAlloc = alloc; - Module.jsepFree = free; - Module.jsepCopy = copy; - Module.jsepCopyAsync = copyAsync; - Module.jsepCreateKernel = createKernel; - Module.jsepReleaseKernel = releaseKernel; - Module.jsepRun = run; +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => 
{ + Module.jsepBackend = backend; + Module.jsepAlloc = alloc; + Module.jsepFree = free; + Module.jsepCopy = copy; + Module.jsepCopyAsync = copyAsync; + Module.jsepCreateKernel = createKernel; + Module.jsepReleaseKernel = releaseKernel; + Module.jsepRunKernel = runKernel; + + Module['jsepOnRunStart'] = sessionId => { + Module['jsepRunPromise'] = new Promise(r => { + Module.jsepRunPromiseResolve = r; + }); + + if (Module.jsepSessionState) { + throw new Error('Session already started'); + } + + Module.jsepSessionState = { + sessionId, + errors: [] + }; + }; + + Module['jsepOnRunEnd'] = sessionId => { + if (Module.jsepSessionState.sessionId !== sessionId) { + throw new Error('Session ID mismatch'); + } + + const errorPromises = Module.jsepSessionState.errors; + Module.jsepSessionState = null; + + if (errorPromises.length > 0) { + const runPromise = Module['jsepRunPromise']; + Module['jsepRunPromise'] = new Promise((resolve, reject) => { + Promise.all(errorPromises).then(errors => { + errors = errors.filter(e => e); + if (errors.length > 0) { + reject(new Error(errors.join('\n'))); + } else { + resolve(runPromise); + } + }, reason => { + reject(reason); + }); + }); + } + }; };