Avoid using infiniteloop in train_cifar10_ddp.py #145

Closed
examples/images/cifar10/README.md (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size
Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. To do so, provide the number of GPUs, set the parallel flag to True, and pass the master address and port on the command line. For example:

```bash
-torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
+torchrun --standalone --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
```
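
For reference, a concrete single-node invocation could look like the sketch below. The GPU count of 4 is an assumption, and the master address and port simply repeat the script's own defaults (localhost:12355, see the flag definitions in the Python diff further down); with `--standalone`, torchrun sets up a local rendezvous on the node, so no external master endpoint needs to be configured for the launcher itself.

```bash
# Hypothetical single-node run, assuming 4 GPUs; other flag values mirror the
# README command above, with master_addr/master_port left at the script defaults.
torchrun --standalone --nproc_per_node=4 train_cifar10_ddp.py \
    --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 \
    --total_steps 400001 --save_step 20000 --parallel True \
    --master_addr "localhost" --master_port "12355"
```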

To compute the FID from the OT-CFM model at the end of training, run:
examples/images/cifar10/train_cifar10_ddp.py (43 additions, 43 deletions)
@@ -9,13 +9,13 @@
import os

import torch
+import tqdm
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
-from tqdm import trange
-from utils_cifar import ema, generate_samples, infiniteloop, setup
+from utils_cifar import ema, generate_samples, setup

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
@@ -46,7 +46,9 @@
flags.DEFINE_string(
    "master_addr", "localhost", help="master address for Distributed Data Parallel"
)
-flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")
+flags.DEFINE_string(
+    "master_port", "12355", help="master port for Distributed Data Parallel"
+)

# Evaluation
flags.DEFINE_integer(
@@ -100,8 +102,6 @@ def train(rank, total_num_gpus, argv):
        drop_last=True,
    )

-    datalooper = infiniteloop(dataloader)
-
    # Calculate number of epochs
    steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size)
    num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch)
@@ -116,9 +116,7 @@
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
-    ).to(
-        rank
-    )  # new dropout + bs of 128
+    ).to(rank)  # new dropout + bs of 128

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
@@ -156,46 +154,48 @@

    global_step = 0  # to keep track of the global step in training loop

-    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
+    with tqdm.trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            if sampler is not None:
                sampler.set_epoch(epoch)

-            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
-                for step in step_pbar:
-                    global_step += step
-
-                    optim.zero_grad()
-                    x1 = next(datalooper).to(rank)
-                    x0 = torch.randn_like(x1)
-                    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
-                    vt = net_model(t, xt)
-                    loss = torch.mean((vt - ut) ** 2)
-                    loss.backward()
-                    torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
-                    optim.step()
-                    sched.step()
-                    ema(net_model, ema_model, FLAGS.ema_decay)  # new
-
-                    # sample and Saving the weights
-                    if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
-                        generate_samples(
-                            net_model, FLAGS.parallel, savedir, global_step, net_="normal"
-                        )
-                        generate_samples(
-                            ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
-                        )
-                        torch.save(
-                            {
-                                "net_model": net_model.state_dict(),
-                                "ema_model": ema_model.state_dict(),
-                                "sched": sched.state_dict(),
-                                "optim": optim.state_dict(),
-                                "step": global_step,
-                            },
-                            savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
-                        )
+            for x1, _ in tqdm.tqdm(dataloader, total=len(dataloader)):
+                global_step += 1
+
+                optim.zero_grad()
+                x1 = x1.to(rank)
+                x0 = torch.randn_like(x1)
+                t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
+                vt = net_model(t, xt)
+                loss = torch.mean((vt - ut) ** 2)
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(
+                    net_model.parameters(), FLAGS.grad_clip
+                )  # new
+                optim.step()
+                sched.step()
+                ema(net_model, ema_model, FLAGS.ema_decay)  # new
+
+                # sample and Saving the weights
+                if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
+                    generate_samples(
+                        net_model, FLAGS.parallel, savedir, global_step, net_="normal"
+                    )
+                    generate_samples(
+                        ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
+                    )
+                    torch.save(
+                        {
+                            "net_model": net_model.state_dict(),
+                            "ema_model": ema_model.state_dict(),
+                            "sched": sched.state_dict(),
+                            "optim": optim.state_dict(),
+                            "step": global_step,
+                        },
+                        savedir
+                        + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
+                    )
Comment from the PR author (Contributor) on lines +163 to +198:
This chunk seems to change a lot, but it actually only modifies where x1 is read. The apparent size of the diff comes from the indentation change, since step_pbar is no longer needed.



def main(argv):
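
To make the reviewer's point concrete, below is a minimal sketch of the two loop structures on a toy in-memory dataset. The `infinite_iter` helper is an assumed stand-in for the removed `infiniteloop` utility (the real `utils_cifar` implementation may differ), and the DDP pieces (DistributedSampler, model, optimizer) are omitted; the point is only that iterating the DataLoader directly gives explicit epochs, which is what lets `sampler.set_epoch(epoch)` reshuffle the data each epoch.

```python
# Minimal sketch, not the repository's code: contrasts the removed
# infinite-iterator pattern with the new per-epoch iteration over the DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the CIFAR-10 pipeline.
dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.zeros(256, dtype=torch.long))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)


def infinite_iter(loader):
    """Assumed equivalent of the removed infiniteloop helper: yield batches forever."""
    while True:
        for x, _ in loader:
            yield x


# Old pattern: one flat step loop pulling batches from an infinite iterator.
looper = infinite_iter(dataloader)
for step in range(8):
    x1 = next(looper)

# New pattern: explicit epochs iterating the DataLoader directly. With DDP,
# sampler.set_epoch(epoch) would be called at the top of each epoch so the
# DistributedSampler reshuffles data across ranks.
global_step = 0
for epoch in range(2):
    for x1, _ in dataloader:
        global_step += 1
```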