From 4178ea500fd0a7e7e85efdc0e92bf1554f66c0ec Mon Sep 17 00:00:00 2001
From: Xiaoming Zhao
Date: Sat, 16 Nov 2024 15:39:48 -0600
Subject: [PATCH 1/3] Fixed global_step in train_cifar10_ddp.py

---
 examples/images/cifar10/train_cifar10_ddp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py
index 851f28c..2932443 100644
--- a/examples/images/cifar10/train_cifar10_ddp.py
+++ b/examples/images/cifar10/train_cifar10_ddp.py
@@ -164,7 +164,7 @@ def train(rank, total_num_gpus, argv):
 
             with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
                 for step in step_pbar:
-                    global_step += step
+                    global_step += 1
 
                     optim.zero_grad()
                     x1 = next(datalooper).to(rank)

From 2781067c3446bbaf2340bb773ba94929729b6592 Mon Sep 17 00:00:00 2001
From: xz
Date: Sat, 16 Nov 2024 16:53:35 -0600
Subject: [PATCH 2/3] Avoid using an infinite generator in train_cifar10_ddp.py

---
 examples/images/cifar10/README.md            |  2 +-
 examples/images/cifar10/train_cifar10_ddp.py | 82 ++++++++++----------
 2 files changed, 43 insertions(+), 41 deletions(-)

diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md
index 4d1e3ae..1198479 100644
--- a/examples/images/cifar10/README.md
+++ b/examples/images/cifar10/README.md
@@ -29,7 +29,7 @@ python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size
 Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example:
 
 ```bash
-torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
+torchrun --standalone --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
 ```
 
 To compute the FID from the OT-CFM model at end of training, run:

diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py
index 2932443..ed731ab 100644
--- a/examples/images/cifar10/train_cifar10_ddp.py
+++ b/examples/images/cifar10/train_cifar10_ddp.py
@@ -9,12 +9,12 @@ import os
 
 import torch
+import tqdm
 from absl import app, flags
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DistributedSampler
 from torchdyn.core import NeuralODE
 from torchvision import datasets, transforms
-from tqdm import trange
 from utils_cifar import ema, generate_samples, infiniteloop, setup
 
 from torchcfm.conditional_flow_matching import (
     ConditionalFlowMatcher,
@@ -46,7 +46,9 @@
 flags.DEFINE_string(
     "master_addr", "localhost", help="master address for Distributed Data Parallel"
 )
-flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")
+flags.DEFINE_string(
+    "master_port", "12355", help="master port for Distributed Data Parallel"
+)
 
 # Evaluation
 flags.DEFINE_integer(
@@ -116,9 +118,7 @@ def train(rank, total_num_gpus, argv):
         num_head_channels=64,
         attention_resolutions="16",
         dropout=0.1,
-    ).to(
-        rank
-    )  # new dropout + bs of 128
+    ).to(rank)  # new dropout + bs of 128
 
     ema_model = copy.deepcopy(net_model)
     optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
@@ -156,46 +156,48 @@ def train(rank, total_num_gpus, argv):
 
     global_step = 0  # to keep track of the global step in training loop
 
-    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
+    with tqdm.trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
         for epoch in epoch_pbar:
             epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
             if sampler is not None:
                 sampler.set_epoch(epoch)
 
-            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
-                for step in step_pbar:
-                    global_step += 1
-
-                    optim.zero_grad()
-                    x1 = next(datalooper).to(rank)
-                    x0 = torch.randn_like(x1)
-                    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
-                    vt = net_model(t, xt)
-                    loss = torch.mean((vt - ut) ** 2)
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
-                    optim.step()
-                    sched.step()
-                    ema(net_model, ema_model, FLAGS.ema_decay)  # new
-
-                    # sample and Saving the weights
-                    if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
-                        generate_samples(
-                            net_model, FLAGS.parallel, savedir, global_step, net_="normal"
-                        )
-                        generate_samples(
-                            ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
-                        )
-                        torch.save(
-                            {
-                                "net_model": net_model.state_dict(),
-                                "ema_model": ema_model.state_dict(),
-                                "sched": sched.state_dict(),
-                                "optim": optim.state_dict(),
-                                "step": global_step,
-                            },
-                            savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
-                        )
+            for x1, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
+                global_step += 1
+
+                optim.zero_grad()
+                x1 = x1.to(rank)
+                x0 = torch.randn_like(x1)
+                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
+                vt = net_model(t, xt)
+                loss = torch.mean((vt - ut) ** 2)
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(
+                    net_model.parameters(), FLAGS.grad_clip
+                )  # new
+                optim.step()
+                sched.step()
+                ema(net_model, ema_model, FLAGS.ema_decay)  # new
+
+                # sample and Saving the weights
+                if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
+                    generate_samples(
+                        net_model, FLAGS.parallel, savedir, global_step, net_="normal"
+                    )
+                    generate_samples(
+                        ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
+                    )
+                    torch.save(
+                        {
+                            "net_model": net_model.state_dict(),
+                            "ema_model": ema_model.state_dict(),
+                            "sched": sched.state_dict(),
+                            "optim": optim.state_dict(),
+                            "step": global_step,
+                        },
+                        savedir
+                        + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
+                    )
 
 
 def main(argv):

From db65caafb62bea5d2606f5961e21fd6ddfdc44eb Mon Sep 17 00:00:00 2001
From: xz
Date: Sat, 16 Nov 2024 16:56:18 -0600
Subject: [PATCH 3/3] remove unused infiniteloop

---
 examples/images/cifar10/train_cifar10_ddp.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py
index ed731ab..8b13849 100644
--- a/examples/images/cifar10/train_cifar10_ddp.py
+++ b/examples/images/cifar10/train_cifar10_ddp.py
@@ -15,7 +15,7 @@
 from torch.utils.data import DistributedSampler
 from torchdyn.core import NeuralODE
 from torchvision import datasets, transforms
-from utils_cifar import ema, generate_samples, infiniteloop, setup
+from utils_cifar import ema, generate_samples, setup
 
 from torchcfm.conditional_flow_matching import (
     ConditionalFlowMatcher,
@@ -102,8 +102,6 @@ def train(rank, total_num_gpus, argv):
         drop_last=True,
     )
 
-    datalooper = infiniteloop(dataloader)
-
     # Calculate number of epochs
     steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
     num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)
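
For reference, the loop structure the series converges on is sketched below: iterate the finite DataLoader directly each epoch instead of drawing from an infinite generator, call `sampler.set_epoch()` so a `DistributedSampler` reshuffles between epochs, and advance `global_step` by exactly one per batch. This is an illustrative, single-process sketch, not part of the patches: `train_loop` and its arguments are hypothetical stand-ins, and the DDP setup, EMA update, gradient clipping, LR schedule, and the conditional-flow-matching objective of the real script are omitted.

```python
import torch
import tqdm
from torch.utils.data import DataLoader, TensorDataset


def train_loop(model, optim, dataloader, sampler, num_epochs, save_step, device="cpu"):
    # Hypothetical stand-in for the patched training loop in train_cifar10_ddp.py.
    global_step = 0  # advanced by exactly 1 per batch, never by the loop index
    for epoch in tqdm.trange(num_epochs, dynamic_ncols=True):
        if sampler is not None:
            sampler.set_epoch(epoch)  # let a DistributedSampler reshuffle its shards
        # iterate the finite DataLoader; the epoch ends when it is exhausted
        for x1, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
            global_step += 1
            optim.zero_grad()
            x1 = x1.to(device)
            x0 = torch.randn_like(x1)
            # stand-in objective; the real script regresses a conditional flow field
            loss = torch.mean((model(x1) - x0) ** 2)
            loss.backward()
            optim.step()
            if save_step > 0 and global_step % save_step == 0:
                pass  # sample / checkpoint here, keyed on global_step


if __name__ == "__main__":
    # toy single-process run on random data (no DDP, no DistributedSampler)
    data = TensorDataset(torch.randn(64, 3, 32, 32), torch.zeros(64))
    loader = DataLoader(data, batch_size=16, drop_last=True)
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 32 * 32, 3 * 32 * 32),
        torch.nn.Unflatten(1, (3, 32, 32)),
    )
    optim = torch.optim.Adam(model.parameters(), lr=2e-4)
    train_loop(model, optim, loader, sampler=None, num_epochs=1, save_step=0)
```

Iterating the DataLoader directly bounds each epoch by the sampler's shard size, so the per-epoch progress bar, the `set_epoch()` reshuffling, and the `save_step` checkpoint cadence all stay consistent with the corrected `global_step` bookkeeping from patch 1.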