diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md
index 4d1e3ae..1198479 100644
--- a/examples/images/cifar10/README.md
+++ b/examples/images/cifar10/README.md
@@ -29,7 +29,7 @@ python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size
 Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example:
 
 ```bash
-torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
+torchrun --standalone --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
 ```
 
 To compute the FID from the OT-CFM model at end of training, run:
diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py
index 851f28c..8b13849 100644
--- a/examples/images/cifar10/train_cifar10_ddp.py
+++ b/examples/images/cifar10/train_cifar10_ddp.py
@@ -9,13 +9,13 @@ import os
 
 import torch
+import tqdm
 from absl import app, flags
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DistributedSampler
 from torchdyn.core import NeuralODE
 from torchvision import datasets, transforms
-from tqdm import trange
 
-from utils_cifar import ema, generate_samples, infiniteloop, setup
+from utils_cifar import ema, generate_samples, setup
 
 from torchcfm.conditional_flow_matching import (
     ConditionalFlowMatcher,
@@ -46,7 +46,9 @@ flags.DEFINE_string(
     "master_addr", "localhost", help="master address for Distributed Data Parallel"
 )
 
-flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")
+flags.DEFINE_string(
+    "master_port", "12355", help="master port for Distributed Data Parallel"
+)
 
 # Evaluation
 flags.DEFINE_integer(
@@ -100,8 +102,6 @@ def train(rank, total_num_gpus, argv):
         drop_last=True,
     )
 
-    datalooper = infiniteloop(dataloader)
-
     # Calculate number of epochs
     steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
     num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)
@@ -116,9 +116,7 @@ def train(rank, total_num_gpus, argv):
         num_head_channels=64,
         attention_resolutions="16",
         dropout=0.1,
-    ).to(
-        rank
-    )  # new dropout + bs of 128
+    ).to(rank)  # new dropout + bs of 128
 
     ema_model = copy.deepcopy(net_model)
     optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
@@ -156,46 +154,48 @@ def train(rank, total_num_gpus, argv):
 
     global_step = 0  # to keep track of the global step in training loop
 
-    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
+    with tqdm.trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
         for epoch in epoch_pbar:
             epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
             if sampler is not None:
                 sampler.set_epoch(epoch)
 
-            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
-                for step in step_pbar:
-                    global_step += step
-
-                    optim.zero_grad()
-                    x1 = next(datalooper).to(rank)
-                    x0 = torch.randn_like(x1)
-                    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
-                    vt = net_model(t, xt)
-                    loss = torch.mean((vt - ut) ** 2)
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
-                    optim.step()
-                    sched.step()
-                    ema(net_model, ema_model, FLAGS.ema_decay)  # new
-
-                    # sample and Saving the weights
-                    if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
-                        generate_samples(
-                            net_model, FLAGS.parallel, savedir, global_step, net_="normal"
-                        )
-                        generate_samples(
-                            ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
-                        )
-                        torch.save(
-                            {
-                                "net_model": net_model.state_dict(),
-                                "ema_model": ema_model.state_dict(),
-                                "sched": sched.state_dict(),
-                                "optim": optim.state_dict(),
-                                "step": global_step,
-                            },
-                            savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
-                        )
+            for x1, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
+                global_step += 1
+
+                optim.zero_grad()
+                x1 = x1.to(rank)
+                x0 = torch.randn_like(x1)
+                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
+                vt = net_model(t, xt)
+                loss = torch.mean((vt - ut) ** 2)
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(
+                    net_model.parameters(), FLAGS.grad_clip
+                )  # new
+                optim.step()
+                sched.step()
+                ema(net_model, ema_model, FLAGS.ema_decay)  # new
+
+                # sample and Saving the weights
+                if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
+                    generate_samples(
+                        net_model, FLAGS.parallel, savedir, global_step, net_="normal"
+                    )
+                    generate_samples(
+                        ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
+                    )
+                    torch.save(
+                        {
+                            "net_model": net_model.state_dict(),
+                            "ema_model": ema_model.state_dict(),
+                            "sched": sched.state_dict(),
+                            "optim": optim.state_dict(),
+                            "step": global_step,
+                        },
+                        savedir
+                        + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
+                    )
 
 
 def main(argv):
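
A side note on the training-loop change above: the patch replaces the `infiniteloop(dataloader)` / `trange(steps_per_epoch)` pattern with direct iteration over the `DataLoader`, so `DistributedSampler.set_epoch(epoch)` takes effect each epoch and `global_step` advances by exactly one per batch (the old `global_step += step` did not). Below is a minimal, standalone sketch of that per-epoch pattern, not part of the patch; `DummyDataset` and `run` are hypothetical stand-ins for the real CIFAR-10 dataset and training step.

```python
# Standalone sketch of the per-epoch data-loading pattern used in the patch.
# DummyDataset / run are hypothetical; the real script uses torchvision CIFAR-10
# and a UNet velocity model.
import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class DummyDataset(Dataset):
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # (image, label) pairs shaped like CIFAR-10
        return torch.randn(3, 32, 32), 0


def run(num_epochs: int = 2, distributed: bool = False) -> int:
    dataset = DummyDataset()
    # Under DDP, DistributedSampler shards the dataset across ranks;
    # set_epoch() must be called every epoch so the shuffle order changes.
    sampler = DistributedSampler(dataset) if distributed else None
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        shuffle=(sampler is None),
        drop_last=True,
    )

    global_step = 0
    for epoch in range(num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)
        for x1, _ in dataloader:
            global_step += 1  # one optimizer step per batch, as in the patch
            # ... sample x0 ~ N(0, I), compute the CFM loss, backward, clip,
            # optimizer / scheduler / EMA updates would go here ...
    return global_step


if __name__ == "__main__":
    print(run())  # 2 epochs * (256 // 32) batches = 16 steps
```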