add is slow on GPU #31

Open
garymm opened this issue Jul 19, 2024 · 5 comments

Comments

@garymm
Contributor

garymm commented Jul 19, 2024

The benchmarks show that adding to a buffer is very slow on GPU (13 ms vs 0.4 ms for Reverb or Stable Baselines, over 30x slower). Has anyone filed a bug against JAX about this?

@EdanToledo
Contributor

Hello, we haven't specifically asked the JAX maintainers about this issue. However, it's important to note that Reverb and Stable Baselines don't store the buffer memory on the GPU, so it's not really a fair comparison, and it would be better to look at the CPU times. If you find the GPU add times too slow and you're not running a fully jitted training loop, you can ensure that the Flashbax buffer is stored on the CPU.
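As a rough illustration of keeping buffer storage on the CPU, the sketch below commits a plain array to a CPU device with `jax.device_put` and updates it there. The `storage` array, its shape, and the `add_one` helper are hypothetical stand-ins and not the Flashbax API; a real buffer state would be a pytree handled the same way.

```python
import jax
import jax.numpy as jnp

# First CPU device; a CPU backend is typically available even on GPU installs.
cpu = jax.devices("cpu")[0]

# Hypothetical stand-in for buffer storage: 1000 slots of 4-dim observations,
# explicitly committed to the CPU device.
storage = jax.device_put(jnp.zeros((1000, 4), dtype=jnp.float32), cpu)

@jax.jit
def add_one(storage, index, item):
    # Write a single timestep into the slot at `index`.
    return storage.at[index].set(item)

# The item is also placed on the CPU so the jitted update runs there.
item = jax.device_put(jnp.ones((4,), dtype=jnp.float32), cpu)
storage = add_one(storage, 0, item)

print(storage.devices())  # the set of devices holding the updated storage
```

Because the inputs are committed to the CPU, the jitted update should execute on the CPU as well, at the cost of a host-to-device transfer whenever sampled data is needed on the accelerator.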

@garymm
Contributor Author

garymm commented Aug 27, 2024

Thanks for the reply. Even compared to Flashbax on TPU, it's much, much slower on GPU, so it might be worth filing a bug with JAX about that. I'm assuming the source data is already on the GPU when you're adding?

@EdanToledo
Contributor

Yes, I believe so. I did the benchmarks a while ago, but I'm sure I would have created the data on device.

@eadadi

eadadi commented Sep 6, 2024

Hi, I would like to ask for more details about where we stand on this.

  1. Currently, GPU speed for adding single timesteps is bad? Can we pinpoint where the delay happens? From this discussion I understand that it is simply the JAX operation that is used?
  2. For adding a batch of timesteps we don't have these delays, right?
  3. Is there anything we can do to improve the situation?
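One way to start narrowing down the questions above is to time the underlying indexed update in isolation, once for a single timestep and once for a batch. The following is only a sketch: the sizes, the `add_single`, `add_batch`, and `time_call` helpers, and the use of `.at[].set()` for both cases are assumptions for illustration, not the Flashbax implementation or its benchmark.

```python
import time
import jax
import jax.numpy as jnp

CAPACITY, FEATURES, BATCH = 10_000, 8, 256  # hypothetical sizes

@jax.jit
def add_single(storage, index, item):
    # Scatter one timestep into one slot.
    return storage.at[index].set(item)

@jax.jit
def add_batch(storage, indices, items):
    # Scatter a batch of timesteps into the slots given by `indices`.
    return storage.at[indices].set(items)

def time_call(fn, *args, repeats=100):
    fn(*args).block_until_ready()  # warm-up call, so compilation is not timed
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()  # wait for asynchronously dispatched work to finish
    return (time.perf_counter() - t0) / repeats

# Arrays are created on JAX's default device (the GPU, if one is available).
storage = jnp.zeros((CAPACITY, FEATURES))
item = jnp.ones((FEATURES,))
items = jnp.ones((BATCH, FEATURES))
indices = jnp.arange(BATCH)

single_s = time_call(add_single, storage, 0, item)
batch_s = time_call(add_batch, storage, indices, items)
print(f"single add: {single_s * 1e6:.1f} us per call")
print(f"batch add:  {batch_s * 1e6:.1f} us per call "
      f"({batch_s / BATCH * 1e6:.2f} us per timestep)")
```

Comparing the per-call and per-timestep numbers on CPU and GPU gives a first hint of whether the cost is dominated by per-call overhead or by the update itself.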

@sash-a
Contributor

sash-a commented Oct 29, 2024

  1. Yes, it's likely the underlying XLA that JAX compiles to.
  2. Fewer delays, but it could likely still be better.
  3. One possible improvement is to inform JAX that we are using unique_indices and indices_are_sorted during our .at[].set() (over here, for example). I'm not sure if this would help, but the docs imply that it might. Unfortunately, I don't have time to test this right now.
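A minimal sketch of the hint suggested in point 3, assuming the write indices really are unique and sorted in ascending order. The `add_batch_hinted` function and the array shapes are hypothetical, not the Flashbax code linked above; whether the hints change GPU timings is untested here.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_batch_hinted(storage, indices, items):
    # Both flags are promises to the compiler, not runtime checks: if the
    # indices are not actually sorted and unique, the result is undefined.
    return storage.at[indices].set(
        items, indices_are_sorted=True, unique_indices=True
    )

storage = jnp.zeros((16, 2))
indices = jnp.arange(4, 8)  # ascending and unique, so the hints hold
items = jnp.ones((4, 2))

hinted = add_batch_hinted(storage, indices, items)
plain = storage.at[indices].set(items)  # same update without hints
print(bool(jnp.array_equal(hinted, plain)))
```

Note that a circular buffer write that wraps around the end produces indices that are unique but no longer sorted, so `indices_are_sorted=True` would only be safe when the wrap-around case is handled separately.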


4 participants