diff --git a/.eslintrc.js b/.eslintrc.js index 0492c52..24f3b46 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -14,6 +14,7 @@ module.exports = { 'indent': 2, 'require-jsdoc': 'off', 'max-len': ['error', {'code': 100}], + 'prefer-rest-params': 'off' }, extends: [ 'eslint:recommended', diff --git a/src/batch_normalization.js b/src/batch_normalization.js index 06b2a92..2e2e206 100644 --- a/src/batch_normalization.js +++ b/src/batch_normalization.js @@ -2,26 +2,8 @@ import {add, sub, mul, div, pow} from './binary.js'; import {reshape} from './reshape.js'; -import {Tensor, Scalar} from './tensor.js'; - -/** - * Check the tensor whether it is a 1-D tensor and its length is equal to `expectedSize`. - * @param {Tensor} a - * @param {Number} expectedSize - * @param {String} name - */ -function check1DTensorWithSize(a, expectedSize, name) { - if (a) { - if (a.rank !== 1) { - throw new Error(`The parameter ${name} is not a 1-D tensor.`); - } else { - if (a.shape[0] !== expectedSize) { - throw new Error(`The length ${a.shape[0]} of the ${name} values is not equal to the ` + - `size ${expectedSize} of the input dimension denoted by options.axis.`); - } - } - } -} +import {Tensor, Scalar} from './lib/tensor.js'; +import {validateBatchNormalizationParams} from './lib/validate-input.js'; /** * Normalize the tensor values of input features across the batch dimension using @@ -32,25 +14,10 @@ function check1DTensorWithSize(a, expectedSize, name) { * @param {MLBatchNormalizationOptions} [options] * @return {Tensor} */ -export function batchNormalization(input, mean, variance, options = {}) { - let axis = options.axis; - if (axis !== undefined) { - if (!Number.isInteger(axis)) { - throw new Error(`Invalid axis ${axis}, axis should be an integer.`); - } - } else { - axis = 1; - } - const dim = input.shape[axis]; - check1DTensorWithSize(mean, dim, 'mean'); - check1DTensorWithSize(variance, dim, 'variance'); - const scale = options.scale; - check1DTensorWithSize(scale, dim, 'scale'); - const bias = options.bias; - check1DTensorWithSize(bias, dim, 'bias'); - const epsilon = options.epsilon ? options.epsilon : 1e-5; - const activation = options.activation; - // The output tensor of the same shape as the input tensor. +export function batchNormalization(input, mean, variance, {axis=1, scale, bias, epsilon=1e-5, + activation = (x) => x} = {}) { + validateBatchNormalizationParams(...arguments); + // The output tensor has the same shape as the input tensor. let output = new Tensor(input.shape); const shape = new Array(input.rank).fill(1); shape[axis] = -1; @@ -63,8 +30,6 @@ export function batchNormalization(input, mean, variance, options = {}) { if (bias) { output = add(output, reshape(bias, shape)); } - if (activation) { - output = activation(output); - } + output = activation(output); return output; } diff --git a/src/binary.js b/src/binary.js index 2d79d2e..8f4db79 100644 --- a/src/binary.js +++ b/src/binary.js @@ -1,7 +1,7 @@ 'use strict'; -import {broadcast, getBroadcastShape} from './broadcast.js'; -import {Tensor, sizeOfShape} from './tensor.js'; +import {broadcast, getBroadcastShape} from './lib/broadcast.js'; +import {Tensor, sizeOfShape} from './lib/tensor.js'; /** * Compute the element-wise binary operation of two input tensors. diff --git a/src/clamp.js b/src/clamp.js index 37d378c..5866b75 100644 --- a/src/clamp.js +++ b/src/clamp.js @@ -1,6 +1,6 @@ 'use strict'; -import {Tensor} from './tensor.js'; +import {Tensor} from './lib/tensor.js'; /** * Clamp the input tensor element-wise within a range specified by the minimum and maximum values. @@ -8,24 +8,11 @@ import {Tensor} from './tensor.js'; * @param {MLClampOptions} [options] * @return {Tensor} */ -export function clamp(input, options = {}) { +export function clamp(input, {minValue=-Infinity, maxValue=Infinity} = {}) { const output = new Tensor(input.shape); for (let i = 0; i < input.size; ++i) { const x = input.getValueByIndex(i); - let y; - if (options.minValue === undefined) { - if (options.maxValue === undefined) { - y = x; - } else { - y = Math.min(x, options.maxValue); - } - } else { - if (options.maxValue === undefined) { - y = Math.max(x, options.minValue); - } else { - y = Math.min(Math.max(x, options.minValue), options.maxValue); - } - } + const y = Math.min(Math.max(x, minValue), maxValue); output.setValueByIndex(i, y); } return output; diff --git a/src/concat.js b/src/concat.js index bf17195..b322430 100644 --- a/src/concat.js +++ b/src/concat.js @@ -1,6 +1,7 @@ 'use strict'; -import {Tensor, sizeOfShape} from './tensor.js'; +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateConcatParams} from './lib/validate-input.js'; /** * Concatenates the input tensors along a given axis. @@ -9,32 +10,11 @@ import {Tensor, sizeOfShape} from './tensor.js'; * @return {Tensor} */ export function concat(inputs, axis) { - const rank = inputs[0].rank; - if (!Number.isInteger(axis)) { - throw new Error(`Invalid axis ${axis}, axis should be an integer.`); - } else { - if (axis < 0 || axis >= rank) { - throw new Error(`Invalid axis ${axis}, axis should be in the interval [0, ${rank}).`); - } - } + validateConcatParams(...arguments); const inputShape = inputs[0].shape; const outputShape = inputShape.slice(); for (let i = 1; i < inputs.length; ++i) { - if (inputs[i].rank !== rank) { - throw new Error('All input tensors should have the same rank.'); - } else { - const shape = inputs[i].shape; - for (let j = 0; j < inputShape.length; ++j) { - if (j !== axis) { - if (inputShape[j] !== shape[j]) { - throw new Error('All input tensors should have the same shape, ' + - 'except for the size of the dimension to concatenate on.'); - } - } else { - outputShape[axis] += shape[axis]; - } - } - } + outputShape[axis] += inputs[i].shape[axis]; } const output = new Tensor(outputShape); for (let i = 0; i < sizeOfShape(outputShape); ++i) { diff --git a/src/conv2d.js b/src/conv2d.js index a821b27..25cc0b2 100644 --- a/src/conv2d.js +++ b/src/conv2d.js @@ -1,37 +1,10 @@ 'use strict'; -import {Tensor} from './tensor.js'; +import {Tensor} from './lib/tensor.js'; +import {validateConv2dParams} from './lib/validate-input.js'; +import {computePaddingForAutoPad} from './lib/compute-padding.js'; import {transpose} from './transpose.js'; -/** - * Compute the beginning and ending pad given input, filter and stride. - * @param {String} autoPad - * @param {Number} inputSize - * @param {Number} effectiveFilterSize - * @param {Number} stride - * @return {Array} [paddingBegin, paddingEnd] - */ -export function computePaddingForAutoPad(autoPad, inputSize, effectiveFilterSize, stride) { - const outSize = Math.ceil(inputSize / stride); - const neededInput = (outSize - 1) * stride + effectiveFilterSize; - const totalPadding = neededInput > inputSize ? neededInput - inputSize : 0; - let paddingBegin; - let paddingEnd; - switch (autoPad) { - case 'same-upper': - paddingBegin = Math.floor(totalPadding / 2); - paddingEnd = Math.floor((totalPadding + 1) / 2); - break; - case 'same-lower': - paddingBegin = Math.floor((totalPadding + 1) / 2); - paddingEnd = Math.floor(totalPadding / 2); - break; - default: - throw new Error('The autoPad is invalid.'); - } - return [paddingBegin, paddingEnd]; -} - /** * Compute a 2-D convolution given 4-D input and filter tensors. * @param {Tensor} input @@ -39,27 +12,21 @@ export function computePaddingForAutoPad(autoPad, inputSize, effectiveFilterSize * @param {MLConv2dOptions} options * @return {Tensor} */ -export function conv2d(input, filter, options = {}) { - if (input.rank !== 4) { - throw new Error('The input should be a 4-D tensor.'); - } - - if (filter.rank !== 4) { - throw new Error('The filter should be a 4-D tensor.'); - } - - const padding = options.padding ? options.padding : [0, 0, 0, 0]; - const strides = options.strides ? options.strides : [1, 1]; - const groups = options.groups ? options.groups : 1; - const dilations = options.dilations ? options.dilations : [1, 1]; - const activation = options.activation; - - const inputLayout = options.inputLayout ? options.inputLayout : 'nchw'; +export function conv2d(input, filter, {padding = [0, 0, 0, 0], + strides = [1, 1], + groups = 1, + dilations = [1, 1], + activation = (x) => x, + inputLayout = 'nchw', + filterLayout = 'oihw', + bias, + autoPad = 'explicit', +} += {}) { if (inputLayout === 'nhwc') { // nhwc -> nchw input = transpose(input, {permutation: [0, 3, 1, 2]}); } - const filterLayout = options.filterLayout ? options.filterLayout : 'oihw'; if (filterLayout === 'hwio') { // hwio -> oihw filter = transpose(filter, {permutation: [3, 2, 0, 1]}); @@ -70,43 +37,27 @@ export function conv2d(input, filter, options = {}) { // ihwo -> oihw filter = transpose(filter, {permutation: [3, 0, 1, 2]}); } + validateConv2dParams(input, filter, {groups, bias}); - const batchCount = input.shape[0]; - const inputChannels = input.shape[1]; - const inputHeight = input.shape[2]; - const inputWidth = input.shape[3]; - const outputChannels = filter.shape[0]; - const filterInputChannels = filter.shape[1]; - const filterHeight = filter.shape[2]; - const filterWidth = filter.shape[3]; - const strideHeight = strides[0]; - const strideWidth = strides[1]; - const dilationHeight = dilations[0]; - const dilationWidth = dilations[1]; + const [batchCount, inputChannels, inputHeight, inputWidth] = input.shape; + const [outputChannels, , filterHeight, filterWidth] = filter.shape; + const [strideHeight, strideWidth] = strides; + const [dilationHeight, dilationWidth] = dilations; const effectiveFilterHeight = filterHeight + (filterHeight - 1) * (dilationHeight - 1); const effectiveFilterWidth = filterWidth + (filterWidth - 1) * (dilationWidth - 1); - if (inputChannels !== filterInputChannels * groups) { - throw new Error('The input channels of filter is invalid.'); - } - - const bias = options.bias; - if (bias && (bias.rank !== 1 || bias.shape[0] != outputChannels)) { - throw new Error('the bias should be a 1-D tensor with the shape of [output_channels].'); - } - let beginningPaddingHeight; let endingPaddingHeight; let beginningPaddingWidth; let endingPaddingWidth; - if (options.autoPad === undefined || options.autoPad === 'explicit') { + if (autoPad === 'explicit') { [beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth] = padding; } else { [beginningPaddingHeight, endingPaddingHeight] = computePaddingForAutoPad( - options.autoPad, inputHeight, effectiveFilterHeight, strideHeight); + autoPad, inputHeight, effectiveFilterHeight, strideHeight); [beginningPaddingWidth, endingPaddingWidth] = computePaddingForAutoPad( - options.autoPad, inputWidth, effectiveFilterWidth, strideWidth); + autoPad, inputWidth, effectiveFilterWidth, strideWidth); } const outputShape = new Array(4); @@ -178,9 +129,7 @@ export function conv2d(input, filter, options = {}) { } } - if (activation) { - output = activation(output); - } + output = activation(output); if (inputLayout === 'nhwc') { // nchw -> nhwc diff --git a/src/gemm.js b/src/gemm.js index 2446dad..7e721a2 100644 --- a/src/gemm.js +++ b/src/gemm.js @@ -2,7 +2,8 @@ import {add, mul} from './binary.js'; import {matmul} from './matmul.js'; -import {Scalar} from './tensor.js'; +import {Scalar} from './lib/tensor.js'; +import {validateGemmParams} from './lib/validate-input.js'; import {transpose} from './transpose.js'; /** @@ -13,19 +14,15 @@ import {transpose} from './transpose.js'; * @param {MLGemmOptions} options * @return {Tensor} */ -export function gemm(a, b, options = {}) { - if (a.rank !== 2) { - throw new Error('The input a is not a 2-D tensor.'); - } - if (b.rank !== 2) { - throw new Error('The input b is not a 2-D tensor.'); - } - const c = options.c ? options.c : undefined; - const alpha = new Scalar(options.alpha ? options.alpha : 1.0); - const beta = new Scalar(options.beta ? options.beta : 1.0); - const aTranspose = options.aTranspose ? options.aTranspose : false; - const bTranspose = options.bTranspose ? options.bTranspose : false; - +export function gemm(a, b, {c = new Scalar(0.0), + alpha: fAlpha = 1.0, + beta: fBeta = 1.0, + aTranspose = false, + bTranspose = false, +} = {}) { + validateGemmParams(...arguments); + const alpha = new Scalar(fAlpha); + const beta = new Scalar(fBeta); if (aTranspose) { a = transpose(a); } @@ -35,7 +32,6 @@ export function gemm(a, b, options = {}) { } let output = matmul(mul(a, alpha), b); - if (c) { output = add(output, mul(c, beta)); } diff --git a/src/gru.js b/src/gru.js index 31004cf..137b407 100644 --- a/src/gru.js +++ b/src/gru.js @@ -3,11 +3,12 @@ import {concat} from './concat.js'; import {gruCell} from './gru_cell.js'; import {reshape} from './reshape.js'; -import {sizeOfShape, Tensor} from './tensor.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; import {slice} from './slice.js'; import {squeeze} from './squeeze.js'; import {tanh} from './tanh.js'; +import {validateGruParams} from './lib/validate-input.js'; /** * Gated Recurrent Unit [GRU] recurrent network using an update gate and a reset gate to compute @@ -20,89 +21,22 @@ import {tanh} from './tanh.js'; * @param {MLGruOptions} options * @return {Array.} */ -export function gru(input, weight, recurrentWeight, steps, hiddenSize, options = {}) { - const bias = options.bias; - const recurrentBias = options.recurrentBias; - const initialHiddenState = options.initialHiddenState; - const resetAfter = options.resetAfter !== undefined ? options.resetAfter : true; - const returnSequence = options.returnSequence !== undefined ? options.returnSequence : false; - const direction = options.direction !== undefined ? options.direction : 'forward'; - const layout = options.layout !== undefined ? options.layout : 'zrn'; - const activations = options.activations ? options.activations : [sigmoid, tanh]; - - if (!Number.isInteger(steps) || steps <= 0) { - throw new Error(`The steps ${steps} is invalid.`); - } - if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { - throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); - } - if (input.rank !== 3) { - throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`); - } - if (input.shape[0] !== steps) { - throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`); - } - const batchSize = input.shape[1]; - const inputSize = input.shape[2]; - if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') { - throw new Error(`The direction ${direction} is invalid.`); - } +export function gru(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, resetAfter = true, + returnSequence = false, direction = 'forward', + layout = 'zrn', activations = [sigmoid, tanh]} = {}) { + validateGruParams(...arguments); const numDirections = (direction === 'both' ? 2 : 1); - if (weight.rank !== 3) { - throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`); - } - if (weight.shape[0] !== numDirections || weight.shape[1] !== 3 * hiddenSize || - weight.shape[2] !== inputSize) { - throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, - ${weight.shape[2]}] is invalid.`); - } - if (recurrentWeight.rank !== 3) { - throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`); - } - if (recurrentWeight.shape[0] !== numDirections || recurrentWeight.shape[1] !== 3 * hiddenSize || - recurrentWeight.shape[2] !== hiddenSize) { - throw new Error(`The shape of recurrentWeight ` + - `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` + - `${recurrentWeight.shape[2]}] is invalid.`); - } - if (bias) { - if (bias.rank !== 2) { - throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`); - } - if (bias.shape[0] !== numDirections || bias.shape[1] !== 3 * hiddenSize) { - throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`); - } - } - if (recurrentBias) { - if (recurrentBias.rank !== 2) { - throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`); - } - if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 3 * hiddenSize) { - throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, - ${recurrentBias.shape[1]}] is invalid.`); - } - } + const batchSize = input.shape[1]; + let hiddenState; if (initialHiddenState) { - if (initialHiddenState.rank !== 3) { - throw new Error( - `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`); - } - if (initialHiddenState.shape[0] !== numDirections || - initialHiddenState.shape[1] !== batchSize || - initialHiddenState.shape[2] !== hiddenSize) { - throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, - ${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`); - } hiddenState = initialHiddenState; } else { const initialHiddenStateShape = [numDirections, batchSize, hiddenSize]; hiddenState = new Tensor( initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0)); } - if (layout !== 'zrn' && layout !== 'rzn') { - throw new Error(`The layout ${layout} is invalid.`); - } let sequence; const cellWeight = []; diff --git a/src/gru_cell.js b/src/gru_cell.js index 9fa6a7b..d47a57f 100644 --- a/src/gru_cell.js +++ b/src/gru_cell.js @@ -2,11 +2,12 @@ import {add, mul, sub} from './binary.js'; import {matmul} from './matmul.js'; -import {Scalar} from './tensor.js'; +import {Scalar} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; import {slice} from './slice.js'; import {tanh} from './tanh.js'; import {transpose} from './transpose.js'; +import {validateGruCellParams} from './lib/validate-input.js'; /** * A single time step of the Gated Recurrent Unit [GRU] recurrent network using an update gate @@ -20,60 +21,10 @@ import {transpose} from './transpose.js'; * @param {MLGruCellOptions} options * @return {Tensor} */ -export function gruCell(input, weight, recurrentWeight, hiddenState, hiddenSize, options = {}) { - const bias = options.bias; - const recurrentBias = options.recurrentBias; - const resetAfter = options.resetAfter !== undefined ? options.resetAfter : true; - const layout = options.layout !== undefined ? options.layout : 'zrn'; - const activations = options.activations ? options.activations : [sigmoid, tanh]; - - if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { - throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); - } - if (input.rank !== 2) { - throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); - } - const batchSize = input.shape[0]; - const inputSize = input.shape[1]; - if (weight.rank !== 2) { - throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); - } - if (weight.shape[0] !== 3 * hiddenSize || weight.shape[1] !== inputSize) { - throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); - } - if (recurrentWeight.rank !== 2) { - throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); - } - if (recurrentWeight.shape[0] !== 3 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { - throw new Error(`The shape of recurrentWeight ` + - `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); - } - if (hiddenState.rank !== 2) { - throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); - } - if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { - throw new Error(`The shape of hiddenState - [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); - } - if (bias) { - if (bias.rank !== 1) { - throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); - } - if (bias.shape[0] !== 3 * hiddenSize) { - throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); - } - } - if (recurrentBias) { - if (recurrentBias.rank !== 1) { - throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`); - } - if (recurrentBias.shape[0] !== 3 * hiddenSize) { - throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); - } - } - if (layout !== 'zrn' && layout !== 'rzn') { - throw new Error(`The layout ${layout} is invalid.`); - } +export function gruCell(input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter = true, + layout = 'zrn', activations = [sigmoid, tanh]} = {}) { + validateGruCellParams(...arguments); const one = new Scalar(1); const zero = new Scalar(0); diff --git a/src/broadcast.js b/src/lib/broadcast.js similarity index 100% rename from src/broadcast.js rename to src/lib/broadcast.js diff --git a/src/lib/compute-padding.js b/src/lib/compute-padding.js new file mode 100644 index 0000000..feb91ac --- /dev/null +++ b/src/lib/compute-padding.js @@ -0,0 +1,28 @@ +/** + * Compute the beginning and ending pad given input, filter and stride. + * @param {String} autoPad + * @param {Number} inputSize + * @param {Number} effectiveFilterSize + * @param {Number} stride + * @return {Array} [paddingBegin, paddingEnd] + */ +export function computePaddingForAutoPad(autoPad, inputSize, effectiveFilterSize, stride) { + const outSize = Math.ceil(inputSize / stride); + const neededInput = (outSize - 1) * stride + effectiveFilterSize; + const totalPadding = neededInput > inputSize ? neededInput - inputSize : 0; + let paddingBegin; + let paddingEnd; + switch (autoPad) { + case 'same-upper': + paddingBegin = Math.floor(totalPadding / 2); + paddingEnd = Math.floor((totalPadding + 1) / 2); + break; + case 'same-lower': + paddingBegin = Math.floor((totalPadding + 1) / 2); + paddingEnd = Math.floor(totalPadding / 2); + break; + default: + throw new Error('The autoPad is invalid.'); + } + return [paddingBegin, paddingEnd]; +} diff --git a/src/tensor.js b/src/lib/tensor.js similarity index 100% rename from src/tensor.js rename to src/lib/tensor.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js new file mode 100644 index 0000000..357083c --- /dev/null +++ b/src/lib/validate-input.js @@ -0,0 +1,376 @@ +'use strict'; + +/** + * Check the tensor whether it is a 1-D tensor and its length is equal to `expectedSize`. + * @param {Tensor} a + * @param {Number} expectedSize + * @param {String} name + */ +function check1DTensorWithSize(a, expectedSize, name) { + if (a) { + if (a.rank !== 1) { + throw new Error(`The parameter ${name} is not a 1-D tensor.`); + } else { + if (a.shape[0] !== expectedSize) { + throw new Error(`The length ${a.shape[0]} of the ${name} values is not equal to the ` + + `size ${expectedSize} of the input dimension denoted by options.axis.`); + } + } + } +} + +export function validateBatchNormalizationParams(input, mean, variance, + {axis=1, scale, bias} = {}) { + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axis ${axis}, axis should be an integer.`); + } + const dim = input.shape[axis]; + check1DTensorWithSize(mean, dim, 'mean'); + check1DTensorWithSize(variance, dim, 'variance'); + check1DTensorWithSize(scale, dim, 'scale'); + check1DTensorWithSize(bias, dim, 'bias'); +} + + +export function validateConcatParams(inputs, axis) { + const rank = inputs[0].rank; + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axis ${axis}, axis should be an integer.`); + } else { + if (axis < 0 || axis >= rank) { + throw new Error(`Invalid axis ${axis}, axis should be in the interval [0, ${rank}).`); + } + } + const inputShape = inputs[0].shape; + for (let i = 1; i < inputs.length; ++i) { + if (inputs[i].rank !== rank) { + throw new Error('All input tensors should have the same rank.'); + } else { + const shape = inputs[i].shape; + for (let j = 0; j < inputShape.length; ++j) { + if (j !== axis) { + if (inputShape[j] !== shape[j]) { + throw new Error('All input tensors should have the same shape, ' + + 'except for the size of the dimension to concatenate on.'); + } + } + } + } + } +} + +export function validateConv2dParams(input, filter, {bias, groups = 1}) { + const inputChannels = input.shape[1]; + const outputChannels = filter.shape[0]; + const filterInputChannels = filter.shape[1]; + if (input.rank !== 4) { + throw new Error('The input should be a 4-D tensor.'); + } + if (filter.rank !== 4) { + throw new Error('The filter should be a 4-D tensor.'); + } + if (inputChannels !== filterInputChannels * groups) { + throw new Error('The input channels of filter is invalid.'); + } + if (bias && (bias.rank !== 1 || bias.shape[0] != outputChannels)) { + throw new Error('the bias should be a 1-D tensor with the shape of [output_channels].'); + } +} + +export function validateGemmParams(a, b) { + if (a.rank !== 2) { + throw new Error('The input a is not a 2-D tensor.'); + } + if (b.rank !== 2) { + throw new Error('The input b is not a 2-D tensor.'); + } +} + +export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, layout = 'zrn'} = {}) { + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 2) { + throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); + } + const batchSize = input.shape[0]; + const inputSize = input.shape[1]; + if (weight.rank !== 2) { + throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); + } + if (weight.shape[0] !== 3 * hiddenSize || weight.shape[1] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); + } + if (recurrentWeight.rank !== 2) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); + } + if (recurrentWeight.shape[0] !== 3 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); + } + if (hiddenState.rank !== 2) { + throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); + } + if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { + throw new Error(`The shape of hiddenState + [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 1) { + throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (bias.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 1) { + throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (recurrentBias.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); + } + } + if (layout !== 'zrn' && layout !== 'rzn') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + +export function validateGruParams(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, + direction = 'forward', layout = 'zrn'}) { + if (!Number.isInteger(steps) || steps <= 0) { + throw new Error(`The steps ${steps} is invalid.`); + } + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 3) { + throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`); + } + if (input.shape[0] !== steps) { + throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`); + } + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') { + throw new Error(`The direction ${direction} is invalid.`); + } + const numDirections = (direction === 'both' ? 2 : 1); + if (weight.rank !== 3) { + throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`); + } + if (weight.shape[0] !== numDirections || weight.shape[1] !== 3 * hiddenSize || + weight.shape[2] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, + ${weight.shape[2]}] is invalid.`); + } + if (recurrentWeight.rank !== 3) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`); + } + if (recurrentWeight.shape[0] !== numDirections || + recurrentWeight.shape[1] !== 3 * hiddenSize || + recurrentWeight.shape[2] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` + + `${recurrentWeight.shape[2]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 2) { + throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`); + } + if (bias.shape[0] !== numDirections || bias.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 2) { + throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`); + } + if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, + ${recurrentBias.shape[1]}] is invalid.`); + } + } + if (initialHiddenState) { + if (initialHiddenState.rank !== 3) { + throw new Error( + `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`); + } + if (initialHiddenState.shape[0] !== numDirections || + initialHiddenState.shape[1] !== batchSize || + initialHiddenState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, + ${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`); + } + } + if (layout !== 'zrn' && layout !== 'rzn') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + +export function validateMatmulParams(a, b) { + const aCols = a.shape[a.rank - 1]; + const bRows = b.shape[b.rank - 2]; + if (aCols !== bRows) { + throw new Error( + `The columns (${aCols}) of input a is not equal to rows (${bRows}) of input b.`); + } +} + +export function validatePool2dParams(input, _, {roundingType = 'floor'}) { + if (input.rank !== 4) { + throw new Error('The input should be a 4-D tensor.'); + } + if (roundingType !== 'floor' && roundingType !== 'ceil') { + throw new Error('The rounding type is invalid.'); + } +} + +export function validateReduceParams(input, _, {axes}) { + if (axes.length > input.rank) { + throw new Error(`The length ${axes.length} of axes is bigger` + + `than input rank ${input.rank}.`); + } + for (let i = 0; i < axes.length; ++i) { + if (axes[i] < 0 || axes[i] >= input.rank) { + throw new Error(`The value ${axes[i]} at axis ${i} of axes is invalid.`); + } + } +} + +export function validateSliceParams(input, starts, sizes, {axes} = {}) { + let inpAxes = axes; + const rank = input.rank; + const startsForAllAxes = new Array(rank).fill(0); + if (axes) { + if (axes.length > rank) { + throw new Error(`The length of axes ${axes.length} is greater than rank ${rank}.`); + } else { + for (const axis of axes) { + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axes value ${axis}, it should be an integer.`); + } else { + if (axis >= rank || axis < -rank) { + throw new Error(`Invalid axes value ${axis}, it should be in the interval ` + + `[${-rank}, ${rank}).`); + } + } + } + } + } else { + inpAxes = [...Array(rank).keys()]; + } + const axesLen = inpAxes.length; + if (starts.length !== axesLen) { + throw new Error(`The length ${starts.length} of starts is not equal to the length ` + + `${axesLen} of axes.`); + } + if (sizes.length !== axesLen) { + throw new Error(`The length ${sizes.length} of sizes is not equal` + + ` to the length ${axesLen} of axes.`); + } + for (let i = 0; i < axesLen; ++i) { + const axis = inpAxes[i] >= 0 ? inpAxes[i] : inpAxes[i] + rank; + const size = input.shape[axis]; + const start = starts[i]; + if (!Number.isInteger(start)) { + throw new Error(`Invalid starts value ${start}, it should be an integer.`); + } + startsForAllAxes[axis] = start >= 0 ? start : start + size; + if (start >= size || start < -size) { + throw new Error(`Invalid starts value ${start}, it shoule be in the interval ` + + `[${-size}, ${size}).`); + } else { + const sliceSize = sizes[i]; + if (!Number.isInteger(sliceSize)) { + throw new Error(`Invalid sizes value ${sliceSize}, it should be an integer.`); + } + if (sliceSize >= 0) { + if (start >= 0) { + if (start + sliceSize > size) { + throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + + `plus the size ${sliceSize} is greater than the dimensional size ${size}`); + } + } else { + if (start + sliceSize > 0) { + throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + + `plus the size ${sliceSize} is greater than the dimensional size ${size}`); + } + } + } else { + if (sliceSize !== -1) { + throw new Error(`The value ${sliceSize} of sizes is invalid,` + + ` it is required to be -1 when it is negative.`); + } + } + } + } +} + +export function validateSoftmaxParams(x) { + if (x.rank !== 2) { + throw new Error('The input is not a 2-D tensor.'); + } +} + +export function validateSplitParams(input, splits, {axis = 0} = {}) { + let inpAxis; + if (axis !== undefined) { + const rank = input.rank; + if (!Number.isInteger(axis)) { + throw new Error(`The axis ${axis} should be an integer.`); + } + if (axis >= rank || axis < -rank) { + throw new Error(`The axis ${axis} should be in the interval [${-rank}, ${rank}).`); + } + inpAxis = axis >= 0 ? axis : rank + axis; + } + if (typeof splits === 'number') { + if (!Number.isInteger(splits) || splits <= 0) { + throw new Error(`Invalid splits ${splits}, it should be a positive integer.`); + } + if (input.shape[inpAxis] % splits !== 0) { + throw new Error(`The splits ${splits} must evenly divide the dimension size ` + + `${input.shape[inpAxis]} of input along options.axis ${inpAxis}.`); + } + } else if (splits instanceof Array) { + if (!splits.every((v) => Number.isInteger(v) && v > 0)) { + throw new Error(`Invalid splits ${splits}, it should be an Array of positive integers.`); + } + const sum = splits.reduce((a, b) => a + b); + if (sum !== input.shape[inpAxis]) { + throw new Error(`Invalid [${splits}], the sum of sizes ${sum} must equal ` + + `to the dimension size ${input.shape[inpAxis]} of input` + + ` along options.axis ${inpAxis}`); + } + } +} + +export function validateSqueezeParams(input, {axes} = {}) { + if (axes) { + if (axes.length > input.rank) { + throw new Error(`The length of axes ${axes.length} is bigger ` + + `than input rank ${input.rank}.`); + } + + for (const axis of axes) { + if (axis < 0 || axis >= input.rank) { + throw new Error(`The value of axes ${axis} is invalid.`); + } + if (axes && input.shape[axis] !== 1) { + throw new Error(`The value ${input.shape[axis]} ` + + `at axis ${axis} of input shape is not 1.`); + } + } + } +} + +export function validateTranposeParams(input, {permutation}) { + if (permutation.length !== input.rank) { + throw new Error( + `The permutation length ${permutation.length} is not equal to rank ${input.rank}.`); + } +} + diff --git a/src/matmul.js b/src/matmul.js index 41e18de..262c84c 100644 --- a/src/matmul.js +++ b/src/matmul.js @@ -1,9 +1,10 @@ 'use strict'; -import {broadcast, getBroadcastShape} from './broadcast.js'; +import {broadcast, getBroadcastShape} from './lib/broadcast.js'; import {reshape} from './reshape.js'; -import {sizeOfShape, Tensor} from './tensor.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateMatmulParams} from './lib/validate-input.js'; /** * Compute the matrix product of two input tensors. @@ -12,11 +13,7 @@ import {sizeOfShape, Tensor} from './tensor.js'; * @return {Tensor} */ export function matmul(a, b) { - let scalarOutput = false; - if (a.rank === 1 && b.rank === 1) { - scalarOutput = true; - } - + const scalarOutput = a.rank === 1 && b.rank === 1; if (a.rank === 1) { a = reshape(a, [1, a.shape[0]]); } @@ -26,13 +23,9 @@ export function matmul(a, b) { if (b.rank === 1) { b = reshape(b, [b.shape[0], 1]); } - const bRows = b.shape[b.rank - 2]; const bCols = b.shape[b.rank - 1]; - if (aCols !== bRows) { - throw new Error( - `The columns (${aCols}) of input a is not equal to rows (${bRows}) of input b.`); - } + validateMatmulParams(a, b); let cShape = [aRows, bCols]; if (a.rank > 2 || b.rank > 2) { diff --git a/src/pool2d.js b/src/pool2d.js index ea92d61..85fd174 100644 --- a/src/pool2d.js +++ b/src/pool2d.js @@ -1,9 +1,10 @@ 'use strict'; -import {computePaddingForAutoPad} from './conv2d.js'; -import {Tensor} from './tensor.js'; +import {computePaddingForAutoPad} from './lib/compute-padding.js'; +import {Tensor} from './lib/tensor.js'; import {transpose} from './transpose.js'; import {meanReducer, maxReducer} from './reduce.js'; +import {validatePool2dParams} from './lib/validate-input.js'; /** * Compute a reduction operation across all the elements within the @@ -13,38 +14,28 @@ import {meanReducer, maxReducer} from './reduce.js'; * @param {MLPool2dOptions} options * @return {Tensor} */ -function pool2d(input, reductionFunc, options = {}) { - if (input.rank !== 4) { - throw new Error('The input should be a 4-D tensor.'); - } - - const padding = options.padding ? options.padding : [0, 0, 0, 0]; - const strides = options.strides ? options.strides : [1, 1]; - const dilations = options.dilations ? options.dilations : [1, 1]; - const roundingType = options.roundingType ? options.roundingType : 'floor'; - if (roundingType !== 'floor' && roundingType !== 'ceil') { - throw new Error('The rounding type is invalid.'); - } +function pool2d(input, reductionFunc, + {padding = [0, 0, 0, 0], + strides = [1, 1], + dilations = [1, 1], + roundingType = 'floor', + layout = 'nchw', + windowDimensions, + autoPad = 'explicit', + outputSizes, + }= {}) { + validatePool2dParams(...arguments); const roundingFunc = roundingType === 'floor' ? Math.floor : Math.ceil; - const layout = options.layout ? options.layout : 'nchw'; if (layout === 'nhwc') { // nhwc -> nchw input = transpose(input, {permutation: [0, 3, 1, 2]}); } - const batchCount = input.shape[0]; - const channels = input.shape[1]; - const inputHeight = input.shape[2]; - const inputWidth = input.shape[3]; - const windowDimensions = options.windowDimensions ? - options.windowDimensions : [inputHeight, inputWidth]; - const windowHeight = windowDimensions[0]; - const windowWidth = windowDimensions[1]; - const strideHeight = strides[0]; - const strideWidth = strides[1]; - const dilationHeight = dilations[0]; - const dilationWidth = dilations[1]; + const [batchCount, channels, inputHeight, inputWidth] = input.shape; + const [windowHeight, windowWidth] = windowDimensions ?? [inputHeight, inputWidth]; + const [strideHeight, strideWidth] = strides; + const [dilationHeight, dilationWidth] = dilations; const effectiveWindowHeight = windowHeight + (windowHeight - 1) * (dilationHeight - 1); const effectiveWindowWidth = windowWidth + (windowWidth - 1) * (dilationWidth - 1); @@ -52,25 +43,25 @@ function pool2d(input, reductionFunc, options = {}) { let endingPaddingHeight; let beginningPaddingWidth; let endingPaddingWidth; - if (options.autoPad === undefined || options.autoPad === 'explicit') { + if (autoPad === 'explicit') { [beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth] = padding; } else { [beginningPaddingHeight, endingPaddingHeight] = computePaddingForAutoPad( - options.autoPad, inputHeight, effectiveWindowHeight, strideHeight); + autoPad, inputHeight, effectiveWindowHeight, strideHeight); [beginningPaddingWidth, endingPaddingWidth] = computePaddingForAutoPad( - options.autoPad, inputWidth, effectiveWindowWidth, strideWidth); + autoPad, inputWidth, effectiveWindowWidth, strideWidth); } const outputShape = new Array(4); outputShape[0] = batchCount; outputShape[1] = channels; - const outputHeight = options.outputSizes ? options.outputSizes[0] : + const outputHeight = outputSizes ? outputSizes[0] : roundingFunc( 1 + (inputHeight - effectiveWindowHeight + beginningPaddingHeight + endingPaddingHeight) / strideHeight); outputShape[2] = outputHeight; - const outputWidth = options.outputSizes ? options.outputSizes[1] : + const outputWidth = outputSizes ? outputSizes[1] : roundingFunc( 1 + (inputWidth - effectiveWindowWidth + beginningPaddingWidth + endingPaddingWidth) / strideWidth); diff --git a/src/reduce.js b/src/reduce.js index d976d7a..5f8aa27 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -1,7 +1,8 @@ 'use strict'; import {squeeze} from './squeeze.js'; -import {sizeOfShape, Tensor} from './tensor.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateReduceParams} from './lib/validate-input.js'; /** * Reduce the input along the dimensions given in axes. @@ -10,31 +11,24 @@ import {sizeOfShape, Tensor} from './tensor.js'; * @param {MLReduceOptions} options * @return {Tensor} */ -function reduce(input, reduceFunc, options = {}) { - const axes = options.axes ? options.axes : - new Array(input.rank).fill(0).map((_, i) => i); - const keepDimensions = options.keepDimensions ? options.keepDimensions : false; - - if (axes.length > input.rank) { - throw new Error(`The length ${axes.length} of axes is bigger than input rank ${input.rank}.`); - } +function reduce(input, reduceFunc, {keepDimensions = false, axes} = {}) { + const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); const outputShape = input.shape.slice(); - for (let i = 0; i < axes.length; ++i) { - if (axes[i] === -1) { - axes[i] = input.rank - 1; + for (let i = 0; i < inpAxes.length; ++i) { + if (inpAxes[i] === -1) { + inpAxes[i] = input.rank - 1; } - if (axes[i] < 0 || axes[i] >= input.rank) { - throw new Error(`The value ${axes[i]} at axis ${i} of axes is invalid.`); - } - outputShape[axes[i]] = 1; + outputShape[inpAxes[i]] = 1; } + validateReduceParams(input, reduceFunc, {keepDimensions, axes: inpAxes}); + // Calculate the "strides" across the reduction dimensions given in axes. - axes.sort((a, b) => a - b); - const reduceDims = axes.map((axis) => input.shape[axis]); + inpAxes.sort((a, b) => a - b); + const reduceDims = inpAxes.map((axis) => input.shape[axis]); const reduceElements = sizeOfShape(reduceDims); - const reduceStrides = new Array(axes.length); + const reduceStrides = new Array(inpAxes.length); reduceStrides[reduceStrides.length - 1] = 1; for (let i = reduceStrides.length - 2; i >= 0; --i) { reduceStrides[i] = reduceStrides[i + 1] * reduceDims[i + 1]; @@ -48,8 +42,8 @@ function reduce(input, reduceFunc, options = {}) { // Calculate the input location given index of elements to reduce. const inputLocation = output.locationFromIndex(outputIndex); let remainingReduceIndex = reduceIndex; - for (let i = 0; i < axes.length; ++i) { - const axis = axes[i]; + for (let i = 0; i < inpAxes.length; ++i) { + const axis = inpAxes[i]; inputLocation[axis] = Math.floor(remainingReduceIndex / reduceStrides[i]); remainingReduceIndex -= inputLocation[axis] * reduceStrides[i]; } diff --git a/src/reshape.js b/src/reshape.js index f80cd2a..e1a428f 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -1,6 +1,6 @@ 'use strict'; -import {Tensor, sizeOfShape} from './tensor.js'; +import {Tensor, sizeOfShape} from './lib/tensor.js'; /** * Alter the shape of a tensor to a new shape. diff --git a/src/slice.js b/src/slice.js index 3522040..90c3910 100644 --- a/src/slice.js +++ b/src/slice.js @@ -1,6 +1,7 @@ 'use strict'; -import {Tensor, sizeOfShape} from './tensor.js'; +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateSliceParams} from './lib/validate-input.js'; /** * Produce a slice of the input tensor. @@ -10,78 +11,24 @@ import {Tensor, sizeOfShape} from './tensor.js'; * @param {MLSliceOptions} options * @return {Tensor} */ -export function slice(input, starts, sizes, options = {}) { +export function slice(input, starts, sizes, {axes} = {}) { + validateSliceParams(...arguments); const rank = input.rank; const startsForAllAxes = new Array(rank).fill(0); - let axes = options.axes; - if (axes) { - if (axes.length > rank) { - throw new Error(`The length of axes ${axes.length} is greater than rank ${rank}.`); - } else { - for (const axis of axes) { - if (!Number.isInteger(axis)) { - throw new Error(`Invalid axes value ${axis}, it should be an integer.`); - } else { - if (axis >= rank || axis < -rank) { - throw new Error(`Invalid axes value ${axis}, it should be in the interval ` + - `[${-rank}, ${rank}).`); - } - } - } - } - } else { - axes = [...Array(rank).keys()]; - } + + axes = axes ?? [...Array(rank).keys()]; const axesLen = axes.length; - if (starts.length !== axesLen) { - throw new Error(`The length ${starts.length} of starts is not equal to the length ` + - `${axesLen} of axes.`); - } - if (sizes.length !== axesLen) { - throw new Error(`The length ${sizes.length} of sizes is not equal to the length ${axesLen} ` + - 'of axes.'); - } const outputShape = input.shape.slice(); for (let i = 0; i < axesLen; ++i) { const axis = axes[i] >= 0 ? axes[i] : axes[i] + rank; const size = input.shape[axis]; const start = starts[i]; - if (!Number.isInteger(start)) { - throw new Error(`Invalid starts value ${start}, it should be an integer.`); - } startsForAllAxes[axis] = start >= 0 ? start : start + size; - if (start >= size || start < -size) { - throw new Error(`Invalid starts value ${start}, it shoule be in the interval ` + - `[${-size}, ${size}).`); + const sliceSize = sizes[i]; + if (sliceSize >= 0) { + outputShape[axis] = sliceSize; } else { - const sliceSize = sizes[i]; - if (!Number.isInteger(sliceSize)) { - throw new Error(`Invalid sizes value ${sliceSize}, it should be an integer.`); - } - if (sliceSize >= 0) { - if (start >= 0) { - if (start + sliceSize > size) { - throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + - `plus the size ${sliceSize} is greater than the dimensional size ${size}`); - } else { - outputShape[axis] = sliceSize; - } - } else { - if (start + sliceSize > 0) { - throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + - `plus the size ${sliceSize} is greater than the dimensional size ${size}`); - } else { - outputShape[axis] = sliceSize; - } - } - } else { - if (sliceSize !== -1) { - throw new Error(`The value ${sliceSize} of sizes is invalid, it is required to be -1 ` + - 'when it is negative.'); - } else { - outputShape[axis] = start >= 0 ? size - start : -start; - } - } + outputShape[axis] = start >= 0 ? size - start : -start; } } const output = new Tensor(outputShape); diff --git a/src/softmax.js b/src/softmax.js index 9c58330..356dddb 100644 --- a/src/softmax.js +++ b/src/softmax.js @@ -3,6 +3,7 @@ import {div, sub} from './binary.js'; import {exp} from './unary.js'; import {reduceMax, reduceSum} from './reduce.js'; +import {validateSoftmaxParams} from './lib/validate-input.js'; /** * Compute the softmax values of the 2-D input tensor along axis 1. @@ -10,9 +11,7 @@ import {reduceMax, reduceSum} from './reduce.js'; * @return {Tensor} */ export function softmax(x) { - if (x.rank !== 2) { - throw new Error('The input is not a 2-D tensor.'); - } + validateSoftmaxParams(...arguments); const maxX = reduceMax(x, {axes: [1], keepDimensions: true}); const expX = exp(sub(x, maxX)); return div(expX, reduceSum(expX, {axes: [1], keepDimensions: true})); diff --git a/src/split.js b/src/split.js index 0030e7d..f2c9685 100644 --- a/src/split.js +++ b/src/split.js @@ -1,6 +1,7 @@ 'use strict'; import {slice} from './slice.js'; +import {validateSplitParams} from './lib/validate-input.js'; /** * Split the input tensor into a number of sub tensors along the given axis. @@ -9,43 +10,20 @@ import {slice} from './slice.js'; * @param {MLSplitOptions} options * @return {Array.} */ -export function split(input, splits, options = {}) { +export function split(input, splits, {axis = 0} = {}) { + validateSplitParams(...arguments); const outputs = []; let sliceSizes = []; const rank = input.rank; - let axis = 0; - if (options.axis !== undefined) { - if (!Number.isInteger(options.axis)) { - throw new Error(`The axis ${options.axis} should be an integer.`); - } - if (options.axis >= rank || options.axis < -rank) { - throw new Error(`The axis ${options.axis} should be in the interval [${-rank}, ${rank}).`); - } - axis = options.axis >= 0 ? options.axis : rank + options.axis; - } + const inpAxis = axis >=0 ? axis : rank + axis; if (typeof splits === 'number') { - if (!Number.isInteger(splits) || splits <= 0) { - throw new Error(`Invalid splits ${splits}, it should be a positive integer.`); - } - if (input.shape[axis] % splits !== 0) { - throw new Error(`The splits ${splits} must evenly divide the dimension size ` + - `${input.shape[axis]} of input along options.axis ${options.axis}.`); - } - sliceSizes = new Array(splits).fill(input.shape[axis] / splits); + sliceSizes = new Array(splits).fill(input.shape[inpAxis] / splits); } else if (splits instanceof Array) { - if (!splits.every((v) => Number.isInteger(v) && v > 0)) { - throw new Error(`Invalid splits ${splits}, it should be an Array of positive integers.`); - } - const sum = splits.reduce((a, b) => a + b); - if (sum !== input.shape[axis]) { - throw new Error(`Invalid [${splits}], the sum of sizes ${sum} must equal to the dimension ` + - `size ${input.shape[axis]} of input along options.axis ${options.axis}`); - } sliceSizes = splits.slice(); } let start = 0; for (const size of sliceSizes) { - outputs.push(slice(input, [start], [size], {axes: [axis]})); + outputs.push(slice(input, [start], [size], {axes: [inpAxis]})); start += size; } return outputs; diff --git a/src/squeeze.js b/src/squeeze.js index 93d9bb6..a463f79 100644 --- a/src/squeeze.js +++ b/src/squeeze.js @@ -1,6 +1,7 @@ 'use strict'; import {reshape} from './reshape.js'; +import {validateSqueezeParams} from './lib/validate-input.js'; /** * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. @@ -8,26 +9,12 @@ import {reshape} from './reshape.js'; * @param {MLSqueezeOptions} options * @return {Tensor} */ -export function squeeze(input, options = {}) { - let axes = options.axes; - if (axes) { - if (axes.length > input.rank) { - throw new Error(`The length of axes ${axes.length} is bigger than input rank ${input.rank}.`); - } +export function squeeze(input, {axes} = {}) { + validateSqueezeParams(...arguments); + const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); - for (const axis of axes) { - if (axis < 0 || axis >= input.rank) { - throw new Error(`The value of axes ${axis} is invalid.`); - } - if (options.axes && input.shape[axis] !== 1) { - throw new Error(`The value ${input.shape[axis]} at axis ${axis} of input shape is not 1.`); - } - } - } else { - axes = new Array(input.rank).fill(0).map((_, i) => i); - } - - const outputShape = input.shape.filter((dim, axis) => !(dim === 1 && axes.indexOf(axis) !== -1)); + const outputShape = input.shape.filter((dim, axis) => + !(dim === 1 && inpAxes.indexOf(axis) !== -1)); const output = reshape(input, outputShape); return output; } diff --git a/src/transpose.js b/src/transpose.js index 2226b54..646c084 100644 --- a/src/transpose.js +++ b/src/transpose.js @@ -1,6 +1,7 @@ 'use strict'; -import {Tensor} from './tensor.js'; +import {Tensor} from './lib/tensor.js'; +import {validateTranposeParams} from './lib/validate-input.js'; /** * Permute the dimensions of the input tensor according to the permutation argument. @@ -8,22 +9,20 @@ import {Tensor} from './tensor.js'; * @param {MLTransposeOptions} [options] * @return {Tensor} */ -export function transpose(input, options = {}) { - const permutation = options.permutation ? options.permutation : - new Array(input.rank).fill(0).map((e, i, a) => a.length - i - 1); - if (permutation.length !== input.rank) { - throw new Error( - `The permutation length ${permutation.length} is not equal to rank ${input.rank}.`); - } +export function transpose(input, {permutation} = {}) { + const inpPermutation = permutation ?? + new Array(input.rank).fill(0).map((e, i, a) => a.length - i - 1); + validateTranposeParams(input, {permutation: inpPermutation}); - const outputShape = new Array(input.rank).fill(0).map((e, i, a) => input.shape[permutation[i]]); + const outputShape = new Array(input.rank).fill(0).map( + (e, i, a) => input.shape[inpPermutation[i]]); const output = new Tensor(outputShape); for (let inputIndex = 0; inputIndex < input.size; ++inputIndex) { const inputValue = input.getValueByIndex(inputIndex); const inputLocation = input.locationFromIndex(inputIndex); const outputLocation = new Array(output.rank); - for (let i = 0; i < permutation.length; ++i) { - outputLocation[i] = inputLocation[permutation[i]]; + for (let i = 0; i < inpPermutation.length; ++i) { + outputLocation[i] = inputLocation[inpPermutation[i]]; } output.setValueByLocation(outputLocation, inputValue); } diff --git a/src/unary.js b/src/unary.js index 59e472c..1d45243 100644 --- a/src/unary.js +++ b/src/unary.js @@ -1,6 +1,6 @@ 'use strict'; -import {Tensor} from './tensor.js'; +import {Tensor} from './lib/tensor.js'; /** * Compute the element-wise unary operation for input tensor. diff --git a/test/batch_normalization_test.js b/test/batch_normalization_test.js index 4988300..c2ae7e5 100644 --- a/test/batch_normalization_test.js +++ b/test/batch_normalization_test.js @@ -5,7 +5,7 @@ import {clamp} from '../src/clamp.js'; import {leakyRelu} from '../src/leaky_relu.js'; import {relu} from '../src/relu.js'; import {sigmoid} from '../src/sigmoid.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test batchNormalization', function() { diff --git a/test/binary_test.js b/test/binary_test.js index b364529..3a0fc53 100644 --- a/test/binary_test.js +++ b/test/binary_test.js @@ -1,7 +1,7 @@ 'use strict'; import {add, sub, mul, div, max, min, pow} from '../src/binary.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test binary', function() { diff --git a/test/brodcast_test.js b/test/brodcast_test.js index 7873add..8aabc92 100644 --- a/test/brodcast_test.js +++ b/test/brodcast_test.js @@ -1,7 +1,7 @@ 'use strict'; -import {broadcast} from '../src/broadcast.js'; -import {Tensor, sizeOfShape} from '../src/tensor.js'; +import {broadcast} from '../src/lib/broadcast.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test broadcast', function() { diff --git a/test/clamp_test.js b/test/clamp_test.js index 4e5bf82..7e1779b 100644 --- a/test/clamp_test.js +++ b/test/clamp_test.js @@ -1,7 +1,7 @@ 'use strict'; import {clamp} from '../src/clamp.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test clamp', function() { diff --git a/test/concat_test.js b/test/concat_test.js index e565967..4ac44e7 100644 --- a/test/concat_test.js +++ b/test/concat_test.js @@ -1,7 +1,7 @@ 'use strict'; import {concat} from '../src/concat.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test concat', function() { diff --git a/test/conv2d_test.js b/test/conv2d_test.js index fc58724..5c06248 100644 --- a/test/conv2d_test.js +++ b/test/conv2d_test.js @@ -5,7 +5,7 @@ import {clamp} from '../src/clamp.js'; import {leakyRelu} from '../src/leaky_relu.js'; import {relu} from '../src/relu.js'; import {sigmoid} from '../src/sigmoid.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; diff --git a/test/gemm_test.js b/test/gemm_test.js index 6da03ac..3aa9dea 100644 --- a/test/gemm_test.js +++ b/test/gemm_test.js @@ -1,7 +1,7 @@ 'use strict'; import {gemm} from '../src/gemm.js'; -import {Tensor, Scalar} from '../src/tensor.js'; +import {Tensor, Scalar} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test gemm', function() { diff --git a/test/gru_cell_test.js b/test/gru_cell_test.js index 59afa54..3939d34 100644 --- a/test/gru_cell_test.js +++ b/test/gru_cell_test.js @@ -1,7 +1,7 @@ 'use strict'; import {gruCell} from '../src/gru_cell.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import {sigmoid} from '../src/sigmoid.js'; import {tanh} from '../src/tanh.js'; import * as utils from './utils.js'; diff --git a/test/gru_test.js b/test/gru_test.js index eceae3b..43c4bb0 100644 --- a/test/gru_test.js +++ b/test/gru_test.js @@ -1,7 +1,7 @@ 'use strict'; import {gru} from '../src/gru.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test gru', function() { diff --git a/test/leaky_relu_test.js b/test/leaky_relu_test.js index 2354edd..c46c78f 100644 --- a/test/leaky_relu_test.js +++ b/test/leaky_relu_test.js @@ -1,7 +1,7 @@ 'use strict'; import {leakyRelu} from '../src/leaky_relu.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test leakyRelu', function() { diff --git a/test/matmul_test.js b/test/matmul_test.js index dd890e3..1f11f2f 100644 --- a/test/matmul_test.js +++ b/test/matmul_test.js @@ -1,7 +1,7 @@ 'use strict'; import {matmul} from '../src/matmul.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test matmul', function() { diff --git a/test/pool2d_test.js b/test/pool2d_test.js index d714df3..39b0996 100644 --- a/test/pool2d_test.js +++ b/test/pool2d_test.js @@ -1,6 +1,6 @@ 'use strict'; import {averagePool2d, maxPool2d} from '../src/pool2d.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; diff --git a/test/reduce_test.js b/test/reduce_test.js index f9f15ae..1e9a734 100644 --- a/test/reduce_test.js +++ b/test/reduce_test.js @@ -1,7 +1,7 @@ 'use strict'; import * as reducers from '../src/reduce.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; diff --git a/test/relu_test.js b/test/relu_test.js index f5183ec..f849ea1 100644 --- a/test/relu_test.js +++ b/test/relu_test.js @@ -1,7 +1,7 @@ 'use strict'; import {relu} from '../src/relu.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test relu', function() { diff --git a/test/reshape_test.js b/test/reshape_test.js index 88dc4fe..640aea2 100644 --- a/test/reshape_test.js +++ b/test/reshape_test.js @@ -1,7 +1,7 @@ 'use strict'; import {reshape} from '../src/reshape.js'; -import {Tensor, sizeOfShape} from '../src/tensor.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test reshape', function() { diff --git a/test/sigmoid_test.js b/test/sigmoid_test.js index 2258c7d..d7c0ca5 100644 --- a/test/sigmoid_test.js +++ b/test/sigmoid_test.js @@ -1,7 +1,7 @@ 'use strict'; import {sigmoid} from '../src/sigmoid.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test sigmoid', function() { diff --git a/test/slice_test.js b/test/slice_test.js index cabb5d4..2d7b03f 100644 --- a/test/slice_test.js +++ b/test/slice_test.js @@ -1,7 +1,7 @@ 'use strict'; import {slice} from '../src/slice.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test slice', function() { diff --git a/test/softmax_test.js b/test/softmax_test.js index 85c9d39..f081b32 100644 --- a/test/softmax_test.js +++ b/test/softmax_test.js @@ -1,7 +1,7 @@ 'use strict'; import {softmax} from '../src/softmax.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test softmax', function() { diff --git a/test/split_test.js b/test/split_test.js index bc6a0d8..738271e 100644 --- a/test/split_test.js +++ b/test/split_test.js @@ -1,7 +1,7 @@ 'use strict'; import {split} from '../src/split.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test split', function() { diff --git a/test/squeeze_test.js b/test/squeeze_test.js index cee577b..e2103be 100644 --- a/test/squeeze_test.js +++ b/test/squeeze_test.js @@ -1,7 +1,7 @@ 'use strict'; import {squeeze} from '../src/squeeze.js'; -import {Tensor, sizeOfShape} from '../src/tensor.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test squeeze', function() { diff --git a/test/tanh_test.js b/test/tanh_test.js index 3e4ad58..1d44aa2 100644 --- a/test/tanh_test.js +++ b/test/tanh_test.js @@ -1,7 +1,7 @@ 'use strict'; import {tanh} from '../src/tanh.js'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test tanh', function() { diff --git a/test/transpose_test.js b/test/transpose_test.js index 4c06dc0..cb5ae2a 100644 --- a/test/transpose_test.js +++ b/test/transpose_test.js @@ -1,6 +1,6 @@ 'use strict'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import {transpose} from '../src/transpose.js'; import * as utils from './utils.js'; diff --git a/test/unary_test.js b/test/unary_test.js index 752084b..be8899a 100644 --- a/test/unary_test.js +++ b/test/unary_test.js @@ -1,6 +1,6 @@ 'use strict'; -import {Tensor} from '../src/tensor.js'; +import {Tensor} from '../src/lib/tensor.js'; import * as unaryFunctions from '../src/unary.js'; import * as utils from './utils.js'; diff --git a/test/utils.js b/test/utils.js index bb3b3c7..1e60b0a 100644 --- a/test/utils.js +++ b/test/utils.js @@ -17,33 +17,37 @@ function getBitwise(value) { } /** - * Check the distance between a and b whether is close enough to the given ULP distance. + * Asserts that the distance between a and b whether is close enough to the given ULP distance. * @param {Number} a * @param {Number} b * @param {Number} nulp A BigInt value. + * @param {String} message A message to report when the assertion fails * @return {Boolean} A boolean value: * true: The distance between a and b is close enough to the given ULP distance. * false: The distance between a and b is far away from the given ULP distance. */ -function almostEqualUlp(a, b, nulp) { +assert.isAlmostEqualUlp = function(a, b, nulp, message) { const aBitwise = getBitwise(a); const bBitwise = getBitwise(b); let distance = aBitwise - bBitwise; distance = distance >= 0 ? distance : -distance; - return distance <= nulp; -} + return assert.isTrue(distance <= nulp, message); +}; -export function checkValue(tensor, expected, nulp = 0n) { +export function checkValue(tensor, expected, nulp = 0) { assert.isTrue(tensor.size === expected.length); for (let i = 0; i < expected.length; ++i) { - assert.isTrue(almostEqualUlp(tensor.getValueByIndex(i), expected[i], nulp)); + assert.isAlmostEqualUlp(tensor.getValueByIndex(i), expected[i], nulp, + `${tensor.getValueByIndex(i)} is almost equal to ${expected[i]}`); } } export function checkShape(tensor, expected) { - assert.equal(tensor.rank, expected.length); + assert.equal(tensor.rank, expected.length, + `Tensor has expected rank ${expected.length}: ${tensor.rank}`); for (let i = 0; i < expected.length; ++i) { - assert.equal(tensor.shape[i], expected[i]); + assert.equal(tensor.shape[i], expected[i], + `Tensor line ${i} has expected length ${expected[i]}: ${tensor.shape[i]}`); } }