Rename dimensions to shape
egalli committed Sep 18, 2024
Parent: 16c14fb
Commit: bbbb0b1
Showing 6 changed files with 32 additions and 33 deletions.
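For orientation, a minimal TypeScript sketch of what the rename means at a call site. It is illustrative only, not taken from the commit; the types are simplified stand-ins for the declarations in js/web/lib/wasm/jsep/webnn/webnn.d.ts, and the usage value is a placeholder rather than the real MLTensorUsage flags.

// Simplified stand-in types; the real declarations live in webnn.d.ts.
type MLOperandDataTypeSketch = 'float32' | 'float16' | 'int32' | 'uint32' | 'int64' | 'uint64' | 'int8' | 'uint8';

interface MLTensorDescriptorSketch {
  dataType: MLOperandDataTypeSketch;
  shape?: readonly number[]; // previously `dimensions?: readonly number[]`
  usage: number;             // placeholder for MLTensorUsage flags
}

// After this commit, a tensor is described like this:
const descriptor: MLTensorDescriptorSketch = {
  dataType: 'float32',
  shape: [1, 3, 224, 224],   // was: dimensions: [1, 3, 224, 224]
  usage: 0,
};

console.log(descriptor.shape);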
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/init.ts
@@ -275,8 +275,8 @@ export const init = async (
// jsepReleaseTensorId,
(tensorId: number) => backend.releaseTensorId(tensorId),
// jsepEnsureTensor
-    async (tensorId: number, onnxDataType: number, dimensions: number[], copyOld) =>
-      backend.ensureTensor(tensorId, onnxDataType, dimensions, copyOld),
+    async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
+      backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
// jsepUploadTensor
(tensorId: number, data: Uint8Array) => {
backend.uploadTensor(tensorId, data);
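As a hedged sketch (not part of the commit), the renamed glue callback can be written in isolation as below. MLTensorStub and WebNNBackendLike are stand-ins for the real WebNN types and backend wired up in init.ts.

// Stand-in types for illustration only.
interface MLTensorStub { destroy(): void }

interface WebNNBackendLike {
  ensureTensor(tensorId: number, onnxDataType: number, shape: readonly number[], copyOld: boolean): Promise<MLTensorStub>;
}

// The jsepEnsureTensor callback simply forwards to the backend; the third
// argument is now named `shape` instead of `dimensions`.
const makeJsepEnsureTensor =
  (backend: WebNNBackendLike) =>
  async (tensorId: number, onnxDataType: number, shape: number[], copyOld: boolean): Promise<MLTensorStub> =>
    backend.ensureTensor(tensorId, onnxDataType, shape, copyOld);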
42 changes: 21 additions & 21 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -29,7 +29,7 @@ export interface TensorManager
ensureTensor(
tensorId: TensorId,
dataType: MLOperandDataType,
-    dimensions: readonly number[],
+    shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor>;
/**
@@ -48,7 +48,7 @@ export interface TensorManager {
/**
* Register an externally created MLTensor with a given MLContext and return a TensorId.
*/
-  registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, dimensions: number[]): TensorId;
+  registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
}

let tensorGuid = 1;
@@ -60,8 +60,8 @@ export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]];
* TensorTracker tracks the MLTensor and pending upload data.
*
* We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until
- * we know the data type and dimensions. This is because future implementations of WebNN will only support creating
- * MLTensors with dataTypes and dimensions.
+ * we know the data type and shape. This is because future implementations of WebNN will only support creating
+ * MLTensors with dataTypes and shape.
*/
class TensorTracker {
private tensorEntry?: MLTensorEntry;
@@ -103,12 +103,12 @@ class TensorTracker {
}

public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean {
-    for (const [mlTensor, dataType, dimensions] of this.tensorCache) {
+    for (const [mlTensor, dataType, shape] of this.tensorCache) {
if (tryMLTensor === mlTensor) {
if (this.context !== context) {
throw new Error('MLTensor cannot be registered with a different MLContext.');
}
-        this.tensorEntry = [mlTensor, dataType, dimensions];
+        this.tensorEntry = [mlTensor, dataType, shape];
return true;
}
}
@@ -117,39 +117,39 @@

public async ensureTensor(
dataType: MLOperandDataType,
-    dimensions: readonly number[],
+    shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor> {
if (this.tensorEntry) {
-      const [mlTensor, existingDataType, existingDimensions] = this.tensorEntry;
-      if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) {
+      const [mlTensor, existingDataType, existingShape] = this.tensorEntry;
+      if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
return mlTensor;
}
}

-    for (const [mlTensor, existingDataType, existingDimensions] of this.tensorCache) {
-      if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) {
+    for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) {
+      if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
if (copyOld && this.tensorEntry) {
// WebNN does not support copyTensorToTensor, so we need to read and write the tensors.
LOG_DEBUG(
'verbose',
() =>
`[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${
dataType
-            }, dimensions: ${dimensions}}`,
+            }, shape: ${shape}}`,
);
const data = await this.context.readTensor(this.tensorEntry[0]);
this.context.writeTensor(mlTensor, data);
}
-        this.tensorEntry = [mlTensor, existingDataType, existingDimensions];
+        this.tensorEntry = [mlTensor, existingDataType, existingShape];
return mlTensor;
}
}
-    LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, dimensions: ${dimensions}}`);
+    LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
// eslint-disable-next-line no-bitwise
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
-    const tensor = await this.context.createTensor({ dataType, dimensions, usage });
-    this.tensorEntry = [tensor, dataType, dimensions];
+    const tensor = await this.context.createTensor({ dataType, shape, usage });
+    this.tensorEntry = [tensor, dataType, shape];
this.tensorCache.push(this.tensorEntry);

if (this.activeUpload) {
@@ -225,15 +225,15 @@ class TensorManagerImpl implements TensorManager {
public async ensureTensor(
tensorId: TensorId,
dataType: MLOperandDataType,
-    dimensions: number[],
+    shape: number[],
copyOld: boolean,
): Promise<MLTensor> {
LOG_DEBUG(
'verbose',
() =>
`[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${
dataType
-        }, dimensions: ${dimensions}, copyOld: ${copyOld}}`,
+        }, shape: ${shape}, copyOld: ${copyOld}}`,
);
const tensor = this.tensorsById.get(tensorId);
if (!tensor) {
@@ -244,7 +244,7 @@
this.tensorIdsByContext.set(this.backend.currentContext, new Set());
}
this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId);
-    return tensor.ensureTensor(dataType, dimensions, copyOld);
+    return tensor.ensureTensor(dataType, shape, copyOld);
}

public upload(tensorId: TensorId, data: Uint8Array): void {
@@ -277,15 +277,15 @@ class TensorManagerImpl implements TensorManager {
mlContext: MLContext,
mlTensor: MLTensor,
dataType: MLOperandDataType,
-    dimensions: readonly number[],
+    shape: readonly number[],
): TensorId {
for (const [tensorId, tensorTracker] of this.tensorsById) {
if (tensorTracker.trySelectTensor(mlContext, mlTensor)) {
return tensorId;
}
}
const tensorId = createNewTensorId();
-    this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, dimensions]));
+    this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape]));
let tensors = this.tensorIdsByContext.get(mlContext);
if (!tensors) {
tensors = new Set();
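Because WebNN offers no tensor-to-tensor copy, the copyOld path in this file falls back to a read-back followed by a write. A minimal sketch of that fallback, using stand-in types in place of the real MLContext and MLTensor from webnn.d.ts:

// Stand-in declarations for illustration.
interface MLTensorLike { destroy(): void }
interface MLContextLike {
  readTensor(source: MLTensorLike): Promise<ArrayBuffer>;
  writeTensor(destination: MLTensorLike, data: ArrayBufferView | ArrayBuffer): void;
}

// Reusing a cached tensor with a matching dataType/shape still requires moving
// the old contents: read the bytes back from the old tensor, then write them
// into the selected one. Correct, but slower than a device-side copy would be.
async function copyTensorContents(context: MLContextLike, from: MLTensorLike, to: MLTensorLike): Promise<void> {
  const data = await context.readTensor(from);
  context.writeTensor(to, data);
}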
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -32,7 +32,7 @@ type MLInputOperandLayout = 'nchw'|'nhwc';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
interface MLOperandDescriptor {
dataType: MLOperandDataType;
-  dimensions?: readonly number[];
+  shape?: readonly number[];
}
interface MLOperand {
dataType(): MLOperandDataType;
@@ -405,7 +405,7 @@ interface MLContext {
createTensor(descriptor: MLTensorDescriptor): Promise<MLTensor>;
writeTensor(
destinationTensor: MLTensor, sourceData: ArrayBufferView|ArrayBuffer, sourceElementOffset?: number,
-      srcElementSize?: number): void;
+      sourceElementSize?: number): void;
readTensor(sourceTensor: MLTensor): Promise<ArrayBuffer>;
readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise<undefined>;
dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void;
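For reference, a hedged usage sketch against the MLContext surface typed above. It assumes a WebNN-enabled browser where an MLContext comes from navigator.ml.createContext(), and it uses placeholder usage values instead of the real MLTensorUsage constants.

// Sketch only: runs in a browser with WebNN enabled; not part of the commit.
const USAGE_READ = 1;  // placeholder for MLTensorUsage.READ
const USAGE_WRITE = 2; // placeholder for MLTensorUsage.WRITE

async function roundTrip(): Promise<Float32Array> {
  const nav = (globalThis as any).navigator;     // browser-only
  const context = await nav.ml.createContext();  // assumed WebNN entry point
  const tensor = await context.createTensor({
    dataType: 'float32',
    shape: [2, 2],                               // renamed from `dimensions`
    usage: USAGE_READ | USAGE_WRITE,
  });
  context.writeTensor(tensor, new Float32Array([1, 2, 3, 4]));
  const buffer = await context.readTensor(tensor); // Promise<ArrayBuffer>
  return new Float32Array(buffer);
}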
7 changes: 3 additions & 4 deletions js/web/lib/wasm/wasm-types.ts
@@ -33,7 +33,7 @@ export declare namespace JSEP {
type EnsureTensorFunction = (
tensorId: number,
dataType: DataType,
-    dimensions: readonly number[],
+    shape: readonly number[],
copyOld: boolean,
) => Promise<MLTensor>;
type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void;
@@ -183,21 +183,20 @@ export declare namespace JSEP {
* [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID.
* @param tensorId - specify the MLTensor ID.
* @param onnxDataType - specify the data type.
-   * @param dimensions - specify the dimensions.
+   * @param shape - specify the dimensions (WebNN shape) of the tensor.
* @param copyOld - specify whether to copy the old tensor if a new tensor was created.
* @returns the MLTensor associated with the tensor ID.
*/
jsepEnsureTensor: (
tensorId: number,
dataType: DataType,
-    dimensions: number[],
+    shape: number[],
copyOld: boolean,
) => Promise<MLTensor>;
/**
* [exported from pre-jsep.js] Upload data to an MLTensor.
* @param tensorId - specify the MLTensor ID.
* @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType.
-   * @param dimensions - specify the dimensions.
* @returns
*/
jsepUploadTensor: (tensorId: number, data: Uint8Array) => void;
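A hedged sketch of a call through the Module surface declared above. The module instance would come from the onnxruntime-web WASM loader; the numeric data type follows ONNX TensorProto, with 1 assumed here for float32.

// Minimal structural type for the two exports shown above; illustration only.
interface JsepModuleLike {
  jsepEnsureTensor(tensorId: number, dataType: number, shape: number[], copyOld: boolean): Promise<unknown>;
  jsepUploadTensor(tensorId: number, data: Uint8Array): void;
}

async function ensureAndUpload(module: JsepModuleLike, tensorId: number): Promise<void> {
  // Ensure an MLTensor of the right type and shape backs this tensor ID
  // (the third argument was previously named `dimensions`).
  await module.jsepEnsureTensor(tensorId, /* float32, assumed */ 1, [1, 4], /* copyOld */ false);
  // Upload the raw bytes for four float32 elements.
  module.jsepUploadTensor(tensorId, new Uint8Array(new Float32Array([1, 2, 3, 4]).buffer));
}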
4 changes: 2 additions & 2 deletions js/web/test/test-runner.ts
@@ -663,7 +663,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty

const mlTensor = await mlContext.createTensor({
dataType,
-    dimensions: dims as number[],
+    shape: dims as number[],
usage: MLTensorUsage.READ,
});

@@ -685,7 +685,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
const mlTensor = await mlContext.createTensor({
dataType,
-    dimensions: cpuTensor.dims as number[],
+    shape: cpuTensor.dims as number[],
usage: MLTensorUsage.WRITE,
});
mlContext.writeTensor(mlTensor, cpuTensor.data);
4 changes: 2 additions & 2 deletions onnxruntime/wasm/pre-jsep.js
Expand Up @@ -232,8 +232,8 @@ Module['jsepInit'] = (name, params) => {
Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => {
return backend['createMLTensorDownloader'](tensorId, type);
}
-    Module['jsepRegisterMLTensor'] = (tensor, dataType, dimensions) => {
-      return backend['registerMLTensor'](tensor, dataType, dimensions);
+    Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
+      return backend['registerMLTensor'](tensor, dataType, shape);
}
}
};
