Hello, I was reading https://github.com/google/flax/tree/master/linen_examples/seq2seq/#L312 and noticed that it uses `jax.mask`, but when I looked it up in the JAX API documentation I couldn't find it. I did find an undocumented implementation here: https://jax.readthedocs.io/en/latest/_modules/jax/api.html?highlight=mask# Does anyone have a sense of what the `mask` function is doing? And/or what role it has in this example? Thanks!
EDIT: It seems `jax.mask` is being redesigned, and its use is currently not recommended as per this comment.

There is not much information about `jax.mask` other than some comments. I guess it is still a work in progress, and there are still some rough edges. The tests have some basic examples, but seq2seq is the only fully working example I know of.

As I understand it, `jax.mask` is a way to circumvent the static-shape limitation in XLA: inputs are padded to a fixed shape, and the padded cells are ignored in the computation. It serves a similar purpose to masking in TF and PyTorch, but with a different implementation.

Specifically for the seq2seq example, your input …

flax/linen_examples/seq2seq/train.py, lines 308 to 320 in 8092211

Considering …

Hope this helps!
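To make "ignoring padded cells" concrete, here is a minimal sketch of the manual, TF/PyTorch-style alternative written in plain JAX. It does **not** use `jax.mask` (an experimental API that is not available in recent JAX releases). The helper names `length_mask` and `masked_mean` and the toy data are hypothetical, chosen only for illustration. The batch is padded to a static length, and a boolean mask built from each sequence's true length removes the padding from the reduction.

```python
import jax
import jax.numpy as jnp


def length_mask(lengths, max_len):
    """Hypothetical helper: boolean mask of shape (batch, max_len).

    True marks a real token, False marks a padded cell.
    """
    positions = jnp.arange(max_len)  # (max_len,)
    return positions[None, :] < lengths[:, None]


@jax.jit
def masked_mean(values, lengths):
    """Hypothetical helper: per-sequence mean over time, ignoring padding.

    values:  (batch, max_len) float array padded to a fixed, static max_len.
    lengths: (batch,) int array holding the true length of each sequence.
    """
    # values.shape[1] is a static Python int under jit, so the mask has a static shape.
    mask = length_mask(lengths, values.shape[1])
    # Zero out padded cells so they contribute nothing to the sum.
    total = jnp.sum(jnp.where(mask, values, 0.0), axis=1)
    # Divide by the true length rather than max_len (guarding against length 0).
    return total / jnp.maximum(lengths, 1)


# Two sequences of true length 2 and 3, padded to a static length of 4.
# The 99.0 entries are padding and must not affect the result.
values = jnp.array([[1.0, 3.0, 99.0, 99.0],
                    [2.0, 4.0, 6.0, 99.0]])
lengths = jnp.array([2, 3])
print(masked_mean(values, lengths))  # [2. 4.]
```

Because `lengths` is an ordinary array argument rather than part of the shape, `masked_mean` compiles once for a given `(batch, max_len)` and is reused for any mix of sequence lengths. This is the static-shape workaround described above done by hand, which `jax.mask` aimed to automate.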