completed lstm
mei1127 committed Mar 1, 2024
1 parent 9a7e0a4 commit 24d6d03
Showing 3 changed files with 334 additions and 0 deletions.
95 changes: 95 additions & 0 deletions src/lib/validate-input.js
@@ -200,6 +200,101 @@ export function validateLstmCellParams(input, weight, recurrentWeight,
}
}

export function validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize,
    {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState,
      direction = 'forward', layout = 'iofg'} = {}) {
if (!Number.isInteger(steps) || steps <= 0) {
throw new Error(`The steps ${steps} is invalid.`);
}
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
}
if (input.rank !== 3) {
throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`);
}
if (input.shape[0] !== steps) {
throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`);
}
const batchSize = input.shape[1];
const inputSize = input.shape[2];
if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') {
throw new Error(`The direction ${direction} is invalid.`);
}
const numDirections = (direction === 'both' ? 2 : 1);
if (weight.rank !== 3) {
throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`);
}
if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize ||
weight.shape[2] !== inputSize) {
throw new Error(`The shape of weight ` +
    `[${weight.shape[0]}, ${weight.shape[1]}, ${weight.shape[2]}] is invalid.`);
}
if (recurrentWeight.rank !== 3) {
throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`);
}
if (recurrentWeight.shape[0] !== numDirections ||
recurrentWeight.shape[1] !== 4 * hiddenSize ||
recurrentWeight.shape[2] !== hiddenSize) {
throw new Error(`The shape of recurrentWeight ` +
`[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` +
`${recurrentWeight.shape[2]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 2) {
throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`);
}
if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`);
}
}
if (recurrentBias) {
if (recurrentBias.rank !== 2) {
throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`);
}
if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of recurrentBias ` +
    `[${recurrentBias.shape[0]}, ${recurrentBias.shape[1]}] is invalid.`);
}
}
if (peepholeWeight) {
if (peepholeWeight.rank !== 2) {
throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`);
}
if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) {
throw new Error(`The shape of peepholeWeight ` +
    `[${peepholeWeight.shape[0]}, ${peepholeWeight.shape[1]}] is invalid.`);
}
}
if (initialHiddenState) {
if (initialHiddenState.rank !== 3) {
throw new Error(
`The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`);
}
if (initialHiddenState.shape[0] !== numDirections ||
initialHiddenState.shape[1] !== batchSize ||
initialHiddenState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialHiddenState ` +
    `[${initialHiddenState.shape[0]}, ${initialHiddenState.shape[1]}, ` +
    `${initialHiddenState.shape[2]}] is invalid.`);
}
}
if (initialCellState) {
if (initialCellState.rank !== 3) {
throw new Error(
`The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`);
}
if (initialCellState.shape[0] !== numDirections ||
initialCellState.shape[1] !== batchSize ||
initialCellState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialCellState ` +
    `[${initialCellState.shape[0]}, ${initialCellState.shape[1]}, ` +
    `${initialCellState.shape[2]}] is invalid.`);
}
}
if (layout !== 'iofg' && layout !== 'ifgo') {
throw new Error(`The layout ${layout} is invalid.`);
}
}


export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize,
{bias, recurrentBias, layout = 'zrn'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
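As a quick reference for the shapes the new validateLstmParams enforces, here is a minimal sketch. It is not part of this commit; the sizes are arbitrary, and the import paths assume the same layout used by src/lstm.js below.

// Sketch only: tensors whose shapes pass the validateLstmParams checks above.
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {validateLstmParams} from './lib/validate-input.js';

const steps = 2;
const numDirections = 1; // direction defaults to 'forward'
const batchSize = 3;
const inputSize = 4;
const hiddenSize = 5;
const zeros = (shape) => new Tensor(shape, new Array(sizeOfShape(shape)).fill(0));

const input = zeros([steps, batchSize, inputSize]);
const weight = zeros([numDirections, 4 * hiddenSize, inputSize]);
const recurrentWeight = zeros([numDirections, 4 * hiddenSize, hiddenSize]);

// Passes: every rank and shape check above holds.
validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize, {});

// Throws: weight.shape[1] must be 4 * hiddenSize.
validateLstmParams(input, zeros([numDirections, 3 * hiddenSize, inputSize]),
    recurrentWeight, steps, hiddenSize, {});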
108 changes: 108 additions & 0 deletions src/lstm.js
@@ -0,0 +1,108 @@
'use strict';

import {concat} from './concat.js';
import {lstmCell} from './lstm_cell.js';
import {reshape} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
import {tanh} from './tanh.js';
import {validateLstmParams} from './lib/validate-input.js';

/**
 * Long Short-Term Memory [LSTM] recurrent network uses an input, output, forget,
 * and cell gate to compute the output state that rolls into the output across the
* temporal sequence of the network.
* @param {Tensor} input
* @param {Tensor} weight
* @param {Tensor} recurrentWeight
* @param {Number} steps
* @param {Number} hiddenSize
* @param {MLLstmOptions} options
* @return {Array.<Tensor>}
*/

export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
activations = [sigmoid, tanh, tanh]} = {}) {
validateLstmParams(...arguments);
const numDirections = (direction === 'both' ? 2 : 1);
const batchSize = input.shape[1];
const inputSize = input.shape[2];

let hiddenState;
let cellState;
if (initialHiddenState) {
hiddenState = initialHiddenState;
} else {
const initialHiddenStateShape = [numDirections, batchSize, hiddenSize];
hiddenState = new Tensor(
initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0));
}
if (initialCellState) {
cellState = initialCellState;
} else {
const initialCellStateShape = [numDirections, batchSize, hiddenSize];
cellState = new Tensor(
    initialCellStateShape, new Array(sizeOfShape(initialCellStateShape)).fill(0));
}

let sequence;
const currentWeight = [];
const currentRecurrentWeight = [];
const currentBias = [];
const currentRecurrentBias = [];
const currentPeepholeWeight = [];

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize])));
currentRecurrentWeight.push(squeeze(slice(recurrentWeight,
[dir, 0, 0], [1, 4 * hiddenSize, hiddenSize])));
currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentRecurrentBias.push(recurrentBias ?
(squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentPeepholeWeight.push(peepholeWeight ?
(squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null);
}

for (let step = 0; step < steps; ++step) {
const currentHidden = [];
const currentCell = [];
let nextHidden = null;
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}

for (let dir = 0; dir < numDirections; ++dir) {
const slice0 = (dir === 1 || direction === 'backward' ? steps - step - 1 : step);
const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize]));

const results = lstmCell(
currentInput, currentWeight[dir], currentRecurrentWeight[dir],
currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir],
layout: layout, activations: activations});

const output = reshape(results[0], [1, null, hiddenSize]);
const cell = reshape(results[1], [1, null, hiddenSize]);

nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output);
nextCell = (nextCell ? concat([nextCell, cell], 0) : cell);
}

hiddenState = nextHidden;
cellState = nextCell;

if (returnSequence) {
nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]);
sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden);
}
}

return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
}
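And a minimal usage sketch for the new lstm() op above, again not part of this commit: values are arbitrary, shapes follow validateLstmParams, and the defaults (forward direction, zero initial states, [sigmoid, tanh, tanh] activations) are used.

import {lstm} from './lstm.js';
import {Tensor} from './lib/tensor.js';

// steps = 1, batchSize = 2, inputSize = 2, hiddenSize = 3, forward => numDirections = 1.
const input = new Tensor([1, 2, 2], new Float32Array([0.5, -0.5, 1, 2]));
const weight = new Tensor([1, 4 * 3, 2], new Float32Array(4 * 3 * 2).fill(0.1));
const recurrentWeight = new Tensor([1, 4 * 3, 3], new Float32Array(4 * 3 * 3).fill(0.1));

// Default options: two outputs, each of shape
// [numDirections, batchSize, hiddenSize] = [1, 2, 3].
const [hiddenState, cellState] = lstm(input, weight, recurrentWeight, 1, 3);

// With returnSequence: true, a third output of shape
// [steps, numDirections, batchSize, hiddenSize] is appended.
const outputs = lstm(input, weight, recurrentWeight, 1, 3, {returnSequence: true});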
131 changes: 131 additions & 0 deletions test/lstm_test.js
@@ -0,0 +1,131 @@
'use strict';

import {lstm} from '../src/lstm.js';
import {relu} from '../src/relu.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test lstm', function() {
it('lstm returnSequence=true ' +
'activations=[relu, relu, relu]', function() {
const steps = 1;
const numDirections = 1;
const batchSize = 2;
const inputSize = 2;
const hiddenSize = 2;
const input = new Tensor([steps, batchSize, inputSize], new Float32Array([1, 2, 2, 1]));
const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize],
new Float32Array([
1, -1, 2, -2, 1, -1, 2, -2,
1, -1, 2, -2, 1, -1, 2, -2,
]));
const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize],
new Array(4 * hiddenSize * hiddenSize).fill(0.1));
const bias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2,
]));
const recurrentBias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2,
]));
const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize],
new Float32Array(3 * hiddenSize).fill(0));
const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize],
new Float32Array(batchSize * hiddenSize).fill(0));
const initialCellState = new Tensor([numDirections, batchSize, hiddenSize],
new Float32Array(batchSize * hiddenSize).fill(0));
const returnSequence = true;
const activations = [
relu,
relu,
relu,
];
const outputs = lstm(
input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence, activations});
console.log('outputs: ', outputs);
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]);
const expected = [
[
1, 8, 27, 216,
],
[
1, 4, 9, 36,
],
[
1, 8, 27, 216,
],
];
for (let i = 0; i < expected.length; ++i) {
utils.checkValue(outputs[i], expected[i]);
}
});

it('lstm steps=2 direction="backward" returnSequence=true ' +
'activations=[relu, relu, relu]', function() {
const steps = 2;
const numDirections = 1;
const batchSize = 2;
const inputSize = 2;
const hiddenSize = 2;
const input = new Tensor([steps, batchSize, inputSize],
new Float32Array([1, 2, 2, 1, 3, 4, 1, 2]));
const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize],
new Float32Array([
1, -1, 2, -2, 1, -1, 2, -2,
1, -1, 2, -2, 1, -1, 2, -2,
]));
const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize],
new Array(4 * hiddenSize * hiddenSize).fill(0.1));
const bias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2,
]));
const recurrentBias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2,
]));
const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize],
new Float32Array(3 * hiddenSize).fill(0));
const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize],
new Float32Array(batchSize * hiddenSize).fill(0));
const initialCellState = new Tensor([numDirections, batchSize, hiddenSize],
new Float32Array(batchSize * hiddenSize).fill(0));
const returnSequence = true;
const direction = 'backward';
const activations = [
relu,
relu,
relu,
];
const outputs = lstm(
input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, direction, returnSequence, activations});
console.log('outputs: ', outputs);
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]);
const expected = [
[10.469, 58.02899999999999, 74.529, 518.9490000000001],
[5.51, 20.009999999999998, 19.11, 75.21000000000001],
[
1,
8,
1,
8,
10.469,
58.02899999999999,
74.529,
518.9490000000001,
],
];
for (let i = 0; i < expected.length; ++i) {
utils.checkValue(outputs[i], expected[i]);
}
});
});
