Try to implement checkpointing (inserting recomputation to trade off computation against memory use) and then automatic checkpointing, which is what PyTorch/JAX users now reportedly need and can't get.
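To make the trade-off concrete, here is a minimal sketch in plain Python, not horde-ad code and not how PyTorch or JAX implement it. It differentiates a chain of 16 scalar layers twice: once storing every intermediate, and once storing only the input of every 4th layer and recomputing the rest during the backward pass. The names `LAYERS`, `grad_store_all` and `grad_checkpointed` are hypothetical, chosen only for illustration.

```python
import math

# Each layer is a pair (f, df): a scalar function and its derivative.
# A chain of 16 identical layers is an arbitrary choice for illustration.
LAYERS = [(math.sin, math.cos)] * 16

def grad_store_all(x, layers):
    """Reverse-mode derivative of the chain, storing every layer input."""
    inputs = []
    for f, _ in layers:
        inputs.append(x)              # remember the input of every layer
        x = f(x)
    g = 1.0
    for (_, df), a in zip(reversed(layers), reversed(inputs)):
        g *= df(a)                    # chain rule, last layer first
    return g, len(inputs)             # derivative, number of stored values

def grad_checkpointed(x, layers, segment):
    """Same derivative, but only the input of every `segment`-th layer is
    stored; the other inputs are recomputed segment by segment."""
    checkpoints = []
    for i, (f, _) in enumerate(layers):
        if i % segment == 0:
            checkpoints.append((i, x))
        x = f(x)
    g = 1.0
    peak = len(checkpoints)           # values alive at the end of the forward pass
    extra_calls = 0                   # forward evaluations repeated for recomputation
    while checkpoints:
        start, a = checkpoints.pop()
        seg = layers[start:start + segment]
        inputs = [a]
        for f, _ in seg[:-1]:         # recompute the inputs inside this segment
            a = f(a)
            inputs.append(a)
            extra_calls += 1
        peak = max(peak, len(checkpoints) + len(inputs))
        for (_, df), b in zip(reversed(seg), reversed(inputs)):
            g *= df(b)
    return g, peak, extra_calls

g_full, stored_full = grad_store_all(0.5, LAYERS)
g_ck, peak_ck, extra = grad_checkpointed(0.5, LAYERS, 4)
print(stored_full, peak_ck, extra)
```

In this toy setting both functions return the same derivative; the checkpointed version keeps fewer values alive at once in exchange for repeating part of the forward pass. Automatic checkpointing would amount to choosing the checkpoint positions without user annotations, which this sketch does not attempt.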
We have an old discussion, starting with @tomjaguarpaw sketching an extension of the POPL paper with checkpointing: Mikolaj/mostly-harmless#20. We also had two variants of (things related to) checkpointing implemented at some point due to a peak of popular interest, but they bit-rotted before anybody found them interesting again and before any benchmarks for them were written, and they were removed when horde-ad got simplified.
I wonder if, in the current mode of operation, where we do reverse differentiation symbolically instead of using the real inputs, the memory leak problems posed in the discussion are gone. More generally, I wonder how checkpointing in the current mode would differ from what Tom describes and whether PyTorch/JAX do checkpointing in both modes of operation.
I'd advise against implementing it again before we have an interest proven by tests and benchmarks written by the interested parties.