How to do mnist-distributed with checkpointing? #9
I think the idea is to call this demo_checkpoint function every X epochs (where X is most likely 1), but of course the example is wrong, since you don't want to be doing that at the end of every X epochs.
Anyway, I don't think there's a need to reload the model from the saved checkpoint, since each time we call loss.backward() (or some AllReduce function) the models are synchronized. Essentially, you can just save the model normally; I save model.module.state_dict().
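As a minimal sketch of the approach described above (the function name `save_checkpoint` and the path default are illustrative, not from the original example):

```python
import torch
import torch.distributed as dist

def save_checkpoint(model, optimizer, epoch, rank, ckpt_path="checkpoint.pt"):
    # With DDP, every backward pass all-reduces the gradients, so all ranks
    # hold identical parameters -- one copy saved from rank 0 is enough.
    if rank == 0:
        torch.save({
            "epoch": epoch,
            # model.module unwraps the DistributedDataParallel wrapper
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, ckpt_path)
    # Keep all ranks in step so none tries to read the file before it exists
    dist.barrier()
```

Because the saved state dict comes from model.module, it can later be loaded into a plain (non-DDP) model without key-name mangling.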
The point of checking rank == 0 is just to save time: since all processes share the same model parameters, you can save once from the main process. Basically, whatever statistic printing/logging you want to do can all go under rank == 0. That's how I have been doing it so far, and I haven't run into any problems. If you want to coordinate validation statistics across processes, you do something like this:
dist.all_reduce communicates the same tensor across all processes and combines it with the specified operation (in this case, sum). This way, we ensure epoch_val_size is the same on every GPU.
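A sketch of that pattern (the helper name and arguments are illustrative; the original comment uses a variable called epoch_val_size in the same role):

```python
import torch
import torch.distributed as dist

def reduce_validation_stats(val_loss_sum, val_count, device="cpu"):
    # Pack the per-rank totals into one tensor so a single all_reduce suffices
    stats = torch.tensor([val_loss_sum, val_count],
                         dtype=torch.float64, device=device)
    # After all_reduce with SUM, every rank holds the same global totals
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_loss, total_count = stats.tolist()
    # Every rank can now compute the identical global mean loss
    return total_loss / total_count
```

Summing counts as well as losses matters when ranks see unequally sized validation shards; dividing the global sum by the global count gives the true mean rather than a mean of per-rank means.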
I saw the tutorial (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints), but as you said, it is not very well written, or is missing something. I was wondering if you could extend your tutorial with checkpointing?
I am personally interested only in processing each batch more quickly by using multiprocessing. So what confuses me is why the code above doesn't simply save the model once training is done (instead, it saves when rank == 0 before training starts). As you said, it's confusing. It would be fantastic if you could extend your mnist example so that the model is saved after all the data has been processed, or every X epochs, since that's the common case.
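A rough sketch of what such an extended loop might look like, combining the per-epoch and end-of-training saves (all names here — `train_loop`, `save_every`, `prefix` — are illustrative, not from the mnist example):

```python
import torch
import torch.distributed as dist

def train_loop(model, optimizer, loader, loss_fn, rank,
               epochs, save_every=1, prefix=""):
    for epoch in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()  # DDP all-reduces gradients here, keeping ranks in sync
            optimizer.step()
        # Periodic checkpoint: every save_every epochs, from rank 0 only
        if rank == 0 and (epoch + 1) % save_every == 0:
            torch.save(model.module.state_dict(),
                       f"{prefix}ckpt_epoch{epoch + 1}.pt")
    # Final save once training is done
    if rank == 0:
        torch.save(model.module.state_dict(), f"{prefix}ckpt_final.pt")
    dist.barrier()  # don't let other ranks exit before the save completes
```

The model passed in is assumed to be wrapped in DistributedDataParallel, which is why the saves go through model.module.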
Btw, thanks for your example, it is fantastic!