Skip to content

Commit

Permalink
Assign to both shape and dimensions when creating MLTensors
Browse files Browse the repository at this point in the history
  • Loading branch information
egalli committed Sep 18, 2024
1 parent 5b3ceec commit 7bc892d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
8 changes: 7 additions & 1 deletion js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,13 @@ class TensorTracker {
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
// eslint-disable-next-line no-bitwise
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
const tensor = await this.context.createTensor({ dataType, shape, usage });
const tensor = await this.context.createTensor({
dataType,
shape,
// Assign both shape and dimensions while transitioning to new API.
dimensions: shape,
usage,
});
this.tensorEntry = [tensor, dataType, shape];
this.tensorCache.push(this.tensorEntry);

Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'
interface MLOperandDescriptor {
dataType: MLOperandDataType;
shape?: readonly number[];
/** @deprecated Use shape instead of dimensions */
dimensions?: readonly number[];
}
interface MLOperand {
dataType(): MLOperandDataType;
Expand Down
4 changes: 4 additions & 0 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,8 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty
const mlTensor = await mlContext.createTensor({
dataType,
shape: dims as number[],
// Assign both shape and dimensions while transitioning to new API.
dimensions: dims as number[],
usage: MLTensorUsage.READ,
});

Expand All @@ -686,6 +688,8 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
const mlTensor = await mlContext.createTensor({
dataType,
shape: cpuTensor.dims as number[],
// Assign both shape and dimensions while transitioning to new API.
dimensions: cpuTensor.dims as number[],
usage: MLTensorUsage.WRITE,
});
mlContext.writeTensor(mlTensor, cpuTensor.data);
Expand Down

0 comments on commit 7bc892d

Please sign in to comment.