optax.MultiSteps out of memory #472
Comments
Hi! Interesting - thanks for reporting this! Are you also at more than ~2/3 memory usage when you use I'm not sure why the two functions use completely different code paths - we should be able to merge them (and deprecate one of them).
I have most of my available memory preallocated by JAX. I tried reducing the batch size from 120 (which works with apply_every) to 30, but it still crashed with MultiSteps.
I am training Llama 2 7B on TPU. Without
I can confirm that the MultiSteps implementation has a much larger memory overhead than just one extra buffer for the gradients (something like 4x extra buffers). This is very problematic when using this class with large models.
I also noticed this issue.
I am having this issue as well, for use in diffusion models.
Facing the same issue.
Change the implementation to allow JAX/XLA to re-use memory buffers. #472 PiperOrigin-RevId: 561129449
Change the implementation to allow JAX/XLA to re-use memory buffers. #472 PiperOrigin-RevId: 561390202
Hi everyone, thanks for flagging it up. I just merged a new version of
you're a king
Hi @hbq1! Thank you for the fix! One question: I am still seeing higher memory consumption with MultiSteps than with apply_every. Is this expected?
As a follow-up, I did some debugging myself, and the problem seems to be in this part of the code (line 414):
If I got it right, JAX is allocating memory for both function outputs (_mid_step and _final_step), so this basically doubles the space needed to store the optimizer states and grads. I'm still trying to figure out a way to solve it, though.
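To make the point above concrete, here is a minimal sketch of gradient accumulation written with a single code path, so there is only one set of outputs instead of one per branch. This is not the Optax implementation: the function name `accumulate_step` and its arguments are hypothetical, and it assumes the two branches in the original code are selected with something like `jax.lax.cond`. Whether XLA actually reuses buffers here depends on the compiler and on buffer donation, and nothing below has been measured.

```python
import jax
import jax.numpy as jnp


def accumulate_step(acc, grads, mini_step, k):
    """One accumulation step with a single code path (hypothetical helper)."""
    # Running mean of the gradients seen so far in the current window.
    acc = jax.tree_util.tree_map(
        lambda a, g: a + (g - a) / (mini_step + 1), acc, grads)
    # True only on the last mini-step of the window.
    emit = mini_step == k - 1
    # Emit the averaged gradients on the last mini-step, zeros otherwise.
    out = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, a, jnp.zeros_like(a)), acc)
    # Reset the accumulator once the average has been emitted.
    new_acc = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, jnp.zeros_like(a), a), acc)
    new_step = jnp.where(emit, 0, mini_step + 1)
    return out, new_acc, new_step


# k is a Python int, so it is marked static for jit.
accumulate_step_jit = jax.jit(accumulate_step, static_argnums=3)

# Toy usage: accumulate two "gradients" (all ones, then all threes) with k=2.
acc = {"w": jnp.zeros((3,))}
step = jnp.asarray(0, dtype=jnp.int32)
out1, acc, step = accumulate_step_jit(acc, {"w": jnp.full((3,), 1.0)}, step, 2)
out2, acc, step = accumulate_step_jit(acc, {"w": jnp.full((3,), 3.0)}, step, 2)
```

In this toy run the first call emits zeros and the second emits the mean of the two gradients, after which the accumulator and the step counter are reset.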
I just added a PR merging the apply_every logic into MultiSteps. From my initial tests, it reduces the memory footprint (I am able to train Llama 2 7B on a v3-8 now) without affecting convergence.
This is really great!
Awesome work @celiolarcher!
I'm glad to be able to help!
I always get an out of memory error using optax.MultiSteps, even when every_k_schedule=1.
Using optax.apply_every(k=1) in a chain works fine.
Later I'm using
opt_state = optimizer.init(params)
and
I have no idea what I could be doing wrong. I'm not changing anything else, like batch size.