Use of jax.mask in seq2seq linen example #820

Answered by myagues
iislucas asked this question in Q&A

EDIT: It seems jax.mask is being redesigned and its use is currently not recommended as per this comment.

There is not much information about jax.mask other than a few comments. I guess it is still a WIP with some rough edges. The tests have some basic examples, but seq2seq is the only fully working example I know of.

Does anyone have a sense of what the mask function is doing?

As I understand it, jax.mask is a way to work around the static shape limitation in XLA by ignoring padded cells in the computation. It serves a similar purpose to masking in TF and PyTorch, but with a different implementation.
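To make the "ignore padded cells" idea concrete, here is a minimal sketch of doing the masking by hand with plain jax.numpy. It does not use jax.mask and is not how jax.mask is implemented; the function name `masked_mean` and the sample data are made up for illustration. The point is only that the array keeps a fixed padded shape (so XLA compiles once), while an explicit boolean mask keeps the padding out of the result.

```python
import jax
import jax.numpy as jnp


def masked_mean(values, lengths):
    """Mean over the valid (non-padded) prefix of each row.

    values:  [batch, max_len] array, padded to a fixed max_len.
    lengths: [batch] integer array with the true length of each row.
    (Hypothetical helper, for illustration only.)
    """
    max_len = values.shape[1]
    # mask[i, t] is True where position t is a real element of row i.
    mask = jnp.arange(max_len)[None, :] < lengths[:, None]
    # Zero out padded cells so they do not contribute to the sum.
    total = jnp.sum(jnp.where(mask, values, 0.0), axis=1)
    # Guard against division by zero for empty rows.
    return total / jnp.maximum(lengths, 1)


# The padded shape is static, so changing the lengths does not
# change the input shapes seen by jit.
masked_mean_jit = jax.jit(masked_mean)

# 99.0 marks padding; it must not affect the result.
values = jnp.array([[1.0, 2.0, 3.0, 99.0],
                    [4.0, 99.0, 99.0, 99.0]])
lengths = jnp.array([3, 1])
result = masked_mean_jit(values, lengths)
```

With this pattern the padding values never reach the output, at the cost of still computing over the padded cells.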

And/or what role it has in this example?

Specifically for the seq2seq example, …

Answer selected by iislucas