How to do mnist-distributed with checkpointing? #9

brando90 opened this issue Feb 18, 2021 · 1 comment

brando90 commented Feb 18, 2021

I saw the tutorial (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints):

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

but as you said, the tutorial is not very well written (or is missing something). I was wondering if you could extend your tutorial with checkpointing?

I am personally interested only in processing each batch more quickly via multiprocessing. What confuses me is why the code above doesn't simply save the model once training is done, but instead saves it (when rank == 0) before training starts. As you said, it's confusing. Extending your mnist example so that I can save my model after processing all the data in MNIST would be fantastic, or saving every X epochs, since that's the common case.

Btw, thanks for your example, it is fantastic!


linminhtoo commented Feb 19, 2021

I think the idea is to call this demo_checkpoint function every X epochs (where X is most likely 1), but of course the example is misleading as written, since you don't want to be re-running these setup lines at the end of every X epochs:

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
Anyway, I don't think there's a need to reload the model from the saved checkpoint, since each time we call loss.backward() (or any AllReduce op), the model replicas are synchronized.
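If you want to convince yourself of that, something like this should work (just a rough sketch, assert_replicas_in_sync is a name I made up): broadcast rank 0's parameters and assert every other rank holds identical values, which is exactly why saving from a single rank is enough.

import torch
import torch.distributed as dist

def assert_replicas_in_sync(ddp_model):
    # Sanity check sketch: after an optimizer step, every rank should hold
    # identical parameters, because DDP all-reduces the gradients.
    for name, param in ddp_model.module.named_parameters():
        reference = param.detach().clone()
        dist.broadcast(reference, src=0)  # overwrite with rank 0's values
        assert torch.equal(param.detach(), reference), f"{name} differs from rank 0"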

Essentially, you can just save the model normally. I save model.module.state_dict():

def _checkpoint_model_and_opt(self, current_epoch: int):
    # call this function from rank == 0 after 1 epoch of training
    if self.dataparallel or self.gpu is not None:
        model_state_dict = self.model.module.state_dict()
    else:
        model_state_dict = self.model.state_dict()
    checkpoint_dict = {
        "epoch": current_epoch,  # epochs are 0-indexed
        "model_name": self.model_name,
        "state_dict": model_state_dict,
        "optimizer": self.optimizer.state_dict(),
        "stats": self.stats,
    }
    checkpoint_filename = (
        self.checkpoint_folder
        / f"{self.model_name}_{self.expt_name}_checkpoint_{current_epoch:04d}.pth.tar"
    )
    torch.save(checkpoint_dict, checkpoint_filename)

The idea of using rank == 0 is just to save time: since all processes share the same model parameters, you can save just once from the main process. Basically, whatever statistic printing/logging you want to do, you can put it all under rank == 0.

So far that's how I have been doing it and I haven't experienced any problems.
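To make the flow concrete, here's a rough sketch of how this fits into a training loop (just an illustration of my setup; self.rank, self.save_every, and self.train_one_epoch are made-up names, not from the snippet above):

import torch.distributed as dist

def train(self, num_epochs: int):
    # The model, DDP wrapper, loss, and optimizer were already built once, not per epoch.
    for epoch in range(num_epochs):
        self.train_one_epoch(epoch)  # forward/backward/step runs on every rank

        # Replicas stay in sync via DDP's gradient all-reduce, so having one
        # process write the checkpoint is enough.
        if self.rank == 0 and (epoch + 1) % self.save_every == 0:
            self._checkpoint_model_and_opt(epoch)

        # Optional: keep ranks in step before the next epoch starts.
        dist.barrier()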

If you want to know about coordinating validation statistics, you can do something like this:

val_batch_size = batch_data.shape[0]
val_batch_size = torch.tensor([val_batch_size]).cuda(self.gpu, non_blocking=True)
dist.all_reduce(val_batch_size, dist.ReduceOp.SUM)
val_batch_size = val_batch_size.item()
epoch_val_size += val_batch_size

dist.all_reduce combines the tensor from all processes using the specified operation (in this case, sum) and writes the result back to every process. This way, we ensure epoch_val_size is the same on every GPU.
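The same pattern works for any statistic you want to aggregate. For example, a rough sketch of averaging a validation loss across GPUs (global_mean_loss is just a name I made up for illustration):

import torch
import torch.distributed as dist

def global_mean_loss(local_loss_sum: float, local_count: int, gpu: int) -> float:
    # Each rank passes the loss sum and sample count accumulated over its
    # own shard of the validation set.
    stats = torch.tensor([local_loss_sum, float(local_count)]).cuda(gpu, non_blocking=True)

    # Sum the partial results from every process; afterwards every rank
    # holds the same global totals.
    dist.all_reduce(stats, dist.ReduceOp.SUM)

    return (stats[0] / stats[1]).item()  # identical value on every rank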
