add is slow on GPU #31
Hello. We haven't specifically asked the JAX maintainers about this issue. However, note that Reverb and Stable Baselines do not store their buffer memory on the GPU, so the comparison isn't really fair; it would be better to compare against the CPU times. If you find the GPU add times too slow and you're not running a fully jitted training loop, you can make sure the Flashbax buffer is stored on the CPU instead.
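A minimal sketch of keeping buffer state on the CPU. It uses a toy ring buffer: `init_buffer` and `add` are hypothetical stand-ins, not Flashbax's API. Because the state and the incoming item are committed to the CPU device with `jax.device_put`, the jitted `add` also executes on the CPU, even when a GPU is the default backend.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy ring buffer; NOT Flashbax's actual implementation or API.
def init_buffer(capacity, obs_dim):
    return {
        "obs": jnp.zeros((capacity, obs_dim), dtype=jnp.float32),
        "index": jnp.array(0, dtype=jnp.int32),
    }

@jax.jit
def add(state, obs):
    capacity = state["obs"].shape[0]  # static shape, known at trace time
    i = state["index"]
    # Write one row at position i, then advance the write index circularly.
    new_obs = jax.lax.dynamic_update_slice(state["obs"], obs[None, :], (i, 0))
    return {"obs": new_obs, "index": (i + 1) % capacity}

# Commit the buffer state and the new item to the CPU; jit follows
# the placement of committed inputs, so the update runs on the CPU.
cpu = jax.devices("cpu")[0]
state = jax.device_put(init_buffer(1000, 4), cpu)
obs = jax.device_put(jnp.ones((4,), dtype=jnp.float32), cpu)
state = add(state, obs)
```

The trade-off is that data produced on the GPU (for example, by a jitted environment step) must be copied to the host before each add, which is why this mainly makes sense when the training loop is not fully jitted.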
Thanks for the reply. Even compared with Flashbax on TPU, it's much, much slower on GPU, so it might be worth filing a bug about that with JAX? I'm assuming the source data is already on the GPU when you're adding?
Yes, I believe so. I ran the benchmarks a while ago, but I'm sure I would have created the data on the device.
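A minimal sketch of how a per-call add latency can be measured fairly in JAX, assuming a hypothetical toy `add` (not Flashbax's): the buffer and the item are created on the target device so no host-to-device transfer is timed, the first call is excluded because it triggers compilation, and `block_until_ready` is called because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical ring-buffer add, used only to illustrate the timing method.
@jax.jit
def add(storage, index, item):
    new_storage = jax.lax.dynamic_update_slice(storage, item[None, :], (index, 0))
    return new_storage, (index + 1) % storage.shape[0]

def time_add(device, capacity=10_000, dim=8, iters=100):
    # Place buffer, write index, and item on the target device up front.
    storage = jax.device_put(jnp.zeros((capacity, dim), jnp.float32), device)
    index = jax.device_put(jnp.int32(0), device)
    item = jax.device_put(jnp.ones((dim,), jnp.float32), device)

    # Warm-up call: compiles the function; not included in the timing.
    storage, index = add(storage, index, item)
    storage.block_until_ready()

    start = time.perf_counter()
    for _ in range(iters):
        storage, index = add(storage, index, item)
    # Wait for the last asynchronous dispatch to finish before stopping the clock.
    storage.block_until_ready()
    index.block_until_ready()
    elapsed_ms = (time.perf_counter() - start) * 1000 / iters
    return elapsed_ms, storage, index
```

For example, `for d in jax.devices(): print(d.platform, time_add(d)[0])` reports the average milliseconds per add on each available device. Without the final `block_until_ready`, the loop would mostly measure dispatch overhead rather than the actual device work.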
Hi, I would like to ask for more details on where things stand with this issue.
The benchmarks show that adding to a buffer is very slow on GPU (13 ms vs. 0.4 ms for Reverb or Stable Baselines, over 30x slower). Has anyone filed a bug against JAX about this?