From b29d4184795de3720f93b91088bd791b22534f04 Mon Sep 17 00:00:00 2001 From: Kilian Date: Wed, 13 Dec 2023 15:53:42 -0500 Subject: [PATCH] make generate images function work with parallel in CIFAR files (#83) * remove unused flag variables * make generate_image function work with parallel * add docstring to generate_samples function --------- Co-authored-by: Alex Tong Co-authored-by: kilian.fatras Reviewed: Quentin Bertrand --- examples/cifar10/README.md | 8 ++++++++ examples/cifar10/compute_fid.py | 1 - examples/cifar10/train_cifar10.py | 14 +++++--------- examples/cifar10/utils_cifar.py | 28 +++++++++++++++++++++++++--- 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md index b219138..ee1d264 100644 --- a/examples/cifar10/README.md +++ b/examples/cifar10/README.md @@ -26,6 +26,14 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 ``` +Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example: + +```bash +python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True +``` + +*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory). + To compute the FID from the OT-CFM model at end of training, run: ```bash diff --git a/examples/cifar10/compute_fid.py b/examples/cifar10/compute_fid.py index 8d23bf9..ed16407 100644 --- a/examples/cifar10/compute_fid.py +++ b/examples/cifar10/compute_fid.py @@ -20,7 +20,6 @@ flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") # Training -flags.DEFINE_bool("parallel", False, help="multi gpu training") flags.DEFINE_string("input_dir", "./results", help="output_directory") flags.DEFINE_string("model", "otcfm", help="flow matching model type") flags.DEFINE_integer("integration_steps", 100, help="number of inference steps") diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index df2d2aa..9e29963 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -33,7 +33,6 @@ flags.DEFINE_integer( "total_steps", 400001, help="total training steps" ) # Lipman et al uses 400k but double batch size -flags.DEFINE_integer("img_size", 32, help="image size") flags.DEFINE_integer("warmup", 5000, help="learning rate warmup") flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128 flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") @@ -46,10 +45,6 @@ 20000, help="frequency of saving checkpoints, 0 to disable during training", ) -flags.DEFINE_integer( - "eval_step", 0, help="frequency of evaluating model, 0 to disable during training" -) -flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation") use_cuda = torch.cuda.is_available() @@ -110,11 +105,12 @@ def train(argv): optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) if FLAGS.parallel: + print( + "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory." + ) net_model = torch.nn.DataParallel(net_model) ema_model = torch.nn.DataParallel(ema_model) - net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint") - ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint") # show model size model_size = 0 for param in net_model.parameters(): @@ -156,8 +152,8 @@ def train(argv): # sample and Saving the weights if FLAGS.save_step > 0 and step % FLAGS.save_step == 0: - generate_samples(net_node, net_model, savedir, step, net_="normal") - generate_samples(ema_node, ema_model, savedir, step, net_="ema") + generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal") + generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema") torch.save( { "net_model": net_model.state_dict(), diff --git a/examples/cifar10/utils_cifar.py b/examples/cifar10/utils_cifar.py index bc47cbb..1eec02e 100644 --- a/examples/cifar10/utils_cifar.py +++ b/examples/cifar10/utils_cifar.py @@ -1,3 +1,5 @@ +import copy + import torch from torchdyn.core import NeuralODE @@ -8,15 +10,35 @@ device = torch.device("cuda" if use_cuda else "cpu") -def generate_samples(node_, model, savedir, step, net_="normal"): +def generate_samples(model, parallel, savedir, step, net_="normal"): + """Save 64 generated images (8 x 8) for sanity check along training. + + Parameters + ---------- + model: + represents the neural network that we want to generate samples from + parallel: bool + represents the parallel training flag. Torchdyn only runs on 1 GPU, we need to send the models from several GPUs to 1 GPU. + savedir: str + represents the path where we want to save the generated images + step: int + represents the current step of training + """ model.eval() + + model_ = copy.deepcopy(model) + if parallel: + # Send the models from GPU to CPU for inference with NeuralODE from Torchdyn + model_ = model_.module.to(device) + + node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint") with torch.no_grad(): traj = node_.trajectory( torch.randn(64, 3, 32, 32).to(device), t_span=torch.linspace(0, 1, 100).to(device), ) - traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1) - traj = traj / 2 + 0.5 + traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1) + traj = traj / 2 + 0.5 save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8) model.train()