make generate images function work with parallel in CIFAR files #83

Merged
merged 16 commits on Dec 13, 2023
8 changes: 8 additions & 0 deletions examples/cifar10/README.md
@@ -26,6 +26,14 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

Note that you can train all our methods on multiple GPUs with DataParallel by setting the `--parallel` flag to True on the command line. For example:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
```

*Note from the authors*: We have observed that training with the parallel flag gives slightly worse performance than training on a single GPU, most likely because DataParallel computes statistics separately on each device. We plan to switch to DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8 GB of GPU memory).
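For reference, a minimal sketch of what the flag does under the hood, based on the `train_cifar10.py` changes in this PR (the `torch.nn.Linear` stand-in and the device-count check are illustrative, not the script's actual code):

```python
import torch

# Sketch of what --parallel enables: the model is wrapped in DataParallel,
# which splits each batch across the visible GPUs (so e.g. normalization
# statistics are computed per device) and gathers the outputs on the first device.
model = torch.nn.Linear(10, 10)  # illustrative stand-in for the CIFAR-10 UNet
parallel = torch.cuda.device_count() > 1  # roughly what --parallel True toggles
if parallel:
    model = torch.nn.DataParallel(model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```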

To compute the FID from the OT-CFM model at the end of training, run:

```bash
1 change: 0 additions & 1 deletion examples/cifar10/compute_fid.py
@@ -20,7 +20,6 @@
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string("input_dir", "./results", help="output_directory")
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
14 changes: 5 additions & 9 deletions examples/cifar10/train_cifar10.py
@@ -33,7 +33,6 @@
flags.DEFINE_integer(
"total_steps", 400001, help="total training steps"
) # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("img_size", 32, help="image size")
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
@@ -46,10 +45,6 @@
20000,
help="frequency of saving checkpoints, 0 to disable during training",
)
flags.DEFINE_integer(
"eval_step", 0, help="frequency of evaluating model, 0 to disable during training"
)
flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation")


use_cuda = torch.cuda.is_available()
@@ -110,11 +105,12 @@ def train(argv):
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
if FLAGS.parallel:
print(
"Warning: parallel training performs slightly worse than single-GPU training because DataParallel computes statistics per device. We recommend training on a single GPU, which requires around 8 GB of GPU memory."
)
net_model = torch.nn.DataParallel(net_model)
ema_model = torch.nn.DataParallel(ema_model)

net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint")
ema_node = NeuralODE(ema_model, solver="euler", sensitivity="adjoint")
# show model size
model_size = 0
for param in net_model.parameters():
@@ -156,8 +152,8 @@ def train(argv):

# Sample and save the weights
if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
generate_samples(net_node, net_model, savedir, step, net_="normal")
generate_samples(ema_node, ema_model, savedir, step, net_="ema")
generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal")
generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema")
torch.save(
{
"net_model": net_model.state_dict(),
28 changes: 25 additions & 3 deletions examples/cifar10/utils_cifar.py
@@ -1,3 +1,5 @@
import copy

import torch
from torchdyn.core import NeuralODE

@@ -8,15 +10,35 @@
device = torch.device("cuda" if use_cuda else "cpu")


def generate_samples(node_, model, savedir, step, net_="normal"):
def generate_samples(model, parallel, savedir, step, net_="normal"):
"""Save 64 generated images (8 x 8) for sanity check along training.

Parameters
----------
model:
represents the neural network that we want to generate samples from
parallel: bool
represents the parallel training flag. Torchdyn only runs on 1 GPU, we need to send the models from several GPUs to 1 GPU.
savedir: str
represents the path where we want to save the generated images
step: int
represents the current step of training
"""
model.eval()

model_ = copy.deepcopy(model)
if parallel:
# Unwrap DataParallel and move the model to a single device for inference with NeuralODE from Torchdyn
model_ = model_.module.to(device)

node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
with torch.no_grad():
traj = node_.trajectory(
torch.randn(64, 3, 32, 32).to(device),
t_span=torch.linspace(0, 1, 100).to(device),
)
traj = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
traj = traj / 2 + 0.5
save_image(traj, savedir + f"{net_}_generated_FM_images_step_{step}.png", nrow=8)

model.train()
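The key detail in `generate_samples` is that torchdyn's `NeuralODE` runs on a single device, so a `DataParallel`-wrapped model has to be unwrapped through its `.module` attribute before sampling. A minimal sketch of that step in isolation, with an illustrative `torch.nn.Linear` stand-in rather than the CIFAR-10 UNet:

```python
import copy

import torch


def unwrap_for_inference(model: torch.nn.Module, parallel: bool) -> torch.nn.Module:
    # Work on a copy so the (possibly wrapped) training model is left untouched.
    model_ = copy.deepcopy(model)
    if parallel:
        # DataParallel keeps the underlying network in .module; torchdyn's
        # NeuralODE expects a plain single-device module.
        model_ = model_.module
    return model_


# Illustrative stand-in, not the script's actual UNet:
net = torch.nn.DataParallel(torch.nn.Linear(8, 8))
single = unwrap_for_inference(net, parallel=True)
assert not isinstance(single, torch.nn.DataParallel)
```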