diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 4932691bda65b..45b5b8b4fa932 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -141,8 +141,9 @@ class TensorWrapper { return this.mlContext.readTensor(this.mlTensor); } - public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean { + public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean { return ( + this.mlContext === context && this.dataType === dataType && this.tensorShape.length === shape.length && this.tensorShape.every((v, i) => v === shape[i]) @@ -176,12 +177,13 @@ class TensorIdTracker { } public async ensureTensor( + context: MLContext, dataType: MLOperandDataType, shape: readonly number[], copyOld: boolean, ): Promise { if (this.wrapper) { - if (this.wrapper.sameTypeAndShape(dataType, shape)) { + if (this.wrapper.canReuseTensor(context, dataType, shape)) { return this.wrapper.tensor; } else { if (copyOld) { @@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager { if (!tensor) { throw new Error('Tensor not found.'); } - return tensor.ensureTensor(dataType, shape, copyOld); + return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { @@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager { readable: boolean, ): Promise { const sessionId = this.backend.currentSessionId; + const context = this.backend.currentContext; for (const [index, tensor] of this.freeTensors.entries()) { - if (tensor.sameTypeAndShape(dataType, shape)) { + if (tensor.canReuseTensor(context, dataType, shape)) { LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); const wrapper = this.freeTensors.splice(index, 1)[0]; wrapper.sessionId = sessionId; return wrapper; } } - const context = this.backend.currentContext; LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); const tensor = await context.createTensor({ dataType,