diff --git a/src/lstm_cell.js b/src/lstm_cell.js index bfafd46..3d07d11 100644 --- a/src/lstm_cell.js +++ b/src/lstm_cell.js @@ -10,7 +10,10 @@ import {transpose} from './transpose.js'; import {validateLstmCellParams} from './lib/validate-input.js'; /** - * + *A single time step of the Long Short-Term Memory [LSTM] recurrent network + *using a cell state, an input, output, and forget gate to compute the cell + *state and the hidden state of the next time step that rolls into the output + *across the temporal sequence of the network. * @param {Tensor} input * @param {Tensor} weight * @param {Tensor} recurrentWeight @@ -18,17 +21,16 @@ import {validateLstmCellParams} from './lib/validate-input.js'; * @param {Tensor} cellState * @param {Number} hiddenSize * @param {MLLstmCellOptions} options - * return {Tensor} + * @return {Tensor} */ - export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, {bias, recurrentBias, peepholeWeight, layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { validateLstmCellParams(...arguments); const zero = new Scalar(0); const inputSize = input.shape[1]; - const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} : - {i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize}; + const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} : + {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; const activation0 = activations[0]; const activation1 = activations[1]; const activation2 = activations[2]; diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js index f84cde0..b57ffa8 100644 --- a/test/lstm_cell_test.js +++ b/test/lstm_cell_test.js @@ -6,7 +6,7 @@ import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test lstmCell', function() { - it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() { + it('lstmCell lstmCell activations=[relu, relu, relu]', function() { const batchSize = 2; const inputSize = 2; const hiddenSize = 2;