completed lstm #71

Merged
merged 13 commits on Apr 16, 2024
3 changes: 2 additions & 1 deletion src/arg_max_min.js
@@ -2,7 +2,8 @@

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js';
-import {squeeze} from './squeeze.js';
+import {squeeze} from './reshape.js';


/**
* Get the index location of the minimum or maximum values of all the input values along the axes.
3 changes: 1 addition & 2 deletions src/gru.js
@@ -2,11 +2,10 @@

import {concat} from './concat.js';
import {gruCell} from './gru_cell.js';
-import {reshape} from './reshape.js';
+import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
-import {squeeze} from './squeeze.js';
import {tanh} from './tanh.js';
import {validateGruParams} from './lib/validate-input.js';

166 changes: 164 additions & 2 deletions src/lib/validate-input.js
@@ -167,6 +167,168 @@ export function validateGemmParams(a, b) {
}
}

export function validateLstmCellParams(input, weight, recurrentWeight,
hiddenState, cellState, hiddenSize,
{bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
}
if (input.rank !== 2) {
throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`);
}
const batchSize = input.shape[0];
const inputSize = input.shape[1];
if (weight.rank !== 2) {
throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`);
}
if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) {
throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`);
}
if (recurrentWeight.rank !== 2) {
throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`);
}
if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) {
throw new Error(`The shape of recurrentWeight ` +
`[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`);
}
if (hiddenState.rank !== 2) {
throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
}
if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
throw new Error(`The shape of hiddenState ` +
`[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
}
if (cellState.rank !== 2) {
throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
}
if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
throw new Error(`The shape of cellState ` +
`[${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 1) {
throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`);
}
if (bias.shape[0] !== 4 * hiddenSize) {
throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`);
}
}
if (recurrentBias) {
if (recurrentBias.rank !== 1) {
throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 1-D tensor.`);
}
if (recurrentBias.shape[0] !== 4 * hiddenSize) {
throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`);
}
}
if (peepholeWeight) {
if (peepholeWeight.rank !== 1) {
throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 1-D tensor.`);
}
if (peepholeWeight.shape[0] !== 3 * hiddenSize) {
throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`);
}
}
if (layout !== 'iofg' && layout !== 'ifgo') {
throw new Error(`The layout ${layout} is invalid.`);
}
}

export function validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState,
direction = 'forward', layout = 'iofg'} = {}) {
if (!Number.isInteger(steps) || steps <= 0) {
throw new Error(`The steps ${steps} is invalid.`);
}
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
}
if (input.rank !== 3) {
throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`);
}
if (input.shape[0] !== steps) {
throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`);
}
const batchSize = input.shape[1];
const inputSize = input.shape[2];
if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') {
throw new Error(`The direction ${direction} is invalid.`);
}
const numDirections = (direction === 'both' ? 2 : 1);
if (weight.rank !== 3) {
throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`);
}
if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize ||
weight.shape[2] !== inputSize) {
throw new Error(`The shape of weight ` +
`[${weight.shape[0]}, ${weight.shape[1]}, ${weight.shape[2]}] is invalid.`);
}
if (recurrentWeight.rank !== 3) {
throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`);
}
if (recurrentWeight.shape[0] !== numDirections ||
recurrentWeight.shape[1] !== 4 * hiddenSize ||
recurrentWeight.shape[2] !== hiddenSize) {
throw new Error(`The shape of recurrentWeight ` +
`[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` +
`${recurrentWeight.shape[2]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 2) {
throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`);
}
if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`);
}
}
if (recurrentBias) {
if (recurrentBias.rank !== 2) {
throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`);
}
if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of recurrentBias ` +
`[${recurrentBias.shape[0]}, ${recurrentBias.shape[1]}] is invalid.`);
}
}
if (peepholeWeight) {
if (peepholeWeight.rank !== 2) {
throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`);
}
if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) {
throw new Error(`The shape of peepholeWeight ` +
`[${peepholeWeight.shape[0]}, ${peepholeWeight.shape[1]}] is invalid.`);
}
}
if (initialHiddenState) {
if (initialHiddenState.rank !== 3) {
throw new Error(
`The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`);
}
if (initialHiddenState.shape[0] !== numDirections ||
initialHiddenState.shape[1] !== batchSize ||
initialHiddenState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialHiddenState ` +
`[${initialHiddenState.shape[0]}, ${initialHiddenState.shape[1]}, ` +
`${initialHiddenState.shape[2]}] is invalid.`);
}
}
if (initialCellState) {
if (initialCellState.rank !== 3) {
throw new Error(
`The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`);
}
if (initialCellState.shape[0] !== numDirections ||
initialCellState.shape[1] !== batchSize ||
initialCellState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialCellState ` +
`[${initialCellState.shape[0]}, ${initialCellState.shape[1]}, ` +
`${initialCellState.shape[2]}] is invalid.`);
}
}
if (layout !== 'iofg' && layout !== 'ifgo') {
throw new Error(`The layout ${layout} is invalid.`);
}
}
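
For context, not part of the diff: a minimal sketch of how one of the shape checks above fires. The zero-filled Tensor data is only there to keep the snippet self-contained, and the exact error text depends on the message formatting in this file.

// Hypothetical illustration: with direction 'forward', numDirections is 1,
// so a weight whose first dimension is 2 must be rejected.
import {Tensor} from './lib/tensor.js';
import {validateLstmParams} from './lib/validate-input.js';

const input = new Tensor([2, 1, 3], new Array(6).fill(0)); // [steps, batchSize, inputSize]
const badWeight = new Tensor([2, 16, 3], new Array(96).fill(0)); // expected [1, 4 * hiddenSize, inputSize]
const recurrentWeight = new Tensor([1, 16, 4], new Array(64).fill(0)); // [1, 4 * hiddenSize, hiddenSize]
try {
validateLstmParams(input, badWeight, recurrentWeight, 2, 4, {direction: 'forward'});
} catch (e) {
console.log(e.message); // e.g. "The shape of weight [2, 16, 3] is invalid."
}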


export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize,
{bias, recurrentBias, layout = 'zrn'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
@@ -194,8 +356,8 @@ export function validateGruCellParams(input, weight, recurrentWeight, hiddenStat
throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
}
if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
-throw new Error(`The shape of hiddenState
-[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
+throw new Error(`The shape of hiddenState ` +
+`[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 1) {
106 changes: 106 additions & 0 deletions src/lstm.js
@@ -0,0 +1,106 @@
'use strict';

import {concat} from './concat.js';
import {lstmCell} from './lstm_cell.js';
import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {tanh} from './tanh.js';
import {validateLstmParams} from './lib/validate-input.js';

/**
* Long Short-Term Memory [LSTM] recurrent network uses an input, output, forget,
* and cell gate to compute the output state that rolls into the output across the
* temporal sequence of the network.
* @param {Tensor} input
* @param {Tensor} weight
* @param {Tensor} recurrentWeight
* @param {Number} steps
* @param {Number} hiddenSize
* @param {MLLstmOptions} options
* @return {Array.<Tensor>}
*/
export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
activations = [sigmoid, tanh, tanh]} = {}) {
validateLstmParams(...arguments);
const numDirections = (direction === 'both' ? 2 : 1);
const batchSize = input.shape[1];
const inputSize = input.shape[2];

let hiddenState;
let cellState;
if (initialHiddenState) {
hiddenState = initialHiddenState;
} else {
const initialHiddenStateShape = [numDirections, batchSize, hiddenSize];
hiddenState = new Tensor(
initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0));
}
if (initialCellState) {
cellState = initialCellState;
} else {
const initialCellStateShape = [numDirections, batchSize, hiddenSize];
cellState = new Tensor(
initialCellStateShape, new Array(sizeOfShape(initialCellStateShape)).fill(0));
}

let sequence;
const currentWeight = [];
const currentRecurrentWeight = [];
const currentBias = [];
const currentRecurrentBias = [];
const currentPeepholeWeight = [];

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize])));
currentRecurrentWeight.push(squeeze(slice(recurrentWeight,
[dir, 0, 0], [1, 4 * hiddenSize, hiddenSize])));
currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentRecurrentBias.push(recurrentBias ?
(squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentPeepholeWeight.push(peepholeWeight ?
(squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null);
}

for (let step = 0; step < steps; ++step) {
const currentHidden = [];
const currentCell = [];
let nextHidden = null;
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}

for (let dir = 0; dir < numDirections; ++dir) {
const slice0 = (dir === 1 || direction === 'backward' ? steps - step - 1 : step);
const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize]));

const results = lstmCell(
currentInput, currentWeight[dir], currentRecurrentWeight[dir],
currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir],
layout: layout, activations: activations});

const output = reshape(results[0], [1, null, hiddenSize]);
const cell = reshape(results[1], [1, null, hiddenSize]);

nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output);
nextCell = (nextCell ? concat([nextCell, cell], 0) : cell);
}

hiddenState = nextHidden;
cellState = nextCell;

if (returnSequence) {
nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]);
sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden);
}
}

return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
}
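
A brief usage sketch, not part of the diff: with the default options, lstm returns [hiddenState, cellState], each shaped [numDirections, batchSize, hiddenSize]. The all-zero tensors below are placeholders whose only purpose is to show the expected shapes.

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {lstm} from './lstm.js';

const steps = 1;
const batchSize = 2;
const inputSize = 3;
const hiddenSize = 4;
const numDirections = 1; // direction defaults to 'forward'
const zeros = (shape) => new Tensor(shape, new Array(sizeOfShape(shape)).fill(0));

const input = zeros([steps, batchSize, inputSize]);
const weight = zeros([numDirections, 4 * hiddenSize, inputSize]);
const recurrentWeight = zeros([numDirections, 4 * hiddenSize, hiddenSize]);

// returnSequence is left at its default (false), so only the final states come back.
const [hiddenState, cellState] = lstm(input, weight, recurrentWeight, steps, hiddenSize);
console.log(hiddenState.shape); // expected: [1, 2, 4]
console.log(cellState.shape); // expected: [1, 2, 4]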