
Wrong output fp16/bf16 dtype in ParallelEmbedding when sharding across vocab #35

Open
dacorvo opened this issue Nov 20, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@dacorvo

dacorvo commented Nov 20, 2024

In the ParallelEmbedding layer, when sharding across vocab, the output is masked at the very end of the operation.

The masking is done by multiplying by a hard-coded float mask, which causes the actual float16/bfloat16 output to be upcast to float32.

A correct implementation would multiply by a mask of the same dtype as the intended output.
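A minimal sketch of the intended behaviour, assuming a vocab-parallel embedding where out-of-range token rows are zeroed at the end (the helper name and tensor shapes are hypothetical, not the library's actual code):

```python
import torch

def mask_parallel_embedding_output(output_parallel: torch.Tensor,
                                    input_mask: torch.Tensor) -> torch.Tensor:
    """Zero out rows whose token ids fall outside this shard's vocab range.

    output_parallel: (batch, seq, hidden) embedding output of this shard.
    input_mask: (batch, seq) bool tensor, True where the token id is
                outside the shard's vocab partition.
    """
    # Buggy pattern: a hard-coded float mask silently upcasts fp16/bf16 to fp32.
    # mask = (~input_mask).unsqueeze(-1).float()

    # Fixed pattern: build the mask in the same dtype as the output.
    mask = (~input_mask).unsqueeze(-1).to(output_parallel.dtype)
    return output_parallel * mask

# Quick check that the output dtype is preserved:
out = torch.randn(2, 4, 8, dtype=torch.bfloat16)
oov = torch.tensor([[False, True, False, False],
                    [True, False, False, False]])
assert mask_parallel_embedding_output(out, oov).dtype == torch.bfloat16
```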

@fayyadd
Contributor

fayyadd commented Nov 21, 2024

Thanks for reporting! Our team will take a look.

@fayyadd fayyadd added the bug Something isn't working label Nov 21, 2024
@fayyadd
Contributor

fayyadd commented Nov 21, 2024

Thanks for reaching out. Our team has a fix for this, and it will be available in an upcoming release. We will update this issue once the fix is released.
