completed lstm #71
LGTM, thanks @mei1127.
src/lstm.js
Outdated
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';
Now that the squeeze op has been removed, could you please also remove it from this WebNN Baseline?
You can refer to the squeeze method given in the Spec, thanks.
function squeeze(builder, op) {
return builder.reshape(op, op.shape().remove(0));
}
The definition is:
function squeeze(input, axes) {
  if (!axes) axes = [];
  if (!axes.length)
    input.shape().forEach((item, i) => { axes.push(i); });
  const shape = Array.from(input.shape());
  // Iterate over the axis values (not their indices), highest first.
  for (const axis of axes.sort((a, b) => b - a))
    if (axis < shape.length && shape[axis] == 1)
      shape.splice(axis, 1);
  return builder.reshape(input, shape);
}
OK, I will revise it next week:)
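For anyone checking the shape logic of the squeeze definition above in isolation, here is a minimal sketch that works on a plain shape array rather than a WebNN operand. The name `squeezedShape` is hypothetical and not part of the spec or this repository; the result would be what gets passed on to `reshape`.

```javascript
// Sketch only: `squeezedShape` is a hypothetical helper operating on plain arrays.
function squeezedShape(shape, axes) {
  // With no axes given, every dimension is a candidate for removal.
  const candidates = (axes && axes.length) ? [...axes] : shape.map((_, i) => i);
  const result = Array.from(shape);
  // Visit axes from highest to lowest so a removal never shifts a pending index.
  for (const axis of candidates.sort((a, b) => b - a)) {
    if (axis < result.length && result[axis] === 1) {
      result.splice(axis, 1);
    }
  }
  return result;
}

console.log(squeezedShape([1, 3, 1, 4]));      // all size-1 dimensions dropped
console.log(squeezedShape([1, 3, 1, 4], [0])); // only axis 0 dropped
```

Sorting with an explicit numeric comparator matters once a tensor has ten or more dimensions, since the default sort compares elements as strings.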
src/lstm.js
Outdated
* @param {MLLstmOptions} options
* @return {Array.<Tensor>}
*/
|
Please delete this blank line.
src/lstm.js
Outdated
export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
activations = [sigmoid, tanh, tanh]}={}) {
activations = [sigmoid, tanh, tanh]}={}) {
activations = [sigmoid, tanh, tanh]} = {}) {
src/lib/validate-input.js
Outdated
throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
}
if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
throw new Error(`The shape of cellState
The error message will be split :( :
The shape of cellState
[2, 3] is invalid.
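To illustrate why the message splits: a line break inside a template literal becomes part of the string. The sketch below shows the problem and one possible way around it (concatenating two literals); the function names are hypothetical and only the message text is taken from the diff above.

```javascript
// Hypothetical helpers for illustration; not code from this PR.
function splitMessage(shape) {
  // The line break inside the template literal is kept in the resulting string.
  return `The shape of cellState
    [${shape.join(', ')}] is invalid.`;
}

function singleLineMessage(shape) {
  // Concatenating two literals wraps the source line without embedding a newline.
  return `The shape of cellState ` +
      `[${shape.join(', ')}] is invalid.`;
}

console.log(splitMessage([2, 3]));      // printed on two lines
console.log(singleLineMessage([2, 3])); // printed on one line
```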
src/lstm_cell.js
Outdated
const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} :
{i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
It could be easier for future readers to visually parse if the indentation were aligned:
const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} :
{i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
const starts = (layout === 'iofg') ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} :
                                     {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
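As an aside, the two offset tables follow directly from the order of the letters in the layout string, so they could also be derived rather than written out. This is only a sketch of that idea: `gateStarts` is a hypothetical helper, and it assumes the non-'iofg' layout is named 'ifgo', which the diff above does not state.

```javascript
// Hypothetical helper: derive each gate's start offset from the layout string.
// Assumes every gate occupies `hiddenSize` consecutive rows in layout order.
function gateStarts(layout, hiddenSize) {
  const starts = {};
  [...layout].forEach((gate, index) => {
    starts[gate] = index * hiddenSize;
  });
  return starts;
}

console.log(gateStarts('iofg', 4)); // same offsets as the first object literal
console.log(gateStarts('ifgo', 4)); // same offsets as the second object literal
```

The explicit literals in the suggestion are arguably easier to read at a glance; the derived form mainly helps if more layouts were ever added.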
);

// forget gate (f)
const f = activation0(
It's cool how easy it is to compose these from existing operators :).
Looks good after some small comments. This is a complex operator! Thanks for adding it. 🙂
Just minor comments. Functionally looks correct. TY Mei.
src/reshape.js
Outdated
*/
export function squeeze(input, {axes} = {}) {
validateSqueezeParams(...arguments);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
Whole word identifiers are more readable for others later (and it's consistent with input.rank rather than inp.rank).
src/reshape.js
Outdated
*/
export function squeeze(input, {axes} = {}) {
validateSqueezeParams(...arguments);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inputAxes = axes ?? Array.from({length: input.rank}, (_, i) => i);
Hmm, this is a kinda confusing way of initializing a sequence, chaining multiple methods together (took me a bit to figure out), rather than just saying Array.from({length: input.rank}, (_, i) => i), which is longer character-wise but clearer intention-wise (and probably more performant than a fill that is overwritten by a map). This is one surprising case where C++ is shorter :b std::ranges::iota(inputAxes, 0).
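A small side-by-side sketch of the two ways to build the index sequence discussed above. The wrapper function names are hypothetical; the expressions inside them are the ones from the diff and the suggestion. Note that Array.from is a static method and is called without `new`.

```javascript
// Hypothetical wrappers around the two expressions being compared.
function iotaWithFillMap(length) {
  // Allocate, fill with placeholder zeros, then overwrite each slot with its index.
  return new Array(length).fill(0).map((_, i) => i);
}

function iotaWithArrayFrom(length) {
  // Build the sequence in one step from an array-like with a `length` property.
  return Array.from({length}, (_, i) => i);
}

console.log(iotaWithFillMap(4));
console.log(iotaWithArrayFrom(4));
```

Both produce the same array; the second states the intent (a sequence of a given length) more directly.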
@huningxin @fdwr @shiyi9801 @BruceDai PTAL, thanks!