Skip to content

Commit

Permalink
Clears GPU Cache when there are no more active sessions (microsoft#22490
Browse files Browse the repository at this point in the history
  • Loading branch information
prathikr authored and ankitm3k committed Dec 11, 2024
1 parent 2e756ab commit 57c3402
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 1 deletion.
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,10 @@ export class WebGpuBackend {
this.sessionStatus = 'default';
}

onCreateSession(): void {
this.gpuDataManager.onCreateSession();
}

onReleaseSession(sessionId: number): void {
this.unregisterBuffers(sessionId);
if (this.capturedCommandList.has(sessionId)) {
Expand Down
31 changes: 30 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ export interface GpuDataManager {
*/
dispose(): void;

/**
* create session related data.
*/
onCreateSession(): void;

/**
* release session related data.
* @param sessionId - specify the session ID.
Expand Down Expand Up @@ -200,6 +205,9 @@ class GpuDataManagerImpl implements GpuDataManager {
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;

// The session count.
private sessionCount: number;

constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
Expand All @@ -213,6 +221,8 @@ class GpuDataManagerImpl implements GpuDataManager {
this.freeBuffers.set(key, []);
this.freeUniformBuffers.set(key, []);
}

this.sessionCount = 0;
}

upload(id: GpuDataId, data: Uint8Array): void {
Expand Down Expand Up @@ -360,7 +370,12 @@ class GpuDataManagerImpl implements GpuDataManager {
release(id: GpuDataId): number {
const cachedData = this.storageCache.get(id);
if (!cachedData) {
throw new Error('releasing data does not exist');
if (this.storageCache.size === 0) {
// cache was previously cleared, no need to release anything.
return 0;
} else {
throw new Error('releasing data does not exist');
}
}

LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.release(id=${id}), gpuDataId=${cachedData.gpuData.id}`);
Expand Down Expand Up @@ -460,6 +475,10 @@ class GpuDataManagerImpl implements GpuDataManager {
this.capturedPendingBuffers = new Map();
}

onCreateSession() {
this.sessionCount += 1;
}

onReleaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
Expand All @@ -469,6 +488,16 @@ class GpuDataManagerImpl implements GpuDataManager {
});
this.capturedPendingBuffers.delete(sessionId);
}

// release the storage cache if no active sessions.
this.sessionCount -= 1;
if (this.sessionCount === 0) {
LOG_DEBUG('warning', () => '[WebGPU] Clearing webgpu buffer cache');
this.storageCache.forEach((storage) => {
storage.gpuData.buffer.destroy();
});
this.storageCache = new Map();
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ export const createSession = async (
checkLastError("Can't create a session.");
}

wasm.jsepOnCreateSession?.();

// clear current MLContext after session creation
if (wasm.currentContext) {
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
Expand Down
6 changes: 6 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ export declare namespace JSEP {
* @param sessionId - specify the session ID.
*/
jsepOnRunStart: (sessionId: number) => void;
/**
* [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is
* called.
* @returns
*/
jsepOnCreateSession: () => void;
/**
* [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is
* called.
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/wasm/pre-jsep.js
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ Module['jsepInit'] = (name, params) => {
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnCreateSession'] = sessionId => {
backend['onCreateSession'](sessionId);
};
Module['jsepOnReleaseSession'] = sessionId => {
backend['onReleaseSession'](sessionId);
};
Expand Down

0 comments on commit 57c3402

Please sign in to comment.