Skip to content

Commit

Permalink
Merge pull request #71 from mei1127/add_lstm
Browse files Browse the repository at this point in the history
completed lstm
  • Loading branch information
huningxin authored Apr 16, 2024
2 parents a9f2ded + 5b098bc commit d01c2fc
Show file tree
Hide file tree
Showing 11 changed files with 623 additions and 61 deletions.
3 changes: 2 additions & 1 deletion src/arg_max_min.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js';
import {squeeze} from './squeeze.js';
import {squeeze} from './reshape.js';


/**
* Get the index location of the minimum or maximum values of all the input values along the axes.
Expand Down
3 changes: 1 addition & 2 deletions src/gru.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import {concat} from './concat.js';
import {gruCell} from './gru_cell.js';
import {reshape} from './reshape.js';
import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
import {tanh} from './tanh.js';
import {validateGruParams} from './lib/validate-input.js';

Expand Down
166 changes: 164 additions & 2 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,168 @@ export function validateGemmParams(a, b) {
}
}

/**
 * Validate the inputs of an lstmCell operation, throwing an Error describing
 * the first violation found.
 * @param {Tensor} input - 2-D input of shape [batchSize, inputSize].
 * @param {Tensor} weight - 2-D weight of shape [4 * hiddenSize, inputSize].
 * @param {Tensor} recurrentWeight - 2-D weight of shape
 *     [4 * hiddenSize, hiddenSize].
 * @param {Tensor} hiddenState - 2-D state of shape [batchSize, hiddenSize].
 * @param {Tensor} cellState - 2-D state of shape [batchSize, hiddenSize].
 * @param {Number} hiddenSize - positive integer.
 * @param {MLLstmCellOptions} [options] - optional bias, recurrentBias,
 *     peepholeWeight and layout.
 */
