Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
misc: make shapes/strides ReadonlyArray<number> in code base (#36)
Browse files Browse the repository at this point in the history
* Update methods in the code base to just consume ReadonlyArray wherever dims or strides are 
  encountered and corresponding changes. This should prevent accidental mutation.

* Remove redundant function computeOffset() in ShapeUtil class. It's functionality is already provided 
   by indicesToOffset().
  • Loading branch information
hariharans29 authored Dec 6, 2018
1 parent e7ba5d8 commit c2ef813
Show file tree
Hide file tree
Showing 18 changed files with 98 additions and 97 deletions.
3 changes: 1 addition & 2 deletions lib/backends/cpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ export function pool(
const curInd = ShapeUtil.offsetToIndices(ind, outputStride);
const startInd = curInd.slice(0);
const x = curInd.slice(0);
let offset = new Array<number>(stridesRank);
// calculate the start indices of kernel corresponding to current output indices
for (let i = 0; i < stridesRank; i++) {
startInd[rank - stridesRank + i] = curInd[rank - stridesRank + i] * strides[i];
Expand All @@ -115,7 +114,7 @@ export function pool(
let isPad = false;
// loop through elements within kernel
for (let i = 0; i < kernelSize; i++) {
offset = ShapeUtil.offsetToIndices(i, kernelStrides);
const offset = ShapeUtil.offsetToIndices(i, kernelStrides);
isPad = false;
// "Shift" the kernel by the kernel start indices to loop through the kernel mapped to current output indices
for (let j = rank - stridesRank; j < rank; j++) {
Expand Down
2 changes: 1 addition & 1 deletion lib/backends/cpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export function softmax(x: Tensor, axis: number): Tensor {
const inputDimensions = x.dims;
const inputRank = inputDimensions.length;

const axisCorrected = util.ShapeUtil.getActualAxisFromNegativeValue(axis, inputRank);
const axisCorrected = util.ShapeUtil.parseAxis(axis, inputRank);
const N = util.ShapeUtil.sizeToDimension(inputDimensions, axisCorrected);
const D = util.ShapeUtil.sizeFromDimension(inputDimensions, axisCorrected);

Expand Down
4 changes: 2 additions & 2 deletions lib/backends/cpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function doTranspose(
let startTargetIndex = 0;

for (let i = 0; i < numBlocks; ++i) {
const sizeOffset = ShapeUtil.computeOffset(targetIndex, stride, numAxes);
const sizeOffset = ShapeUtil.indicesToOffset(targetIndex, stride, numAxes);
arrayCopyHelper(target, source, startTargetIndex, startSourceIndex + sizeOffset, numElementsInBlock);

ShapeUtil.incrementIndex(targetIndex, targetDims, numAxes);
Expand All @@ -108,7 +108,7 @@ function doTransposeEltWise(
let startTargetIndex = 0;

for (let i = 0; i < numBlocks; ++i) {
const sourceOffset = ShapeUtil.computeOffset(targetIndex, stride, numAxes);
const sourceOffset = ShapeUtil.indicesToOffset(targetIndex, stride, numAxes);
target[startTargetIndex++] = source[sourceOffset];
ShapeUtil.incrementIndex(targetIndex, targetDims, numAxes);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/backends/wasm/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {WasmInferenceHandler} from '../inference-handler';
export class WasmSoftmax extends Softmax {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const x = inputs[0];
const axisCorrected = ShapeUtil.getActualAxisFromNegativeValue(this.axis, x.dims.length);
const axisCorrected = ShapeUtil.parseAxis(this.axis, x.dims.length);
const N = ShapeUtil.sizeToDimension(x.dims, axisCorrected);
const D = ShapeUtil.sizeFromDimension(x.dims, axisCorrected);
const y = new Tensor(x.dims, x.type);
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/glsl-definitions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ export interface GlslFunction<T extends FunctionType> {
}
export interface GlslValueFunction extends GlslFunction<FunctionType.ValueBased> {}
export interface GlslPositionalFunction extends GlslFunction<FunctionType.Positional> {
inputShape: number[];
outputShape: number[];
inputShape: ReadonlyArray<number>;
outputShape: ReadonlyArray<number>;
}

export class GlslContext {
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/glsl-shape-utils-lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ export class ShapeUtilsGlslLib extends GlslLib {
});
return result;
}
static indexToOffsetSingle(name: string, rank: number, strides: number[]): string {
static indexToOffsetSingle(name: string, rank: number, strides: ReadonlyArray<number>): string {
let block = '';
for (let i = rank - 1; i >= 0; --i) {
block += `
Expand Down Expand Up @@ -126,7 +126,7 @@ export class ShapeUtilsGlslLib extends GlslLib {
});
return result;
}
static offsetToIndicesSingle(name: string, rank: number, strides: number[]): string {
static offsetToIndicesSingle(name: string, rank: number, strides: ReadonlyArray<number>): string {
const stridesBlock = [];
for (let i = 0; i < rank - 1; ++i) {
stridesBlock.push(`
Expand Down
11 changes: 6 additions & 5 deletions lib/backends/webgl/inference-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export class WebGLInferenceHandler implements InferenceHandler {
}
return tensor;
}
getOrCreateTextureLayout(tensor: Tensor, channels = 1, unpackedShape?: number[]): TextureLayout {
getOrCreateTextureLayout(tensor: Tensor, channels = 1, unpackedShape?: ReadonlyArray<number>): TextureLayout {
const td = this.getTextureData(tensor);
if (td) {
return td;
Expand All @@ -87,8 +87,8 @@ export class WebGLInferenceHandler implements InferenceHandler {
this.textureToTensor = new Map();
}
createTextureData(
dataType: Tensor.DataType, shape: number[], strides?: number[], data?: Tensor.NumberType, channels?: number,
width?: number, height?: number): TextureData {
dataType: Tensor.DataType, shape: ReadonlyArray<number>, strides?: ReadonlyArray<number>,
data?: Tensor.NumberType, channels?: number, width?: number, height?: number): TextureData {
Logger.verbose('InferenceHandler', `Creating TextureData: shape:[${shape}], channels:${channels ? channels : 1}`);
const td = this.textureManager.createTexture(dataType, shape, strides, data, channels, width, height);
return td;
Expand All @@ -98,8 +98,9 @@ export class WebGLInferenceHandler implements InferenceHandler {
const td = this.textureManager.createTextureFromLayout(dataType, layout, data);
return td;
}
createBasicTextureLayout(shape: number[], channels = 1, unpackedShape?: number[], prefs?: WidthHeightPrefs):
TextureLayout {
createBasicTextureLayout(
shape: ReadonlyArray<number>, channels = 1, unpackedShape?: ReadonlyArray<number>,
prefs?: WidthHeightPrefs): TextureLayout {
const [width, height] = this.session.layoutStrategy.computeTextureWH(shape, prefs);
if (channels === 1) {
unpackedShape = shape;
Expand Down
10 changes: 6 additions & 4 deletions lib/backends/webgl/ops/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ export function getPadFunction(
}
}
function getPadConstant(
name: string, shape: ReadonlyArray<number>, strides: number[], width: number, height: number, pads: number[],
value: number) {
name: string, shape: ReadonlyArray<number>, strides: ReadonlyArray<number>, width: number, height: number,
pads: number[], value: number) {
const rank = shape.length;
let block = '';
for (let i = rank - 1; i >= 0; --i) {
Expand All @@ -82,7 +82,8 @@ function getPadConstant(
`;
}
function getPadReflect(
name: string, shape: ReadonlyArray<number>, strides: number[], width: number, height: number, pads: number[]) {
name: string, shape: ReadonlyArray<number>, strides: ReadonlyArray<number>, width: number, height: number,
pads: number[]) {
const rank = shape.length;

let block = '';
Expand Down Expand Up @@ -110,7 +111,8 @@ function getPadReflect(
`;
}
function getPadEdge(
name: string, shape: ReadonlyArray<number>, strides: number[], width: number, height: number, pads: number[]) {
name: string, shape: ReadonlyArray<number>, strides: ReadonlyArray<number>, width: number, height: number,
pads: number[]) {
const rank = shape.length;

let block = '';
Expand Down
10 changes: 6 additions & 4 deletions lib/backends/webgl/ops/reshape.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ export class WebGLReshape extends Reshape {
this.outputShape = attributes.getInts('output_shape', []);
this.dimsToKeep = attributes.getInts('dims_to_keep', []);
}
getOutputShape(inferenceHandler: WebGLInferenceHandler, inputShapes: number[][]): number[] {
getOutputShape(inferenceHandler: WebGLInferenceHandler, inputShapes: Array<ReadonlyArray<number>>):
ReadonlyArray<number> {
if (inputShapes.length >= 2) {
return inputShapes[1];
}
Expand All @@ -32,7 +33,7 @@ export class WebGLReshape extends Reshape {
}
}
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const inputShape = inputs[0].dims.slice();
const inputShape: ReadonlyArray<number> = inputs[0].dims.slice();
const inputShapes = [inputShape];
if (inputs.length > 1) {
inputShapes.push(ShapeUtil.calculateReshapedDims(inputShape, inputs[1].integerData));
Expand All @@ -55,7 +56,7 @@ export class WebGLReshape extends Reshape {
arrayType: inputTD.arrayType
})];
}
getPositionalFunction(inferenceHandler: WebGLInferenceHandler, inputShape: number[], name?: string):
getPositionalFunction(inferenceHandler: WebGLInferenceHandler, inputShape: ReadonlyArray<number>, name?: string):
GlslPositionalFunction {
const outputShape = this.getOutputShape(inferenceHandler, [inputShape]);
if (!name) {
Expand All @@ -69,7 +70,8 @@ export class WebGLReshape extends Reshape {
outputShape
};
}
protected getReshapeFunctionBody(name: string, inputShape: number[], outputShape: number[]): string {
protected getReshapeFunctionBody(name: string, inputShape: ReadonlyArray<number>, outputShape: ReadonlyArray<number>):
string {
const inputStrides = ShapeUtil.computeStrides(inputShape);
const outputStrides = ShapeUtil.computeStrides(outputShape);
return `
Expand Down
2 changes: 1 addition & 1 deletion lib/backends/webgl/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ export class WebGLSoftmax extends Softmax {
}
createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] {
const inputShape = inputs[0].dims.slice();
const axisCorrected = ShapeUtil.getActualAxisFromNegativeValue(this.axis, inputShape.length);
const axisCorrected = ShapeUtil.parseAxis(this.axis, inputShape.length);
const N = ShapeUtil.sizeToDimension(inputShape, axisCorrected);
const D = ShapeUtil.sizeFromDimension(inputShape, axisCorrected);
const computeMaxProgramInfo = this.createComputeMaxProgramInfo(inferenceHandler, inputs[0], N, D, [N]);
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export class WebGLTranspose extends Transpose implements WebGLOperator, Position
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
}
getOutputShape(handler: WebGLInferenceHandler, inputShapes: number[][]): number[] {
getOutputShape(handler: WebGLInferenceHandler, inputShapes: Array<ReadonlyArray<number>>): ReadonlyArray<number> {
const perm = this.getAdjustedPerm(inputShapes[0]);
return ShapeUtil.sortBasedOnPerm(inputShapes[0], perm);
}
Expand Down Expand Up @@ -70,7 +70,7 @@ export class WebGLTranspose extends Transpose implements WebGLOperator, Position
addPositionalSub(positionalSubOperator: PositionalSubOperator): void {
this.positionalSubs.push(positionalSubOperator);
}
protected getAdjustedPerm(inputShape: number[]): number[] {
protected getAdjustedPerm(inputShape: ReadonlyArray<number>): number[] {
let perm = this.perm;
if (perm && perm.length !== inputShape.length) {
perm = [...(inputShape.keys())].reverse();
Expand Down
6 changes: 3 additions & 3 deletions lib/backends/webgl/texture-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ export interface TextureLayout {
width: number;
height: number;
channels: number;
shape: number[];
strides: number[];
unpackedShape: number[];
shape: ReadonlyArray<number>;
strides: ReadonlyArray<number>;
unpackedShape: ReadonlyArray<number>;
}
export interface TextureData extends TextureLayout {
dataType: Tensor.DataType;
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/texture-layout-strategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ export interface WidthHeightPrefs {
* for mapping n-dimensional arrays to 2D textures (and back)
*/
export interface TextureLayoutStrategy {
computeTextureWH(shape: number[], prefs?: WidthHeightPrefs): [number, number];
computeTextureWH(shape: ReadonlyArray<number>, prefs?: WidthHeightPrefs): [number, number];
}

/**
* This strategy try to find the minimal max(W,H) that fulfills (W * H == totalSize)
*/
export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy {
constructor(public maxTextureSize: number) {}
computeTextureWH(shape: number[], prefs?: WidthHeightPrefs): [number, number] {
computeTextureWH(shape: ReadonlyArray<number>, prefs?: WidthHeightPrefs): [number, number] {
const maxTextureSize = this.maxTextureSize;
if (prefs) {
// check to see if dims fit
Expand Down
7 changes: 4 additions & 3 deletions lib/backends/webgl/texture-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ export class TextureManager {
return {...layout, dataType, texture, arrayType: textureDataType};
}
createTexture(
dataType: Tensor.DataType, shape: number[], strides?: number[], data?: Tensor.NumberType, channels?: number,
width?: number, height?: number, unpackedShape?: number[]): TextureData {
dataType: Tensor.DataType, shape: ReadonlyArray<number>, strides?: ReadonlyArray<number>,
data?: Tensor.NumberType, channels?: number, width?: number, height?: number,
unpackedShape?: ReadonlyArray<number>): TextureData {
return this.profiler.event('backend', 'TextureManager.createTexture', () => {
if (!width || !height) {
[width, height] = this.layoutStrategy.computeTextureWH(shape);
Expand Down Expand Up @@ -86,7 +87,7 @@ export class TextureManager {
return this.toTensorData(dataType, data);
});
}
saveTexture(texture: WebGLTexture, dims: number[]): void {
saveTexture(texture: WebGLTexture, dims: ReadonlyArray<number>): void {
return this.profiler.event('backend', 'TextureManager.saveTexture', () => {
const size = `${dims[0]}-${dims[1]}`;
Logger.verbose('TextureManager', `caching texture of size ${size}`);
Expand Down
4 changes: 2 additions & 2 deletions lib/backends/webgl/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export function using<T extends Disposable>(resource: T, func: (resource: T) =>
resource.dispose();
}
}
export function expandArray(shape: number[], newLength: number, fill: number) {
export function expandArray(shape: ReadonlyArray<number>, newLength: number, fill: number) {
if (shape.length === newLength) {
return shape;
}
Expand All @@ -27,7 +27,7 @@ export function expandArray(shape: number[], newLength: number, fill: number) {
* NOTE: it is always the last dim that gets packed.
* @param unpackedShape original shape to create a packed version from
*/
export function getPackedShape(unpackedShape: number[]): number[] {
export function getPackedShape(unpackedShape: ReadonlyArray<number>): ReadonlyArray<number> {
const len = unpackedShape.length;
return unpackedShape.slice(0, len - 1).concat(unpackedShape[len - 1] / 4);
}
6 changes: 3 additions & 3 deletions lib/backends/webgl/webgl-operator-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ export class WebGLOperatorHelper {
return [inferenceHandler.getTensor(runData.outputTextureData)];
}
static getFinalLayout(
inferenceHandler: WebGLInferenceHandler, positionalSubFunctions: GlslPositionalFunction[], outputShape: number[],
channels: number, prefs?: WidthHeightPrefs): TextureLayout {
inferenceHandler: WebGLInferenceHandler, positionalSubFunctions: GlslPositionalFunction[],
outputShape: ReadonlyArray<number>, channels: number, prefs?: WidthHeightPrefs): TextureLayout {
let finalShape = outputShape;
if (positionalSubFunctions.length > 0) {
finalShape = positionalSubFunctions[positionalSubFunctions.length - 1].outputShape;
Expand All @@ -35,7 +35,7 @@ export class WebGLOperatorHelper {
}
static getPositionalFunctions(
inferenceHandler: WebGLInferenceHandler, subOperators: PositionalSubOperator[],
outputShape: number[]): GlslPositionalFunction[] {
outputShape: ReadonlyArray<number>): GlslPositionalFunction[] {
let shape = outputShape;
if (subOperators && subOperators.length > 0) {
const result = new Array<GlslPositionalFunction>(subOperators.length);
Expand Down
3 changes: 2 additions & 1 deletion lib/backends/webgl/webgl-operator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import {ProgramInfo} from './program-info';
import {RunData} from './program-manager';

export interface PositionalSubOperator extends Operator {
getPositionalFunction(handler: WebGLInferenceHandler, inputShape: number[], name?: string): GlslPositionalFunction;
getPositionalFunction(handler: WebGLInferenceHandler, inputShape: ReadonlyArray<number>, name?: string):
GlslPositionalFunction;
}
export interface WebGLRunnable extends Operator {
addPositionalSub(positionalSubOperator: PositionalSubOperator): void;
Expand Down
Loading

0 comments on commit c2ef813

Please sign in to comment.