
Implement checkpointing #107

Open
Mikolaj opened this issue Sep 2, 2023 · 0 comments
Mikolaj commented Sep 2, 2023

Try to implement checkpointing (inserting recomputation to trade extra computation for lower memory use) and then automatic checkpointing, which PyTorch/JAX users now reportedly need and can't get.
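For reference, here is a minimal, self-contained Python sketch of the trade-off described above (this is not horde-ad's API; the stages, segment length, and helper names are all illustrative): plain reverse mode over a chain of unary functions stores every intermediate activation, while the checkpointed variant keeps only every k-th activation and recomputes the rest segment by segment during the backward sweep.

```python
import math

# Each stage is a (forward, derivative) pair for a scalar function;
# the composite is stages[-1](... stages[0](x) ...).
stages = [
    (math.sin, math.cos),
    (math.exp, math.exp),
    (lambda t: t * t, lambda t: 2.0 * t),
    (math.tanh, lambda t: 1.0 - math.tanh(t) ** 2),
]

def grad_full(x):
    """Plain reverse mode: store every intermediate activation."""
    acts = [x]
    for f, _ in stages:
        acts.append(f(acts[-1]))
    g = 1.0
    # Chain rule, walking the tape backwards over the stored activations.
    for (_, df), a in zip(reversed(stages), reversed(acts[:-1])):
        g *= df(a)
    return g

def grad_checkpointed(x, every=2):
    """Checkpointing: keep only every `every`-th activation; recompute
    the others from the nearest checkpoint during the backward sweep."""
    checkpoints = {0: x}
    a = x
    for i, (f, _) in enumerate(stages):
        a = f(a)
        if (i + 1) % every == 0:
            checkpoints[i + 1] = a
    g = 1.0
    i = len(stages)
    while i > 0:
        start = max(k for k in checkpoints if k < i)
        # Recompute this segment's activations from its checkpoint.
        seg = [checkpoints[start]]
        for f, _ in stages[start:i]:
            seg.append(f(seg[-1]))
        for j in range(i - 1, start - 1, -1):
            g *= stages[j][1](seg[j - start])
        i = start
    return g
```

Both functions return the same derivative; the checkpointed one holds O(n/k) activations at the cost of one extra forward pass per segment, which is the trade-off automatic checkpointing would tune.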

We have an old discussion, started by @tomjaguarpaw, sketching an extension of the POPL paper with checkpointing: Mikolaj/mostly-harmless#20. We also had two variants of (things related to) checkpointing implemented at some point, prompted by a peak of popular interest, but the code bit-rotted before anybody found it interesting again and before any benchmarks for it were written, and it was removed when horde-ad got simplified.

I wonder whether, in the current mode of operation where we do reverse differentiation symbolically instead of on the real inputs, the memory-leak problems raised in that discussion are gone. More generally, I wonder how checkpointing in the current mode would differ from what Tom describes, and whether PyTorch/JAX do checkpointing in both modes of operation.

I'd advise against implementing it again before interest is proven by tests and benchmarks written by the interested parties.