export function validateLstmCellParams(input, weight, recurrentWeight,
    hiddenState, cellState, hiddenSize,
    {bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) {
  if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
    throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
  }
  if (input.rank !== 2) {
    throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`);
  }
  const batchSize = input.shape[0];
  const inputSize = input.shape[1];
  if (weight.rank !== 2) {
    throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`);
  }
  if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) {
    throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`);
  }
  if (recurrentWeight.rank !== 2) {
    throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`);
  }
  if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) {
    throw new Error(`The shape of recurrentWeight ` +
        `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`);
  }
  if (hiddenState.rank !== 2) {
    throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
  }
  if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
    // Concatenate so the message stays on one line (a multi-line template
    // literal would embed a newline plus source indentation in the message).
    throw new Error(`The shape of hiddenState ` +
        `[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
  }
  if (cellState.rank !== 2) {
    throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
  }
  if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
    throw new Error(`The shape of cellState ` +
        `[${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`);
  }
  if (bias) {
    if (bias.rank !== 1) {
      throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`);
    }
    if (bias.shape[0] !== 4 * hiddenSize) {
      throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`);
    }
  }
  if (recurrentBias) {
    if (recurrentBias.rank !== 1) {
      // Fixed: report recurrentBias.rank (was bias.rank, which is wrong and
      // throws a TypeError when bias is not supplied).
      throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 1-D tensor.`);
    }
    if (recurrentBias.shape[0] !== 4 * hiddenSize) {
      throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`);
    }
  }
  if (peepholeWeight) {
    if (peepholeWeight.rank !== 1) {
      // Fixed: report peepholeWeight.rank (was bias.rank).
      throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 1-D tensor.`);
    }
    if (peepholeWeight.shape[0] !== 3 * hiddenSize) {
      throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`);
    }
  }
  if (layout !== 'iofg' && layout !== 'ifgo') {
    throw new Error(`The layout ${layout} is invalid.`);
  }
}

/**
 * Validate the inputs of an lstm operation, throwing an Error describing the
 * first violation found.
 * @param {Tensor} input - 3-D input of shape [steps, batchSize, inputSize].
 * @param {Tensor} weight - 3-D weight of shape
 *     [numDirections, 4 * hiddenSize, inputSize].
 * @param {Tensor} recurrentWeight - 3-D weight of shape
 *     [numDirections, 4 * hiddenSize, hiddenSize].
 * @param {Number} steps - positive integer, number of time steps.
 * @param {Number} hiddenSize - positive integer.
 * @param {MLLstmOptions} [options] - optional bias, recurrentBias,
 *     peepholeWeight, initialHiddenState, initialCellState, direction, layout.
 */
export function validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize,
    {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState,
      direction = 'forward', layout = 'iofg'} = {}) {
  // Note the `= {}` default above: without it, omitting the options argument
  // made the destructuring throw a TypeError instead of validating.
  if (!Number.isInteger(steps) || steps <= 0) {
    throw new Error(`The steps ${steps} is invalid.`);
  }
  if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
    throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
  }
  if (input.rank !== 3) {
    throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`);
  }
  if (input.shape[0] !== steps) {
    throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`);
  }
  const batchSize = input.shape[1];
  const inputSize = input.shape[2];
  if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') {
    throw new Error(`The direction ${direction} is invalid.`);
  }
  const numDirections = (direction === 'both' ? 2 : 1);
  if (weight.rank !== 3) {
    throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`);
  }
  if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize ||
      weight.shape[2] !== inputSize) {
    // Concatenate so the message stays on one line (a multi-line template
    // literal would embed a newline plus source indentation in the message).
    throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, ` +
        `${weight.shape[2]}] is invalid.`);
  }
  if (recurrentWeight.rank !== 3) {
    throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`);
  }
  if (recurrentWeight.shape[0] !== numDirections ||
      recurrentWeight.shape[1] !== 4 * hiddenSize ||
      recurrentWeight.shape[2] !== hiddenSize) {
    throw new Error(`The shape of recurrentWeight ` +
        `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` +
        `${recurrentWeight.shape[2]}] is invalid.`);
  }
  if (bias) {
    if (bias.rank !== 2) {
      throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`);
    }
    if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) {
      throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`);
    }
  }
  if (recurrentBias) {
    if (recurrentBias.rank !== 2) {
      throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`);
    }
    if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) {
      throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, ` +
          `${recurrentBias.shape[1]}] is invalid.`);
    }
  }
  if (peepholeWeight) {
    if (peepholeWeight.rank !== 2) {
      throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`);
    }
    if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) {
      throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}, ` +
          `${peepholeWeight.shape[1]}] is invalid.`);
    }
  }
  if (initialHiddenState) {
    if (initialHiddenState.rank !== 3) {
      throw new Error(
          `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`);
    }
    if (initialHiddenState.shape[0] !== numDirections ||
        initialHiddenState.shape[1] !== batchSize ||
        initialHiddenState.shape[2] !== hiddenSize) {
      throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, ` +
          `${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`);
    }
  }
  if (initialCellState) {
    if (initialCellState.rank !== 3) {
      throw new Error(
          `The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`);
    }
    if (initialCellState.shape[0] !== numDirections ||
        initialCellState.shape[1] !== batchSize ||
        initialCellState.shape[2] !== hiddenSize) {
      throw new Error(`The shape of initialCellState [${initialCellState.shape[0]}, ` +
          `${initialCellState.shape[1]}, ${initialCellState.shape[2]}] is invalid.`);
    }
  }
  if (layout !== 'iofg' && layout !== 'ifgo') {
    throw new Error(`The layout ${layout} is invalid.`);
  }
}


export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize,
{bias, recurrentBias, layout = 'zrn'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
Expand Down Expand Up @@ -194,8 +356,8 @@ export function validateGruCellParams(input, weight, recurrentWeight, hiddenStat
throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
}
if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
throw new Error(`The shape of hiddenState
[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
throw new Error(`The shape of hiddenState [${hiddenState.shape[0]},
${hiddenState.shape[1]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 1) {
Expand Down
106 changes: 106 additions & 0 deletions src/lstm.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
'use strict';

import {concat} from './concat.js';
import {lstmCell} from './lstm_cell.js';
import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {tanh} from './tanh.js';
import {validateLstmParams} from './lib/validate-input.js';

/**
 * Long Short-Term Memory [LSTM] recurrent network uses an input, output,
 * forget, and cell gate to compute the output state that rolls into the
 * output across the temporal sequence of the network.
 * @param {Tensor} input - 3-D input of shape [steps, batchSize, inputSize].
 * @param {Tensor} weight - 3-D weight of shape
 *     [numDirections, 4 * hiddenSize, inputSize].
 * @param {Tensor} recurrentWeight - 3-D weight of shape
 *     [numDirections, 4 * hiddenSize, hiddenSize].
 * @param {Number} steps - number of time steps to unroll.
 * @param {Number} hiddenSize
 * @param {MLLstmOptions} options
 * @return {Array.<Tensor>} [hiddenState, cellState] plus, when returnSequence
 *     is true, the per-step output sequence as a third element.
 */
export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
    {bias, recurrentBias, peepholeWeight, initialHiddenState,
      initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
      activations = [sigmoid, tanh, tanh]} = {}) {
  // Pass the options explicitly instead of `...arguments`: with `arguments`,
  // calling lstm() without an options object forwarded only five arguments
  // and made validateLstmParams' destructuring throw a TypeError.
  validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize,
      {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState,
        direction, layout});
  const numDirections = (direction === 'both' ? 2 : 1);
  const batchSize = input.shape[1];
  const inputSize = input.shape[2];

  // Default the hidden and cell states to zeros when not supplied.
  let hiddenState;
  let cellState;
  if (initialHiddenState) {
    hiddenState = initialHiddenState;
  } else {
    const initialHiddenStateShape = [numDirections, batchSize, hiddenSize];
    hiddenState = new Tensor(
        initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0));
  }
  if (initialCellState) {
    cellState = initialCellState;
  } else {
    // Renamed from `initialCellState`, which shadowed the option of the same
    // name; now mirrors `initialHiddenStateShape` above.
    const initialCellStateShape = [numDirections, batchSize, hiddenSize];
    cellState = new Tensor(
        initialCellStateShape, new Array(sizeOfShape(initialCellStateShape)).fill(0));
  }

  let sequence;
  // Per-direction slices of the weights and optional biases, computed once
  // up front; optional entries are null when the option is absent.
  const currentWeight = [];
  const currentRecurrentWeight = [];
  const currentBias = [];
  const currentRecurrentBias = [];
  const currentPeepholeWeight = [];

  for (let dir = 0; dir < numDirections; ++dir) {
    currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize])));
    currentRecurrentWeight.push(squeeze(slice(recurrentWeight,
        [dir, 0, 0], [1, 4 * hiddenSize, hiddenSize])));
    currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null);
    currentRecurrentBias.push(recurrentBias ?
        (squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null);
    currentPeepholeWeight.push(peepholeWeight ?
        (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null);
  }

  for (let step = 0; step < steps; ++step) {
    const currentHidden = [];
    const currentCell = [];
    let nextHidden = null;
    let nextCell = null;

    for (let dir = 0; dir < numDirections; ++dir) {
      currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
      currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
    }

    for (let dir = 0; dir < numDirections; ++dir) {
      // The backward pass (and the second direction of 'both') consumes the
      // input sequence in reverse time order.
      const slice0 = (dir === 1 || direction === 'backward' ? steps - step - 1 : step);
      const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize]));

      const results = lstmCell(
          currentInput, currentWeight[dir], currentRecurrentWeight[dir],
          currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir],
            recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir],
            layout: layout, activations: activations});

      const output = reshape(results[0], [1, null, hiddenSize]);
      const cell = reshape(results[1], [1, null, hiddenSize]);

      // Stack each direction's result along the leading numDirections axis.
      nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output);
      nextCell = (nextCell ? concat([nextCell, cell], 0) : cell);
    }

    hiddenState = nextHidden;
    cellState = nextCell;

    if (returnSequence) {
      nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]);
      sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden);
    }
  }

  return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
}
Loading

0 comments on commit d01c2fc

Please sign in to comment.