From abfceb9acbc17794b729ae7ef99c8745abfa8bc0 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Wed, 28 Feb 2024 15:49:55 +0800 Subject: [PATCH 01/11] complete lstm_cell --- src/lib/validate-input.js | 67 ++++++++++++++++++ src/lstm_cell.js | 143 ++++++++++++++++++++++++++++++++++++++ test/lstm_cell_test.js | 57 +++++++++++++++ 3 files changed, 267 insertions(+) create mode 100644 src/lstm_cell.js create mode 100644 test/lstm_cell_test.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index ca1e92c..448f525 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -133,6 +133,73 @@ export function validateGemmParams(a, b) { } } +export function validateLstmCellParams(input, weight, recurrentWeight, + hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) { + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 2) { + throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); + } + const batchSize = input.shape[0]; + const inputSize = input.shape[1]; + if (weight.rank !== 2) { + throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); + } + if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); + } + if (recurrentWeight.rank !== 2) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); + } + if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); + } + if (hiddenState.rank !== 2) { + throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); + } + if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { + 
throw new Error(`The shape of hiddenState + [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + } + if (cellState.rank !== 2) { + throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`); + } + if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) { + throw new Error(`The shape of cellState + [${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 1) { + throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (bias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 1) { + throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (recurrentBias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); + } + } + if (peepholeWeight) { + if (peepholeWeight.rank !== 1) { + throw new Error(`The peepholeWeight (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (peepholeWeight.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`); + } + } + if (layout !== 'iofg' && layout !== 'ifgo') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, {bias, recurrentBias, layout = 'zrn'} = {}) { if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { diff --git a/src/lstm_cell.js b/src/lstm_cell.js new file mode 100644 index 0000000..bfafd46 --- /dev/null +++ b/src/lstm_cell.js @@ -0,0 +1,143 @@ +'use strict'; + +import {add, mul} from './binary.js'; +import {matmul} from './matmul.js'; +import {Scalar} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {tanh} from './tanh.js'; +import {transpose} from './transpose.js'; +import 
{validateLstmCellParams} from './lib/validate-input.js'; + +/** + * + * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Tensor} hiddenState + * @param {Tensor} cellState + * @param {Number} hiddenSize + * @param {MLLstmCellOptions} options + * return {Tensor} + */ + +export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, + layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { + validateLstmCellParams(...arguments); + const zero = new Scalar(0); + const inputSize = input.shape[1]; + const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} : + {i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize}; + const activation0 = activations[0]; + const activation1 = activations[1]; + const activation2 = activations[2]; + + // input gate (i) + const i = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.i], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.i], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.i], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.i, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.i, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // forget gate (f) + const f = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.f], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.f], [hiddenSize]) : zero), + (recurrentBias ? 
slice(recurrentBias, [starts.f], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.f, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose( + slice(recurrentWeight, [starts.f, 0], [hiddenSize, hiddenSize]), + ), + ), + ), + ), + ), + ); + + // cell gate (g) + const g = activation1( + add( + add( + (bias ? slice(bias, [starts.g], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.g], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.g, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.g, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ); + + // output gate (o) + const o = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.o], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.o], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.o], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.o, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.o, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // output cell state (ct) + const ct = add(mul(f, cellState), mul(i, g)); + + // output hidden state (ht) + const ht = mul(o, activation2(ct)); + + return [ht, ct]; +} diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js new file mode 100644 index 0000000..f84cde0 --- /dev/null +++ b/test/lstm_cell_test.js @@ -0,0 +1,57 @@ +'use strict'; + +import {lstmCell} from '../src/lstm_cell.js'; +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test lstmCell', function() { + it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() { + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([batchSize, 
inputSize], [1, 2, 2, 1]); + const weight = new Tensor([4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([4 * hiddenSize, hiddenSize], + new Float32Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const cellState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const bias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([3* hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstmCell( + input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, activations}); + utils.checkShape(outputs[0], [batchSize, hiddenSize]); + utils.checkShape(outputs[1], [batchSize, hiddenSize]); + const expected = [ + [ + 1, 8, 27, 216, + ], + [ + 1, 4, 9, 36, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); +}); From 9a7e0a486fe07165dbb91d9f4702734817a89d28 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Wed, 28 Feb 2024 16:17:05 +0800 Subject: [PATCH 02/11] revised lstm_cell.js and lstm_cell_test.js --- src/lstm_cell.js | 12 +++++++----- test/lstm_cell_test.js | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lstm_cell.js b/src/lstm_cell.js index bfafd46..3d07d11 100644 --- a/src/lstm_cell.js +++ b/src/lstm_cell.js @@ -10,7 +10,10 @@ import {transpose} from './transpose.js'; import {validateLstmCellParams} from './lib/validate-input.js'; /** - * + *A single time step of the Long Short-Term Memory [LSTM] recurrent network + *using a cell state, an input, 
output, and forget gate to compute the cell + *state and the hidden state of the next time step that rolls into the output + *across the temporal sequence of the network. * @param {Tensor} input * @param {Tensor} weight * @param {Tensor} recurrentWeight @@ -18,17 +21,16 @@ import {validateLstmCellParams} from './lib/validate-input.js'; * @param {Tensor} cellState * @param {Number} hiddenSize * @param {MLLstmCellOptions} options - * return {Tensor} + * @return {Tensor} */ - export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, {bias, recurrentBias, peepholeWeight, layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { validateLstmCellParams(...arguments); const zero = new Scalar(0); const inputSize = input.shape[1]; - const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} : - {i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize}; + const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} : + {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; const activation0 = activations[0]; const activation1 = activations[1]; const activation2 = activations[2]; diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js index f84cde0..b57ffa8 100644 --- a/test/lstm_cell_test.js +++ b/test/lstm_cell_test.js @@ -6,7 +6,7 @@ import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test lstmCell', function() { - it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() { + it('lstmCell lstmCell activations=[relu, relu, relu]', function() { const batchSize = 2; const inputSize = 2; const hiddenSize = 2; From 24d6d030734bd2fd3b392d152fa3c424b6157896 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Fri, 1 Mar 2024 13:11:32 +0800 Subject: [PATCH 03/11] completed lstm --- src/lib/validate-input.js | 95 +++++++++++++++++++++++++++ src/lstm.js | 108 +++++++++++++++++++++++++++++++ test/lstm_test.js | 131 
++++++++++++++++++++++++++++++++++++++ 3 files changed, 334 insertions(+) create mode 100644 src/lstm.js create mode 100644 test/lstm_test.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 448f525..ae9bb8d 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -200,6 +200,101 @@ export function validateLstmCellParams(input, weight, recurrentWeight, } } +export function validateLstmParams(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, + direction = 'forward', layout = 'iofg'}) { + if (!Number.isInteger(steps) || steps <= 0) { + throw new Error(`The steps ${steps} is invalid.`); + } + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 3) { + throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`); + } + if (input.shape[0] !== steps) { + throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`); + } + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') { + throw new Error(`The direction ${direction} is invalid.`); + } + const numDirections = (direction === 'both' ? 
2 : 1); + if (weight.rank !== 3) { + throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`); + } + if (weight.shape[0] !== numDirections || weight.shape[1] !== 4 * hiddenSize || + weight.shape[2] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, + ${weight.shape[2]}] is invalid.`); + } + if (recurrentWeight.rank !== 3) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`); + } + if (recurrentWeight.shape[0] !== numDirections || + recurrentWeight.shape[1] !== 4 * hiddenSize || + recurrentWeight.shape[2] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` + + `${recurrentWeight.shape[2]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 2) { + throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`); + } + if (bias.shape[0] !== numDirections || bias.shape[1] !== 4 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 2) { + throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`); + } + if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 4 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, + ${recurrentBias.shape[1]}] is invalid.`); + } + } + if (peepholeWeight) { + if (peepholeWeight.rank !== 2) { + throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 2-D tensor.`); + } + if (peepholeWeight.shape[0] !== numDirections || peepholeWeight.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}, + ${peepholeWeight.shape[1]}] is invalid.`); + } + } + if (initialHiddenState) { + if (initialHiddenState.rank !== 3) { + throw new Error( + `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 
3-D tensor.`); + } + if (initialHiddenState.shape[0] !== numDirections || + initialHiddenState.shape[1] !== batchSize || + initialHiddenState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, + ${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`); + } + } + if (initialCellState) { + if (initialCellState.rank !== 3) { + throw new Error( + `The initialCellState (rank ${initialCellState.rank}) is not a 3-D tensor.`); + } + if (initialCellState.shape[0] !== numDirections || + initialCellState.shape[1] !== batchSize || + initialCellState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialCellState [${initialCellState.shape[0]}, + ${initialCellState.shape[1]}, ${initialCellState.shape[2]}] is invalid.`); + } + } + if (layout !== 'iofg' && layout !== 'ifgo') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + + export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, {bias, recurrentBias, layout = 'zrn'} = {}) { if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { diff --git a/src/lstm.js b/src/lstm.js new file mode 100644 index 0000000..3c0b5f6 --- /dev/null +++ b/src/lstm.js @@ -0,0 +1,108 @@ +'use strict'; + +import {concat} from './concat.js'; +import {lstmCell} from './lstm_cell.js'; +import {reshape} from './reshape.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {squeeze} from './squeeze.js'; +import {tanh} from './tanh.js'; +import {validateLstmParams} from './lib/validate-input.js'; + +/** + *Long Short-Term Memory [LSTM] recurrent network uses an input, output, forget, + *and cell gate to compute the output state that rolls into the output across the + * temporal sequence of the network. 
+ * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Number} steps + * @param {Number} hiddenSize + * @param {MLLstmOptions} options + * @return {Array.} + */ + +export function lstm(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg', + activations = [sigmoid, tanh, tanh]}={}) { + validateLstmParams(...arguments); + const numDirections = (direction == 'both' ? 2 : 1); + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + + let hiddenState; + let cellState; + if (initialHiddenState) { + hiddenState = initialHiddenState; + } else { + const initialHiddenStateShape = [numDirections, batchSize, hiddenSize]; + hiddenState = new Tensor( + initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0)); + } + if (initialCellState) { + cellState = initialCellState; + } else { + const initialCellState = [numDirections, batchSize, hiddenSize]; + cellState = new Tensor( + initialCellState, new Array(sizeOfShape(initialCellState)).fill(0)); + } + + let sequence; + const currentWeight = []; + const currentRecurrentWeight = []; + const currentBias = []; + const currentRecurrentBias = []; + const currentPeepholeWeight = []; + + for (let dir = 0; dir < numDirections; ++dir) { + currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize]))); + currentRecurrentWeight.push(squeeze(slice(recurrentWeight, + [dir, 0, 0], [1, 4 * hiddenSize, hiddenSize]))); + currentBias.push(bias ? (squeeze(slice(bias, [dir, 0], [1, 4 * hiddenSize]))) : null); + currentRecurrentBias.push(recurrentBias ? + (squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null); + currentPeepholeWeight.push(peepholeWeight ? 
+ (squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null); + } + + for (let step = 0; step < steps; ++step) { + const currentHidden = []; + const currentCell = []; + let nextHidden = null; + let nextCell = null; + + for (let dir = 0; dir < numDirections; ++dir) { + currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]))); + currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]))); + } + + for (let dir = 0; dir < numDirections; ++dir) { + const slice0 = (dir == 1 || direction == 'backward' ? steps - step - 1 : step); + const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize])); + + const results = lstmCell( + currentInput, currentWeight[dir], currentRecurrentWeight[dir], + currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir], + recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir], + layout: layout, activations: activations}); + + const output = reshape(results[0], [1, null, hiddenSize]); + const cell = reshape(results[1], [1, null, hiddenSize]); + + nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output); + nextCell = (nextCell ? concat([nextCell, cell], 0) : cell); + } + + hiddenState = nextHidden; + cellState = nextCell; + + if (returnSequence) { + nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]); + sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden); + } + } + + return (sequence ? 
[hiddenState, cellState, sequence] : [hiddenState, cellState]); +} diff --git a/test/lstm_test.js b/test/lstm_test.js new file mode 100644 index 0000000..c03d885 --- /dev/null +++ b/test/lstm_test.js @@ -0,0 +1,131 @@ +'use strict'; + +import {lstm} from '../src/lstm.js'; +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test lstm', function() { + it('lstm returnSequence=true ' + + 'activations=[relu, relu, relu]', function() { + const steps = 1; + const numDirections = 1; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], new Float32Array([1, 2, 2, 1])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const initialCellState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const returnSequence = true; + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, returnSequence, activations}); + console.log('outputs: ', outputs); + utils.checkShape(outputs[0], [numDirections, batchSize, 
hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [ + 1, 8, 27, 216, + ], + [ + 1, 4, 9, 36, + ], + [ + 1, 8, 27, 216, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); + + it('lstm steps=2 direction="backward" returnSequence=true' + + 'activations=[relu, relu, relu]', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([steps, batchSize, inputSize], + new Float32Array([1, 2, 2, 1, 3, 4, 1, 2])); + const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize], + new Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([numDirections, 4 * hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([numDirections, 3 * hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const initialCellState = new Tensor([numDirections, batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const returnSequence = true; + const direction = 'backward'; + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstm( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, peepholeWeight, initialHiddenState, + initialCellState, direction, returnSequence, activations}); + console.log('outputs: ', outputs); + utils.checkShape(outputs[0], 
[numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [10.469, 58.02899999999999, 74.529, 518.9490000000001], + [5.51, 20.009999999999998, 19.11, 75.21000000000001], + [ + 1, + 8, + 1, + 8, + 10.469, + 58.02899999999999, + 74.529, + 518.9490000000001, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); +}); From 8d6e71e8b3d8f7e9987272a8a0f077d8e9294bd9 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Fri, 1 Mar 2024 16:55:47 +0800 Subject: [PATCH 04/11] revised lstm.js lstm_cell.js and validate-input.js --- src/lib/validate-input.js | 6 ++---- src/lstm.js | 3 +-- src/lstm_cell.js | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index ae9bb8d..6a4611f 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -168,8 +168,7 @@ export function validateLstmCellParams(input, weight, recurrentWeight, throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`); } if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) { - throw new Error(`The shape of cellState - [${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`); + throw new Error(`The shape of cellState [${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`); } if (bias) { if (bias.rank !== 1) { @@ -322,8 +321,7 @@ export function validateGruCellParams(input, weight, recurrentWeight, hiddenStat throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); } if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { - throw new Error(`The shape of hiddenState - [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + throw new Error(`The shape of hiddenState [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is 
invalid.`); } if (bias) { if (bias.rank !== 1) { diff --git a/src/lstm.js b/src/lstm.js index 3c0b5f6..648b1b1 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -22,11 +22,10 @@ import {validateLstmParams} from './lib/validate-input.js'; * @param {MLLstmOptions} options * @return {Array.} */ - export function lstm(input, weight, recurrentWeight, steps, hiddenSize, {bias, recurrentBias, peepholeWeight, initialHiddenState, initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg', - activations = [sigmoid, tanh, tanh]}={}) { + activations = [sigmoid, tanh, tanh]} = {}) { validateLstmParams(...arguments); const numDirections = (direction == 'both' ? 2 : 1); const batchSize = input.shape[1]; diff --git a/src/lstm_cell.js b/src/lstm_cell.js index 3d07d11..7a52374 100644 --- a/src/lstm_cell.js +++ b/src/lstm_cell.js @@ -29,8 +29,8 @@ export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, validateLstmCellParams(...arguments); const zero = new Scalar(0); const inputSize = input.shape[1]; - const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} : - {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; + const starts = (layout === 'iofg') ? 
{i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} : + {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; const activation0 = activations[0]; const activation1 = activations[1]; const activation2 = activations[2]; From 95c0ab178119f1ab2a4379387116c936c1f75bcd Mon Sep 17 00:00:00 2001 From: mei1127 Date: Fri, 1 Mar 2024 16:59:13 +0800 Subject: [PATCH 05/11] revised validate-input.js to ensure length Date: Mon, 4 Mar 2024 11:49:09 +0800 Subject: [PATCH 06/11] revised some files related to squeeze --- src/gru.js | 3 +-- src/lstm.js | 3 +-- src/reshape.js | 17 +++++++++++++++++ src/squeeze.js | 20 -------------------- test/squeeze_test.js | 35 ----------------------------------- 5 files changed, 19 insertions(+), 59 deletions(-) delete mode 100644 src/squeeze.js delete mode 100644 test/squeeze_test.js diff --git a/src/gru.js b/src/gru.js index 80583c0..cf93319 100644 --- a/src/gru.js +++ b/src/gru.js @@ -2,11 +2,10 @@ import {concat} from './concat.js'; import {gruCell} from './gru_cell.js'; -import {reshape} from './reshape.js'; +import {reshape, squeeze} from './reshape.js'; import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; import {slice} from './slice.js'; -import {squeeze} from './squeeze.js'; import {tanh} from './tanh.js'; import {validateGruParams} from './lib/validate-input.js'; diff --git a/src/lstm.js b/src/lstm.js index 648b1b1..99474d7 100644 --- a/src/lstm.js +++ b/src/lstm.js @@ -2,11 +2,10 @@ import {concat} from './concat.js'; import {lstmCell} from './lstm_cell.js'; -import {reshape} from './reshape.js'; +import {reshape, squeeze} from './reshape.js'; import {sizeOfShape, Tensor} from './lib/tensor.js'; import {sigmoid} from './sigmoid.js'; import {slice} from './slice.js'; -import {squeeze} from './squeeze.js'; import {tanh} from './tanh.js'; import {validateLstmParams} from './lib/validate-input.js'; diff --git a/src/reshape.js b/src/reshape.js index 4cb0ffc..c4c3f49 100644 --- 
a/src/reshape.js +++ b/src/reshape.js @@ -1,6 +1,7 @@ 'use strict'; import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateSqueezeParams} from './lib/validate-input.js'; /** * Alter the shape of a tensor to a new shape. @@ -31,3 +32,19 @@ export function reshape(input, newShape) { const output = new Tensor(outputShape, input.data); return output; } + +/** + * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. + * @param {Tensor} input + * @param {MLSqueezeOptions} options + * @return {Tensor} + */ +export function squeeze(input, {axes} = {}) { + validateSqueezeParams(...arguments); + const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + + const outputShape = input.shape.filter((dim, axis) => + !(dim === 1 && inpAxes.indexOf(axis) !== -1)); + const output = reshape(input, outputShape); + return output; +} diff --git a/src/squeeze.js b/src/squeeze.js deleted file mode 100644 index a463f79..0000000 --- a/src/squeeze.js +++ /dev/null @@ -1,20 +0,0 @@ -'use strict'; - -import {reshape} from './reshape.js'; -import {validateSqueezeParams} from './lib/validate-input.js'; - -/** - * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. - * @param {Tensor} input - * @param {MLSqueezeOptions} options - * @return {Tensor} - */ -export function squeeze(input, {axes} = {}) { - validateSqueezeParams(...arguments); - const inpAxes = axes ?? 
new Array(input.rank).fill(0).map((_, i) => i); - - const outputShape = input.shape.filter((dim, axis) => - !(dim === 1 && inpAxes.indexOf(axis) !== -1)); - const output = reshape(input, outputShape); - return output; -} diff --git a/test/squeeze_test.js b/test/squeeze_test.js deleted file mode 100644 index e2103be..0000000 --- a/test/squeeze_test.js +++ /dev/null @@ -1,35 +0,0 @@ -'use strict'; - -import {squeeze} from '../src/squeeze.js'; -import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; -import * as utils from './utils.js'; - -describe('test squeeze', function() { - function testSqueeze(oldShape, axes, expectedShape) { - const bufferSize = sizeOfShape(oldShape); - const inputBuffer = new Array(bufferSize); - for (let i = 0; i < inputBuffer.length; ++i) { - inputBuffer[i] = Math.random(); - } - const x = new Tensor(oldShape, inputBuffer); - const y = squeeze(x, {axes}); - utils.checkShape(y, expectedShape); - utils.checkValue(y, inputBuffer); - } - - it('squeeze one dimension by default', function() { - testSqueeze([1, 3, 4, 5], undefined, [3, 4, 5]); - }); - - it('squeeze one dimension with axes', function() { - testSqueeze([1, 3, 1, 5], [0], [3, 1, 5]); - }); - - it('squeeze two dimensions by default', function() { - testSqueeze([1, 3, 1, 5], undefined, [3, 5]); - }); - - it('squeeze two dimensions with axes', function() { - testSqueeze([1, 3, 1, 5], [0, 2], [3, 5]); - }); -}); From bda1651d1d9423918c017a6a436df1ba7659d233 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Mon, 4 Mar 2024 11:57:25 +0800 Subject: [PATCH 07/11] revised reduce.js enter the commit message for your changes. 
Lines starting --- src/reduce.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/reduce.js b/src/reduce.js index 390fa47..3468be4 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -1,7 +1,7 @@ 'use strict'; import {pow} from './binary.js'; -import {squeeze} from './squeeze.js'; +import {squeeze} from './reshape.js'; import {abs, exp, log} from './unary.js'; import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; import {validateReduceParams} from './lib/validate-input.js'; From c451f9115b950db93982c591e1c05391869e7456 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Mon, 4 Mar 2024 12:06:19 +0800 Subject: [PATCH 08/11] revised arg_max_min.js --- src/arg_max_min.js | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 src/arg_max_min.js diff --git a/src/arg_max_min.js b/src/arg_max_min.js new file mode 100644 index 0000000..cb5fe70 --- /dev/null +++ b/src/arg_max_min.js @@ -0,0 +1,68 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js'; +import {squeeze} from './reshape.js'; + +/** + * Get the index location of the minimum or maximum values of all the input values along the axes. + * @param {Tensor} input + * @param {Function} reduceFunc + * @param {MLArgMinMaxOptions} [options] + * @return {Tensor} + */ +export function argMaxMin( + input, + reduceFunc, + { + axes = null, + keepDimensions = false, + selectLastIndex = false, + } = {}) { + // If axes aren't present (defaulting to null), all dimensions are reduced. + // See https://webmachinelearning.github.io/webnn/#dom-mlargminmaxoptions-axes. + const inputAxes = axes ?? 
new Array(input.rank).fill(0).map((_, i) => i); + const outputShape = input.shape.slice(); + + for (let i = 0; i < inputAxes.length; ++i) { + outputShape[inputAxes[i]] = 1; + } + + let output = new Tensor(outputShape); + const tensor = reduceFunc(input, {axes: inputAxes, keepDimensions: true}); + + for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) { + const value = tensor.getValueByIndex(outputIndex); + const inputLocation = output.locationFromIndex(outputIndex); + const selectedArray = selectValuesToReduce(input, inputAxes, inputLocation); + const index = + selectLastIndex ? selectedArray.lastIndexOf(value) : selectedArray.indexOf(value); + output.setValueByIndex(outputIndex, index); + } + + if (!keepDimensions) { + output = squeeze(output, {axes}); + } + + return output; +} + +/** + * Get the index location of the maximum values of all the input values along the axes. + * @param {Tensor} input + * @param {MLArgMinMaxOptions} [options] + * @return {Tensor} + */ +export function argMax(input, options = {}) { + return argMaxMin(input, reduceMax, options); +} + +/** + * Get the index location of the minimum values of all the input values along the axes. + * @param {Tensor} input + * @param {MLArgMinMaxOptions} [options] + * @return {Tensor} + */ +export function argMin(input, options = {}) { + return argMaxMin(input, reduceMin, options); +} From 190fe607182ac867d008bd8d96cce09a8ad1a25c Mon Sep 17 00:00:00 2001 From: mei1127 Date: Mon, 4 Mar 2024 13:59:10 +0800 Subject: [PATCH 09/11] revised reshape.js --- src/reshape.js | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/reshape.js b/src/reshape.js index c4c3f49..a25335f 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -41,10 +41,9 @@ export function reshape(input, newShape) { */ export function squeeze(input, {axes} = {}) { validateSqueezeParams(...arguments); - const inpAxes = axes ?? 
new Array(input.rank).fill(0).map((_, i) => i); - + const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); const outputShape = input.shape.filter((dim, axis) => - !(dim === 1 && inpAxes.indexOf(axis) !== -1)); + !(dim === 1 && inputAxes.indexOf(axis) !== -1)); const output = reshape(input, outputShape); return output; } From d8ebc9386bc70864b8a6bf736f1c643d0f66af20 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Mon, 4 Mar 2024 14:02:46 +0800 Subject: [PATCH 10/11] revised reshape.js --- src/reshape.js | 1 + 1 file changed, 1 insertion(+) diff --git a/src/reshape.js b/src/reshape.js index a25335f..ac2323d 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -42,6 +42,7 @@ export function reshape(input, newShape) { export function squeeze(input, {axes} = {}) { validateSqueezeParams(...arguments); const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + const outputShape = input.shape.filter((dim, axis) => !(dim === 1 && inputAxes.indexOf(axis) !== -1)); const output = reshape(input, outputShape); From 5b098bcb0228f142247105b15c6c77287c2eade5 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Mon, 4 Mar 2024 14:51:28 +0800 Subject: [PATCH 11/11] revised reshape.js --- src/reshape.js | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/reshape.js b/src/reshape.js index ac2323d..112b567 100644 --- a/src/reshape.js +++ b/src/reshape.js @@ -41,8 +41,7 @@ export function reshape(input, newShape) { */ export function squeeze(input, {axes} = {}) { validateSqueezeParams(...arguments); - const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); - + const inputAxes = axes ?? Array.from({length: input.rank}, (_, i) => i); const outputShape = input.shape.filter((dim, axis) => !(dim === 1 && inputAxes.indexOf(axis) !== -1)); const output = reshape(input, outputShape);