Large Model output (bsz x seq_len x vocab_size) repeated 3 times in memory with jax.value_and_grad(). #4097
saschafrey asked this question in Q&A (unanswered · 0 replies)
I'm training a model with a custom vocabulary of a few thousand tokens, which naturally brings certain memory constraints with it. What I'm struggling to understand is why the expected output of the model (see title) is represented 3 times in GPU memory during my training step. Is there a way to work around this? It significantly limits the model/batch size I can fit on a given GPU. TensorBoard profiler traces and memory views are attached, with the XLA ops that produce the 3 large tensors highlighted in red. Happy to provide more details.
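For reference, my training step looks roughly like this (a minimal sketch with stand-in shapes, model, and parameter names, not my actual code):

```python
import jax
import jax.numpy as jnp

# Stand-in shapes; the real model and vocabulary are not shown here.
BSZ, SEQ_LEN, VOCAB, D_MODEL = 32, 512, 4096, 768

def apply_model(params, tokens):
    # Stand-in for the real network: embedding -> ... -> final "decoder" projection.
    h = params["embed"][tokens]                  # (bsz, seq_len, d_model)
    return h @ params["decoder"]                 # (bsz, seq_len, vocab_size) -- the large tensor

def loss_fn(params, tokens, labels):
    logits = apply_model(params, tokens)
    logp = jax.nn.log_softmax(logits, axis=-1)   # another (bsz, seq_len, vocab_size) buffer
    nll = -jnp.take_along_axis(logp, labels[..., None], axis=-1).squeeze(-1)
    return nll.mean()

@jax.jit
def train_step(params, tokens, labels):
    # value_and_grad keeps the forward activations (the logits, the log-softmax
    # output, ...) alive for the backward pass, which is where I see several
    # (bsz, seq_len, vocab_size)-sized buffers coexist at peak memory.
    return jax.value_and_grad(loss_fn)(params, tokens, labels)
```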
I have already tried applying the jax.checkpoint() decorator to modules other than the final output layer ("decoder"), with no major effect.
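One workaround I'm considering (sketch only, using the same stand-in names as above) is to compute the final projection and the cross-entropy in chunks along the sequence axis and remat each chunk, so that only a (bsz, chunk, vocab_size) slice of the logits is live at a time:

```python
def chunked_loss_fn(params, tokens, labels, chunk=64):
    # Assumes seq_len is divisible by `chunk`.
    h = params["embed"][tokens]                     # (bsz, seq_len, d_model)

    def per_chunk(args):
        h_c, labels_c = args
        logits_c = h_c @ params["decoder"]          # (bsz, chunk, vocab_size)
        logp_c = jax.nn.log_softmax(logits_c, axis=-1)
        return -jnp.take_along_axis(logp_c, labels_c[..., None], axis=-1).squeeze(-1)

    n_chunks = h.shape[1] // chunk
    h_chunks = h.reshape(h.shape[0], n_chunks, chunk, -1).swapaxes(0, 1)
    label_chunks = labels.reshape(labels.shape[0], n_chunks, chunk).swapaxes(0, 1)
    # jax.checkpoint (remat) recomputes each chunk's logits in the backward pass
    # instead of keeping every chunk's activations alive.
    nll = jax.lax.map(jax.checkpoint(per_chunk), (h_chunks, label_chunks))
    return nll.mean()
```

I'm not sure whether XLA actually releases each chunk's logits between iterations here, or whether this is the intended way to deal with a large vocabulary projection, so pointers to a better approach would be appreciated.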