Skip to content

Commit

Permalink
[WebNN EP] Fix issues with MLTensor caching
Browse files Browse the repository at this point in the history
This PR fixes a bug that occurs when searching for a compatible `MLTensor` in the cache. We were not checking the number of dimensions in the shape, which meant that a cached buffer of shape `[1]` could be matched to a request for shape `[1, 1, 256, 256]`.

This PR also adds better handling when attempting to force an `MLTensor` to a different shape.
  • Loading branch information
egalli committed Nov 2, 2024
1 parent d419df4 commit 959600f
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,33 @@ export interface TensorManager {
let tensorGuid = 1;
const createNewTensorId = (): TensorId => tensorGuid++;

/**
* Map from MLOperandDataType to size in bytes.
*/
const webnnDataTypeToSize = new Map<MLOperandDataType, number>([
['float32', 4],
['float16', 2],
['int32', 4],
['uint32', 4],
['int64', 8],
['uint64', 8],
['int8', 1],
['uint8', 1],
['int4', 0.5],
['uint4', 0.5],
]);

/**
* Calculate the byte length of a tensor with the given data type and shape.
*/
const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => {
const size = webnnDataTypeToSize.get(dataType);
if (!size) {
throw new Error('Unsupported data type.');
}
return Math.ceil(shape.reduce((a, b) => a * b) * size);
};

/**
* TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
*/
Expand Down Expand Up @@ -92,6 +119,10 @@ class TensorWrapper {
return this.tensorShape;
}

public get byteLength(): number {
return calculateByteLength(this.dataType, this.tensorShape);
}

public destroy(): void {
LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
this.mlTensor.destroy();
Expand All @@ -111,7 +142,11 @@ class TensorWrapper {
}

public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]);
return (
this.dataType === dataType &&
this.tensorShape.length === shape.length &&
this.tensorShape.every((v, i) => v === shape[i])
);
}
}

Expand All @@ -136,6 +171,7 @@ class TensorIdTracker {
public releaseTensor(): void {
if (this.tensorWrapper) {
this.tensorManager.releaseTensor(this.tensorWrapper);
this.wrapper = undefined;
}
}

Expand All @@ -149,6 +185,9 @@ class TensorIdTracker {
return this.wrapper.tensor;
} else {
if (copyOld) {
if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) {
throw new Error('Unable to copy data to tensor with different size.');
}
this.activeUpload = new Uint8Array(await this.wrapper.read());
}
this.tensorManager.releaseTensor(this.wrapper);
Expand All @@ -169,8 +208,13 @@ class TensorIdTracker {

public upload(data: Uint8Array): void {
if (this.wrapper) {
this.wrapper.write(data);
return;
if (data.byteLength === this.wrapper.byteLength) {
this.wrapper.write(data);
return;
} else {
LOG_DEBUG('verbose', () => 'Data size does not match tensor size. Releasing tensor.');
this.releaseTensor();
}
}

if (this.activeUpload) {
Expand Down Expand Up @@ -312,6 +356,7 @@ class TensorManagerImpl implements TensorManager {
const sessionId = this.backend.currentSessionId;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;
Expand Down

0 comments on commit 959600f

Please sign in to comment.