Thanks for reaching out. Our team has a fix for this and it will be available in the upcoming releases. We will update this issue once the fix is released.
In the `ParallelEmbedding` layer, when sharding across the vocabulary, the output is masked at the very end of the operation. It seems that the masking is done by multiplying by a hard-coded `float` mask, which causes the actual `float16`/`bfloat16` output to be upcast to `float32`. A correct implementation would multiply by a mask of the same dtype as the intended output.
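For illustration, here is a minimal PyTorch sketch (not the actual `ParallelEmbedding` code; the tensor names and shapes are made up) showing how multiplying a `bfloat16` tensor by a `float32` mask promotes the result to `float32`, and how casting the mask to the output's dtype keeps the original precision:

```python
import torch

# Hypothetical stand-ins for the sharded embedding output and the mask of
# tokens that fall outside this shard's vocabulary range.
output = torch.randn(4, 8, dtype=torch.bfloat16)
input_mask = torch.tensor([False, True, False, False])

# Problematic pattern: a hard-coded float mask defaults to float32,
# so PyTorch's type promotion upcasts the product to float32.
upcast = output * (~input_mask).unsqueeze(-1).float()
print(upcast.dtype)  # torch.float32

# Fix: build the mask in the same dtype as the output.
masked = output * (~input_mask).unsqueeze(-1).to(output.dtype)
print(masked.dtype)  # torch.bfloat16
```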