register fused rmsnorm as pytorch custom op #296
base: gh/tianyu-l/11/base
Conversation
[ghstack-poisoned]
ghstack-source-id: 401d968feaa2e58eedb573c07739694358a8d4a6 Pull Request resolved: #296
Force-pushed 17cda29 to e34d2ac
Force-pushed 1959d5f to c3daa9a
n00b question, but is there any benefit in making a Triton function a custom op? User-defined Triton functions should just work with compile.
@msaroufim Hmm, I don't know much. Maybe it depends on the way the Triton function is wrapped? In this PR there is an autograd.Function wrapping the Triton function. cc: @Chillee @lessw2020 for more context. Beyond compile, making it a custom op can be helpful for other things, e.g. registering customized DTensor sharding propagation rules, allowing PP tracing to work (although we no longer have it in torchtitan), etc.
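For anyone following along, here is a minimal sketch (not the code from this PR) of registering a fused kernel as a PyTorch custom op via torch.library; the op name `mylib::fused_rms_norm` and the eager fallback body are hypothetical placeholders, and the real version would launch the fused Triton kernel and register an autograd formula as well.

```python
import torch

# Hypothetical op name; not the name used in this PR.
@torch.library.custom_op("mylib::fused_rms_norm", mutates_args=())
def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # The real implementation would launch the fused Triton kernel here;
    # a plain eager RMSNorm is used for illustration only.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

@fused_rms_norm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation for FakeTensor-based tracing and compile.
    return torch.empty_like(x)

# Usage:
# y = torch.ops.mylib.fused_rms_norm(x, w, 1e-6)
```

Registering the kernel as an op gives tracing, FakeTensor propagation, and DTensor sharding rules a single stable operator to target, rather than a Python-level autograd.Function that compile has to see through.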
Should we perhaps just remove this custom kernel? I believe RMSNorm should now work in the PyTorch nightlies.
@msaroufim Oh, which RMSNorm are you referring to?
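(For reference, recent PyTorch releases do expose an RMSNorm primitive directly; a minimal usage sketch, assuming PyTorch 2.4+ and normalization over the last dimension:)

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
weight = torch.ones(16)

# Functional RMSNorm over the last dimension; a module form exists as torch.nn.RMSNorm(16).
y = F.rms_norm(x, normalized_shape=(16,), weight=weight, eps=1e-6)
```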
Stack from ghstack (oldest at bottom):