diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index bfb74355b0d70..50d83f5af26e0 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -902,6 +902,10 @@ export class WebGpuBackend { this.sessionStatus = 'default'; } + onCreateSession(): void { + this.gpuDataManager.onCreateSession(); + } + onReleaseSession(sessionId: number): void { this.unregisterBuffers(sessionId); if (this.capturedCommandList.has(sessionId)) { diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 33e8c95c141ee..4e14c8b58e2bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -64,6 +64,11 @@ export interface GpuDataManager { */ dispose(): void; + /** + * create session related data. + */ + onCreateSession(): void; + /** * release session related data. * @param sessionId - specify the session ID. @@ -200,6 +205,9 @@ class GpuDataManagerImpl implements GpuDataManager { // a SessionID -> GPUBuffer[] mapping. private capturedPendingBuffers: Map<number, GPUBuffer[]>; + // The session count. + private sessionCount: number; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -213,6 +221,8 @@ class GpuDataManagerImpl implements GpuDataManager { this.freeBuffers.set(key, []); this.freeUniformBuffers.set(key, []); } + + this.sessionCount = 0; } upload(id: GpuDataId, data: Uint8Array): void { @@ -360,7 +370,12 @@ class GpuDataManagerImpl implements GpuDataManager { release(id: GpuDataId): number { const cachedData = this.storageCache.get(id); if (!cachedData) { - throw new Error('releasing data does not exist'); + if (this.storageCache.size === 0) { + // cache was previously cleared, no need to release anything. + return 0; + } else { + throw new Error('releasing data does not exist'); + } } LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`); @@ -460,6 +475,10 @@ class GpuDataManagerImpl implements GpuDataManager { this.capturedPendingBuffers = new Map(); } + onCreateSession() { + this.sessionCount += 1; + } + onReleaseSession(sessionId: number) { // release the captured pending buffers. const pendingBuffers = this.capturedPendingBuffers.get(sessionId); @@ -469,6 +488,16 @@ class GpuDataManagerImpl implements GpuDataManager { }); this.capturedPendingBuffers.delete(sessionId); } + + // release the storage cache if no active sessions. + this.sessionCount -= 1; + if (this.sessionCount === 0) { + LOG_DEBUG('warning', () => '[WebGPU] Clearing webgpu buffer cache'); + this.storageCache.forEach((storage) => { + storage.gpuData.buffer.destroy(); + }); + this.storageCache = new Map(); + } } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5f219f63aaf61..19f89feb9d0d7 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -317,6 +317,8 @@ export const createSession = async ( checkLastError("Can't create a session."); } + wasm.jsepOnCreateSession?.(); + // clear current MLContext after session creation if (wasm.currentContext) { wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 3e08fe97f559d..16674f0f4a79e 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -141,6 +141,12 @@ export declare namespace JSEP { * @param sessionId - specify the session ID. */ jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is + * called. + * @returns + */ + jsepOnCreateSession: () => void; /** * [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is * called. diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 78d60326dd0a8..0efbcab3a3238 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -192,6 +192,9 @@ Module['jsepInit'] = (name, params) => { Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnCreateSession'] = sessionId => { + backend['onCreateSession'](sessionId); + }; Module['jsepOnReleaseSession'] = sessionId => { backend['onReleaseSession'](sessionId); };