
Fix computation in normalization layers #876

Merged · 4 commits · Oct 18, 2024

Conversation

@AakashKumarNain (Contributor)

Irrespective of the dtype passed to the normalization layers, the calculations should be done in fp32 or a higher-precision dtype. Not doing so can cause training instabilities, which are all too common when training models in half precision. We could delegate this to the end user to handle explicitly, but it is an easy thing to miss. This PR takes care of that.
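A minimal sketch of the pattern being described, not the exact change made in this PR: upcast to float32 before computing the statistics, then cast back to the original dtype. The `layer_norm` function and its signature here are illustrative, not Equinox's API.

```python
import jax.numpy as jnp


def layer_norm(x, weight, bias, eps=1e-5):
    """Illustrative LayerNorm that always computes its statistics in float32."""
    orig_dtype = x.dtype
    # Upcast: computing mean/variance directly in float16/bfloat16 is where the
    # instability comes from.
    x32 = x.astype(jnp.float32)
    mean = jnp.mean(x32, keepdims=True)
    var = jnp.var(x32, keepdims=True)
    out = (x32 - mean) / jnp.sqrt(var + eps)
    # Apply the affine parameters in float32 as well, then cast back so the
    # layer returns the same dtype it was given.
    out = out * weight.astype(jnp.float32) + bias.astype(jnp.float32)
    return out.astype(orig_dtype)
```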

@patrick-kidger (Owner)

Looks like the current implementation doesn't pass under strict-dtype-promotion (see failing tests), but other than that this sounds good to me!
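For context, JAX can be configured to disallow implicit dtype promotion entirely, which is the mode the failing tests refer to. A small sketch of what that looks like; the arrays here are purely illustrative.

```python
import jax
import jax.numpy as jnp

# Under strict promotion, mixing float16 and float32 raises an error
# instead of silently upcasting.
jax.config.update("jax_numpy_dtype_promotion", "strict")

x = jnp.ones(4, dtype=jnp.float16)
w = jnp.ones(4, dtype=jnp.float32)

try:
    _ = x * w  # implicit promotion: rejected in strict mode
except Exception as err:
    print(f"rejected: {type(err).__name__}")

_ = x.astype(jnp.float32) * w  # explicit cast: allowed
```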

@AakashKumarNain (Contributor, Author)

Thanks @patrick-kidger, I have fixed it. Please let me know if you have any more suggestions on this one.

@patrick-kidger (Owner)

Alright, this LGTM! Final question before I merge this: do you have a reference for needing higher precision for these operations? (As much for me to read up on as anything else :D )

@AakashKumarNain (Contributor, Author)

I think this is pretty much enough: https://pytorch.org/docs/stable/amp.html#autocast-op-reference

@patrick-kidger merged commit e2d7e38 into patrick-kidger:main on Oct 18, 2024 (2 checks passed)
@patrick-kidger (Owner)

Alright, looks good to me! Thanks for spotting this, and for the reference. Merged :)
