diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 685f3dc019461..d13136d252d2a 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -91,12 +91,12 @@ export class WebNNBackend { // Current session is not a WebNN session. return; } + this.tensorManager.releaseTensorsForSession(sessionId); this.mlContextBySessionId.delete(sessionId); const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); - this.tensorManager.releaseTensorsForContext(mlContext); } } diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 9475de019ed1d..13888fa855ef6 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -42,9 +42,9 @@ export interface TensorManager { download(tensorId: TensorId): Promise; download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise; /** - * Release all tensors for a MLContext. + * Release all tensors for a given session. */ - releaseTensorsForContext(mlContext: MLContext): void; + releaseTensorsForSession(session: number): void; /** * Register an externally created MLTensor with a given MLContext and return a TensorId. */ @@ -54,65 +54,89 @@ export interface TensorManager { let tensorGuid = 1; const createNewTensorId = (): TensorId => tensorGuid++; -export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]]; - /** - * TensorTracker tracks the MLTensor and pending upload data. - * - * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until - * we know the data type and shape. This is because future implementations of WebNN will only support creating - * MLTensors with dataTypes and shape. + * TensorWrapper wraps an MLTensor and provides a way to track the last session that used it. */ -class TensorTracker { - private tensorEntry?: MLTensorEntry; - private activeUpload?: Uint8Array; - private tensorCache: MLTensorEntry[]; +class TensorWrapper { + // The id of the last session that used this tensor. + public sessionId: number; - constructor( - private mlContext?: MLContext, - tensorEntry?: MLTensorEntry, - ) { - this.tensorEntry = tensorEntry; - this.tensorCache = tensorEntry ? [tensorEntry] : []; + private mlContext: MLContext; + private mlTensor: MLTensor; + private dataType: MLOperandDataType; + private tensorShape: readonly number[]; + + constructor(descriptor: { + sessionId: number; + context: MLContext; + tensor: MLTensor; + dataType: MLOperandDataType; + shape: readonly number[]; + }) { + this.sessionId = descriptor.sessionId; + this.mlContext = descriptor.context; + this.mlTensor = descriptor.tensor; + this.dataType = descriptor.dataType; + this.tensorShape = descriptor.shape; } - public get tensor(): MLTensor | undefined { - return this.tensorEntry?.[0]; + public get tensor(): MLTensor { + return this.mlTensor; } - public get context(): MLContext { - if (!this.mlContext) { - throw new Error('MLContext has not been set.'); - } - return this.mlContext; + public get type(): MLOperandDataType { + return this.dataType; } - public set context(mlContext: MLContext) { - if (this.mlContext && this.mlContext !== mlContext) { - throw new Error('MLTensor in use in a different MLContext.'); - } - this.mlContext = mlContext; + public get shape(): readonly number[] { + return this.tensorShape; } public destroy(): void { - for (const [mlTensor] of this.tensorCache) { - mlTensor.destroy(); + LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy'); + this.mlTensor.destroy(); + } + + public write(data: Uint8Array): void { + this.mlContext.writeTensor(this.mlTensor, data); + } + + public async read(): Promise; + public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise; + async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + if (dstBuffer) { + return this.mlContext.readTensor(this.mlTensor, dstBuffer); } - this.tensorCache = []; - this.tensorEntry = undefined; + return this.mlContext.readTensor(this.mlTensor); } - public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean { - for (const [mlTensor, dataType, shape] of this.tensorCache) { - if (tryMLTensor === mlTensor) { - if (this.context !== context) { - throw new Error('MLTensor cannot be registered with a different MLContext.'); - } - this.tensorEntry = [mlTensor, dataType, shape]; - return true; - } + public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean { + return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]); + } +} + +/** + * TensorTracker tracks the MLTensor and pending upload data. + * + * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until + * we know the data type and shape. This is because WebNN only support creating MLTensors with dataTypes and shape. + */ +class TensorIdTracker { + private activeUpload?: Uint8Array; + + constructor( + private tensorManager: TensorManagerImpl, + private wrapper?: TensorWrapper, + ) {} + + public get tensorWrapper(): TensorWrapper | undefined { + return this.wrapper; + } + + public releaseTensor(): void { + if (this.tensorWrapper) { + this.tensorManager.releaseTensor(this.tensorWrapper); } - return false; } public async ensureTensor( @@ -120,55 +144,40 @@ class TensorTracker { shape: readonly number[], copyOld: boolean, ): Promise { - if (this.tensorEntry) { - const [mlTensor, existingDataType, existingShape] = this.tensorEntry; - if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { - return mlTensor; - } - } - - for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) { - if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { - if (copyOld && this.tensorEntry) { - // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. - LOG_DEBUG( - 'verbose', - () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`, - ); - const data = await this.context.readTensor(this.tensorEntry[0]); - this.context.writeTensor(mlTensor, data); + if (this.wrapper) { + if (this.wrapper.sameTypeAndShape(dataType, shape)) { + return this.wrapper.tensor; + } else { + if (copyOld) { + this.activeUpload = new Uint8Array(await this.wrapper.read()); } - this.tensorEntry = [mlTensor, existingDataType, existingShape]; - return mlTensor; + this.tensorManager.releaseTensor(this.wrapper); } } - LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); + // eslint-disable-next-line no-bitwise const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; - const tensor = await this.context.createTensor({ - dataType, - shape, - // Assign both shape and dimensions while transitioning to new API. - dimensions: shape, - usage, - }); - this.tensorEntry = [tensor, dataType, shape]; - this.tensorCache.push(this.tensorEntry); + this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage); - if (this.activeUpload) { - this.mlContext?.writeTensor(tensor, this.activeUpload); + if (copyOld && this.activeUpload) { + this.wrapper.write(this.activeUpload); this.activeUpload = undefined; } - return tensor; + return this.wrapper.tensor; } public upload(data: Uint8Array): void { - if (!this.tensorEntry) { - this.activeUpload = new Uint8Array(data); + if (this.wrapper) { + this.wrapper.write(data); return; } - this.mlContext?.writeTensor(this.tensorEntry[0], data); + + if (this.activeUpload) { + this.activeUpload.set(data); + } else { + this.activeUpload = new Uint8Array(data); + } } public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { @@ -179,49 +188,42 @@ class TensorTracker { } else { new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); } - return; } else { return this.activeUpload.buffer; } } - if (!this.tensorEntry) { + if (!this.wrapper) { throw new Error('Tensor has not been created.'); } - if (dstBuffer) { - return this.context.readTensor(this.tensorEntry[0], dstBuffer); + if (!dstBuffer) { + return this.wrapper.read(); } - return this.context.readTensor(this.tensorEntry[0]); + return this.wrapper.read(dstBuffer); } } class TensorManagerImpl implements TensorManager { - private tensorsById = new Map(); - private tensorIdsByContext = new Map>(); + private tensorTrackersById: Map = new Map(); + private freeTensors: TensorWrapper[] = []; + private externalTensors: Set = new Set(); constructor(private backend: WebNNBackend) {} public reserveTensorId(): TensorId { const tensorId = createNewTensorId(); - this.tensorsById.set(tensorId, new TensorTracker()); + this.tensorTrackersById.set(tensorId, new TensorIdTracker(this)); return tensorId; } public releaseTensorId(tensorId: TensorId): void { - const tensorTracker = this.tensorsById.get(tensorId); + const tensorTracker = this.tensorTrackersById.get(tensorId); if (!tensorTracker) { return; } - tensorTracker.destroy(); - this.tensorsById.delete(tensorId); - for (const [mlContext, tensors] of this.tensorIdsByContext) { - if (tensors.has(tensorId)) { - tensors.delete(tensorId); - if (tensors.size === 0) { - this.tensorIdsByContext.delete(mlContext); - } - break; - } + this.tensorTrackersById.delete(tensorId); + if (tensorTracker.tensorWrapper) { + this.releaseTensor(tensorTracker.tensorWrapper); } } @@ -238,20 +240,19 @@ class TensorManagerImpl implements TensorManager { dataType }, shape: ${shape}, copyOld: ${copyOld}}`, ); - const tensor = this.tensorsById.get(tensorId); + const tensor = this.tensorTrackersById.get(tensorId); if (!tensor) { throw new Error('Tensor not found.'); } - tensor.context = this.backend.currentContext; - if (!this.tensorIdsByContext.has(this.backend.currentContext)) { - this.tensorIdsByContext.set(this.backend.currentContext, new Set()); - } - this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId); return tensor.ensureTensor(dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { - this.tensorsById.get(tensorId)!.upload(data); + const tensor = this.tensorTrackersById.get(tensorId); + if (!tensor) { + throw new Error('Tensor not found.'); + } + tensor.upload(data); } public async download(tensorId: TensorId): Promise; @@ -261,19 +262,20 @@ class TensorManagerImpl implements TensorManager { 'verbose', () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`, ); - return this.tensorsById.get(tensorId)!.download(dstBuffer); + const tensorTracker = this.tensorTrackersById.get(tensorId); + if (!tensorTracker) { + throw new Error('Tensor not found.'); + } + return tensorTracker.download(dstBuffer); } - public releaseTensorsForContext(mlContext: MLContext): void { - const tensors = this.tensorIdsByContext.get(mlContext); - if (!tensors) { - return; - } - for (const tensorId of tensors) { - this.tensorsById.get(tensorId)!.destroy(); - this.tensorsById.delete(tensorId); + public releaseTensorsForSession(sessionId: number): void { + for (const tensor of this.freeTensors) { + if (tensor.sessionId === sessionId) { + tensor.destroy(); + } } - this.tensorIdsByContext.delete(mlContext); + this.freeTensors = this.freeTensors.filter((tensor) => tensor.sessionId !== sessionId); } public registerTensor( @@ -282,20 +284,56 @@ class TensorManagerImpl implements TensorManager { dataType: MLOperandDataType, shape: readonly number[], ): TensorId { - for (const [tensorId, tensorTracker] of this.tensorsById) { - if (tensorTracker.trySelectTensor(mlContext, mlTensor)) { - return tensorId; + const tensorId = createNewTensorId(); + // Defaulting to READ | WRITE if usage is not provided. + // eslint-disable-next-line no-bitwise + const wrapper = new TensorWrapper({ + sessionId: this.backend.currentSessionId, + context: mlContext, + tensor: mlTensor, + dataType, + shape, + }); + this.tensorTrackersById.set(tensorId, new TensorIdTracker(this, wrapper)); + this.externalTensors.add(wrapper); + return tensorId; + } + + /** + * Get or create an MLTensor with the given data type and shape. + */ + public async getCachedTensor( + dataType: MLOperandDataType, + shape: readonly number[], + usage: MLTensorUsageFlags, + ): Promise { + const sessionId = this.backend.currentSessionId; + for (const [index, tensor] of this.freeTensors.entries()) { + if (tensor.sameTypeAndShape(dataType, shape)) { + const wrapper = this.freeTensors.splice(index, 1)[0]; + wrapper.sessionId = sessionId; + return wrapper; } } - const tensorId = createNewTensorId(); - this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape])); - let tensors = this.tensorIdsByContext.get(mlContext); - if (!tensors) { - tensors = new Set(); - this.tensorIdsByContext.set(mlContext, tensors); + const context = this.backend.currentContext; + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); + const tensor = await context.createTensor({ + dataType, + shape, + dimensions: shape, + usage, + }); + return new TensorWrapper({ sessionId, context, tensor, dataType, shape }); + } + + /** + * Release tensor for reuse unless external. + */ + public releaseTensor(tensorWrapper: TensorWrapper) { + if (this.externalTensors.has(tensorWrapper)) { + this.externalTensors.delete(tensorWrapper); } - tensors.add(tensorId); - return tensorId; + this.freeTensors.push(tensorWrapper); } }