Commit
make generate images function work with parallel in CIFAR files (#83)
* remove unused flag variables

* make generate_image function work with parallel

* add docstring to generate_samples function

---------

Co-authored-by: Alex Tong <[email protected]>
Co-authored-by: kilian.fatras <[email protected]>
Reviewed by: Quentin Bertrand
3 people authored Dec 13, 2023
1 parent 3c0cfdd commit b29d418
Showing 4 changed files with 38 additions and 13 deletions.
8 changes: 8 additions & 0 deletions examples/cifar10/README.md
@@ -26,6 +26,14 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

Note that you can train all our methods on multiple GPUs with DataParallel by setting the `--parallel` flag to `True` on the command line. For example:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
```

*Note from the authors*: We have observed that training with `--parallel` gives slightly worse performance than training on a single GPU, most likely because DataParallel computes statistics separately on each device. We plan to switch to DistributedDataParallel to address this in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8 GB of GPU memory).
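
For reference, here is a minimal sketch of the DistributedDataParallel setup mentioned above. It is not part of these examples: the `wrap_model_ddp` helper and the `torchrun` launch command are illustrative assumptions, and the sketch presumes one process per GPU with an NCCL backend.

```python
# Illustrative sketch only; not part of these examples.
# Launch with: torchrun --nproc_per_node=<num_gpus> your_training_script.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model_ddp(model: torch.nn.Module) -> DDP:
    """Wrap a model in DDP: each process owns one GPU and gradients are
    all-reduced, so statistics match single-GPU training more closely
    than with DataParallel."""
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return DDP(model.to(local_rank), device_ids=[local_rank])
```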

To compute the FID from the OT-CFM model at the end of training, run:

```bash
1 change: 0 additions & 1 deletion examples/cifar10/compute_fid.py
@@ -20,7 +20,6 @@
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string("input_dir", "./results", help="output_directory")
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
14 changes: 5 additions & 9 deletions examples/cifar10/train_cifar10.py
@@ -33,7 +33,6 @@
flags.DEFINE_integer(
    "total_steps", 400001, help="total training steps"
)  # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("img_size", 32, help="image size")
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
@@ -46,10 +45,6 @@
    20000,
    help="frequency of saving checkpoints, 0 to disable during training",
)
flags.DEFINE_integer(
    "eval_step", 0, help="frequency of evaluating model, 0 to disable during training"
)
flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation")


use_cuda = torch.cuda.is_available()
@@ -110,11 +105,12 @@ def train(argv):
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if FLAGS.parallel:
        print(
            "Warning: parallel training performs slightly worse than single-GPU training because DataParallel computes statistics per device. We recommend training on a single GPU, which requires around 8 GB of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint")
    ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint")
    # show model size
    model_size = 0
    for param in net_model.parameters():
@@ -156,8 +152,8 @@ def train(argv):

            # sample and save the weights
            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                generate_samples(net_node, net_model, savedir, step, net_="normal")
                generate_samples(ema_node, ema_model, savedir, step, net_="ema")
                generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal")
                generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema")
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
28 changes: 25 additions & 3 deletions examples/cifar10/utils_cifar.py
@@ -1,3 +1,5 @@
import copy

import torch
from torchdyn.core import NeuralODE

@@ -8,15 +10,35 @@
device = torch.device("cuda" if use_cuda else "cpu")


def generate_samples(node_, model, savedir, step, net_="normal"):
def generate_samples(model, parallel, savedir, step, net_="normal"):
    """Save 64 generated images (8 x 8 grid) as a sanity check during training.

    Parameters
    ----------
    model:
        the neural network to sample from
    parallel: bool
        whether the model is wrapped in DataParallel. Torchdyn's NeuralODE runs on a
        single GPU, so a wrapped model must first be unwrapped and moved to one device.
    savedir: str
        path where the generated images are saved
    step: int
        current training step
    """
    model.eval()

    model_ = copy.deepcopy(model)
    if parallel:
        # Unwrap the DataParallel module and move it to a single device for
        # inference with NeuralODE from Torchdyn
        model_ = model_.module.to(device)

    node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
    with torch.no_grad():
        traj = node_.trajectory(
            torch.randn(64, 3, 32, 32).to(device),
            t_span=torch.linspace(0, 1, 100).to(device),
        )
        traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
        traj = traj / 2 + 0.5
    save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8)

    model.train()
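
For orientation, a minimal usage sketch of the updated `generate_samples` signature; this is not part of the commit. The `ToyVelocityField` class and the `./results/` directory are placeholder assumptions, and the import assumes `examples/cifar10` is on the Python path.

```python
import os

import torch

from utils_cifar import generate_samples  # assumes examples/cifar10 is on the path


# Placeholder vector field with the (t, x) forward signature expected by torchdyn.
class ToyVelocityField(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.zeros(1))

    def forward(self, t, x):
        return self.scale * x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyVelocityField().to(device)
if torch.cuda.device_count() > 1:
    # Mirrors the --parallel training path.
    model = torch.nn.DataParallel(model)

os.makedirs("./results/", exist_ok=True)
# parallel=True makes generate_samples unwrap .module before building the NeuralODE.
generate_samples(
    model,
    parallel=isinstance(model, torch.nn.DataParallel),
    savedir="./results/",
    step=0,
    net_="normal",
)
```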
