Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

completed lstm #71

Merged
merged 13 commits into from
Apr 16, 2024
162 changes: 162 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,168 @@ export function validateGemmParams(a, b) {
}
}

export function validateLstmCellParams(input, weight, recurrentWeight,
hiddenState, cellState, hiddenSize,
{bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
}
if (input.rank !== 2) {
throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`);
}
const batchSize = input.shape[0];
const inputSize = input.shape[1];
if (weight.rank !== 2) {
throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`);
}
if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) {
throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`);
}
if (recurrentWeight.rank !== 2) {
throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`);
}
if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) {
throw new Error(`The shape of recurrentWeight ` +
`[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`);
}
if (hiddenState.rank !== 2) {
throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
}
if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
throw new Error(`The shape of hiddenState
[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
}
if (cellState.rank !== 2) {
throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
}
if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
throw new Error(`The shape of cellState
Copy link

@fdwr fdwr Mar 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message will be split :( :

The shape of cellState
  [2, 3] is invalid.

[${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 1) {
throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`);
}
if (bias.shape[0] !== 4 * hiddenSize) {
throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`);
}
}
if (recurrentBias) {
if (recurrentBias.rank !== 1) {
throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`);
}
if (recurrentBias.shape[0] !== 4 * hiddenSize) {
throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`);
}
}
if (peepholeWeight) {
if (peepholeWeight.rank !== 1) {
throw new Error(`The peepholeWeight (rank ${bias.rank}) is not a 1-D tensor.`);
}
if (peepholeWeight.shape[0] !== 3 * hiddenSize) {
throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`);
}
}
if (layout !== 'iofg' && layout !== 'ifgo') {
throw new Error(`The layout ${layout} is invalid.`);
}
}

export function validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState,
direction = 'forward', layout = 'iofg'}) {
if (!Number.isInteger(steps) || steps <= 0) {
throw new Error(`The steps ${steps} is invalid.`);
}
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
}
if (input.rank !== 3) {
throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`);
}
if (input.shape[0] !== steps) {
throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`);
}
const batchSize = input.shape[1];
const inputSize = input.shape[2];
if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') {
throw new Error(`The direction ${direction} is invalid.`);
}
const numDirections = (direction === 'both' ? 2 : 1);
if (weight.rank !== 3) {
throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`);
}
if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize ||
weight.shape[2] !== inputSize) {
throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]},
${weight.shape[2]}] is invalid.`);
}
if (recurrentWeight.rank !== 3) {
throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`);
}
if (recurrentWeight.shape[0] !== numDirections ||
recurrentWeight.shape[1] !== 4 * hiddenSize ||
recurrentWeight.shape[2] !== hiddenSize) {
throw new Error(`The shape of recurrentWeight ` +
`[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` +
`${recurrentWeight.shape[2]}] is invalid.`);
}
if (bias) {
if (bias.rank !== 2) {
throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`);
}
if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`);
}
}
if (recurrentBias) {
if (recurrentBias.rank !== 2) {
throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`);
}
if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) {
throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]},
${recurrentBias.shape[1]}] is invalid.`);
}
}
if (peepholeWeight) {
if (peepholeWeight.rank !== 2) {
throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`);
}
if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) {
throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]},
${peepholeWeight.shape[1]}] is invalid.`);
}
}
if (initialHiddenState) {
if (initialHiddenState.rank !== 3) {
throw new Error(
`The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`);
}
if (initialHiddenState.shape[0] !== numDirections ||
initialHiddenState.shape[1] !== batchSize ||
initialHiddenState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]},
${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`);
}
}
if (initialCellState) {
if (initialCellState.rank !== 3) {
throw new Error(
`The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`);
}
if (initialCellState.shape[0] !== numDirections ||
initialCellState.shape[1] !== batchSize ||
initialCellState.shape[2] !== hiddenSize) {
throw new Error(`The shape of initialCellState [${initialCellState.shape[0]},
${initialCellState.shape[1]}, ${initialCellState.shape[2]}] is invalid.`);
}
}
if (layout !== 'iofg' && layout !== 'ifgo') {
throw new Error(`The layout ${layout} is invalid.`);
}
}


export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize,
{bias, recurrentBias, layout = 'zrn'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
Expand Down
108 changes: 108 additions & 0 deletions src/lstm.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
'use strict';

import {concat} from './concat.js';
import {lstmCell} from './lstm_cell.js';
import {reshape} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now squeeze op has been removed, would you please help also remove it from this WebNN Baseline.
You can refer to the given squeeze method in Spec, thanks.

function squeeze(builder, op) {
  return builder.reshape(op, op.shape().remove(0));
}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The definition is:

function squeeze(input, axes) {
  if (!axes) axes = [];
  if (!axes.length)
    input.shape().forEach((item, i) => { axes.push(i); });
  shape = Array.from(input.shape());
  for (let axis in axes.sort().reverse())
    if (axis < shape.length && shape[axis] == 1)
      shape.splice(axis, 1);
  return builder.reshape(input, shape);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will revise it next week:)

import {tanh} from './tanh.js';
import {validateLstmParams} from './lib/validate-input.js';

/**
*Long Short-Term Memory [LSTM] recurrent network uses an input, output, forget,
*and cell gate to compute the output state that rolls into the output across the
* temporal sequence of the network.
* @param {Tensor} input
* @param {Tensor} weight
* @param {Tensor} recurrentWeight
* @param {Number} steps
* @param {Number} hiddenSize
* @param {MLLstmOptions} options
* @return {Array.<Tensor>}
*/

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please delete this blank line.

export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
activations = [sigmoid, tanh, tanh]}={}) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
activations = [sigmoid, tanh, tanh]}={}) {
activations = [sigmoid, tanh, tanh]} = {}) {

validateLstmParams(...arguments);
const numDirections = (direction == 'both' ? 2 : 1);
const batchSize = input.shape[1];
const inputSize = input.shape[2];

let hiddenState;
let cellState;
if (initialHiddenState) {
hiddenState = initialHiddenState;
} else {
const initialHiddenStateShape = [numDirections, batchSize, hiddenSize];
hiddenState = new Tensor(
initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0));
}
if (initialCellState) {
cellState = initialCellState;
} else {
const initialCellState = [numDirections, batchSize, hiddenSize];
cellState = new Tensor(
initialCellState, new Array(sizeOfShape(initialCellState)).fill(0));
}

let sequence;
const currentWeight = [];
const currentRecurrentWeight = [];
const currentBias = [];
const currentRecurrentBias = [];
const currentPeepholeWeight = [];

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize])));
currentRecurrentWeight.push(squeeze(slice(recurrentWeight,
[dir, 0, 0], [1, 4 * hiddenSize, hiddenSize])));
currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentRecurrentBias.push(recurrentBias ?
(squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentPeepholeWeight.push(peepholeWeight ?
(squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null);
}

for (let step = 0; step < steps; ++step) {
const currentHidden = [];
const currentCell = [];
let nextHidden = null;
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}

for (let dir = 0; dir < numDirections; ++dir) {
const slice0 = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize]));

const results = lstmCell(
currentInput, currentWeight[dir], currentRecurrentWeight[dir],
currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir],
layout: layout, activations: activations});

const output = reshape(results[0], [1, null, hiddenSize]);
const cell = reshape(results[1], [1, null, hiddenSize]);

nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output);
nextCell = (nextCell ? concat([nextCell, cell], 0) : cell);
}

hiddenState = nextHidden;
cellState = nextCell;

if (returnSequence) {
nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]);
sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden);
}
}

return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
}
Loading
Loading