From ccc89af3d5b09ff735d090f91d2335bb8b58f957 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 27 Dec 2023 16:55:32 +0800 Subject: [PATCH 01/28] [js/webgpu] Add record/replay support --- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 ++ js/web/lib/wasm/jsep/backend-webgpu.ts | 7 +++++++ js/web/lib/wasm/wasm-core-impl.ts | 3 ++- onnxruntime/wasm/js_internal_api.js | 6 ++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 6c55dcc1bfd32..2c899911447dc 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -177,6 +177,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + jsepRunStart: (sessionId: number) => void; + jsepRunEnd: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 0148f32cdd91b..e77083b0fb064 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -517,4 +517,11 @@ export class WebGpuBackend { }; } // #endregion + + runStart(sessionId: number): void { + LOG_DEBUG('info', () => `runStart sessionId: ${sessionId}`); + } + runEnd(sessionId: number): void { + LOG_DEBUG('info', () => `runEnd sessionId: ${sessionId}`); + } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a9dfd9218bb6f..5da6fbb899613 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -406,6 +406,7 @@ export const run = async( const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { + wasm.jsepRunStart(sessionId); [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors @@ -581,7 +582,7 @@ export const run = async( if (ioBindingState) { wasm._OrtClearBoundOutputs(ioBindingState.handle); } - + wasm.jsepRunEnd(sessionId); return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 427ad6f6d14f3..67963dd19e1f0 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -166,4 +166,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepRunStart'] = (sessionId) => { + return backend['runStart'](sessionId); + }; + Module['jsepRunEnd'] = (sessionId) => { + return backend['runEnd'](sessionId); + }; }; From c675f71ef4f4a9ab40e396ae7e00c1ebf73ff928 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 28 Dec 2023 16:55:50 +0800 Subject: [PATCH 02/28] Add record/replay support in c++ --- js/web/lib/wasm/binding/ort-wasm.d.ts | 6 +- js/web/lib/wasm/jsep/backend-webgpu.ts | 32 ++++++++++ js/web/lib/wasm/jsep/init.ts | 8 ++- .../providers/js/js_execution_provider.cc | 60 +++++++++++++++++++ .../core/providers/js/js_execution_provider.h | 18 ++++++ onnxruntime/core/session/inference_session.cc | 12 ++-- onnxruntime/wasm/js_internal_api.js | 5 +- 7 files changed, 134 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 2c899911447dc..5d9a498b50a3d 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,6 +13,9 @@ export 
declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -123,7 +126,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, + captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; /** * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e77083b0fb064..e78e78dc78c51 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -122,6 +122,9 @@ export class WebGpuBackend { return data; } + // required min regular runs before graph capture for the necessary memory allocations. + // const int min_num_runs_before_webgpu_graph_capture_ = 1 + /** * a KernelID -> kernel info mapping. value is * [ op_type, name, run function, [optional] preprocess_attribute_once function ] @@ -520,8 +523,37 @@ export class WebGpuBackend { runStart(sessionId: number): void { LOG_DEBUG('info', () => `runStart sessionId: ${sessionId}`); + /* + // Begin webgpu graph capture. + if (webgpu_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOG_DEBUG('info', () => 'Capturing the webgpu graph for this model'); + CaptureBegin(); + } + */ } runEnd(sessionId: number): void { LOG_DEBUG('info', () => `runEnd sessionId: ${sessionId}`); + // End webgpu graph capture. + /* + if (webgpu_graph_enable_ && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(); + // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. 
+ ReplayGraph(); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + */ + } + captureBegin(): void { + LOG_DEBUG('info', () => 'captureBegin'); + } + captureEnd(): void { + LOG_DEBUG('info', () => 'captureEnd'); + } + replay(): void { + LOG_DEBUG('info', () => 'replay'); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3c6edf3ebb35d..8b9b8b0b732ea 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay()); }; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..b7efe6eff1605 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -3,6 +3,7 @@ #include "js_execution_provider.h" +#include #include #include #include @@ -749,4 +750,63 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } +Status JsExecutionProvider::OnRunStart() { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; + EM_ASM({ Module.jsepCaptureBegin(); }); + } + return Status::OK(); +} + +Status JsExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + EM_ASM({ Module.jsepCaptureEnd(); }); + is_graph_captured_ = true; + // CUDA work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. + EM_ASM({ Module.jsepReplay(); }); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureEnabled() const { + return true; +} + +bool JsExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status JsExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + EM_ASM({ Module.jsepReplay(); }); + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), + // therefore following increment is guaranteed to be thread safe. 
+ ++regular_run_count_before_graph_capture_; +} +/* +void JsExecutionProvider::CaptureBegin() { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(); +} + +void JsExecutionProvider::CaptureEnd() { + cuda_graph_.CaptureEnd(); + is_graph_captured_ = true; +} +*/ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 39d43498c0717..8b8a804b730d0 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -57,7 +57,25 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + + Status OnRunStart() override; + Status OnRunEnd(bool sync_stream) override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 665cdbc36a963..5028989c75608 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -140,7 +140,7 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { +static bool AreAllComputeNodesAssignedToCudaJSEp(const Graph& graph) { bool nodes_on_cpu_and_cuda_eps_only = true; for (const auto& node : graph.Nodes()) { @@ -149,6 +149,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { // Empty node provider means CPU EP if (!node_provider.empty() && node_provider != kCudaExecutionProvider && + node_provider != kJsExecutionProvider && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_eps_only = false; break; @@ -1705,7 +1706,9 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). 
- std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kJsExecutionProvider}; for (auto& it : cuda_graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); @@ -1722,12 +1725,13 @@ common::Status InferenceSession::Initialize() { "as the model has control flow nodes which can't be supported by CUDA Graphs.")); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaEp(graph)) { + if (!AreAllComputeNodesAssignedToCudaJSEp(graph)) { LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " << " as all compute graph nodes have not been partitioned to the CUDA EP."; diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 67963dd19e1f0..908e3467f490f 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -4,7 +4,7 @@ 'use strict'; // init JSEP -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => { +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; @@ -13,6 +13,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; + Module.jsepCaptureBegin = captureBegin; + Module.jsepCaptureEnd = captureEnd; + Module.jsepReplay = replay; // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. 
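Patch 02 above only adds the plumbing: the JS EP calls `Module.jsepCaptureBegin`/`jsepCaptureEnd`/`jsepReplay` through `EM_ASM`, and the WebGPU backend exposes stub handlers that just log. The actual recording lands in the next patch. The sketch below illustrates the record/replay idea those handlers implement, assuming the WebGPU type declarations from `@webgpu/types`; the names `DispatchRecorder` and `RecordedDispatch` are illustrative and not part of the real backend.

```typescript
// Minimal sketch of record/replay for WebGPU compute dispatches.
// Assumes WebGPU type declarations (e.g. @webgpu/types); names are illustrative.

interface RecordedDispatch {
  pipeline: GPUComputePipeline;
  bindGroup: GPUBindGroup;
  dispatchGroup: [number, number, number];
}

class DispatchRecorder {
  private capturing = false;
  private commands: RecordedDispatch[] = [];

  captureBegin(): void {
    this.commands = [];
    this.capturing = true;
  }

  captureEnd(): void {
    this.capturing = false;
  }

  // Called for every kernel dispatch. While capturing, the pipeline, bind group and
  // dispatch size are remembered so the same GPU work can be re-issued later without
  // re-running any CPU-side kernel code.
  dispatch(encoder: GPUComputePassEncoder, d: RecordedDispatch): void {
    if (this.capturing) {
      this.commands.push(d);
    }
    encoder.setPipeline(d.pipeline);
    encoder.setBindGroup(0, d.bindGroup);
    encoder.dispatchWorkgroups(...d.dispatchGroup);
  }

  // Re-issues the captured dispatches in order. This is only valid while every buffer
  // referenced by the recorded bind groups is still alive, which is why intermediate
  // buffers must not be released for a session that holds a captured command list.
  replay(encoder: GPUComputePassEncoder): void {
    for (const d of this.commands) {
      encoder.setPipeline(d.pipeline);
      encoder.setBindGroup(0, d.bindGroup);
      encoder.dispatchWorkgroups(...d.dispatchGroup);
    }
  }
}
```

The patches that follow do essentially this inside `WebGpuBackend` and `ProgramManager`, with two additions: captured command lists are keyed by session ID so multiple sessions can capture independently, and `GpuDataManager` defers releasing intermediate GPU buffers while capture or replay is active, because the recorded bind groups keep referencing them.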
From 547e0055c8d09b0f70866ad80789d626fc639643 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 29 Dec 2023 14:31:38 +0800 Subject: [PATCH 03/28] support record/replay in js --- js/web/lib/wasm/jsep/backend-webgpu.ts | 28 ++++++++++++++++++- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 28 +++++++++++-------- .../lib/wasm/jsep/webgpu/program-manager.ts | 16 +++++++++-- js/web/lib/wasm/jsep/webgpu/types.ts | 6 ++++ .../providers/js/js_execution_provider.cc | 3 -- 5 files changed, 62 insertions(+), 19 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e78e78dc78c51..615a459bef3bd 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -8,7 +8,14 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, StatusType} from './webgpu/types'; + +interface CommandInfo { + kernelId: number; + computePipeline: GPUComputePipeline; + bindGroup: GPUBindGroup; + dispatchGroup: [number, number, number]; +} const getProgramInputTensorInfoDependencyKey = (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => { @@ -141,6 +148,8 @@ export class WebGpuBackend { queryTimeBase?: bigint; env: Env; + status: StatusType = StatusType.default; + capturedCommandList: CommandInfo[] = []; /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. 
@@ -549,11 +558,28 @@ export class WebGpuBackend { } captureBegin(): void { LOG_DEBUG('info', () => 'captureBegin'); + this.capturedCommandList = []; + this.status = StatusType.capture; } captureEnd(): void { LOG_DEBUG('info', () => 'captureEnd'); + this.status = StatusType.default; } replay(): void { LOG_DEBUG('info', () => 'replay'); + this.status = StatusType.replay; + const length = this.capturedCommandList.length; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = this.capturedCommandList[i]; + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.pendingDispatchNumber++; + if (this.pendingDispatchNumber >= 16) { + this.flush(); + } + } + this.status = StatusType.default; } } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..54d51c4d2a9b7 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -4,7 +4,7 @@ import {WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; -import {GpuData, GpuDataId, GpuDataType} from './types'; +import {GpuData, GpuDataId, GpuDataType, StatusType} from './types'; /** * manages GpuDataId -> GpuBuffer @@ -312,20 +312,24 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. - this.freeBuffers.get(buffer.size)!.push(buffer); + + // Don't release intermediate tensors in non-default mode. + if (this.backend.status === StatusType.default) { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. 
+ this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 0d699326366b3..48ede3b728424 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -9,7 +9,7 @@ import {LOG_DEBUG} from '../log'; import {TensorView} from '../tensor-view'; import {createShaderHelper} from './ops/common'; -import {Artifact, GpuData, ProgramInfo} from './types'; +import {Artifact, GpuData, ProgramInfo, StatusType} from './types'; /** * ProgramManager is the main class behind running computations @@ -41,7 +41,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -54,8 +53,19 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.status === StatusType.capture) { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + this.backend.capturedCommandList.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 23fa33a9bba8f..ad806f3a01ca1 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,6 +5,12 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export enum StatusType { + default = 0, + capture = 1, + replay = 2 +} + export enum GpuDataType { default = 0, upload = 1, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index b7efe6eff1605..77576841d5b6e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -763,9 +763,6 @@ Status JsExecutionProvider::OnRunEnd(bool sync_stream) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); is_graph_captured_ = true; - // CUDA work issued to a capturing stream doesn’t actually run on the GPU, - // so run the captured graph here to actually execute the work. 
- EM_ASM({ Module.jsepReplay(); }); } else { IncrementRegularRunCountBeforeGraphCapture(); } From 62d64b83df2edb2ecb69aeede3d072a2dd9a6e9e Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 3 Jan 2024 16:37:48 +0800 Subject: [PATCH 04/28] remove unused codes --- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 -- js/web/lib/wasm/jsep/backend-webgpu.ts | 29 ------------------- js/web/lib/wasm/wasm-core-impl.ts | 2 -- .../providers/js/js_execution_provider.cc | 14 --------- .../core/providers/js/js_execution_provider.h | 4 --- onnxruntime/wasm/js_internal_api.js | 6 ---- 6 files changed, 57 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 5d9a498b50a3d..1a0d7f02a0de5 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -181,8 +181,6 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; - jsepRunStart: (sessionId: number) => void; - jsepRunEnd: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 615a459bef3bd..fd194545841af 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -129,9 +129,6 @@ export class WebGpuBackend { return data; } - // required min regular runs before graph capture for the necessary memory allocations. - // const int min_num_runs_before_webgpu_graph_capture_ = 1 - /** * a KernelID -> kernel info mapping. value is * [ op_type, name, run function, [optional] preprocess_attribute_once function ] @@ -530,32 +527,6 @@ export class WebGpuBackend { } // #endregion - runStart(sessionId: number): void { - LOG_DEBUG('info', () => `runStart sessionId: ${sessionId}`); - /* - // Begin webgpu graph capture. - if (webgpu_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { - LOG_DEBUG('info', () => 'Capturing the webgpu graph for this model'); - CaptureBegin(); - } - */ - } - runEnd(sessionId: number): void { - LOG_DEBUG('info', () => `runEnd sessionId: ${sessionId}`); - // End webgpu graph capture. - /* - if (webgpu_graph_enable_ && !IsGraphCaptured()) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(); - // CUDA work issued to a capturing stream doesn’t actually run on the GPU, - // so run the captured graph here to actually execute the work. 
- ReplayGraph(); - } else { - IncrementRegularRunCountBeforeGraphCapture(); - } - } - */ - } captureBegin(): void { LOG_DEBUG('info', () => 'captureBegin'); this.capturedCommandList = []; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5da6fbb899613..7c63e83dea3c8 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -406,7 +406,6 @@ export const run = async( const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { - wasm.jsepRunStart(sessionId); [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors @@ -582,7 +581,6 @@ export const run = async( if (ioBindingState) { wasm._OrtClearBoundOutputs(ioBindingState.handle); } - wasm.jsepRunEnd(sessionId); return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 77576841d5b6e..18ee4dce6b8fe 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -790,20 +790,6 @@ bool JsExecutionProvider::IsGraphCaptureAllowed() const { } void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { - // Please note that this function is not thread safe. - // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), - // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } -/* -void JsExecutionProvider::CaptureBegin() { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); -} - -void JsExecutionProvider::CaptureEnd() { - cuda_graph_.CaptureEnd(); - is_graph_captured_ = true; -} -*/ } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 8b8a804b730d0..ee4ab1902c83e 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -71,10 +71,6 @@ class JsExecutionProvider : public IExecutionProvider { DataLayout preferred_data_layout_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; - // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: - // (1) memory pattern is enabled. (2) arena allocation for stream. - // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs - // to allocate enough memory in Arena before graph capturing. const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
}; diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 908e3467f490f..5367562724d83 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -169,10 +169,4 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepRunStart'] = (sessionId) => { - return backend['runStart'](sessionId); - }; - Module['jsepRunEnd'] = (sessionId) => { - return backend['runEnd'](sessionId); - }; }; From be58e40f481d2532181be68c702b476d80646a78 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 5 Jan 2024 17:46:03 +0800 Subject: [PATCH 05/28] add EP option graphCaptureEnabled --- js/common/lib/inference-session.ts | 1 + js/web/lib/wasm/session-options.ts | 13 +++++++++++++ .../core/providers/js/js_execution_provider.cc | 4 ++-- .../core/providers/js/js_execution_provider.h | 9 +++++++++ onnxruntime/core/session/provider_registration.cc | 2 ++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index c7760692eed00..96944c7459a1a 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -237,6 +237,7 @@ export declare namespace InferenceSession { export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgpu'; preferredLayout?: 'NCHW'|'NHWC'; + graphCaptureEnabled?: boolean; } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 45ea48a2df209..ee322842035ff 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -115,6 +115,19 @@ const setExecutionProviders = `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); } } + if (webgpuOptions?.graphCaptureEnabled) { + if (webgpuOptions.graphCaptureEnabled !== true && webgpuOptions.graphCaptureEnabled !== false) { + throw new Error( + `graphCaptureEnabled must be either 'true' or 'false': ${webgpuOptions.graphCaptureEnabled}`); + } + const keyDataOffset = allocWasmString('graphCaptureEnabled', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.graphCaptureEnabled.toString(), allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError(`Can't set a session config entry: 'graphCaptureEnabled' - ${ + webgpuOptions.graphCaptureEnabled}.`); + } + } } break; case 'wasm': diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 18ee4dce6b8fe..ec353269a7b20 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -682,7 +682,7 @@ using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, - preferred_data_layout_{info.data_layout} { + preferred_data_layout_{info.data_layout}, graph_capture_enabled_(info.graph_capture_enabled) { } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -772,7 +772,7 @@ Status JsExecutionProvider::OnRunEnd(bool sync_stream) { } bool 
JsExecutionProvider::IsGraphCaptureEnabled() const { - return true; + return graph_capture_enabled_; } bool JsExecutionProvider::IsGraphCaptured() const { diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index ee4ab1902c83e..477d5dfc622b1 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -30,10 +30,18 @@ struct JsExecutionProviderInfo { data_layout = DataLayout::NHWC; } } + const std::string& graph_capture_enabled_str = po.at("graph_capture_enabled"); + if (graph_capture_enabled_str == "true") + { + graph_capture_enabled = true; + } else { + graph_capture_enabled = false; + } } // JSEP default preferred layout is NHWC DataLayout data_layout = DataLayout::NHWC; + bool graph_capture_enabled = false; }; class JsExecutionProvider : public IExecutionProvider { @@ -69,6 +77,7 @@ class JsExecutionProvider : public IExecutionProvider { bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool graph_capture_enabled_ = false; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 2e9af9f1f9bb2..a83e801359929 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -141,6 +141,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } + std::string graph_capture_enabled = options->value.config_options.GetConfigOrDefault("graphCaptureEnabled", "false"); + provider_options["graph_capture_enabled"] = graph_capture_enabled; options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); From 012066bd64ed35fd659657966760d0d57c401f53 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 8 Jan 2024 11:08:45 +0800 Subject: [PATCH 06/28] add sessionId to capture/replay methods --- js/web/lib/wasm/binding/ort-wasm.d.ts | 6 +-- js/web/lib/wasm/jsep/backend-webgpu.ts | 38 ++++++++++++++----- js/web/lib/wasm/jsep/init.ts | 6 +-- .../lib/wasm/jsep/webgpu/program-manager.ts | 3 +- .../providers/js/js_execution_provider.cc | 6 +-- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 1a0d7f02a0de5..83ff433385c3f 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,9 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; - type CaptureBeginFunction = () => void; - type CaptureEndFunction = () => void; - type ReplayFunction = () => void; + type CaptureBeginFunction = (sessionHandle: number) => void; + type CaptureEndFunction = (sessionHandle: number) => void; + type ReplayFunction = (sessionHandle: number) => void; } export interface OrtWasmModule extends EmscriptenModule { diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts 
b/js/web/lib/wasm/jsep/backend-webgpu.ts index fd194545841af..5186beb2adb3a 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -94,6 +94,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being captured/replay. + * `null` means no session is being captured. + * only valid when captureGraphEnabled = true. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -146,7 +153,10 @@ export class WebGpuBackend { env: Env; status: StatusType = StatusType.default; - capturedCommandList: CommandInfo[] = []; + /** + * a SessionID -> CommandInfo[] mapping. + */ + capturedCommandList: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -527,22 +537,30 @@ export class WebGpuBackend { } // #endregion - captureBegin(): void { - LOG_DEBUG('info', () => 'captureBegin'); - this.capturedCommandList = []; + captureBegin(sessionHandle: number): void { + LOG_DEBUG('info', () => `captureBegin ${sessionHandle}`); + this.currentSessionId = sessionHandle; + let sessionCommandList = this.capturedCommandList.get(sessionHandle); + if (!sessionCommandList) { + sessionCommandList = []; + this.capturedCommandList.set(sessionHandle, sessionCommandList); + } this.status = StatusType.capture; } - captureEnd(): void { - LOG_DEBUG('info', () => 'captureEnd'); + captureEnd(sessionHandle: number): void { + LOG_DEBUG('info', () => `captureEnd ${sessionHandle}`); + this.currentSessionId = null; this.status = StatusType.default; } - replay(): void { - LOG_DEBUG('info', () => 'replay'); + replay(sessionHandle: number): void { + LOG_DEBUG('info', () => `replay ${sessionHandle}`); + this.currentSessionId = sessionHandle; this.status = StatusType.replay; - const length = this.capturedCommandList.length; + const sessionCommandList = this.capturedCommandList.get(sessionHandle); + const length = sessionCommandList!.length; for (let i = 0; i < length; i++) { const computePassEncoder = this.getComputePassEncoder(); - const command = this.capturedCommandList[i]; + const command = sessionCommandList![i]; computePassEncoder.setPipeline(command.computePipeline); computePassEncoder.setBindGroup(0, command.bindGroup); computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 8b9b8b0b732ea..470077bd81e69 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -203,9 +203,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte return backend.computeKernel(kernel, context, errors); }, // jsepCaptureBegin - () => backend.captureBegin(), + (sessionHandle: number) => backend.captureBegin(sessionHandle), // jsepCaptureEnd - () => backend.captureEnd(), + (sessionHandle: number) => backend.captureEnd(sessionHandle), // jsepReplay - () => backend.replay()); + (sessionHandle: number) => backend.replay(sessionHandle)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 48ede3b728424..87061af61ef1c 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -61,7 +61,8 @@ export class ProgramManager { bindGroup, dispatchGroup }; - this.backend.capturedCommandList.push(commandInfo); + const 
sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList?.push(commandInfo); } computePassEncoder.setPipeline(buildArtifact.computePipeline); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index ec353269a7b20..428bce2a10832 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -753,7 +753,7 @@ JsExecutionProvider::~JsExecutionProvider() { Status JsExecutionProvider::OnRunStart() { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; - EM_ASM({ Module.jsepCaptureBegin(); }); + EM_ASM({ Module.jsepCaptureBegin(Module.jsepSessionState.sessionHandle); }); } return Status::OK(); } @@ -761,7 +761,7 @@ Status JsExecutionProvider::OnRunStart() { Status JsExecutionProvider::OnRunEnd(bool sync_stream) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { - EM_ASM({ Module.jsepCaptureEnd(); }); + EM_ASM({ Module.jsepCaptureEnd(Module.jsepSessionState.sessionHandle); }); is_graph_captured_ = true; } else { IncrementRegularRunCountBeforeGraphCapture(); @@ -781,7 +781,7 @@ bool JsExecutionProvider::IsGraphCaptured() const { Status JsExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); - EM_ASM({ Module.jsepReplay(); }); + EM_ASM({ Module.jsepReplay(Module.jsepSessionState.sessionHandle); }); return Status::OK(); } From db026152c681fbad54d747cae09e4b06118c4fb1 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 8 Jan 2024 15:43:22 +0800 Subject: [PATCH 07/28] Add releaseSession interface --- js/web/lib/wasm/binding/ort-wasm.d.ts | 6 +++ js/web/lib/wasm/jsep/backend-webgpu.ts | 7 ++++ .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 41 ++++++++++++++++++- js/web/lib/wasm/wasm-core-impl.ts | 1 + onnxruntime/wasm/js_internal_api.js | 3 ++ 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 83ff433385c3f..2c51ee5462f5e 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -181,6 +181,12 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Release a session. + * @param sessionId - specify the session ID. 
+ * @returns + */ + jsepReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5186beb2adb3a..1d80378ef7800 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -571,4 +571,11 @@ export class WebGpuBackend { } this.status = StatusType.default; } + + releaseSession(sessionId: number): void { + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + this.gpuDataManager.releaseSession(sessionId); + } } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 54d51c4d2a9b7..64c8818faf4a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + releaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -313,6 +324,10 @@ class GpuDataManagerImpl implements GpuDataManager { } this.buffersForUploadingPending = []; + if (this.buffersPending.length === 0) { + return; + } + // Don't release intermediate tensors in non-default mode. if (this.backend.status === StatusType.default) { for (const buffer of this.buffersPending) { @@ -329,6 +344,16 @@ class GpuDataManagerImpl implements GpuDataManager { } } this.buffersPending = []; + } else { + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); + } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } } @@ -348,9 +373,23 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + releaseSession(sessionId: number) { + // release the captured pending buffers. 
+ const pendingBffers = this.capturedPendingBuffers.get(sessionId); + pendingBffers!.forEach(buffer => { + buffer.destroy(); + }); } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7c63e83dea3c8..7f3d301c23c9e 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -303,6 +303,7 @@ export const releaseSession = (sessionId: number): void => { } wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 5367562724d83..7453abcda1923 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -169,4 +169,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepReleaseSession'] = sessionId => { + backend['releaseSession'](sessionId); + }; }; From 2d0c878fc686531a8dddd3a0ec7b7e893a209a0d Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 8 Jan 2024 16:59:52 +0800 Subject: [PATCH 08/28] Create an internal buffer for each external buffer --- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 64c8818faf4a2..eda999eb878da 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -143,7 +143,7 @@ class GpuDataManagerImpl implements GpuDataManager { private freeUniformBuffers: Map; // The external buffers registered users for IO Binding. - private externalBuffers: Map; + private externalBuffers: Map; // The pendingBuffers for capture graph. // a SessionID -> GPUBuffer[] mapping. @@ -219,38 +219,41 @@ class GpuDataManagerImpl implements GpuDataManager { } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { - let id: number|undefined; if (previousBuffer) { - id = this.externalBuffers.get(previousBuffer); - if (id === undefined) { + const ids = this.externalBuffers.get(previousBuffer); + if (ids === undefined) { throw new Error('previous buffer is not registered'); } if (buffer === previousBuffer) { LOG_DEBUG( 'verbose', () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ - id}, buffer is the same, skip.`); - return id; + ids[0]}, buffer is the same, skip.`); + return ids[1]; } this.externalBuffers.delete(previousBuffer); - } else { - id = createNewGpuDataId(); } + const id = createNewGpuDataId(); this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); - this.externalBuffers.set(buffer, id); LOG_DEBUG( 'verbose', () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); - return id; + + // copy the externl data to an internal gpu buffer. 
+ const internalGpuData = this.create(originalSize); + const internalId = internalGpuData.id; + this.memcpy(id, internalId); + this.externalBuffers.set(buffer, [id, internalId]); + return internalId; } unregisterExternalBuffer(buffer: GPUBuffer): void { const id = this.externalBuffers.get(buffer); if (id !== undefined) { - this.storageCache.delete(id); + this.storageCache.delete(id[0]); this.externalBuffers.delete(buffer); - LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id[0]}`); } } From c00b29bd5d302386c57ea1e09e8ae93c7505ac15 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 8 Jan 2024 17:04:08 +0800 Subject: [PATCH 09/28] Revert "Create an internal buffer for each external buffer" This reverts commit 80b53bc01a7d643e7cdc3491579e754a45d8b621. --- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index eda999eb878da..64c8818faf4a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -143,7 +143,7 @@ class GpuDataManagerImpl implements GpuDataManager { private freeUniformBuffers: Map; // The external buffers registered users for IO Binding. - private externalBuffers: Map; + private externalBuffers: Map; // The pendingBuffers for capture graph. // a SessionID -> GPUBuffer[] mapping. @@ -219,41 +219,38 @@ class GpuDataManagerImpl implements GpuDataManager { } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; if (previousBuffer) { - const ids = this.externalBuffers.get(previousBuffer); - if (ids === undefined) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { throw new Error('previous buffer is not registered'); } if (buffer === previousBuffer) { LOG_DEBUG( 'verbose', () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ - ids[0]}, buffer is the same, skip.`); - return ids[1]; + id}, buffer is the same, skip.`); + return id; } this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); } - const id = createNewGpuDataId(); this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); LOG_DEBUG( 'verbose', () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); - - // copy the externl data to an internal gpu buffer. 
- const internalGpuData = this.create(originalSize); - const internalId = internalGpuData.id; - this.memcpy(id, internalId); - this.externalBuffers.set(buffer, [id, internalId]); - return internalId; + return id; } unregisterExternalBuffer(buffer: GPUBuffer): void { const id = this.externalBuffers.get(buffer); if (id !== undefined) { - this.storageCache.delete(id[0]); + this.storageCache.delete(id); this.externalBuffers.delete(buffer); - LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id[0]}`); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); } } From e1a4bc4ece71384585cfd81735f973a929d4a790 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 9 Jan 2024 13:32:45 +0800 Subject: [PATCH 10/28] throw errrors when not supported --- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 64c8818faf4a2..08729aac924bf 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -231,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { From 387ff444ad0d2f5dcd74c672e6b2af4566493688 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 9 Jan 2024 17:19:10 +0800 Subject: [PATCH 11/28] only bind input/output once for IOBinding when graphCaptureEnabled = true --- js/web/lib/wasm/wasm-core-impl.ts | 89 ++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7f3d301c23c9e..93b7782e7f9d4 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -138,7 +138,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, graphCaptureEnabled: boolean, inputOutputBounded: boolean ]; const activeSessions = new Map(); @@ -219,6 +219,19 @@ export const createSession = checkLastError('Can\'t create a session.'); } + let graphCaptureEnabled = false; + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const executionProviders = options?.executionProviders; + for (const ep of executionProviders!) { + const epName = typeof ep === 'string' ? ep : ep.name; + if (epName === 'webgpu') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + graphCaptureEnabled = + webgpuOptions.graphCaptureEnabled === undefined ? false : webgpuOptions.graphCaptureEnabled; + } + } + } + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); const inputNames = []; @@ -242,6 +255,10 @@ export const createSession = outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (graphCaptureEnabled) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? 
options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; @@ -267,7 +284,9 @@ export const createSession = }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, graphCaptureEnabled, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -388,7 +407,8 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, inputOutputBounded] = + session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -434,41 +454,46 @@ export const run = async( } if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { - const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; - - if (inputNamesUTF8Encoded.length !== inputCount) { - throw new Error(`input count from feeds (${ - inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); - } + if (!graphCaptureEnabled || (graphCaptureEnabled && !inputOutputBounded)) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; - // process inputs - for (let i = 0; i < inputCount; i++) { - const index = inputIndices[i]; - const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); - if (errorCode !== 0) { - checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); } - } - // process pre-allocated outputs - for (let i = 0; i < outputCount; i++) { - const index = outputIndices[i]; - const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. - - if (location) { - // output is pre-allocated. bind the tensor. - const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + // process inputs + for (let i = 0; i < inputCount; i++) { + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); if (errorCode !== 0) { - checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); } - } else { - // output is not pre-allocated. reset preferred location. - const errorCode = - wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); - if (errorCode !== 0) { - checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + + // process pre-allocated outputs + for (let i = 0; i < outputCount; i++) { + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. 
+ const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, true]); } } @@ -579,7 +604,7 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !graphCaptureEnabled) { wasm._OrtClearBoundOutputs(ioBindingState.handle); } return output; From 79f392cbb8d47f84dc77547b45dbb04e6b4ff76e Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 10 Jan 2024 15:06:28 +0800 Subject: [PATCH 12/28] nits --- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 08729aac924bf..8866351169704 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -389,10 +389,13 @@ class GpuDataManagerImpl implements GpuDataManager { releaseSession(sessionId: number) { // release the captured pending buffers. - const pendingBffers = this.capturedPendingBuffers.get(sessionId); - pendingBffers!.forEach(buffer => { - buffer.destroy(); - }); + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } From 030d3477255ed977236e34ddfa4f52f0f53790cf Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 10 Jan 2024 16:35:18 +0800 Subject: [PATCH 13/28] update name and annotation --- .../providers/js/js_execution_provider.cc | 3 +- .../core/providers/js/js_execution_provider.h | 7 +-- onnxruntime/core/session/inference_session.cc | 61 +++++++++++-------- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 428bce2a10832..025611e890483 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -682,7 +682,8 @@ using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, - preferred_data_layout_{info.data_layout}, graph_capture_enabled_(info.graph_capture_enabled) { + preferred_data_layout_{info.data_layout}, + graph_capture_enabled_(info.graph_capture_enabled) { } std::vector JsExecutionProvider::CreatePreferredAllocators() { diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 477d5dfc622b1..b89c4e03bb6f4 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -31,11 +31,10 @@ struct JsExecutionProviderInfo { } } const std::string& graph_capture_enabled_str = po.at("graph_capture_enabled"); - if 
(graph_capture_enabled_str == "true") - { - graph_capture_enabled = true; + if (graph_capture_enabled_str == "true") { + graph_capture_enabled = true; } else { - graph_capture_enabled = false; + graph_capture_enabled = false; } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5028989c75608..5f58e9374adc6 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -140,8 +140,8 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaJSEp(const Graph& graph) { - bool nodes_on_cpu_and_cuda_eps_only = true; +static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { + bool nodes_on_cpu_and_cuda_and_js_eps_only = true; for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); @@ -151,18 +151,18 @@ static bool AreAllComputeNodesAssignedToCudaJSEp(const Graph& graph) { node_provider != kCudaExecutionProvider && node_provider != kJsExecutionProvider && node_provider != kCpuExecutionProvider) { - nodes_on_cpu_and_cuda_eps_only = false; + nodes_on_cpu_and_cuda_and_js_eps_only = false; break; } } - // If we see nodes assigned to EPs other than CPU or CUDA + // If we see nodes assigned to EPs other than CPU, or CUDA/JS // (or) if there are Memcpy nodes, then all compute nodes have - // not been parititoned to the CUDA EP. + // not been parititoned to the CUDA/JS EP. // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy // involved as shape subgraphs will be forced onto CPU and these will not have // Memcpy nodes involved. - return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); } static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { @@ -1693,7 +1693,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently graph capture is only considered by CUDA EP, TRT EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1706,50 +1706,61 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kJsExecutionProvider}; + // + // Check for JS EP: + // If the JS EP is part of the providers list for this session AND + // The JS EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the JS EP, + // Then the JS EP is cached for triggering a ReplayGraph() in Run(). 
+ // + std::vector graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kJsExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // Graphs capture can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << "as the model has control flow nodes which can't be supported by " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the graph capture feature as requested by the user " + "as the model has control flow nodes which can't be supported by" + + target_ep->Type())); } if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { - // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaJSEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the graph capture feature as requested by the user " + " as all compute graph nodes have not been partitioned to the " + + target_ep->Type())); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the graph capture feature with caution. 
" << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the graph capture feature."; } } else { // Following code path is for TRT EP currently. From c4cfde0924f280532552fb8832c2c8a678ab1bb2 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 10 Jan 2024 17:18:10 +0800 Subject: [PATCH 14/28] fix format issues --- js/web/lib/wasm/wasm-core-impl.ts | 127 ++++++++++++++++-------------- 1 file changed, 68 insertions(+), 59 deletions(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 93b7782e7f9d4..3f059d9cdf319 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,7 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let currentEpName: string; // #region Initializations /** @@ -105,6 +106,7 @@ export const initEp = async(env: Env, epName: string): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env, adapter); } + currentEpName = epName; }; // #endregion Initializations @@ -220,7 +222,7 @@ export const createSession = } let graphCaptureEnabled = false; - if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (currentEpName === 'webgpu') { const executionProviders = options?.executionProviders; for (const ep of executionProviders!) { const epName = typeof ep === 'string' ? 
ep : ep.name; @@ -331,70 +333,75 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + graphCaptureEnabled = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (graphCaptureEnabled && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when graphCaptureEnabled is true.`); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. 
session=${sessionId}, index=${index}.`); + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -431,13 +438,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], graphCaptureEnabled); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + graphCaptureEnabled); } let inputValuesIndex = inputValuesOffset / 4; From d105c52057f536fc33355053f0fb0adecd2d79dd Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 11 Jan 2024 17:07:26 +0800 Subject: [PATCH 15/28] fix lint/format errors --- js/web/lib/wasm/wasm-core-impl.ts | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 3f059d9cdf319..3c2c2e8be92f5 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -414,8 +414,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, inputOutputBounded] = - session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const graphCaptureEnabled = session[4]; + const inputOutputBounded = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; From cc2ff91c40eecf4e38253d0450c8a01aa47b701a Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 15 Jan 2024 15:55:42 +0800 Subject: [PATCH 16/28] nits --- js/web/lib/wasm/jsep/backend-webgpu.ts | 8 ++++---- js/web/lib/wasm/wasm-core-impl.ts | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e62bea498a37c..b78a1f995bc1e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,10 +13,10 @@ import {ProgramManager} from './webgpu/program-manager'; import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, StatusType, TimestampQuery} from './webgpu/types'; interface CommandInfo { - kernelId: number; - computePipeline: GPUComputePipeline; - bindGroup: GPUBindGroup; - dispatchGroup: [number, number, number]; + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; } interface KernelInfo { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index c443900064f4d..652ab9344c198 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -301,8 +301,8 @@ export const createSession = async( } activeSessions.set( - sessionHandle, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, graphCaptureEnabled, false]); + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, graphCaptureEnabled, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); From e630dbf528fc3a955702cceb968930d0abdfc652 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 18 Jan 2024 09:26:01 +0800 Subject: [PATCH 17/28] enable timestamp query --- js/web/lib/wasm/jsep/backend-webgpu.ts | 48 +++++++++++++------ .../lib/wasm/jsep/webgpu/program-manager.ts | 19 +++++++- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index b78a1f995bc1e..a5dcf0a45a826 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -26,7 +26,7 @@ interface KernelInfo { readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown]; } -interface PendingKernelInfo { +export interface PendingKernelInfo { readonly kernelId: number; readonly programName: string; readonly inputTensorViews: readonly TensorView[]; @@ -160,7 +160,7 @@ export class WebGpuBackend { pendingDispatchNumber = 0; // info of kernels pending submission for a single batch - private pendingKernels: PendingKernelInfo[] = []; + pendingKernels: PendingKernelInfo[] = []; // queryReadBuffer -> pendingKernels mapping for all the batches private pendingQueries: Map = new Map(); private queryResolveBuffer?: GPUBuffer; @@ -175,6 +175,11 @@ export class WebGpuBackend { */ 
capturedCommandList: Map = new Map(); + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + capturedPendingKernels: Map = new Map(); + /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. */ @@ -259,6 +264,8 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + // getCommandEncoder must be put before checking this.queryType since this.queryType is updated there. + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -269,7 +276,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -509,17 +516,9 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { - const pendingKernelInfo: PendingKernelInfo = { - kernelId: this.currentKernelId!, - programName: artifact.programInfo.name, - inputTensorViews, - outputTensorViews, - }; - this.pendingKernels.push(pendingKernelInfo); - } - - this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); + this.programManager.run( + artifact, inputDatas, outputDatas, inputTensorViews, outputTensorViews, normalizedDispatchGroup, + uniformBufferBinding); TRACE_FUNC_END(program.name); return outputTensorViews; @@ -682,9 +681,12 @@ export class WebGpuBackend { LOG_DEBUG('info', () => `captureBegin ${sessionHandle}`); this.currentSessionId = sessionHandle; let sessionCommandList = this.capturedCommandList.get(sessionHandle); + let sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle); if (!sessionCommandList) { sessionCommandList = []; this.capturedCommandList.set(sessionHandle, sessionCommandList); + sessionPendingKernels = []; + this.capturedPendingKernels.set(sessionHandle, sessionPendingKernels); } this.status = StatusType.capture; } @@ -694,22 +696,35 @@ export class WebGpuBackend { this.status = StatusType.default; } replay(sessionHandle: number): void { + // make sure previous commands are all submitted. 
+ this.flush(); LOG_DEBUG('info', () => `replay ${sessionHandle}`); this.currentSessionId = sessionHandle; this.status = StatusType.replay; const sessionCommandList = this.capturedCommandList.get(sessionHandle); + const sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle); const length = sessionCommandList!.length; + this.pendingKernels = []; for (let i = 0; i < length; i++) { const computePassEncoder = this.getComputePassEncoder(); const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); computePassEncoder.setPipeline(command.computePipeline); computePassEncoder.setBindGroup(0, command.bindGroup); computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); this.pendingDispatchNumber++; - if (this.pendingDispatchNumber >= 16) { + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { this.flush(); } } + this.flush(); this.status = StatusType.default; } @@ -717,6 +732,9 @@ export class WebGpuBackend { if (this.capturedCommandList.has(sessionId)) { this.capturedCommandList.delete(sessionId); } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } this.gpuDataManager.releaseSession(sessionId); } } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index dada309ab4d9b..9f8cab0e55549 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -3,8 +3,9 @@ import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {WebGpuBackend} from '../backend-webgpu'; +import {PendingKernelInfo, WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; +import {TensorView} from '../tensor-view'; import {createShaderHelper} from './ops/common'; import {Artifact, GpuData, ProgramInfo, StatusType} from './types'; @@ -32,7 +33,8 @@ export class ProgramManager { setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } - run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], + run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], inputTensorViews: readonly TensorView[], + outputTensorViews: readonly TensorView[], dispatchGroup: [number, number, number], uniformBufferBinding: GPUBindingResource|undefined): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; @@ -68,6 +70,19 @@ export class ProgramManager { this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; + if (this.backend.queryType !== 'none' || this.backend.status === StatusType.capture) { + const pendingKernelInfo: PendingKernelInfo = { + kernelId: this.backend.currentKernelId!, + programName: buildArtifact.programInfo.name, + inputTensorViews, + outputTensorViews, + }; + this.backend.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.backend.capturedPendingKernels.get(this.backend.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } + if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || this.backend.queryType === 'at-passes') { this.backend.endComputePass(); 
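The capture/replay machinery built up in the patch above boils down to recording each GPU dispatch as a small command descriptor while a session runs in capture mode, then re-issuing those descriptors on later runs. The following is only a minimal sketch of that idea — the CaptureRecorder class and its method names are invented for illustration, and the real WebGpuBackend additionally batches dispatches, writes timestamp queries, and tracks pending kernels for profiling:

// Minimal record/replay sketch (illustrative; simplified from the actual backend).
interface RecordedCommand {
  readonly computePipeline: GPUComputePipeline;
  readonly bindGroup: GPUBindGroup;
  readonly dispatchGroup: [number, number, number];
}

class CaptureRecorder {
  // sessionId -> dispatches recorded for that session.
  private readonly commands = new Map<number, RecordedCommand[]>();

  // Called for every dispatch while the session is in capture mode.
  record(sessionId: number, command: RecordedCommand): void {
    let list = this.commands.get(sessionId);
    if (!list) {
      list = [];
      this.commands.set(sessionId, list);
    }
    list.push(command);
  }

  // Called on later runs: re-issues the captured dispatches in order.
  replay(sessionId: number, pass: GPUComputePassEncoder): void {
    for (const command of this.commands.get(sessionId) ?? []) {
      pass.setPipeline(command.computePipeline);
      pass.setBindGroup(0, command.bindGroup);
      pass.dispatchWorkgroups(...command.dispatchGroup);
    }
  }

  // Called when the session is released, so captured resources can be dropped.
  release(sessionId: number): void {
    this.commands.delete(sessionId);
  }
}
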
From 4c313ad31fef3d254bd67f69949a3c91b6c90835 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 19 Jan 2024 13:38:11 +0800 Subject: [PATCH 18/28] address Yulong's comments --- js/web/lib/wasm/binding/ort-wasm.d.ts | 8 +------- js/web/lib/wasm/jsep/backend-webgpu.ts | 5 +++-- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 7 ++++--- js/web/lib/wasm/jsep/webgpu/program-manager.ts | 2 +- js/web/lib/wasm/session-options.ts | 7 +++---- js/web/lib/wasm/wasm-core-impl.ts | 3 +-- onnxruntime/wasm/js_internal_api.js | 7 ++----- 7 files changed, 15 insertions(+), 24 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index e34d635dc3ecc..9a936fbbedc97 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -162,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule { * @returns the GPU data ID for the registered GPU buffer. */ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; /** * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. * @@ -191,7 +185,7 @@ export interface OrtWasmModule extends EmscriptenModule { * @param sessionId - specify the session ID. * @returns */ - jsepReleaseSession: (sessionId: number) => void; + jsepOnReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a5dcf0a45a826..18bd719a9554b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -728,13 +728,14 @@ export class WebGpuBackend { this.status = StatusType.default; } - releaseSession(sessionId: number): void { + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); if (this.capturedCommandList.has(sessionId)) { this.capturedCommandList.delete(sessionId); } if (this.capturedPendingKernels.has(sessionId)) { this.capturedPendingKernels.delete(sessionId); } - this.gpuDataManager.releaseSession(sessionId); + this.gpuDataManager.onReleaseSession(sessionId); } } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 8866351169704..bd8d79b6df4dd 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -68,7 +68,7 @@ export interface GpuDataManager { * release session related data. * @param sessionId - specify the session ID. */ - releaseSession(sessionId: number): void; + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -331,7 +331,6 @@ class GpuDataManagerImpl implements GpuDataManager { return; } - // Don't release intermediate tensors in non-default mode. if (this.backend.status === StatusType.default) { for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise @@ -348,6 +347,8 @@ class GpuDataManagerImpl implements GpuDataManager { } this.buffersPending = []; } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. 
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); if (!capturedBuffers) { capturedBuffers = []; @@ -387,7 +388,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.capturedPendingBuffers = new Map(); } - releaseSession(sessionId: number) { + onReleaseSession(sessionId: number) { // release the captured pending buffers. const pendingBuffers = this.capturedPendingBuffers.get(sessionId); if (pendingBuffers) { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 9f8cab0e55549..e54f1855a885f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -61,7 +61,7 @@ export class ProgramManager { dispatchGroup }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); - sessionCommandList?.push(commandInfo); + sessionCommandList!.push(commandInfo); } computePassEncoder.setPipeline(buildArtifact.computePipeline); diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index c939ebf0b3c52..fb29b7b5ce6b0 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -112,10 +112,9 @@ const setExecutionProviders = `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); } } - if (webgpuOptions?.graphCaptureEnabled) { - if (webgpuOptions.graphCaptureEnabled !== true && webgpuOptions.graphCaptureEnabled !== false) { - throw new Error( - `graphCaptureEnabled must be either 'true' or 'false': ${webgpuOptions.graphCaptureEnabled}`); + if (webgpuOptions?.graphCaptureEnabled !== undefined) { + if (typeof webgpuOptions.graphCaptureEnabled !== 'boolean') { + throw new Error(`graphCaptureEnabled must be a boolean value: ${webgpuOptions.graphCaptureEnabled}`); } const keyDataOffset = allocWasmString('graphCaptureEnabled', allocs); const valueDataOffset = allocWasmString(webgpuOptions.graphCaptureEnabled.toString(), allocs); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 652ab9344c198..afd3592591c47 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -340,8 +340,7 @@ export const releaseSession = (sessionId: number): void => { wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); - wasm.jsepReleaseSession?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 0b254abf2db41..9a0406d8d0ef6 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -180,16 +180,13 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module['jsepUnregisterBuffers'] = sessionId => { - backend['unregisterBuffers'](sessionId); - }; Module['jsepGetBuffer'] = (dataId) => { return backend['getBuffer'](dataId); }; Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepReleaseSession'] = sessionId => { - backend['releaseSession'](sessionId); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); }; }; 
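From the application side, the feature as it stands after this patch is opted into through the WebGPU execution provider options (at this stage of the series the flag is still named graphCaptureEnabled; a later patch renames it to enableGraphCapture and moves it onto the session options). The sketch below shows the intended usage under the constraint enforced earlier in the series that inputs and outputs must live in GPU buffers when capture is on; the model path, tensor name, shapes, buffer sizes, and the pre-existing WebGPU device are placeholder assumptions, not part of the patches:

// Usage sketch only; assumes an application-provided WebGPU device.
import * as ort from 'onnxruntime-web/webgpu';
declare const device: GPUDevice;

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: [{name: 'webgpu', graphCaptureEnabled: true}],
  // Outputs must stay on the GPU when graph capture is enabled.
  preferredOutputLocation: 'gpu-buffer',
});

// Inputs must be provided as pre-allocated GPU buffers; sizes here are illustrative.
const inputBuffer = device.createBuffer({
  size: 1 * 3 * 224 * 224 * 4,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
const input = ort.Tensor.fromGpuBuffer(inputBuffer, {dataType: 'float32', dims: [1, 3, 224, 224]});

// An early run records the GPU commands; subsequent same-shaped runs replay them.
const results = await session.run({input});
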
From 3f3c6dfa7f138114dcfbd4b2c40db34ba04aa9c5 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 19 Jan 2024 15:13:12 +0800 Subject: [PATCH 19/28] reuse the storage buffer --- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index bd8d79b6df4dd..ee07ed3a46e7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -347,15 +347,20 @@ class GpuDataManagerImpl implements GpuDataManager { } this.buffersPending = []; } else { - // Don't release intermediate tensors in non-default mode. - // TODO: reuse the storage buffers in non-default mode. + // Don't release uniform buffers in non-default mode. let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); if (!capturedBuffers) { capturedBuffers = []; this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } for (const buffer of this.buffersPending) { - capturedBuffers.push(buffer); + // eslint-disable-next-line no-bitwise + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + } else { + capturedBuffers.push(buffer); + } } this.buffersPending = []; } From b992f6cf6092d655393929dfe0b665c551cc6f66 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 22 Jan 2024 15:19:46 +0800 Subject: [PATCH 20/28] Revert "reuse the storage buffer" This reverts commit 3f3c6dfa7f138114dcfbd4b2c40db34ba04aa9c5. --- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index ee07ed3a46e7d..bd8d79b6df4dd 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -347,20 +347,15 @@ class GpuDataManagerImpl implements GpuDataManager { } this.buffersPending = []; } else { - // Don't release uniform buffers in non-default mode. + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); if (!capturedBuffers) { capturedBuffers = []; this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. 
- this.freeBuffers.get(buffer.size)!.push(buffer); - } else { - capturedBuffers.push(buffer); - } + capturedBuffers.push(buffer); } this.buffersPending = []; } From 217298400fe9b82bb38402ad6b94816cb0b71f6c Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 23 Jan 2024 12:57:01 +0800 Subject: [PATCH 21/28] integrate setQueryType changes --- js/web/lib/wasm/jsep/backend-webgpu.ts | 27 ++++++++++++------- .../lib/wasm/jsep/webgpu/program-manager.ts | 19 ++----------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 8ac9262741733..e2ddaf3183521 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -26,7 +26,7 @@ interface KernelInfo { readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown]; } -export interface PendingKernelInfo { +interface PendingKernelInfo { readonly kernelId: number; readonly programName: string; readonly inputTensorViews: readonly TensorView[]; @@ -160,7 +160,7 @@ export class WebGpuBackend { pendingDispatchNumber = 0; // info of kernels pending submission for a single batch - pendingKernels: PendingKernelInfo[] = []; + private pendingKernels: PendingKernelInfo[] = []; // queryReadBuffer -> pendingKernels mapping for all the batches private pendingQueries: Map = new Map(); private queryResolveBuffer?: GPUBuffer; @@ -178,7 +178,7 @@ export class WebGpuBackend { /** * a SessionID -> PendingKernelInfo[] mapping for profiling. */ - capturedPendingKernels: Map = new Map(); + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -262,7 +262,6 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { - // getCommandEncoder must be put before checking this.queryType since this.queryType is updated there. const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; @@ -514,9 +513,20 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - this.programManager.run( - artifact, inputDatas, outputDatas, inputTensorViews, outputTensorViews, normalizedDispatchGroup, - uniformBufferBinding); + if (this.queryType !== 'none' || this.status === StatusType.capture) { + const pendingKernelInfo: PendingKernelInfo = { + kernelId: this.currentKernelId!, + programName: artifact.programInfo.name, + inputTensorViews, + outputTensorViews, + }; + this.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } + + this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); TRACE_FUNC_END(program.name); return outputTensorViews; @@ -694,8 +704,6 @@ export class WebGpuBackend { this.status = StatusType.default; } replay(sessionHandle: number): void { - // make sure previous commands are all submitted. 
- this.flush(); LOG_DEBUG('info', () => `replay ${sessionHandle}`); this.currentSessionId = sessionHandle; this.status = StatusType.replay; @@ -722,7 +730,6 @@ export class WebGpuBackend { this.flush(); } } - this.flush(); this.status = StatusType.default; } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index e54f1855a885f..b84f10cd9023a 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -3,9 +3,8 @@ import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {PendingKernelInfo, WebGpuBackend} from '../backend-webgpu'; +import {WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; -import {TensorView} from '../tensor-view'; import {createShaderHelper} from './ops/common'; import {Artifact, GpuData, ProgramInfo, StatusType} from './types'; @@ -33,8 +32,7 @@ export class ProgramManager { setArtifact(key: unknown, artifact: Artifact): void { this.repo.set(key, artifact); } - run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], inputTensorViews: readonly TensorView[], - outputTensorViews: readonly TensorView[], dispatchGroup: [number, number, number], + run(buildArtifact: Artifact, inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number], uniformBufferBinding: GPUBindingResource|undefined): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; @@ -70,19 +68,6 @@ export class ProgramManager { this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; - if (this.backend.queryType !== 'none' || this.backend.status === StatusType.capture) { - const pendingKernelInfo: PendingKernelInfo = { - kernelId: this.backend.currentKernelId!, - programName: buildArtifact.programInfo.name, - inputTensorViews, - outputTensorViews, - }; - this.backend.pendingKernels.push(pendingKernelInfo); - - const sessionPendingKernels = this.backend.capturedPendingKernels.get(this.backend.currentSessionId!); - sessionPendingKernels!.push(pendingKernelInfo); - } - if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || this.backend.queryType === 'at-passes') { this.backend.endComputePass(); From 6e0ef2001c1b6d7cb7ca5cbed1546ba8d7a9a564 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 23 Jan 2024 13:51:10 +0800 Subject: [PATCH 22/28] flush the left commands before status changed --- js/web/lib/wasm/jsep/backend-webgpu.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e2ddaf3183521..5c5a29dfb3512 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -700,6 +700,8 @@ export class WebGpuBackend { } captureEnd(sessionHandle: number): void { LOG_DEBUG('info', () => `captureEnd ${sessionHandle}`); + // flush the left commands before we change the status. + this.flush(); this.currentSessionId = null; this.status = StatusType.default; } @@ -730,6 +732,8 @@ export class WebGpuBackend { this.flush(); } } + // flush the left commands before we change the status. 
+ this.flush(); this.status = StatusType.default; } From b785a050b2171aa2331c9b8310678258602f8863 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Thu, 25 Jan 2024 15:01:58 +0800 Subject: [PATCH 23/28] address comments --- js/web/lib/wasm/binding/ort-wasm.d.ts | 15 +++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 52 +++++++++---------- js/web/lib/wasm/jsep/init.ts | 6 +-- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 4 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 4 +- js/web/lib/wasm/jsep/webgpu/types.ts | 6 +-- js/web/lib/wasm/wasm-core-impl.ts | 2 +- .../providers/js/js_execution_provider.cc | 6 +-- onnxruntime/wasm/js_internal_api.js | 4 +- 9 files changed, 48 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 8e55df50cadcc..272a041fe832a 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,9 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; - type CaptureBeginFunction = (sessionHandle: number) => void; - type CaptureEndFunction = (sessionHandle: number) => void; - type ReplayFunction = (sessionHandle: number) => void; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -181,11 +181,14 @@ export interface OrtWasmModule extends EmscriptenModule { (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. */ - jsepOnRunStart: () => void; + jsepOnRunStart: (sessionId: number) => void; /** - * [exported from js_internal_api.js] Release a session. + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. * @param sessionId - specify the session ID. * @returns */ diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5c5a29dfb3512..5b9d0c2faef65 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, StatusType, TimestampQuery} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; interface CommandInfo { readonly kernelId: number; @@ -111,9 +111,9 @@ export class WebGpuBackend { programManager: ProgramManager; /** - * representing the session ID of which is currently being captured/replay. - * `null` means no session is being captured. - * only valid when captureGraphEnabled = true. + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. 
*/ currentSessionId: number|null = null; @@ -169,9 +169,9 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; - status: StatusType = StatusType.default; + sessionStatus: SessionState = 'default'; /** - * a SessionID -> CommandInfo[] mapping. + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. */ capturedCommandList: Map = new Map(); @@ -513,7 +513,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none' || this.status === StatusType.capture) { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -685,32 +685,29 @@ export class WebGpuBackend { } } - captureBegin(sessionHandle: number): void { - LOG_DEBUG('info', () => `captureBegin ${sessionHandle}`); - this.currentSessionId = sessionHandle; - let sessionCommandList = this.capturedCommandList.get(sessionHandle); - let sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle); + captureBegin(): void { + LOG_DEBUG('info', () => 'captureBegin'); + let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); if (!sessionCommandList) { sessionCommandList = []; - this.capturedCommandList.set(sessionHandle, sessionCommandList); + this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); sessionPendingKernels = []; - this.capturedPendingKernels.set(sessionHandle, sessionPendingKernels); + this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); } - this.status = StatusType.capture; + this.sessionStatus = 'capturing'; } - captureEnd(sessionHandle: number): void { - LOG_DEBUG('info', () => `captureEnd ${sessionHandle}`); + captureEnd(): void { + LOG_DEBUG('info', () => 'captureEnd'); // flush the left commands before we change the status. this.flush(); - this.currentSessionId = null; - this.status = StatusType.default; + this.sessionStatus = 'default'; } - replay(sessionHandle: number): void { - LOG_DEBUG('info', () => `replay ${sessionHandle}`); - this.currentSessionId = sessionHandle; - this.status = StatusType.replay; - const sessionCommandList = this.capturedCommandList.get(sessionHandle); - const sessionPendingKernels = this.capturedPendingKernels.get(sessionHandle); + replay(): void { + LOG_DEBUG('info', () => 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); const length = sessionCommandList!.length; this.pendingKernels = []; for (let i = 0; i < length; i++) { @@ -734,7 +731,7 @@ export class WebGpuBackend { } // flush the left commands before we change the status. 
this.flush(); - this.status = StatusType.default; + this.sessionStatus = 'default'; } onReleaseSession(sessionId: number): void { @@ -748,7 +745,8 @@ export class WebGpuBackend { this.gpuDataManager.onReleaseSession(sessionId); } - onRunStart(): void { + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 6901d5b7b7b44..786ae41646554 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -203,9 +203,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte return backend.computeKernel(kernel, context, errors); }, // jsepCaptureBegin - (sessionHandle: number) => backend.captureBegin(sessionHandle), + () => backend.captureBegin(), // jsepCaptureEnd - (sessionHandle: number) => backend.captureEnd(sessionHandle), + () => backend.captureEnd(), // jsepReplay - (sessionHandle: number) => backend.replay(sessionHandle)); + () => backend.replay()); }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index bd8d79b6df4dd..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -4,7 +4,7 @@ import {WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; -import {GpuData, GpuDataId, GpuDataType, StatusType} from './types'; +import {GpuData, GpuDataId, GpuDataType} from './types'; /** * manages GpuDataId -> GpuBuffer @@ -331,7 +331,7 @@ class GpuDataManagerImpl implements GpuDataManager { return; } - if (this.backend.status === StatusType.default) { + if (this.backend.sessionStatus === 'default') { for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index b84f10cd9023a..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -7,7 +7,7 @@ import {WebGpuBackend} from '../backend-webgpu'; import {LOG_DEBUG} from '../log'; import {createShaderHelper} from './ops/common'; -import {Artifact, GpuData, ProgramInfo, StatusType} from './types'; +import {Artifact, GpuData, ProgramInfo} from './types'; /** * ProgramManager is the main class behind running computations @@ -51,7 +51,7 @@ export class ProgramManager { const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - if (this.backend.status === StatusType.capture) { + if (this.backend.sessionStatus === 'capturing') { const commandInfo = { kernelId: this.backend.currentKernelId!, computePipeline: buildArtifact.computePipeline, diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 097255797d34c..52ff1510d0a86 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,11 +5,7 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; -export enum StatusType { - default = 0, - capture = 1, - replay = 2 -} +export type SessionState = 'default'|'capturing'|'replaying'; export enum GpuDataType { default = 0, diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 0f4800f9ad8c2..44fb20b4b2c38 100644 --- 
a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -526,7 +526,7 @@ export const run = async( } } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index b4e6c7862e1dd..6b4b60534f974 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -756,7 +756,7 @@ JsExecutionProvider::~JsExecutionProvider() { Status JsExecutionProvider::OnRunStart() { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; - EM_ASM({ Module.jsepCaptureBegin(Module.jsepSessionState.sessionHandle); }); + EM_ASM({ Module.jsepCaptureBegin(); }); } return Status::OK(); } @@ -764,7 +764,7 @@ Status JsExecutionProvider::OnRunStart() { Status JsExecutionProvider::OnRunEnd(bool sync_stream) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { - EM_ASM({ Module.jsepCaptureEnd(Module.jsepSessionState.sessionHandle); }); + EM_ASM({ Module.jsepCaptureEnd(); }); is_graph_captured_ = true; } else { IncrementRegularRunCountBeforeGraphCapture(); @@ -784,7 +784,7 @@ bool JsExecutionProvider::IsGraphCaptured() const { Status JsExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); - EM_ASM({ Module.jsepReplay(Module.jsepSessionState.sessionHandle); }); + EM_ASM({ Module.jsepReplay(); }); return Status::OK(); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index db3f4c4c088e2..cf9bfe111e8ca 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -189,7 +189,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepOnReleaseSession'] = sessionId => { backend['onReleaseSession'](sessionId); }; - Module['jsepOnRunStart'] = () => { - return backend['onRunStart'](); + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; }; From 3a80d5c1d8b092e935690c3aec7b23c05ce0835b Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 26 Jan 2024 10:59:43 +0800 Subject: [PATCH 24/28] rename to enableGraphCapture and move to SessionOptions --- js/common/lib/inference-session.ts | 9 ++++- js/web/lib/wasm/session-options.ts | 24 ++++++------ js/web/lib/wasm/wasm-core-impl.ts | 39 +++++++------------ .../providers/js/js_execution_provider.cc | 11 ++++-- .../core/providers/js/js_execution_provider.h | 12 ++---- .../core/providers/js/js_provider_factory.cc | 11 +++--- .../js/js_provider_factory_creator.h | 4 +- .../core/session/provider_registration.cc | 4 +- 8 files changed, 52 insertions(+), 62 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 3f7d3c8c1eb30..4f85c3b46e253 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. 
*/ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ @@ -238,7 +244,6 @@ export declare namespace InferenceSession { export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgpu'; preferredLayout?: 'NCHW'|'NHWC'; - graphCaptureEnabled?: boolean; } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index fb29b7b5ce6b0..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -112,18 +112,6 @@ const setExecutionProviders = `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); } } - if (webgpuOptions?.graphCaptureEnabled !== undefined) { - if (typeof webgpuOptions.graphCaptureEnabled !== 'boolean') { - throw new Error(`graphCaptureEnabled must be a boolean value: ${webgpuOptions.graphCaptureEnabled}`); - } - const keyDataOffset = allocWasmString('graphCaptureEnabled', allocs); - const valueDataOffset = allocWasmString(webgpuOptions.graphCaptureEnabled.toString(), allocs); - if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== - 0) { - checkLastError(`Can't set a session config entry: 'graphCaptureEnabled' - ${ - webgpuOptions.graphCaptureEnabled}.`); - } - } } break; case 'wasm': @@ -180,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 44fb20b4b2c38..701387b02ddcb 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -11,7 +11,6 @@ import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; import {loadFile} from './wasm-utils-load-file'; -let currentEpName: string; // #region Initializations /** @@ -107,7 +106,6 @@ export const initEp = async(env: Env, epName: string): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env, adapter); } - currentEpName = epName; }; // #endregion Initializations @@ -141,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ 
inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null, graphCaptureEnabled: boolean, inputOutputBounded: boolean + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBounded: boolean ]; const activeSessions = new Map(); @@ -237,18 +235,7 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); - let graphCaptureEnabled = false; - if (currentEpName === 'webgpu') { - const executionProviders = options?.executionProviders; - for (const ep of executionProviders!) { - const epName = typeof ep === 'string' ? ep : ep.name; - if (epName === 'webgpu') { - const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; - graphCaptureEnabled = - webgpuOptions.graphCaptureEnabled === undefined ? false : webgpuOptions.graphCaptureEnabled; - } - } - } + const enableGraphCapture = options?.enableGraphCapture === undefined ? false : options.enableGraphCapture; const inputNames = []; const outputNames = []; @@ -271,7 +258,7 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { - if (graphCaptureEnabled) { + if (enableGraphCapture) { outputPreferredLocations.push('gpu-buffer'); continue; } @@ -302,7 +289,7 @@ export const createSession = async( activeSessions.set( sessionHandle, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, graphCaptureEnabled, false]); + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -350,7 +337,7 @@ export const releaseSession = (sessionId: number): void => { export const prepareInputOutputTensor = (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, - graphCaptureEnabled = false): void => { + enableGraphCapture = false): void => { if (!tensor) { tensorHandles.push(0); return; @@ -369,9 +356,9 @@ export const prepareInputOutputTensor = throw new Error('String tensor is not supported on GPU.'); } - if (graphCaptureEnabled && location !== 'gpu-buffer') { + if (enableGraphCapture && location !== 'gpu-buffer') { throw new Error( - `External buffer must be provided for input/output index ${index} when graphCaptureEnabled is true.`); + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); } if (location === 'gpu-buffer') { @@ -434,7 +421,7 @@ export const run = async( const inputNamesUTF8Encoded = session[1]; const outputNamesUTF8Encoded = session[2]; const ioBindingState = session[3]; - const graphCaptureEnabled = session[4]; + const enableGraphCapture = session[4]; const inputOutputBounded = session[5]; const inputCount = inputIndices.length; @@ -459,14 +446,14 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { prepareInputOutputTensor( - inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], graphCaptureEnabled); + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], - graphCaptureEnabled); + enableGraphCapture); } let inputValuesIndex = 
inputValuesOffset / 4; @@ -483,7 +470,7 @@ export const run = async( } if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { - if (!graphCaptureEnabled || (graphCaptureEnabled && !inputOutputBounded)) { + if (!enableGraphCapture || (enableGraphCapture && !inputOutputBounded)) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -522,7 +509,7 @@ export const run = async( } activeSessions.set( sessionId, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, graphCaptureEnabled, true]); + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } } @@ -633,7 +620,7 @@ export const run = async( } } - if (ioBindingState && !graphCaptureEnabled) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); } return output; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6b4b60534f974..308e1c7d952d9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -682,10 +682,13 @@ std::unique_ptr RegisterKernels() { using namespace js; -JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, - preferred_data_layout_{info.data_layout}, - graph_capture_enabled_(info.graph_capture_enabled) { + preferred_data_layout_{info.data_layout} { + if (session_options) { + enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; + LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; + } } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -775,7 +778,7 @@ Status JsExecutionProvider::OnRunEnd(bool sync_stream) { } bool JsExecutionProvider::IsGraphCaptureEnabled() const { - return graph_capture_enabled_; + return enable_graph_capture_; } bool JsExecutionProvider::IsGraphCaptured() const { diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index b89c4e03bb6f4..91a3256ec2bd5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -5,6 +5,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" @@ -30,22 +31,15 @@ struct JsExecutionProviderInfo { data_layout = DataLayout::NHWC; } } - const std::string& graph_capture_enabled_str = po.at("graph_capture_enabled"); - if (graph_capture_enabled_str == "true") { - graph_capture_enabled = true; - } else { - graph_capture_enabled = false; - } } // JSEP default preferred layout is NHWC DataLayout data_layout = DataLayout::NHWC; - bool graph_capture_enabled = false; }; class JsExecutionProvider : public IExecutionProvider { public: - JsExecutionProvider(const JsExecutionProviderInfo& info); + JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ~JsExecutionProvider() override; std::vector> GetCapability( @@ -76,7 +70,7 @@ class JsExecutionProvider : public IExecutionProvider 
{ bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; - bool graph_capture_enabled_ = false; + bool enable_graph_capture_ = false; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc index 5b7329a87cf6a..cbdf99f702150 100644 --- a/onnxruntime/core/providers/js/js_provider_factory.cc +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -10,21 +10,22 @@ namespace onnxruntime { struct JsProviderFactory : IExecutionProviderFactory { - JsProviderFactory(const ProviderOptions& provider_options) - : info_{provider_options} { + JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) + : info_{provider_options}, session_options_(session_options) { } std::unique_ptr CreateProvider() override { - return std::make_unique(info_); + return std::make_unique(info_, session_options_); } private: JsExecutionProviderInfo info_; + const SessionOptions* session_options_; }; std::shared_ptr JsProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - return std::make_shared(provider_options); + const ProviderOptions& provider_options, const SessionOptions* session_options) { + return std::make_shared(provider_options, session_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h index dbabe255c2d7b..510b0fb4248ca 100644 --- a/onnxruntime/core/providers/js/js_provider_factory_creator.h +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -9,9 +9,11 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; struct JsProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 9301597df4d7d..964355956b4ab 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -145,9 +145,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } - std::string graph_capture_enabled = options->value.config_options.GetConfigOrDefault("graphCaptureEnabled", "false"); - provider_options["graph_capture_enabled"] = graph_capture_enabled; - options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); #endif From dc8cc2b1bb5ad44ba02ea130be5cae5ccc214d77 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Fri, 26 Jan 2024 15:41:20 +0800 Subject: [PATCH 25/28] nits --- js/web/lib/wasm/jsep/backend-webgpu.ts | 2 ++ js/web/lib/wasm/wasm-core-impl.ts | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git 
a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index f39a8ba065465..7e00785362c0b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -711,6 +711,8 @@ export class WebGpuBackend { sessionPendingKernels = []; this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); } + // flush the left commands before we change the status. + this.flush(); this.sessionStatus = 'capturing'; } captureEnd(): void { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7726ee0ae7bce..2f652ecf72f71 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -258,7 +258,7 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { - if (enableGraphCapture) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { outputPreferredLocations.push('gpu-buffer'); continue; } @@ -268,6 +268,10 @@ export const createSession = async( if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } From dff25faa81bba7d296348e1b6bc41588cd471314 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 29 Jan 2024 14:30:54 +0800 Subject: [PATCH 26/28] Address Yulong's comments --- js/web/lib/wasm/jsep/backend-webgpu.ts | 6 +++--- js/web/lib/wasm/wasm-core-impl.ts | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 7e00785362c0b..e1faecfc046e3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -702,7 +702,7 @@ export class WebGpuBackend { } captureBegin(): void { - LOG_DEBUG('info', () => 'captureBegin'); + LOG_DEBUG('info', 'captureBegin'); let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); if (!sessionCommandList) { @@ -716,13 +716,13 @@ export class WebGpuBackend { this.sessionStatus = 'capturing'; } captureEnd(): void { - LOG_DEBUG('info', () => 'captureEnd'); + LOG_DEBUG('info', 'captureEnd'); // flush the left commands before we change the status. 
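    // (that is: submit anything still pending on the current command encoder while the
    // session status is unchanged, so the status switch below does not affect how those
    // already-recorded commands are handled)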
this.flush(); this.sessionStatus = 'default'; } replay(): void { - LOG_DEBUG('info', () => 'replay'); + LOG_DEBUG('info', 'replay'); this.sessionStatus = 'replaying'; const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 2f652ecf72f71..f3423d6636750 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -139,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBounded: boolean + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,7 +235,7 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); - const enableGraphCapture = options?.enableGraphCapture === undefined ? false : options.enableGraphCapture; + const enableGraphCapture = !!options?.enableGraphCapture; const inputNames = []; const outputNames = []; @@ -426,7 +426,7 @@ export const run = async( const outputNamesUTF8Encoded = session[2]; const ioBindingState = session[3]; const enableGraphCapture = session[4]; - const inputOutputBounded = session[5]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -474,7 +474,7 @@ export const run = async( } if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { - if (!enableGraphCapture || (enableGraphCapture && !inputOutputBounded)) { + if (!enableGraphCapture || !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -626,6 +626,9 @@ export const run = async( if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } return output; } finally { From 1beb3f16ef276227ac905629174ac3e6eaf0ee7a Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 29 Jan 2024 15:34:17 +0800 Subject: [PATCH 27/28] further simplify if (!enableGraphCapture || !inputOutputBound) --- js/web/lib/wasm/wasm-core-impl.ts | 66 +++++++++++++++---------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index f3423d6636750..08f3067ac74cc 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -473,48 +473,46 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { - if (!enableGraphCapture || !inputOutputBound) { - const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; - if (inputNamesUTF8Encoded.length !== inputCount) { - throw new Error(`input count from feeds (${ - inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); - } + if 
(inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + } - // process inputs - for (let i = 0; i < inputCount; i++) { - const index = inputIndices[i]; - const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); - if (errorCode !== 0) { - checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); - } + // process inputs + for (let i = 0; i < inputCount; i++) { + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); } + } - // process pre-allocated outputs - for (let i = 0; i < outputCount; i++) { - const index = outputIndices[i]; - const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + // process pre-allocated outputs + for (let i = 0; i < outputCount; i++) { + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. - if (location) { - // output is pre-allocated. bind the tensor. - const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); - if (errorCode !== 0) { - checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); - } - } else { - // output is not pre-allocated. reset preferred location. - const errorCode = - wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); - if (errorCode !== 0) { - checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); - } + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); } } - activeSessions.set( - sessionId, - [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } wasm.jsepOnRunStart?.(sessionHandle); From 59f5f92a1911e46862a9033353cc276ccb97c8ac Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Mon, 29 Jan 2024 17:50:35 +0800 Subject: [PATCH 28/28] call OrtClearBoundOutputs when release session --- js/web/lib/wasm/wasm-core-impl.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 08f3067ac74cc..37b9ed6a1002f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -325,9 +325,12 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); }
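
A minimal usage sketch of the `enableGraphCapture` option introduced by this series (not part of the patch). It assumes the public onnxruntime-web API surface — `InferenceSession.create`, `Tensor.fromGpuBuffer`, and `ort.env.webgpu.device` exposing the EP's GPUDevice — and the model path, tensor names, shapes and buffer sizes below are hypothetical.

import * as ort from 'onnxruntime-web/webgpu';

// Graph capture is a session-level option for the WebGPU EP; it replaces the removed
// webgpu-EP-level `graphCaptureEnabled` provider option.
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  enableGraphCapture: true,
});

// With capture enabled, inputs must already live in GPU buffers and outputs default to
// (and must stay at) the 'gpu-buffer' location.
const device = ort.env.webgpu.device;  // assumption: the shared GPUDevice is exposed here
const inputBuffer = device.createBuffer({
  size: 1 * 3 * 224 * 224 * 4,  // float32 [1, 3, 224, 224]
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
const outputBuffer = device.createBuffer({
  size: 1 * 1000 * 4,  // float32 [1, 1000]
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});

const feeds = {input: ort.Tensor.fromGpuBuffer(inputBuffer, {dataType: 'float32', dims: [1, 3, 224, 224]})};
const fetches = {output: ort.Tensor.fromGpuBuffer(outputBuffer, {dataType: 'float32', dims: [1, 1000]})};

// The first run records the command list; subsequent runs with the same bound buffers replay it.
await session.run(feeds, fetches);
await session.run(feeds, fetches);

Because bound outputs are deliberately kept alive across runs when capture is enabled, they are only cleared when the session itself is released, which is what the final hunk above adds.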