diff --git a/src/arg_max_min.js b/src/arg_max_min.js index 8e1f404..4c4d2a0 100644 --- a/src/arg_max_min.js +++ b/src/arg_max_min.js @@ -2,7 +2,8 @@ import {Tensor, sizeOfShape} from './lib/tensor.js'; import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js'; -import {squeeze} from './squeeze.js'; +import {squeeze} from './reshape.js'; + /** * Get the index location of the minimum or maxmium values of all the input values along the axes. diff --git a/src/gru.js b/src/gru.js index 80583c0..cf93319 100644 --- a/src/gru.js +++ b/src/gru.js @@ -2,11 +2,10 @@ import {concat} from './concat.js'; import {gruCell} from './gru_cell.js'; -import {reshape} from './reshape.js'; +import {reshape, squeeze} from './reshape.js'; import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; import {slice} from './slice.js'; -import {squeeze} from './squeeze.js'; import {tanh} from './tanh.js'; import {validateGruParams} from './lib/validate-input.js'; diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 723099d..2a8cd49 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -167,6 +167,168 @@ export function validateGemmParams(a, b) { } } +export function validateLstmCellParams(input, weight, recurrentWeight, + hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) { + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 2) { + throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); + } + const batchSize = input.shape[0]; + const inputSize = input.shape[1]; + if (weight.rank !== 2) { + throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); + } + if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); + } + if 
(recurrentWeight.rank !== 2) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); + } + if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); + } + if (hiddenState.rank !== 2) { + throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); + } + if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { + throw new Error(`The shape of hiddenState + [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + } + if (cellState.rank !== 2) { + throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`); + } + if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) { + throw new Error(`The shape of cellState [${cellState.shape[0]}, + ${cellState.shape[1]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 1) { + throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (bias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 1) { + throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 1-D tensor.`); + } + if (recurrentBias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); + } + } + if (peepholeWeight) { + if (peepholeWeight.rank !== 1) { + throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 1-D tensor.`); + } + if (peepholeWeight.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`); + } + } + if (layout !== 'iofg' && layout !== 'ifgo') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + +export function validateLstmParams(input, weight, recurrentWeight, steps, 
hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, + direction = 'forward', layout = 'iofg'}) { + if (!Number.isInteger(steps) || steps <= 0) { + throw new Error(`The steps ${steps} is invalid.`); + } + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 3) { + throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`); + } + if (input.shape[0] !== steps) { + throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`); + } + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') { + throw new Error(`The direction ${direction} is invalid.`); + } + const numDirections = (direction === 'both' ? 2 : 1); + if (weight.rank !== 3) { + throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`); + } + if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize || + weight.shape[2] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, + ${weight.shape[2]}] is invalid.`); + } + if (recurrentWeight.rank !== 3) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`); + } + if (recurrentWeight.shape[0] !== numDirections || + recurrentWeight.shape[1] !== 4 * hiddenSize || + recurrentWeight.shape[2] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` + + `${recurrentWeight.shape[2]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 2) { + throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`); + } + if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`); + } + } + if (recurrentBias) { + if 
(recurrentBias.rank !== 2) { + throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`); + } + if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, + ${recurrentBias.shape[1]}] is invalid.`); + } + } + if (peepholeWeight) { + if (peepholeWeight.rank !== 2) { + throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`); + } + if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}, + ${peepholeWeight.shape[1]}] is invalid.`); + } + } + if (initialHiddenState) { + if (initialHiddenState.rank !== 3) { + throw new Error( + `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`); + } + if (initialHiddenState.shape[0] !== numDirections || + initialHiddenState.shape[1] !== batchSize || + initialHiddenState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, + ${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`); + } + } + if (initialCellState) { + if (initialCellState.rank !== 3) { + throw new Error( + `The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`); + } + if (initialCellState.shape[0] !== numDirections || + initialCellState.shape[1] !== batchSize || + initialCellState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialCellState [${initialCellState.shape[0]}, + ${initialCellState.shape[1]}, ${initialCellState.shape[2]}] is invalid.`); + } + } + if (layout !== 'iofg' && layout !== 'ifgo') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + + export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, {bias, recurrentBias, layout = 'zrn'} = {}) { if (!Number.isInteger(hiddenSize) || hiddenSize 
<= 0) { @@ -194,8 +356,8 @@ export function validateGruCellParams(input, weight, recurrentWeight, hiddenStat throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); } if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { - throw new Error(`The shape of hiddenState - [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + throw new Error(`The shape of hiddenState [${hiddenState.shape[0]}, + ${hiddenState.shape[1]}] is invalid.`); } if (bias) { if (bias.rank !== 1) { diff --git a/src/lstm.js b/src/lstm.js new file mode 100644 index 0000000..99474d7 --- /dev/null +++ b/src/lstm.js @@ -0,0 +1,106 @@ +'use strict'; + +import {concat} from './concat.js'; +import {lstmCell} from './lstm_cell.js'; +import {reshape, squeeze} from './reshape.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {tanh} from './tanh.js'; +import {validateLstmParams} from './lib/validate-input.js'; + +/** + *Long Short-Term Memory [LSTM] recurrent network uses an input, output, forget, + *and cell gate to compute the output state that rolls into the output across the + * temporal sequence of the network. + * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Number} steps + * @param {Number} hiddenSize + * @param {MLLstmOptions} options + * @return {Array.} + */ +export function lstm(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg', + activations = [sigmoid, tanh, tanh]} = {}) { + validateLstmParams(...arguments); + const numDirections = (direction == 'both' ? 
2 : 1); + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + + let hiddenState; + let cellState; + if (initialHiddenState) { + hiddenState = initialHiddenState; + } else { + const initialHiddenStateShape = [numDirections, batchSize, hiddenSize]; + hiddenState = new Tensor( + initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0)); + } + if (initialCellState) { + cellState = initialCellState; + } else { + const initialCellState = [numDirections, batchSize, hiddenSize]; + cellState = new Tensor( + initialCellState, new Array(sizeOfShape(initialCellState)).fill(0)); + } + + let sequence; + const currentWeight = []; + const currentRecurrentWeight = []; + const currentBias = []; + const currentRecurrentBias = []; + const currentPeepholeWeight = []; + + for (let dir = 0; dir < numDirections; ++dir) { + currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize]))); + currentRecurrentWeight.push(squeeze(slice(recurrentWeight, + [dir, 0, 0], [1, 4 * hiddenSize, hiddenSize]))); + currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null); + currentRecurrentBias.push(recurrentBias ? + (squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null); + currentPeepholeWeight.push(peepholeWeight ? + (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null); + } + + for (let step = 0; step < steps; ++step) { + const currentHidden = []; + const currentCell = []; + let nextHidden = null; + let nextCell = null; + + for (let dir = 0; dir < numDirections; ++dir) { + currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]))); + currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]))); + } + + for (let dir = 0; dir < numDirections; ++dir) { + const slice0 = (dir == 1 || direction == 'backward' ? 
steps - step - 1 : step); + const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize])); + + const results = lstmCell( + currentInput, currentWeight[dir], currentRecurrentWeight[dir], + currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir], + recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], + layout: layout, activations: activations}); + + const output = reshape(results[0], [1, null, hiddenSize]); + const cell = reshape(results[1], [1, null, hiddenSize]); + + nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); + nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); + } + + hiddenState = nextHidden; + cellState = nextCell; + + if (returnSequence) { + nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]); + sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); + } + } + + return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]); +} diff --git a/src/lstm_cell.js b/src/lstm_cell.js new file mode 100644 index 0000000..7a52374 --- /dev/null +++ b/src/lstm_cell.js @@ -0,0 +1,145 @@ +'use strict'; + +import {add, mul} from './binary.js'; +import {matmul} from './matmul.js'; +import {Scalar} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {tanh} from './tanh.js'; +import {transpose} from './transpose.js'; +import {validateLstmCellParams} from './lib/validate-input.js'; + +/** + *A single time step of the Long Short-Term Memory [LSTM] recurrent network + *using a cell state, an input, output, and forget gate to compute the cell + *state and the hidden state of the next time step that rolls into the output + *across the temporal sequence of the network. 
+ * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Tensor} hiddenState + * @param {Tensor} cellState + * @param {Number} hiddenSize + * @param {MLLstmCellOptions} options + * @return {Tensor} + */ +export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, + layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { + validateLstmCellParams(...arguments); + const zero = new Scalar(0); + const inputSize = input.shape[1]; + const starts = (layout === 'iofg') ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} : + {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; + const activation0 = activations[0]; + const activation1 = activations[1]; + const activation2 = activations[2]; + + // input gate (i) + const i = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.i], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.i], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.i], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.i, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.i, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // forget gate (f) + const f = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.f], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.f], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.f], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.f, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose( + slice(recurrentWeight, [starts.f, 0], [hiddenSize, hiddenSize]), + ), + ), + ), + ), + ), + ); + + // cell gate (g) + const g = activation1( + add( + add( + (bias ? 
slice(bias, [starts.g], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.g], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.g, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.g, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ); + + // output gate (o) + const o = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.o], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.o], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.o], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.o, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.o, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // output cell state (ct) + const ct = add(mul(f, cellState), mul(i, g)); + + // output hidden state (ht) + const ht = mul(o, activation2(ct)); + + return [ht, ct]; +} diff --git a/src/reduce.js b/src/reduce.js index aa7428a..3dd8fbf 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -1,7 +1,7 @@ 'use strict'; import {pow} from './binary.js'; -import {squeeze} from './squeeze.js'; +import {squeeze} from './reshape.js'; import {abs, exp, log} from './unary.js'; import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; import {validateReduceParams} from './lib/validate-input.js'; diff --git a/src/reshape.js b/src/reshape.js index 4cb0ffc..112b567 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -1,6 +1,7 @@ 'use strict'; import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateSqueezeParams} from './lib/validate-input.js'; /** * Alter the shape of a tensor to a new shape. 
@@ -31,3 +32,18 @@ export function reshape(input, newShape) { const output = new Tensor(outputShape, input.data); return output; } + +/** + * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. + * @param {Tensor} input + * @param {MLSqueezeOptions} options + * @return {Tensor} + */ +export function squeeze(input, {axes} = {}) { + validateSqueezeParams(...arguments); + const inputAxes = axes ?? Array.from({length: input.rank}, (_, i) => i); + const outputShape = input.shape.filter((dim, axis) => + !(dim === 1 && inputAxes.indexOf(axis) !== -1)); + const output = reshape(input, outputShape); + return output; +} diff --git a/src/squeeze.js b/src/squeeze.js deleted file mode 100644 index a463f79..0000000 --- a/src/squeeze.js +++ /dev/null @@ -1,20 +0,0 @@ -'use strict'; - -import {reshape} from './reshape.js'; -import {validateSqueezeParams} from './lib/validate-input.js'; - -/** - * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. - * @param {Tensor} input - * @param {MLSqueezeOptions} options - * @return {Tensor} - */ -export function squeeze(input, {axes} = {}) { - validateSqueezeParams(...arguments); - const inpAxes = axes ?? 
new Array(input.rank).fill(0).map((_, i) => i); - - const outputShape = input.shape.filter((dim, axis) => - !(dim === 1 && inpAxes.indexOf(axis) !== -1)); - const output = reshape(input, outputShape); - return output; -} diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js new file mode 100644 index 0000000..b57ffa8 --- /dev/null +++ b/test/lstm_cell_test.js @@ -0,0 +1,57 @@ +'use strict'; + +import {lstmCell} from '../src/lstm_cell.js'; +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test lstmCell', function() { + it('lstmCell lstmCell activations=[relu, relu, relu]', function() { + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([batchSize, inputSize], [1, 2, 2, 1]); + const weight = new Tensor([4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([4 * hiddenSize, hiddenSize], + new Float32Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const cellState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const bias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([3* hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstmCell( + input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, activations}); + utils.checkShape(outputs[0], [batchSize, hiddenSize]); + utils.checkShape(outputs[1], [batchSize, hiddenSize]); + const expected = [ + [ + 1, 8, 27, 216, + ], + [ + 1, 4, 9, 36, 
+ ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); +}); diff --git a/test/lstm_test.js b/test/lstm_test.js new file mode 100644 index 0000000..c03d885 --- /dev/null +++ b/test/lstm_test.js @@ -0,0 +1,131 @@ +'use strict'; + +import {lstm} from '../src/lstm.js'; +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test lstm', function() { + it('lstm returnSequence=true ' + + 'activations=[relu, relu, relu]', function() { + const steps = 1; + const numDirections = 1; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], new Float32Array([1, 2, 2, 1])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const initialCellState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const returnSequence = true; + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, returnSequence, activations}); + console.log('outputs: ', outputs); + 
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [ + 1, 8, 27, 216, + ], + [ + 1, 4, 9, 36, + ], + [ + 1, 8, 27, 216, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); + + it('lstm steps=2 direction="backward" returnSequence=true' + + 'activations=[relu, relu, relu]', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], + new Float32Array([1, 2, 2, 1, 3, 4, 1, 2])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const initialCellState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const returnSequence = true; + const direction = 'backward'; + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, direction, returnSequence, activations}); + 
console.log('outputs: ', outputs); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [10.469, 58.02899999999999, 74.529, 518.9490000000001], + [5.51, 20.009999999999998, 19.11, 75.21000000000001], + [ + 1, + 8, + 1, + 8, + 10.469, + 58.02899999999999, + 74.529, + 518.9490000000001, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); +}); diff --git a/test/squeeze_test.js b/test/squeeze_test.js deleted file mode 100644 index e2103be..0000000 --- a/test/squeeze_test.js +++ /dev/null @@ -1,35 +0,0 @@ -'use strict'; - -import {squeeze} from '../src/squeeze.js'; -import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; -import * as utils from './utils.js'; - -describe('test squeeze', function() { - function testSqueeze(oldShape, axes, expectedShape) { - const bufferSize = sizeOfShape(oldShape); - const inputBuffer = new Array(bufferSize); - for (let i = 0; i < inputBuffer.length; ++i) { - inputBuffer[i] = Math.random(); - } - const x = new Tensor(oldShape, inputBuffer); - const y = squeeze(x, {axes}); - utils.checkShape(y, expectedShape); - utils.checkValue(y, inputBuffer); - } - - it('squeeze one dimension by default', function() { - testSqueeze([1, 3, 4, 5], undefined, [3, 4, 5]); - }); - - it('squeeze one dimension with axes', function() { - testSqueeze([1, 3, 1, 5], [0], [3, 1, 5]); - }); - - it('squeeze two dimensions by default', function() { - testSqueeze([1, 3, 1, 5], undefined, [3, 5]); - }); - - it('squeeze two dimensions with axes', function() { - testSqueeze([1, 3, 1, 5], [0, 2], [3, 5]); - }); -});