Skip to content

Commit

Permalink
[WebNN] Fixes MLTensor caching across different contexts
Browse files Browse the repository at this point in the history
We weren't checking that MLTensors were from the same context before reusing them.

Found while debugging microsoft/webnn-developer-preview#69
  • Loading branch information
egalli committed Dec 13, 2024
1 parent 62e7e24 commit c0557b6
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ class TensorWrapper {
return this.mlContext.readTensor(this.mlTensor);
}

public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
return (
this.mlContext === context &&
this.dataType === dataType &&
this.tensorShape.length === shape.length &&
this.tensorShape.every((v, i) => v === shape[i])
Expand Down Expand Up @@ -176,12 +177,13 @@ class TensorIdTracker {
}

public async ensureTensor(
context: MLContext,
dataType: MLOperandDataType,
shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor> {
if (this.wrapper) {
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
return this.wrapper.tensor;
} else {
if (copyOld) {
Expand Down Expand Up @@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
if (!tensor) {
throw new Error('Tensor not found.');
}
return tensor.ensureTensor(dataType, shape, copyOld);
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
}

public upload(tensorId: TensorId, data: Uint8Array): void {
Expand Down Expand Up @@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
readable: boolean,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
const context = this.backend.currentContext;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
if (tensor.canReuseTensor(context, dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;
}
}
const context = this.backend.currentContext;
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
const tensor = await context.createTensor({
dataType,
Expand Down

0 comments on commit c0557b6

Please sign in to comment.