Skip to content

Commit

Permalink
revised lstm_cell.js and lstm_cell_test.js
Browse files Browse the repository at this point in the history
  • Loading branch information
mei1127 committed Feb 28, 2024
1 parent abfceb9 commit 9a7e0a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/lstm_cell.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,27 @@ import {transpose} from './transpose.js';
import {validateLstmCellParams} from './lib/validate-input.js';

/**
*
*A single time step of the Long Short-Term Memory [LSTM] recurrent network
*using a cell state, an input, output, and forget gate to compute the cell
*state and the hidden state of the next time step that rolls into the output
*across the temporal sequence of the network.
* @param {Tensor} input
* @param {Tensor} weight
* @param {Tensor} recurrentWeight
* @param {Tensor} hiddenState
* @param {Tensor} cellState
* @param {Number} hiddenSize
* @param {MLLstmCellOptions} options
* return {Tensor}
* @return {Tensor}
*/

export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
{bias, recurrentBias, peepholeWeight,
layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) {
validateLstmCellParams(...arguments);
const zero = new Scalar(0);
const inputSize = input.shape[1];
const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} :
{i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize};
const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} :
{i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
const activation0 = activations[0];
const activation1 = activations[1];
const activation2 = activations[2];
Expand Down
2 changes: 1 addition & 1 deletion test/lstm_cell_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test lstmCell', function() {
it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() {
it('lstmCell lstmCell activations=[relu, relu, relu]', function() {
const batchSize = 2;
const inputSize = 2;
const hiddenSize = 2;
Expand Down

0 comments on commit 9a7e0a4

Please sign in to comment.