diff --git a/lib/backends/cpu/ops/pool.ts b/lib/backends/cpu/ops/pool.ts index 3f36f98c..f74001a4 100644 --- a/lib/backends/cpu/ops/pool.ts +++ b/lib/backends/cpu/ops/pool.ts @@ -105,7 +105,6 @@ export function pool( const curInd = ShapeUtil.offsetToIndices(ind, outputStride); const startInd = curInd.slice(0); const x = curInd.slice(0); - let offset = new Array(stridesRank); // calculate the start indices of kernel corresponding to current output indices for (let i = 0; i < stridesRank; i++) { startInd[rank - stridesRank + i] = curInd[rank - stridesRank + i] * strides[i]; @@ -115,7 +114,7 @@ export function pool( let isPad = false; // loop through elements within kernel for (let i = 0; i < kernelSize; i++) { - offset = ShapeUtil.offsetToIndices(i, kernelStrides); + const offset = ShapeUtil.offsetToIndices(i, kernelStrides); isPad = false; // "Shift" the kernel by the kernel start indices to loop through the kernel mapped to current output indices for (let j = rank - stridesRank; j < rank; j++) { diff --git a/lib/backends/cpu/ops/softmax.ts b/lib/backends/cpu/ops/softmax.ts index 36878014..af4494a2 100644 --- a/lib/backends/cpu/ops/softmax.ts +++ b/lib/backends/cpu/ops/softmax.ts @@ -17,7 +17,7 @@ export function softmax(x: Tensor, axis: number): Tensor { const inputDimensions = x.dims; const inputRank = inputDimensions.length; - const axisCorrected = util.ShapeUtil.getActualAxisFromNegativeValue(axis, inputRank); + const axisCorrected = util.ShapeUtil.parseAxis(axis, inputRank); const N = util.ShapeUtil.sizeToDimension(inputDimensions, axisCorrected); const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axisCorrected); diff --git a/lib/backends/cpu/ops/transpose.ts b/lib/backends/cpu/ops/transpose.ts index 35345028..1daad713 100644 --- a/lib/backends/cpu/ops/transpose.ts +++ b/lib/backends/cpu/ops/transpose.ts @@ -89,7 +89,7 @@ function doTranspose( let startTargetIndex = 0; for (let i = 0; i < numBlocks; ++i) { - const sizeOffset = ShapeUtil.computeOffset(targetIndex, stride, numAxes); + const sizeOffset = ShapeUtil.indicesToOffset(targetIndex, stride, numAxes); arrayCopyHelper(target, source, startTargetIndex, startSourceIndex + sizeOffset, numElementsInBlock); ShapeUtil.incrementIndex(targetIndex, targetDims, numAxes); @@ -108,7 +108,7 @@ function doTransposeEltWise( let startTargetIndex = 0; for (let i = 0; i < numBlocks; ++i) { - const sourceOffset = ShapeUtil.computeOffset(targetIndex, stride, numAxes); + const sourceOffset = ShapeUtil.indicesToOffset(targetIndex, stride, numAxes); target[startTargetIndex++] = source[sourceOffset]; ShapeUtil.incrementIndex(targetIndex, targetDims, numAxes); } diff --git a/lib/backends/wasm/ops/softmax.ts b/lib/backends/wasm/ops/softmax.ts index bcdc8baa..a50595a2 100644 --- a/lib/backends/wasm/ops/softmax.ts +++ b/lib/backends/wasm/ops/softmax.ts @@ -10,7 +10,7 @@ import {WasmInferenceHandler} from '../inference-handler'; export class WasmSoftmax extends Softmax { run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { const x = inputs[0]; - const axisCorrected = ShapeUtil.getActualAxisFromNegativeValue(this.axis, x.dims.length); + const axisCorrected = ShapeUtil.parseAxis(this.axis, x.dims.length); const N = ShapeUtil.sizeToDimension(x.dims, axisCorrected); const D = ShapeUtil.sizeFromDimension(x.dims, axisCorrected); const y = new Tensor(x.dims, x.type); diff --git a/lib/backends/webgl/glsl-definitions.ts b/lib/backends/webgl/glsl-definitions.ts index 11e14150..2eabbc5c 100644 --- a/lib/backends/webgl/glsl-definitions.ts +++ b/lib/backends/webgl/glsl-definitions.ts @@ -16,8 +16,8 @@ export interface GlslFunction { } export interface GlslValueFunction extends GlslFunction {} export interface GlslPositionalFunction extends GlslFunction { - inputShape: number[]; - outputShape: number[]; + inputShape: ReadonlyArray; + outputShape: ReadonlyArray; } export class GlslContext { diff --git a/lib/backends/webgl/glsl-shape-utils-lib.ts b/lib/backends/webgl/glsl-shape-utils-lib.ts index 804247b6..f792dbee 100644 --- a/lib/backends/webgl/glsl-shape-utils-lib.ts +++ b/lib/backends/webgl/glsl-shape-utils-lib.ts @@ -95,7 +95,7 @@ export class ShapeUtilsGlslLib extends GlslLib { }); return result; } - static indexToOffsetSingle(name: string, rank: number, strides: number[]): string { + static indexToOffsetSingle(name: string, rank: number, strides: ReadonlyArray): string { let block = ''; for (let i = rank - 1; i >= 0; --i) { block += ` @@ -126,7 +126,7 @@ export class ShapeUtilsGlslLib extends GlslLib { }); return result; } - static offsetToIndicesSingle(name: string, rank: number, strides: number[]): string { + static offsetToIndicesSingle(name: string, rank: number, strides: ReadonlyArray): string { const stridesBlock = []; for (let i = 0; i < rank - 1; ++i) { stridesBlock.push(` diff --git a/lib/backends/webgl/inference-handler.ts b/lib/backends/webgl/inference-handler.ts index c8c4dc3f..fc6bc027 100644 --- a/lib/backends/webgl/inference-handler.ts +++ b/lib/backends/webgl/inference-handler.ts @@ -73,7 +73,7 @@ export class WebGLInferenceHandler implements InferenceHandler { } return tensor; } - getOrCreateTextureLayout(tensor: Tensor, channels = 1, unpackedShape?: number[]): TextureLayout { + getOrCreateTextureLayout(tensor: Tensor, channels = 1, unpackedShape?: ReadonlyArray): TextureLayout { const td = this.getTextureData(tensor); if (td) { return td; @@ -87,8 +87,8 @@ export class WebGLInferenceHandler implements InferenceHandler { this.textureToTensor = new Map(); } createTextureData( - dataType: Tensor.DataType, shape: number[], strides?: number[], data?: Tensor.NumberType, channels?: number, - width?: number, height?: number): TextureData { + dataType: Tensor.DataType, shape: ReadonlyArray, strides?: ReadonlyArray, + data?: Tensor.NumberType, channels?: number, width?: number, height?: number): TextureData { Logger.verbose('InferenceHandler', `Creating TextureData: shape:[${shape}], channels:${channels ? channels : 1}`); const td = this.textureManager.createTexture(dataType, shape, strides, data, channels, width, height); return td; @@ -98,8 +98,9 @@ export class WebGLInferenceHandler implements InferenceHandler { const td = this.textureManager.createTextureFromLayout(dataType, layout, data); return td; } - createBasicTextureLayout(shape: number[], channels = 1, unpackedShape?: number[], prefs?: WidthHeightPrefs): - TextureLayout { + createBasicTextureLayout( + shape: ReadonlyArray, channels = 1, unpackedShape?: ReadonlyArray, + prefs?: WidthHeightPrefs): TextureLayout { const [width, height] = this.session.layoutStrategy.computeTextureWH(shape, prefs); if (channels === 1) { unpackedShape = shape; diff --git a/lib/backends/webgl/ops/pad.ts b/lib/backends/webgl/ops/pad.ts index dae7c1b8..cd16e254 100644 --- a/lib/backends/webgl/ops/pad.ts +++ b/lib/backends/webgl/ops/pad.ts @@ -57,8 +57,8 @@ export function getPadFunction( } } function getPadConstant( - name: string, shape: ReadonlyArray, strides: number[], width: number, height: number, pads: number[], - value: number) { + name: string, shape: ReadonlyArray, strides: ReadonlyArray, width: number, height: number, + pads: number[], value: number) { const rank = shape.length; let block = ''; for (let i = rank - 1; i >= 0; --i) { @@ -82,7 +82,8 @@ function getPadConstant( `; } function getPadReflect( - name: string, shape: ReadonlyArray, strides: number[], width: number, height: number, pads: number[]) { + name: string, shape: ReadonlyArray, strides: ReadonlyArray, width: number, height: number, + pads: number[]) { const rank = shape.length; let block = ''; @@ -110,7 +111,8 @@ function getPadReflect( `; } function getPadEdge( - name: string, shape: ReadonlyArray, strides: number[], width: number, height: number, pads: number[]) { + name: string, shape: ReadonlyArray, strides: ReadonlyArray, width: number, height: number, + pads: number[]) { const rank = shape.length; let block = ''; diff --git a/lib/backends/webgl/ops/reshape.ts b/lib/backends/webgl/ops/reshape.ts index f874a307..cb48e8e2 100644 --- a/lib/backends/webgl/ops/reshape.ts +++ b/lib/backends/webgl/ops/reshape.ts @@ -15,7 +15,8 @@ export class WebGLReshape extends Reshape { this.outputShape = attributes.getInts('output_shape', []); this.dimsToKeep = attributes.getInts('dims_to_keep', []); } - getOutputShape(inferenceHandler: WebGLInferenceHandler, inputShapes: number[][]): number[] { + getOutputShape(inferenceHandler: WebGLInferenceHandler, inputShapes: Array>): + ReadonlyArray { if (inputShapes.length >= 2) { return inputShapes[1]; } @@ -32,7 +33,7 @@ export class WebGLReshape extends Reshape { } } run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { - const inputShape = inputs[0].dims.slice(); + const inputShape: ReadonlyArray = inputs[0].dims.slice(); const inputShapes = [inputShape]; if (inputs.length > 1) { inputShapes.push(ShapeUtil.calculateReshapedDims(inputShape, inputs[1].integerData)); @@ -55,7 +56,7 @@ export class WebGLReshape extends Reshape { arrayType: inputTD.arrayType })]; } - getPositionalFunction(inferenceHandler: WebGLInferenceHandler, inputShape: number[], name?: string): + getPositionalFunction(inferenceHandler: WebGLInferenceHandler, inputShape: ReadonlyArray, name?: string): GlslPositionalFunction { const outputShape = this.getOutputShape(inferenceHandler, [inputShape]); if (!name) { @@ -69,7 +70,8 @@ export class WebGLReshape extends Reshape { outputShape }; } - protected getReshapeFunctionBody(name: string, inputShape: number[], outputShape: number[]): string { + protected getReshapeFunctionBody(name: string, inputShape: ReadonlyArray, outputShape: ReadonlyArray): + string { const inputStrides = ShapeUtil.computeStrides(inputShape); const outputStrides = ShapeUtil.computeStrides(outputShape); return ` diff --git a/lib/backends/webgl/ops/softmax.ts b/lib/backends/webgl/ops/softmax.ts index 71e2d2d9..9abb94d7 100644 --- a/lib/backends/webgl/ops/softmax.ts +++ b/lib/backends/webgl/ops/softmax.ts @@ -185,7 +185,7 @@ export class WebGLSoftmax extends Softmax { } createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] { const inputShape = inputs[0].dims.slice(); - const axisCorrected = ShapeUtil.getActualAxisFromNegativeValue(this.axis, inputShape.length); + const axisCorrected = ShapeUtil.parseAxis(this.axis, inputShape.length); const N = ShapeUtil.sizeToDimension(inputShape, axisCorrected); const D = ShapeUtil.sizeFromDimension(inputShape, axisCorrected); const computeMaxProgramInfo = this.createComputeMaxProgramInfo(inferenceHandler, inputs[0], N, D, [N]); diff --git a/lib/backends/webgl/ops/transpose.ts b/lib/backends/webgl/ops/transpose.ts index fe3566f5..5c7bf46b 100644 --- a/lib/backends/webgl/ops/transpose.ts +++ b/lib/backends/webgl/ops/transpose.ts @@ -15,7 +15,7 @@ export class WebGLTranspose extends Transpose implements WebGLOperator, Position run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { return WebGLOperatorHelper.run(this, inferenceHandler, inputs); } - getOutputShape(handler: WebGLInferenceHandler, inputShapes: number[][]): number[] { + getOutputShape(handler: WebGLInferenceHandler, inputShapes: Array>): ReadonlyArray { const perm = this.getAdjustedPerm(inputShapes[0]); return ShapeUtil.sortBasedOnPerm(inputShapes[0], perm); } @@ -70,7 +70,7 @@ export class WebGLTranspose extends Transpose implements WebGLOperator, Position addPositionalSub(positionalSubOperator: PositionalSubOperator): void { this.positionalSubs.push(positionalSubOperator); } - protected getAdjustedPerm(inputShape: number[]): number[] { + protected getAdjustedPerm(inputShape: ReadonlyArray): number[] { let perm = this.perm; if (perm && perm.length !== inputShape.length) { perm = [...(inputShape.keys())].reverse(); diff --git a/lib/backends/webgl/texture-data.ts b/lib/backends/webgl/texture-data.ts index f1eb943b..2d04cea4 100644 --- a/lib/backends/webgl/texture-data.ts +++ b/lib/backends/webgl/texture-data.ts @@ -14,9 +14,9 @@ export interface TextureLayout { width: number; height: number; channels: number; - shape: number[]; - strides: number[]; - unpackedShape: number[]; + shape: ReadonlyArray; + strides: ReadonlyArray; + unpackedShape: ReadonlyArray; } export interface TextureData extends TextureLayout { dataType: Tensor.DataType; diff --git a/lib/backends/webgl/texture-layout-strategy.ts b/lib/backends/webgl/texture-layout-strategy.ts index 475a0e1a..62f8331a 100644 --- a/lib/backends/webgl/texture-layout-strategy.ts +++ b/lib/backends/webgl/texture-layout-strategy.ts @@ -12,7 +12,7 @@ export interface WidthHeightPrefs { * for mapping n-dimensional arrays to 2D textures (and back) */ export interface TextureLayoutStrategy { - computeTextureWH(shape: number[], prefs?: WidthHeightPrefs): [number, number]; + computeTextureWH(shape: ReadonlyArray, prefs?: WidthHeightPrefs): [number, number]; } /** @@ -20,7 +20,7 @@ export interface TextureLayoutStrategy { */ export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy { constructor(public maxTextureSize: number) {} - computeTextureWH(shape: number[], prefs?: WidthHeightPrefs): [number, number] { + computeTextureWH(shape: ReadonlyArray, prefs?: WidthHeightPrefs): [number, number] { const maxTextureSize = this.maxTextureSize; if (prefs) { // check to see if dims fit diff --git a/lib/backends/webgl/texture-manager.ts b/lib/backends/webgl/texture-manager.ts index 38a8cd80..c562718d 100644 --- a/lib/backends/webgl/texture-manager.ts +++ b/lib/backends/webgl/texture-manager.ts @@ -55,8 +55,9 @@ export class TextureManager { return {...layout, dataType, texture, arrayType: textureDataType}; } createTexture( - dataType: Tensor.DataType, shape: number[], strides?: number[], data?: Tensor.NumberType, channels?: number, - width?: number, height?: number, unpackedShape?: number[]): TextureData { + dataType: Tensor.DataType, shape: ReadonlyArray, strides?: ReadonlyArray, + data?: Tensor.NumberType, channels?: number, width?: number, height?: number, + unpackedShape?: ReadonlyArray): TextureData { return this.profiler.event('backend', 'TextureManager.createTexture', () => { if (!width || !height) { [width, height] = this.layoutStrategy.computeTextureWH(shape); @@ -86,7 +87,7 @@ export class TextureManager { return this.toTensorData(dataType, data); }); } - saveTexture(texture: WebGLTexture, dims: number[]): void { + saveTexture(texture: WebGLTexture, dims: ReadonlyArray): void { return this.profiler.event('backend', 'TextureManager.saveTexture', () => { const size = `${dims[0]}-${dims[1]}`; Logger.verbose('TextureManager', `caching texture of size ${size}`); diff --git a/lib/backends/webgl/utils.ts b/lib/backends/webgl/utils.ts index ae694ecf..9b1f8a88 100644 --- a/lib/backends/webgl/utils.ts +++ b/lib/backends/webgl/utils.ts @@ -12,7 +12,7 @@ export function using(resource: T, func: (resource: T) => resource.dispose(); } } -export function expandArray(shape: number[], newLength: number, fill: number) { +export function expandArray(shape: ReadonlyArray, newLength: number, fill: number) { if (shape.length === newLength) { return shape; } @@ -27,7 +27,7 @@ export function expandArray(shape: number[], newLength: number, fill: number) { * NOTE: it is always the last dim that gets packed. * @param unpackedShape original shape to create a packed version from */ -export function getPackedShape(unpackedShape: number[]): number[] { +export function getPackedShape(unpackedShape: ReadonlyArray): ReadonlyArray { const len = unpackedShape.length; return unpackedShape.slice(0, len - 1).concat(unpackedShape[len - 1] / 4); } diff --git a/lib/backends/webgl/webgl-operator-utils.ts b/lib/backends/webgl/webgl-operator-utils.ts index 9b964f7a..7c3a79c9 100644 --- a/lib/backends/webgl/webgl-operator-utils.ts +++ b/lib/backends/webgl/webgl-operator-utils.ts @@ -23,8 +23,8 @@ export class WebGLOperatorHelper { return [inferenceHandler.getTensor(runData.outputTextureData)]; } static getFinalLayout( - inferenceHandler: WebGLInferenceHandler, positionalSubFunctions: GlslPositionalFunction[], outputShape: number[], - channels: number, prefs?: WidthHeightPrefs): TextureLayout { + inferenceHandler: WebGLInferenceHandler, positionalSubFunctions: GlslPositionalFunction[], + outputShape: ReadonlyArray, channels: number, prefs?: WidthHeightPrefs): TextureLayout { let finalShape = outputShape; if (positionalSubFunctions.length > 0) { finalShape = positionalSubFunctions[positionalSubFunctions.length - 1].outputShape; @@ -35,7 +35,7 @@ export class WebGLOperatorHelper { } static getPositionalFunctions( inferenceHandler: WebGLInferenceHandler, subOperators: PositionalSubOperator[], - outputShape: number[]): GlslPositionalFunction[] { + outputShape: ReadonlyArray): GlslPositionalFunction[] { let shape = outputShape; if (subOperators && subOperators.length > 0) { const result = new Array(subOperators.length); diff --git a/lib/backends/webgl/webgl-operator.ts b/lib/backends/webgl/webgl-operator.ts index 7d481fee..682a910f 100644 --- a/lib/backends/webgl/webgl-operator.ts +++ b/lib/backends/webgl/webgl-operator.ts @@ -10,7 +10,8 @@ import {ProgramInfo} from './program-info'; import {RunData} from './program-manager'; export interface PositionalSubOperator extends Operator { - getPositionalFunction(handler: WebGLInferenceHandler, inputShape: number[], name?: string): GlslPositionalFunction; + getPositionalFunction(handler: WebGLInferenceHandler, inputShape: ReadonlyArray, name?: string): + GlslPositionalFunction; } export interface WebGLRunnable extends Operator { addPositionalSub(positionalSubOperator: PositionalSubOperator): void; diff --git a/lib/util.ts b/lib/util.ts index 524d517e..35271298 100644 --- a/lib/util.ts +++ b/lib/util.ts @@ -78,23 +78,23 @@ export class BroadcastUtil { /** * Given the indices of a broadcasted tensor, calculate the original indices * @param indices The given indices of the broadcasted tensor. - * @param shapeOrigin The origin shape of the tensor before broadcast + * @param originalShape The original shape of the tensor before broadcast * @param isMatMul Whether the operation is MatMul * @returns The calculated indices that maps to the original tensor. If the * operation is MatMul, the indices of last 2 dimensions will keep as same as * input indices */ - static index(indices: number[], shapeOrigin: number[], isMatMul = false): number[] { + static index(indices: ReadonlyArray, originalShape: ReadonlyArray, isMatMul = false): number[] { // we assume the parameter indices is valid. ie. it should have the same // length as the broadcasted shape, and for each dimension the index should // not be out of range. - const dimOffset = indices.length - shapeOrigin.length; - const indicesOrigin = indices.slice(dimOffset); - const dimLen = isMatMul ? indicesOrigin.length - 2 : indicesOrigin.length; + const dimOffset = indices.length - originalShape.length; + const indicesOriginal = indices.slice(dimOffset); + const dimLen = isMatMul ? indicesOriginal.length - 2 : indicesOriginal.length; for (let i = 0; i < dimLen; i++) { - indicesOrigin[i] = indices[dimOffset + i] % shapeOrigin[i]; + indicesOriginal[i] = indices[dimOffset + i] % originalShape[i]; } - return indicesOrigin; + return indicesOriginal; } /** @@ -188,7 +188,7 @@ export class GemmUtil { // will throw exception if the input shapes are not compatible static getShapeOfGemmResult( leftShape: ReadonlyArray, transLeft: boolean, rightShape: ReadonlyArray, transRight: boolean, - biasShape: ReadonlyArray): number[] { + biasShape: ReadonlyArray): ReadonlyArray { if (leftShape.length !== 2 || rightShape.length !== 2) { throw new Error(`shape need to be of size 2`); } @@ -338,19 +338,19 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + // `axis` inclusive static sizeFromDimension(dims: ReadonlyArray, axis: number): number { - if (axis > dims.length) { + if (axis < 0 || axis >= dims.length) { throw new Error(`invalid dimension of ${axis} for sizeFromDimension as Tensor has ${dims.length} dimensions.`); } - return ShapeUtil.getSizeFromDimensionRange(dims, axis, dims.length); } + // `axis` exclusive static sizeToDimension(dims: ReadonlyArray, axis: number): number { - if (axis > dims.length) { + if (axis < 0 || axis > dims.length) { throw new Error(`invalid dimension of ${axis} for sizeToDimension as Tensor has ${dims.length} dimensions.`); } - return ShapeUtil.getSizeFromDimensionRange(dims, 0, axis); } @@ -369,59 +369,42 @@ export class ShapeUtil { return size; } - // Computes the offset up until the start index for the specified axis - /** - * @param index Given index to compute offset for in the flattened - * @param stride The strides of the tensor corresponding to the index - * @param axis The 1-indexed axis upto which the offset is to be computed for. If undefined, axis == rank of the - * index. - */ - - static computeOffset(index: number[], stride: number[], axis?: number) { - if (axis === undefined) { - axis = index.length; - } - let offset = 0; - for (let i = 0; i < axis; ++i) { - offset += (index[i] * stride[i]); - } - return offset; - } - static computeStrides(shape: ReadonlyArray): number[] { - const rank = shape.length; + static computeStrides(dims: ReadonlyArray): ReadonlyArray { + const rank = dims.length; if (rank < 2) { return [1]; } - const strides = new Array(rank); strides[rank - 1] = 1; - strides[rank - 2] = shape[rank - 1]; + strides[rank - 2] = dims[rank - 1]; for (let i = rank - 3; i >= 0; --i) { - strides[i] = strides[i + 1] * shape[i + 1]; + strides[i] = strides[i + 1] * dims[i + 1]; } return strides; } - static transpose(dims: number[]): number[] { - return dims.reverse(); + + static transpose(dims: ReadonlyArray): ReadonlyArray { + const copy = dims.slice(); + return copy.reverse(); } - static indicesToOffset(indices: number[], strides: number[]): number { - const rank = strides.length; - if (rank === 0) { - return 0; + + static indicesToOffset(indices: ReadonlyArray, strides: ReadonlyArray, axis?: number): number { + if (axis === undefined) { + axis = indices.length; } - let index = indices[indices.length - 1]; - for (let i = 0; i < indices.length - 1; ++i) { - index += strides[i] * indices[i]; + let offset = 0; + for (let i = 0; i < axis; ++i) { + offset += strides[i] * indices[i]; } - return index; + return offset; } - static offsetToIndices(offset: number, strides: number[]): number[] { + static offsetToIndices(offset: number, strides: ReadonlyArray): ReadonlyArray { const rank = strides.length; if (rank === 0) { return []; } else if (rank === 1) { - return [offset]; + return [offset * strides[0]]; } const indices: number[] = new Array(strides.length); for (let i = 0; i < indices.length - 1; ++i) { @@ -431,8 +414,9 @@ export class ShapeUtil { indices[indices.length - 1] = offset; return indices; } - static getActualAxisFromNegativeValue(axis: number, tensorRank: number): number { - if (axis < -tensorRank && axis > (tensorRank - 1)) { + + static parseAxis(axis: number, tensorRank: number): number { + if (axis < -tensorRank && axis >= tensorRank) { throw new Error('unsupported axis for this operation.'); } return axis < 0 ? axis + tensorRank : axis; @@ -442,13 +426,20 @@ export class ShapeUtil { // ordering), wrapping around the specified upper_bound. /** * Increment an index into a tensor (in lexicographic ordering), wrapping around the specified upper_bound. - * @param index Given index to increment + * @param index Given index to increment (Will be mutated) * @param dims The dimensions of the tensor for which the given index corresponds to * @param axisToIncrementOn The 1-indexed axis to increment on. If undefined, axisToIncrementOn == rank */ - static incrementIndex(index: number[], dims: number[], axisToIncrementOn?: number) { + static incrementIndex(index: number[], dims: ReadonlyArray, axisToIncrementOn?: number) { + if (dims.length === 0 || index.length === 0) { + throw new Error(`Index incrementing unsupported for scalar Tensor`); + } if (axisToIncrementOn === undefined) { axisToIncrementOn = dims.length; + } else { + if (axisToIncrementOn <= 0 || axisToIncrementOn > dims.length) { + throw new Error(`Incorrect axis to increment on`); + } } for (let k = axisToIncrementOn - 1; k >= 0; --k) { @@ -459,6 +450,7 @@ export class ShapeUtil { index[k] = 0; } } + /** * Produces a new dimensions array based on the values in the 'originalDimensions' and 'shape' array * Used in Reshape @@ -471,7 +463,9 @@ export class ShapeUtil { * https://github.com/onnx/onnx/blob/master/docs/Operators.md#Reshape */ - static calculateReshapedDims(originalDims: ReadonlyArray, shapeHints: number[]|Tensor.IntegerType): number[] { + static calculateReshapedDims( + originalDims: ReadonlyArray, + shapeHints: number[]|ReadonlyArray|Tensor.IntegerType): ReadonlyArray { const nDims = shapeHints.length; const reshapedDims = new Array(nDims); let unknownDimension = -1; @@ -516,28 +510,30 @@ export class ShapeUtil { * @param a Array to be sorted such as dims or strides * @param perm Perm given; if null a will be reversed */ - static sortBasedOnPerm(a: ReadonlyArray, perm?: number[]): number[] { + static sortBasedOnPerm(a: ReadonlyArray, perm?: ReadonlyArray): ReadonlyArray { if (perm) { return perm.map((v) => a[v]); } else { return a.slice().reverse(); } } + /** * Pads a given shape according to the padding values * @param dims shape of the Tensor to be padded * @param pad pad values */ - static padShape(dims: ReadonlyArray, pad: number[]): number[] { + static padShape(dims: ReadonlyArray, pad: ReadonlyArray): ReadonlyArray { const rank = dims.length; return dims.map((v, i) => v + pad[i] + pad[i + rank]); } + /** * Determines if the two shapes are identical * @param shape1 * @param shape2 */ - static areEqual(shape1: number[], shape2: number[]): boolean { + static areEqual(shape1: ReadonlyArray, shape2: ReadonlyArray): boolean { if (shape1.length !== shape2.length) { return false; } @@ -724,11 +720,10 @@ export class ReduceUtil { // loop through the output and calculate result one by one const size = ShapeUtil.size(outputDims); const ndY = ndarray(new Array(size), outputDims); - let indices = new Array(outputDims.length); const strides = ShapeUtil.computeStrides(outputDims); const inputStrides = ShapeUtil.computeStrides(dims); for (let i = 0; i < size; i++) { - indices = ShapeUtil.offsetToIndices(i, strides); + const indices = ShapeUtil.offsetToIndices(i, strides); // map index const indicesY = BroadcastUtil.index(indices, dims); ndY.set(