Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

completed lstm #71

Merged
merged 13 commits into from
Apr 16, 2024
3 changes: 2 additions & 1 deletion src/arg_max_min.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {reduceMax, reduceMin, selectValuesToReduce} from './reduce.js';
import {squeeze} from './squeeze.js';
import {squeeze} from './reshape.js';


/**
* Get the index location of the minimum or maxmium values of all the input values along the axes.
Expand Down
3 changes: 1 addition & 2 deletions src/gru.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import {concat} from './concat.js';
import {gruCell} from './gru_cell.js';
import {reshape} from './reshape.js';
import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
import {tanh} from './tanh.js';
import {validateGruParams} from './lib/validate-input.js';

Expand Down
3 changes: 1 addition & 2 deletions src/lstm.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import {concat} from './concat.js';
import {lstmCell} from './lstm_cell.js';
import {reshape} from './reshape.js';
import {reshape, squeeze} from './reshape.js';
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
import {tanh} from './tanh.js';
import {validateLstmParams} from './lib/validate-input.js';

Expand Down
2 changes: 1 addition & 1 deletion src/reduce.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'use strict';

import {pow} from './binary.js';
import {squeeze} from './squeeze.js';
import {squeeze} from './reshape.js';
import {abs, exp, log} from './unary.js';
import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js';
import {validateReduceParams} from './lib/validate-input.js';
Expand Down
17 changes: 17 additions & 0 deletions src/reshape.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateSqueezeParams} from './lib/validate-input.js';

/**
* Alter the shape of a tensor to a new shape.
Expand Down Expand Up @@ -31,3 +32,19 @@ export function reshape(input, newShape) {
const output = new Tensor(outputShape, input.data);
return output;
}

/**
* Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape.
* @param {Tensor} input
* @param {MLSqueezeOptions} options
* @return {Tensor}
*/
export function squeeze(input, {axes} = {}) {
validateSqueezeParams(...arguments);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);

Whole word identifiers are more readable for others later (and it's consistent with input.rank rather than inp.rank).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inputAxes = axes ?? new Array.from({length: input.rank}, (_, i) => i);

Hmm, kinda confusing way of initializing a sequence by chaining multiple methods together (took me a bit to figure out), rather than just saying Array.from({length: input.rank}, (_, i) => i), which is longer character-wise, but it's clearer intention-wise (and probably more performant than a fill which is overwritten by a map). This one surprising case where C++ is shorter :b std::ranges::iota(inputAxes, 0).


const outputShape = input.shape.filter((dim, axis) =>
!(dim === 1 && inpAxes.indexOf(axis) !== -1));
const output = reshape(input, outputShape);
return output;
}
20 changes: 0 additions & 20 deletions src/squeeze.js

This file was deleted.

35 changes: 0 additions & 35 deletions test/squeeze_test.js

This file was deleted.

Loading