From c25e1918a80dfacbe9475c055d61ac997f28262a Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh <120030096+ImahnShekhzadeh@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:27:21 +0200 Subject: [PATCH] add support for distributed data parallel training (#116) * make code changes in `train_cifar10.py` to allow DDP (distributed data parallel) * add instructions to README on how to run cifar10 image generation code on multiple GPUs * fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting * fix: load checkpoint on right device * fix runner ci requirements (#125) * change pytorch lightning version * fix pip version * fix pip in code cov * change variable name `world_size` to `total_num_gpus` * change: do not overwrite batch size flag * add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps * fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode * rename file, update README * add original CIFAR10 training file --------- Co-authored-by: Alexander Tong --- examples/images/cifar10/README.md | 6 +- examples/images/cifar10/compute_fid.py | 2 +- examples/images/cifar10/train_cifar10_ddp.py | 214 +++++++++++++++++++ examples/images/cifar10/utils_cifar.py | 30 +++ 4 files changed, 247 insertions(+), 5 deletions(-) create mode 100644 examples/images/cifar10/train_cifar10_ddp.py diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md index 7d5012d..4d1e3ae 100644 --- a/examples/images/cifar10/README.md +++ b/examples/images/cifar10/README.md @@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 ``` -Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example: +Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example: ```bash -python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True +torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT" ``` -*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory). - To compute the FID from the OT-CFM model at end of training, run: ```bash diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py index ffa66c2..7596699 100644 --- a/examples/images/cifar10/compute_fid.py +++ b/examples/images/cifar10/compute_fid.py @@ -51,7 +51,7 @@ # Load the model PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt" print("path: ", PATH) -checkpoint = torch.load(PATH) +checkpoint = torch.load(PATH, map_location=device) state_dict = checkpoint["ema_model"] try: new_net.load_state_dict(state_dict) diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py new file mode 100644 index 0000000..851f28c --- /dev/null +++ b/examples/images/cifar10/train_cifar10_ddp.py @@ -0,0 +1,214 @@ +# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master. + +# Authors: Kilian Fatras +# Alexander Tong +# Imahn Shekhzadeh + +import copy +import math +import os + +import torch +from absl import app, flags +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DistributedSampler +from torchdyn.core import NeuralODE +from torchvision import datasets, transforms +from tqdm import trange +from utils_cifar import ema, generate_samples, infiniteloop, setup + +from torchcfm.conditional_flow_matching import ( + ConditionalFlowMatcher, + ExactOptimalTransportConditionalFlowMatcher, + TargetConditionalFlowMatcher, + VariancePreservingConditionalFlowMatcher, +) +from torchcfm.models.unet.unet import UNetModelWrapper + +FLAGS = flags.FLAGS + +flags.DEFINE_string("model", "otcfm", help="flow matching model type") +flags.DEFINE_string("output_dir", "./results/", help="output_directory") +# UNet +flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") + +# Training +flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4 +flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping") +flags.DEFINE_integer( + "total_steps", 400001, help="total training steps" +) # Lipman et al uses 400k but double batch size +flags.DEFINE_integer("warmup", 5000, help="learning rate warmup") +flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128 +flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") +flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate") +flags.DEFINE_bool("parallel", False, help="multi gpu training") +flags.DEFINE_string( + "master_addr", "localhost", help="master address for Distributed Data Parallel" +) +flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel") + +# Evaluation +flags.DEFINE_integer( + "save_step", + 20000, + help="frequency of saving checkpoints, 0 to disable during training", +) + + +def warmup_lr(step): + return min(step, FLAGS.warmup) / FLAGS.warmup + + +def train(rank, total_num_gpus, argv): + print( + "lr, total_steps, ema decay, save_step:", + FLAGS.lr, + FLAGS.total_steps, + FLAGS.ema_decay, + FLAGS.save_step, + ) + + if FLAGS.parallel and total_num_gpus > 1: + # When using `DistributedDataParallel`, we need to divide the batch + # size ourselves based on the total number of GPUs of the current node. + batch_size_per_gpu = FLAGS.batch_size // total_num_gpus + setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port) + else: + batch_size_per_gpu = FLAGS.batch_size + + # DATASETS/DATALOADER + dataset = datasets.CIFAR10( + root="./data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + sampler = DistributedSampler(dataset) if FLAGS.parallel else None + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size_per_gpu, + sampler=sampler, + shuffle=False if FLAGS.parallel else True, + num_workers=FLAGS.num_workers, + drop_last=True, + ) + + datalooper = infiniteloop(dataloader) + + # Calculate number of epochs + steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size) + num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch) + + # MODELS + net_model = UNetModelWrapper( + dim=(3, 32, 32), + num_res_blocks=2, + num_channels=FLAGS.num_channel, + channel_mult=[1, 2, 2, 2], + num_heads=4, + num_head_channels=64, + attention_resolutions="16", + dropout=0.1, + ).to( + rank + ) # new dropout + bs of 128 + + ema_model = copy.deepcopy(net_model) + optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) + sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) + if FLAGS.parallel: + net_model = DistributedDataParallel(net_model, device_ids=[rank]) + ema_model = DistributedDataParallel(ema_model, device_ids=[rank]) + + # show model size + model_size = 0 + for param in net_model.parameters(): + model_size += param.data.nelement() + print("Model params: %.2f M" % (model_size / 1024 / 1024)) + + ################################# + # OT-CFM + ################################# + + sigma = 0.0 + if FLAGS.model == "otcfm": + FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "icfm": + FM = ConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "fm": + FM = TargetConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "si": + FM = VariancePreservingConditionalFlowMatcher(sigma=sigma) + else: + raise NotImplementedError( + f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']" + ) + + savedir = FLAGS.output_dir + FLAGS.model + "/" + os.makedirs(savedir, exist_ok=True) + + global_step = 0 # to keep track of the global step in training loop + + with trange(num_epochs, dynamic_ncols=True) as epoch_pbar: + for epoch in epoch_pbar: + epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}") + if sampler is not None: + sampler.set_epoch(epoch) + + with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar: + for step in step_pbar: + global_step += step + + optim.zero_grad() + x1 = next(datalooper).to(rank) + x0 = torch.randn_like(x1) + t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) + vt = net_model(t, xt) + loss = torch.mean((vt - ut) ** 2) + loss.backward() + torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new + optim.step() + sched.step() + ema(net_model, ema_model, FLAGS.ema_decay) # new + + # sample and Saving the weights + if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0: + generate_samples( + net_model, FLAGS.parallel, savedir, global_step, net_="normal" + ) + generate_samples( + ema_model, FLAGS.parallel, savedir, global_step, net_="ema" + ) + torch.save( + { + "net_model": net_model.state_dict(), + "ema_model": ema_model.state_dict(), + "sched": sched.state_dict(), + "optim": optim.state_dict(), + "step": global_step, + }, + savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt", + ) + + +def main(argv): + # get world size (number of GPUs) + total_num_gpus = int(os.getenv("WORLD_SIZE", 1)) + + if FLAGS.parallel and total_num_gpus > 1: + train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv) + else: + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + train(rank=device, total_num_gpus=total_num_gpus, argv=argv) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/images/cifar10/utils_cifar.py b/examples/images/cifar10/utils_cifar.py index 50f6a7e..cfa36b8 100644 --- a/examples/images/cifar10/utils_cifar.py +++ b/examples/images/cifar10/utils_cifar.py @@ -1,6 +1,8 @@ import copy +import os import torch +from torch import distributed as dist from torchdyn.core import NeuralODE # from torchvision.transforms import ToPILImage @@ -10,6 +12,34 @@ device = torch.device("cuda" if use_cuda else "cpu") +def setup( + rank: int, + total_num_gpus: int, + master_addr: str = "localhost", + master_port: str = "12355", + backend: str = "nccl", +): + """Initialize the distributed environment. + + Args: + rank: Rank of the current process. + total_num_gpus: Number of GPUs used in the job. + master_addr: IP address of the master node. + master_port: Port number of the master node. + backend: Backend to use. + """ + + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + # initialize the process group + dist.init_process_group( + backend=backend, + rank=rank, + world_size=total_num_gpus, + ) + + def generate_samples(model, parallel, savedir, step, net_="normal"): """Save 64 generated images (8 x 8) for sanity check along training.