
add support for distributed data parallel training #116

Merged
merged 11 commits into from
Aug 21, 2024
6 changes: 2 additions & 4 deletions examples/images/cifar10/README.md
@@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example:
Note that you can train all our methods in parallel across multiple GPUs with DistributedDataParallel. To do so, provide the number of GPUs, set the parallel flag to True, and pass the master address and port on the command line. As an example:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True
torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
```

*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory).

To compute the FID from the OT-CFM model at the end of training, run:

```bash
2 changes: 1 addition & 1 deletion examples/images/cifar10/compute_fid.py
@@ -51,7 +51,7 @@
# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
checkpoint = torch.load(PATH)
checkpoint = torch.load(PATH, map_location=device)
state_dict = checkpoint["ema_model"]
try:
new_net.load_state_dict(state_dict)
48 changes: 33 additions & 15 deletions examples/images/cifar10/train_cifar10.py
@@ -8,10 +8,12 @@

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop
from utils_cifar import ema, generate_samples, infiniteloop, setup

from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
@@ -39,6 +41,10 @@
flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string(
"master_addr", "localhost", help="master address for Distributed Data Parallel"
)
flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel")

# Evaluation
flags.DEFINE_integer(
@@ -48,15 +54,11 @@
)


use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def warmup_lr(step):
return min(step, FLAGS.warmup) / FLAGS.warmup


def train(argv):
def train(rank, world_size, argv):
print(
"lr, total_steps, ema decay, save_step:",
FLAGS.lr,
@@ -65,6 +67,12 @@ def train(argv):
FLAGS.save_step,
)

if FLAGS.parallel and world_size > 1:
# When using `DistributedDataParallel`, we need to divide the batch
# size ourselves based on the total number of GPUs of the current node.
FLAGS.batch_size = int(FLAGS.batch_size / world_size)
setup(rank, world_size, FLAGS.master_addr, FLAGS.master_port)

# DATASETS/DATALOADER
dataset = datasets.CIFAR10(
root="./data",
@@ -81,7 +89,8 @@ def train(argv):
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=FLAGS.batch_size,
shuffle=True,
sampler=DistributedSampler(dataset) if FLAGS.parallel else None,
shuffle=False if FLAGS.parallel else True,
Collaborator
Hmm, I am rather unsure about this. Where do you shuffle the data then?

Contributor Author
Fair point. I found this warning in the PyTorch docs:
[screenshot of the PyTorch docs warning for DistributedSampler: call set_epoch() at the beginning of each epoch, before creating the DataLoader iterator, otherwise the same ordering is used in every epoch]

So in my current implementation, train_sampler.set_epoch(epoch) is missing, which I will add now.
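
For reference, a minimal, self-contained sketch of where shuffling happens with a DistributedSampler (the dataset, ranks, and batch size below are illustrative stand-ins, not the PR's values):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Illustrative stand-ins; in the PR these come from CIFAR-10 and the torchrun env vars.
dataset = TensorDataset(torch.randn(256, 3, 32, 32))
rank, world_size, num_epochs = 0, 1, 2

# DistributedSampler shuffles within each rank's shard (shuffle=True is its
# default), so the DataLoader itself must not shuffle.
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False, drop_last=True)

for epoch in range(num_epochs):
    # Re-seed the sampler so each epoch uses a different ordering;
    # without set_epoch, every epoch repeats the same order.
    sampler.set_epoch(epoch)
    for (x1,) in loader:
        pass  # forward/backward would go here
```

With the sampler supplying the per-rank, per-epoch permutation, the loader's own shuffling stays off, which is what the `shuffle=False if FLAGS.parallel else True` line in the diff does.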

Collaborator

Perfect. Once you have finished your change, I will run the code myself. Once I get it working, I will merge the PR.

Final question, can you try to load and run the existing checkpoints? I just want to be sure that people can reproduce our results. Thx.

Contributor Author
@ImahnShekhzadeh Jul 29, 2024

Ok, I refactored the training loop to use num_epochs instead of FLAGS.total_steps, since sampler.set_epoch(epoch) uses an epoch count. However, I think we need to change more than this. The PyTorch warning I pasted above mentions that we need to use sampler.set_epoch(epoch) "before creating the DataLoader iterator", but right now, the data loader iterator is created once before the training loop:

from utils_cifar import infiniteloop

datalooper = infiniteloop(dataloader)

The way I would change this is by having a training loop like this:

# datalooper = infiniteloop(dataloader)

with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
        if sampler is not None:
            sampler.set_epoch(epoch)
        for batch_idx, data in enumerate(dataloader):
            # step += 1  # tricky
            optim.zero_grad()
            x1 = data.to(rank)  # old: `x1 = next(datalooper)`
            [...]

Is this fine by you? IMO, what is a bit tricky is handling the step counter correctly (it determines when checkpoints are saved and when samples are generated during training). In a distributed setup, several processes run in parallel, so we would probably save checkpoints and images multiple times (once per process/GPU). And since the filenames do not reflect the process ID, one process would also overwrite the files of the others. What do you think?
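
A minimal sketch of the bookkeeping described above — a global step counter plus rank-0-only saving — assuming illustrative names (`is_main_process`, `run_training`) that are not part of the PR:

```python
import torch
from torch import distributed as dist


def is_main_process() -> bool:
    """True for single-GPU runs and for rank 0 in a distributed run."""
    return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0


def run_training(net_model, ema_model, optim, dataloader, sampler, num_epochs, save_step, savedir):
    step = 0  # global step counter, incremented across all epochs
    for epoch in range(num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # re-shuffle once per epoch
        for x1 in dataloader:
            step += 1
            # ... forward / backward / optim.step() as in the existing loop ...
            if step % save_step == 0 and is_main_process():
                # Only rank 0 writes, so processes never overwrite each other's files.
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "step": step,
                    },
                    f"{savedir}/cifar10_weights_step_{step}.pt",
                )
```

Counting steps across epochs keeps the save cadence comparable to the original `total_steps`-based loop, and the rank-0 gate sidesteps the filename collision.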

Contributor Author

About your question: When you say "existing checkpoints", which ones do you mean? I had once run the training and generation of samples on one GPU and gotten an FID of 3.8 (which is only slightly worse than the 3.5 you report).
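
As a side note on loading older checkpoints into the new code, a sketch under the assumption that models saved from a DataParallel/DistributedDataParallel wrapper carry a "module." key prefix (`load_ema_weights` is an illustrative helper, not part of this PR; the "ema_model" key mirrors compute_fid.py):

```python
import torch


def load_ema_weights(model: torch.nn.Module, path: str, device: str = "cpu") -> None:
    """Load the "ema_model" state dict from a checkpoint, tolerating a leading
    "module." key prefix (added by DataParallel / DistributedDataParallel)."""
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint["ema_model"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Keys saved from a wrapped model: strip the "module." prefix.
        stripped = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(stripped)
```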

num_workers=FLAGS.num_workers,
drop_last=True,
)
@@ -99,18 +108,15 @@ def train(argv):
attention_resolutions="16",
dropout=0.1,
).to(
device
rank
) # new dropout + bs of 128

ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
if FLAGS.parallel:
print(
"Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory."
)
net_model = torch.nn.DataParallel(net_model)
ema_model = torch.nn.DataParallel(ema_model)
net_model = DistributedDataParallel(net_model, device_ids=[rank])
ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

# show model size
model_size = 0
@@ -142,7 +148,7 @@ def train(argv):
with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
for step in pbar:
optim.zero_grad()
x1 = next(datalooper).to(device)
x1 = next(datalooper).to(rank)
x0 = torch.randn_like(x1)
t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
vt = net_model(t, xt)
@@ -169,5 +175,17 @@
)


def main(argv):
# get world size (number of GPUs)
world_size = int(os.getenv("WORLD_SIZE", 1))

if FLAGS.parallel and world_size > 1:
train(rank=int(os.getenv("RANK", 0)), world_size=world_size, argv=argv)
else:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train(rank=device, world_size=world_size, argv=argv)


if __name__ == "__main__":
app.run(train)
app.run(main)
30 changes: 30 additions & 0 deletions examples/images/cifar10/utils_cifar.py
@@ -1,6 +1,8 @@
import copy
import os

import torch
from torch import distributed as dist
from torchdyn.core import NeuralODE

# from torchvision.transforms import ToPILImage
@@ -10,6 +12,34 @@
device = torch.device("cuda" if use_cuda else "cpu")


def setup(
rank: int,
world_size: int,
master_addr: str = "localhost",
master_port: str = "12355",
backend: str = "nccl",
):
"""Initialize the distributed environment.

Args:
rank: Rank of the current process.
world_size: Number of processes participating in the job.
master_addr: IP address of the master node.
master_port: Port number of the master node.
backend: Backend to use.
"""

os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

# initialize the process group
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
)


def generate_samples(model, parallel, savedir, step, net_="normal"):
"""Save 64 generated images (8 x 8) for sanity check along training.
