From 959600fce88102bf530e5e70a6fb77cfd06e8c1e Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Sat, 2 Nov 2024 01:17:34 -0700 Subject: [PATCH] [WebNN EP] Fix issues with MLTensor caching This PR fixes a bug that occurs when searching for compatible `MLTensor` in the cache. We were missing checking the number of dimensions in the shape. This would mean that a cached buffer of shape `[1]` could match for `[1, 1, 256, 256]`. This PR also adds better handling when attempting to force an `MLTensor` to a different shape. --- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 51 ++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 916dec4545af3..35b1640afa266 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -54,6 +54,33 @@ export interface TensorManager { let tensorGuid = 1; const createNewTensorId = (): TensorId => tensorGuid++; +/** + * Map from MLOperandDataType to size in bytes. + */ +const webnnDataTypeToSize = new Map([ + ['float32', 4], + ['float16', 2], + ['int32', 4], + ['uint32', 4], + ['int64', 8], + ['uint64', 8], + ['int8', 1], + ['uint8', 1], + ['int4', 0.5], + ['uint4', 0.5], +]); + +/** + * Calculate the byte length of a tensor with the given data type and shape ([] is a scalar: one element). + */ +const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => { + const size = webnnDataTypeToSize.get(dataType); + if (!size) { + throw new Error('Unsupported data type.'); + } + return Math.ceil(shape.reduce((a, b) => a * b, 1) * size); +}; + /** * TensorWrapper wraps an MLTensor and provides a way to track the last session that used it. 
*/ @@ -92,6 +119,10 @@ class TensorWrapper { return this.tensorShape; } + public get byteLength(): number { + return calculateByteLength(this.dataType, this.tensorShape); + } + public destroy(): void { LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy'); this.mlTensor.destroy(); @@ -111,7 +142,11 @@ class TensorWrapper { } public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean { - return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]); + return ( + this.dataType === dataType && + this.tensorShape.length === shape.length && + this.tensorShape.every((v, i) => v === shape[i]) + ); } } @@ -136,6 +171,7 @@ class TensorIdTracker { public releaseTensor(): void { if (this.tensorWrapper) { this.tensorManager.releaseTensor(this.tensorWrapper); + this.wrapper = undefined; } } @@ -149,6 +185,9 @@ class TensorIdTracker { return this.wrapper.tensor; } else { if (copyOld) { + if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) { + throw new Error('Unable to copy data to tensor with different size.'); + } this.activeUpload = new Uint8Array(await this.wrapper.read()); } this.tensorManager.releaseTensor(this.wrapper); @@ -169,8 +208,13 @@ class TensorIdTracker { public upload(data: Uint8Array): void { if (this.wrapper) { - this.wrapper.write(data); - return; + if (data.byteLength === this.wrapper.byteLength) { + this.wrapper.write(data); + return; + } else { + LOG_DEBUG('verbose', () => 'Data size does not match tensor size. 
Releasing tensor.'); + this.releaseTensor(); + } } if (this.activeUpload) { @@ -312,6 +356,7 @@ class TensorManagerImpl implements TensorManager { const sessionId = this.backend.currentSessionId; for (const [index, tensor] of this.freeTensors.entries()) { if (tensor.sameTypeAndShape(dataType, shape)) { + LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); const wrapper = this.freeTensors.splice(index, 1)[0]; wrapper.sessionId = sessionId; return wrapper;