WMT slower in Pytorch than Jax #467

WMT PyTorch is currently slower than JAX.
This bug is intended to at least document possible causes.
@runame @pomonam could you please summarize current findings and possible solutions?

Comments
The main cause of the speed difference seems to be the update step, which is ~20% faster in JAX. The data loading is also faster in JAX, but the difference is insignificant in absolute terms. Juhan and I suspect that the model code is responsible for the slowdown. I will try to 1) use some of the new optimized transformer functions from PyTorch 2 and 2) rewrite the masking to be compatible with …
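For reference, a minimal sketch of what using one of the PyTorch 2 fused attention functions could look like. This is illustrative only, not the actual WMT model code; the mask layout is an assumption.

```python
import torch
import torch.nn.functional as F

def attention(q, k, v, key_padding_mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    attn_mask = None
    if key_padding_mask is not None:
        # Hypothetical mask layout: (batch, seq_len) booleans, True = real token.
        # Broadcast to (batch, 1, 1, seq_len); True positions are attended to.
        attn_mask = key_padding_mask[:, None, None, :]
    # Dispatches to fused kernels (e.g. FlashAttention) when the inputs allow it.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```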
FWIW I did try using a HF transformer implementation with boolean masks and that …
@runame I managed to …
I don't have perf numbers yet since it feels stuck at …
@msaroufim Thank you so much for looking into this! Yes, I will try this and see if there are any new issues.
@msaroufim @pomonam Thanks a lot for investigating! I won't get to working on this before Sunday or Monday, will check this thread for updates then.
@runame One thing I wanted to try out (but did not have time to do) was to use the PyTorch default functions (e.g., …)
Oh interesting, so you can confirm that it does compile but you're getting graph breaks? Are you running …
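For context, a generic way to surface graph breaks with `torch.compile` (not the exact command used for this workload) is to force full-graph compilation, which turns the first break into a hard error:

```python
import torch

# Toy stand-in for the WMT model; fullgraph=True makes any graph break an error
# instead of silently falling back to eager between subgraphs.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, fullgraph=True)
out = compiled(torch.randn(4, 8))

# In recent PyTorch 2.x, running with the environment variable TORCH_LOGS=graph_breaks
# logs break locations and reasons without turning them into errors.
```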
@msaroufim I just tried running with Branch …
Without …
Branch …
Well, it is promising that … If I could get access to your VM to run experiments quickly, that might help as well @priyakasimbeg - we can chat more in person tomorrow.
EDIT: This is a legit error actually; FWIW I created an …
@msaroufim Is the …
Hmm, it should also be usable by …
@msaroufim After the refactor in #489 the only remaining graph breaks are of this type: …
After setting …
@runame I chatted with Will Constable about this, and his point is that the DDPOptimizer will always give you graph breaks, so if you have it enabled you won't be able to do …
Do either you or @pomonam have a smaller repro of the linked error here: #487 (comment)?
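One way to test that point is to turn off the DDPOptimizer's graph splitting. A minimal sketch, assuming a recent PyTorch 2.x and an already initialized process group (this trades away the compute/communication overlap the DDPOptimizer provides):

```python
import torch
import torch._dynamo
from torch.nn.parallel import DistributedDataParallel as DDP

# Disable DDPOptimizer graph splitting so fullgraph=True becomes possible under DDP.
torch._dynamo.config.optimize_ddp = False

# `model` is assumed to be defined elsewhere and placed on the local device.
model = torch.nn.Linear(8, 8).cuda()  # stand-in
ddp_model = DDP(model)
compiled = torch.compile(ddp_model, fullgraph=True)
```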
@msaroufim Not sure if it's useful, but I have created a smaller repro using the same model here. It runs successfully on a single GPU and fails with DDP.
Update: I decided not to follow up on this because this issue is currently not blocking us.
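The actual repro linked above is not reproduced here; the following is only a generic skeleton of this kind of single-GPU-works / DDP-fails setup, with a made-up toy model standing in for the WMT model:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy stand-in model; the real repro used the WMT model.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
    ).cuda()
    model = DDP(model, device_ids=[local_rank])
    model = torch.compile(model)  # reported to work on 1 GPU but fail under DDP

    x = torch.randn(8, 16, device="cuda")
    model(x).sum().backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=<num_gpus> repro.py
```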
@msaroufim Since we don't know why we get basically no speed improvements from using …
@runame @msaroufim Not sure if this is relevant anymore (if you're unblocked some other way), but this specific issue about 'hooks' from the DDPOptimizer should have been fixed on master by pytorch/pytorch#107834.
@BoyuanFeng is working on this. Compiling the loss function in addition to the model seems to significantly speed up WMT. |
Resolved in #597 after torch.compiling the loss functions. |
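A minimal sketch of that pattern, compiling the loss computation in addition to the model. The actual #597 changes and the WMT label-smoothed loss are not reproduced here; a generic cross-entropy loss stands in:

```python
import torch
import torch.nn.functional as F

@torch.compile
def loss_fn(logits, targets):
    # Stand-in for the WMT cross-entropy / label-smoothing loss.
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), label_smoothing=0.1
    )

model = torch.compile(torch.nn.Linear(32, 1000))  # stand-in for the WMT transformer

logits = model(torch.randn(4, 32))
loss = loss_fn(logits, torch.randint(0, 1000, (4,)))
loss.backward()
```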