From 51ea2cd11f29d59526aa18986278dc82a7246df1 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Sun, 19 May 2024 19:24:52 +0200 Subject: [PATCH 01/11] make code changes in `train_cifar10.py` to allow DDP (distributed data parallel) --- examples/images/cifar10/train_cifar10.py | 40 ++++++++++++++++++------ examples/images/cifar10/utils_cifar.py | 30 ++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 14b8b04..574d752 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -8,10 +8,12 @@ import torch from absl import app, flags +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DistributedSampler from torchdyn.core import NeuralODE from torchvision import datasets, transforms from tqdm import trange -from utils_cifar import ema, generate_samples, infiniteloop +from utils_cifar import ema, generate_samples, infiniteloop, setup from torchcfm.conditional_flow_matching import ( ConditionalFlowMatcher, @@ -39,6 +41,10 @@ flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate") flags.DEFINE_bool("parallel", False, help="multi gpu training") +flags.DEFINE_string( + "master_addr", "localhost", help="master address for Distributed Data Parallel" +) +flags.DEFINE_string("master_port", "12355", help="master port for Distributed Data Parallel") # Evaluation flags.DEFINE_integer( @@ -56,7 +62,7 @@ def warmup_lr(step): return min(step, FLAGS.warmup) / FLAGS.warmup -def train(argv): +def train(rank, world_size, argv): print( "lr, total_steps, ema decay, save_step:", FLAGS.lr, @@ -65,6 +71,12 @@ def train(argv): FLAGS.save_step, ) + if FLAGS.parallel and world_size > 1: + # When using `DistributedDataParallel`, we need to divide the batch + # size ourselves based on the total number of GPUs of the current node. + FLAGS.batch_size = int(FLAGS.batch_size / world_size) + setup(rank, world_size, FLAGS.master_addr, FLAGS.master_port) + # DATASETS/DATALOADER dataset = datasets.CIFAR10( root="./data", @@ -81,7 +93,8 @@ def train(argv): dataloader = torch.utils.data.DataLoader( dataset, batch_size=FLAGS.batch_size, - shuffle=True, + sampler=DistributedSampler(dataset) if FLAGS.parallel else None, + shuffle=False if FLAGS.parallel else True, num_workers=FLAGS.num_workers, drop_last=True, ) @@ -99,18 +112,15 @@ def train(argv): attention_resolutions="16", dropout=0.1, ).to( - device + rank ) # new dropout + bs of 128 ema_model = copy.deepcopy(net_model) optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) if FLAGS.parallel: - print( - "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory." - ) - net_model = torch.nn.DataParallel(net_model) - ema_model = torch.nn.DataParallel(ema_model) + net_model = DistributedDataParallel(net_model, device_ids=[rank]) + ema_model = DistributedDataParallel(ema_model, device_ids=[rank]) # show model size model_size = 0 @@ -169,5 +179,15 @@ def train(argv): ) +def main(argv): + # get world size (number of GPUs) + world_size = int(os.getenv("WORLD_SIZE", 1)) + + if FLAGS.parallel and world_size > 1: + train(rank=int(os.getenv("RANK", 0)), world_size=world_size, argv=argv) + else: + train(rank=device, world_size=world_size, argv=argv) + + if __name__ == "__main__": - app.run(train) + app.run(main) diff --git a/examples/images/cifar10/utils_cifar.py b/examples/images/cifar10/utils_cifar.py index 50f6a7e..33a4650 100644 --- a/examples/images/cifar10/utils_cifar.py +++ b/examples/images/cifar10/utils_cifar.py @@ -1,6 +1,8 @@ import copy +import os import torch +from torch import distributed as dist from torchdyn.core import NeuralODE # from torchvision.transforms import ToPILImage @@ -10,6 +12,34 @@ device = torch.device("cuda" if use_cuda else "cpu") +def setup( + rank: int, + world_size: int, + master_addr: str = "localhost", + master_port: str = "12355", + backend: str = "nccl", +): + """Initialize the distributed environment. + + Args: + rank: Rank of the current process. + world_size: Number of processes participating in the job. + master_addr: IP address of the master node. + master_port: Port number of the master node. + backend: Backend to use. + """ + + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + # initialize the process group + dist.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + + def generate_samples(model, parallel, savedir, step, net_="normal"): """Save 64 generated images (8 x 8) for sanity check along training. From 71122c9d2366c606d8c6b35764c8da64c3043e8f Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Sun, 19 May 2024 23:53:06 +0200 Subject: [PATCH 02/11] add instructions to README on how to run cifar10 image generation code on multiple GPUs --- examples/images/cifar10/README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md index 7d5012d..7ee61a4 100644 --- a/examples/images/cifar10/README.md +++ b/examples/images/cifar10/README.md @@ -26,14 +26,12 @@ python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_siz python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 ``` -Note that you can train all our methods in parallel using multiple GPUs and DataParallel. You can do this by setting the parallel flag to True in the command line. As an example: +Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example: ```bash -python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True +torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT" ``` -*Note from the authors*: We have observed that training with parallel leads to slightly poorer performance than what you can get with one GPU. The reason is probably that DataParallel computes statistics over each device. We are thinking of using DistributedDataParallel to solve this problem in the future. In the meantime, we strongly encourage users to train on a single GPU (the provided scripts require about 8G of GPU memory). - To compute the FID from the OT-CFM model at end of training, run: ```bash From d0b0da24e816fea7f2db96941fff5d54ec203e73 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Sun, 19 May 2024 23:54:21 +0200 Subject: [PATCH 03/11] fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting --- examples/images/cifar10/train_cifar10.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 574d752..d238410 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -54,10 +54,6 @@ ) -use_cuda = torch.cuda.is_available() -device = torch.device("cuda" if use_cuda else "cpu") - - def warmup_lr(step): return min(step, FLAGS.warmup) / FLAGS.warmup @@ -152,7 +148,7 @@ def train(rank, world_size, argv): with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar: for step in pbar: optim.zero_grad() - x1 = next(datalooper).to(device) + x1 = next(datalooper).to(rank) x0 = torch.randn_like(x1) t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) vt = net_model(t, xt) @@ -186,6 +182,8 @@ def main(argv): if FLAGS.parallel and world_size > 1: train(rank=int(os.getenv("RANK", 0)), world_size=world_size, argv=argv) else: + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") train(rank=device, world_size=world_size, argv=argv) From 333d73f2d7b9d82bf6147c0b2d6764d851b7cb63 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Tue, 21 May 2024 00:35:53 +0200 Subject: [PATCH 04/11] fix: load checkpoint on right device --- examples/images/cifar10/compute_fid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py index ffa66c2..7596699 100644 --- a/examples/images/cifar10/compute_fid.py +++ b/examples/images/cifar10/compute_fid.py @@ -51,7 +51,7 @@ # Load the model PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt" print("path: ", PATH) -checkpoint = torch.load(PATH) +checkpoint = torch.load(PATH, map_location=device) state_dict = checkpoint["ema_model"] try: new_net.load_state_dict(state_dict) From 707dfbeae6903b633bde0031807cd87e776309a2 Mon Sep 17 00:00:00 2001 From: Alexander Tong Date: Thu, 11 Jul 2024 16:58:15 -0400 Subject: [PATCH 05/11] fix runner ci requirements (#125) * change pytorch lightning version * fix pip version * fix pip in code cov --- .github/workflows/test_runner.yaml | 6 ++++-- runner-requirements.txt | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_runner.yaml b/.github/workflows/test_runner.yaml index e901d93..617caa3 100644 --- a/.github/workflows/test_runner.yaml +++ b/.github/workflows/test_runner.yaml @@ -27,7 +27,8 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip + # Fix pip version < 24.1 due to lightning incomaptibility + python -m pip install pip==23.2.1 pip install -r runner-requirements.txt pip install pytest pip install sh @@ -56,7 +57,8 @@ jobs: - name: Install dependencies run: | - python -m pip install --upgrade pip + # Fix pip version < 24.1 due to lightning incomaptibility + python -m pip install pip==23.2.1 pip install -r runner-requirements.txt pip install pytest pip install pytest-cov[toml] diff --git a/runner-requirements.txt b/runner-requirements.txt index 3f40053..90e92dd 100644 --- a/runner-requirements.txt +++ b/runner-requirements.txt @@ -8,7 +8,7 @@ # --------- pytorch --------- # torch>=1.11.0,<2.0.0 torchvision>=0.11.0 -pytorch-lightning==1.8.3 +pytorch-lightning==1.8.3.post2 torchmetrics==0.11.0 # --------- hydra --------- # From eb90f190bb7443a01e2795e76fb57248c3c9df12 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Mon, 29 Jul 2024 19:12:07 +0200 Subject: [PATCH 06/11] change variable name `world_size` to `total_num_gpus` --- examples/images/cifar10/train_cifar10.py | 16 ++++++++-------- examples/images/cifar10/utils_cifar.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index d238410..2ec3262 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -58,7 +58,7 @@ def warmup_lr(step): return min(step, FLAGS.warmup) / FLAGS.warmup -def train(rank, world_size, argv): +def train(rank, total_num_gpus, argv): print( "lr, total_steps, ema decay, save_step:", FLAGS.lr, @@ -67,11 +67,11 @@ def train(rank, world_size, argv): FLAGS.save_step, ) - if FLAGS.parallel and world_size > 1: + if FLAGS.parallel and total_num_gpus > 1: # When using `DistributedDataParallel`, we need to divide the batch # size ourselves based on the total number of GPUs of the current node. - FLAGS.batch_size = int(FLAGS.batch_size / world_size) - setup(rank, world_size, FLAGS.master_addr, FLAGS.master_port) + FLAGS.batch_size = int(FLAGS.batch_size / total_num_gpus) + setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port) # DATASETS/DATALOADER dataset = datasets.CIFAR10( @@ -177,14 +177,14 @@ def train(rank, world_size, argv): def main(argv): # get world size (number of GPUs) - world_size = int(os.getenv("WORLD_SIZE", 1)) + total_num_gpus = int(os.getenv("WORLD_SIZE", 1)) - if FLAGS.parallel and world_size > 1: - train(rank=int(os.getenv("RANK", 0)), world_size=world_size, argv=argv) + if FLAGS.parallel and total_num_gpus > 1: + train(rank=int(os.getenv("RANK", 0)), total_num_gpus=total_num_gpus, argv=argv) else: use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") - train(rank=device, world_size=world_size, argv=argv) + train(rank=device, total_num_gpus=total_num_gpus, argv=argv) if __name__ == "__main__": diff --git a/examples/images/cifar10/utils_cifar.py b/examples/images/cifar10/utils_cifar.py index 33a4650..cfa36b8 100644 --- a/examples/images/cifar10/utils_cifar.py +++ b/examples/images/cifar10/utils_cifar.py @@ -14,7 +14,7 @@ def setup( rank: int, - world_size: int, + total_num_gpus: int, master_addr: str = "localhost", master_port: str = "12355", backend: str = "nccl", @@ -23,7 +23,7 @@ def setup( Args: rank: Rank of the current process. - world_size: Number of processes participating in the job. + total_num_gpus: Number of GPUs used in the job. master_addr: IP address of the master node. master_port: Port number of the master node. backend: Backend to use. @@ -36,7 +36,7 @@ def setup( dist.init_process_group( backend=backend, rank=rank, - world_size=world_size, + world_size=total_num_gpus, ) From f2586ce1deff6195aa56238b99cdc59fff690e8b Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Mon, 29 Jul 2024 19:16:05 +0200 Subject: [PATCH 07/11] change: do not overwrite batch size flag --- examples/images/cifar10/train_cifar10.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 2ec3262..4bd77cd 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -70,8 +70,10 @@ def train(rank, total_num_gpus, argv): if FLAGS.parallel and total_num_gpus > 1: # When using `DistributedDataParallel`, we need to divide the batch # size ourselves based on the total number of GPUs of the current node. - FLAGS.batch_size = int(FLAGS.batch_size / total_num_gpus) + batch_size_per_gpu = FLAGS.batch_size // total_num_gpus setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port) + else: + batch_size_per_gpu = FLAGS.batch_size # DATASETS/DATALOADER dataset = datasets.CIFAR10( @@ -88,7 +90,7 @@ def train(rank, total_num_gpus, argv): ) dataloader = torch.utils.data.DataLoader( dataset, - batch_size=FLAGS.batch_size, + batch_size=batch_size_per_gpu, sampler=DistributedSampler(dataset) if FLAGS.parallel else None, shuffle=False if FLAGS.parallel else True, num_workers=FLAGS.num_workers, From 443b00062f186175b4b9e774ebf6e7852a5ff761 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Mon, 29 Jul 2024 20:07:35 +0200 Subject: [PATCH 08/11] add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps --- examples/images/cifar10/train_cifar10.py | 74 +++++++++++++++--------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 4bd77cd..3ecf321 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -2,8 +2,10 @@ # Authors: Kilian Fatras # Alexander Tong +# Imahn Shekhzadeh import copy +import math import os import torch @@ -99,6 +101,10 @@ def train(rank, total_num_gpus, argv): datalooper = infiniteloop(dataloader) + # Calculate number of epochs + steps_per_epoch = math.ceil(len(dataset) / FLAGS.batch_size) + num_epochs = math.ceil(FLAGS.total_steps / steps_per_epoch) + # MODELS net_model = UNetModelWrapper( dim=(3, 32, 32), @@ -147,34 +153,46 @@ def train(rank, total_num_gpus, argv): savedir = FLAGS.output_dir + FLAGS.model + "/" os.makedirs(savedir, exist_ok=True) - with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar: - for step in pbar: - optim.zero_grad() - x1 = next(datalooper).to(rank) - x0 = torch.randn_like(x1) - t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) - vt = net_model(t, xt) - loss = torch.mean((vt - ut) ** 2) - loss.backward() - torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new - optim.step() - sched.step() - ema(net_model, ema_model, FLAGS.ema_decay) # new - - # sample and Saving the weights - if FLAGS.save_step > 0 and step % FLAGS.save_step == 0: - generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal") - generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema") - torch.save( - { - "net_model": net_model.state_dict(), - "ema_model": ema_model.state_dict(), - "sched": sched.state_dict(), - "optim": optim.state_dict(), - "step": step, - }, - savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt", - ) + global_step = 0 # to keep track of the global step in training loop + + with trange(num_epochs, dynamic_ncols=True) as epoch_bar: + for epoch in epoch_bar: + epoch_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}") + + with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar: + for step in step_pbar: + global_step += step + + optim.zero_grad() + x1 = next(datalooper).to(rank) + x0 = torch.randn_like(x1) + t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) + vt = net_model(t, xt) + loss = torch.mean((vt - ut) ** 2) + loss.backward() + torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new + optim.step() + sched.step() + ema(net_model, ema_model, FLAGS.ema_decay) # new + + # sample and Saving the weights + if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0: + generate_samples( + net_model, FLAGS.parallel, savedir, global_step, net_="normal" + ) + generate_samples( + ema_model, FLAGS.parallel, savedir, global_step, net_="ema" + ) + torch.save( + { + "net_model": net_model.state_dict(), + "ema_model": ema_model.state_dict(), + "sched": sched.state_dict(), + "optim": optim.state_dict(), + "step": global_step, + }, + savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt", + ) def main(argv): From f8bc6467534e088b3d237eb424dce1e7c59f1f03 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Mon, 29 Jul 2024 20:13:18 +0200 Subject: [PATCH 09/11] fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode --- examples/images/cifar10/train_cifar10.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 3ecf321..851f28c 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -90,10 +90,11 @@ def train(rank, total_num_gpus, argv): ] ), ) + sampler = DistributedSampler(dataset) if FLAGS.parallel else None dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size_per_gpu, - sampler=DistributedSampler(dataset) if FLAGS.parallel else None, + sampler=sampler, shuffle=False if FLAGS.parallel else True, num_workers=FLAGS.num_workers, drop_last=True, @@ -155,9 +156,11 @@ def train(rank, total_num_gpus, argv): global_step = 0 # to keep track of the global step in training loop - with trange(num_epochs, dynamic_ncols=True) as epoch_bar: - for epoch in epoch_bar: - epoch_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}") + with trange(num_epochs, dynamic_ncols=True) as epoch_pbar: + for epoch in epoch_pbar: + epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}") + if sampler is not None: + sampler.set_epoch(epoch) with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar: for step in step_pbar: From c9294342c566826e72503cc50f1cf8a9772481bd Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Thu, 8 Aug 2024 10:07:47 +0200 Subject: [PATCH 10/11] rename file, update README --- examples/images/cifar10/README.md | 2 +- .../images/cifar10/{train_cifar10.py => train_cifar10_ddp.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename examples/images/cifar10/{train_cifar10.py => train_cifar10_ddp.py} (100%) diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md index 7ee61a4..4d1e3ae 100644 --- a/examples/images/cifar10/README.md +++ b/examples/images/cifar10/README.md @@ -29,7 +29,7 @@ python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example: ```bash -torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT" +torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT" ``` To compute the FID from the OT-CFM model at end of training, run: diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10_ddp.py similarity index 100% rename from examples/images/cifar10/train_cifar10.py rename to examples/images/cifar10/train_cifar10_ddp.py From 0389ea423136ed1d94f2cf788113a959640504a1 Mon Sep 17 00:00:00 2001 From: Imahn Shekhzadeh Date: Thu, 8 Aug 2024 10:08:31 +0200 Subject: [PATCH 11/11] add original CIFAR10 training file --- examples/images/cifar10/train_cifar10.py | 173 +++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 examples/images/cifar10/train_cifar10.py diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py new file mode 100644 index 0000000..14b8b04 --- /dev/null +++ b/examples/images/cifar10/train_cifar10.py @@ -0,0 +1,173 @@ +# Inspired from https://github.com/w86763777/pytorch-ddpm/tree/master. + +# Authors: Kilian Fatras +# Alexander Tong + +import copy +import os + +import torch +from absl import app, flags +from torchdyn.core import NeuralODE +from torchvision import datasets, transforms +from tqdm import trange +from utils_cifar import ema, generate_samples, infiniteloop + +from torchcfm.conditional_flow_matching import ( + ConditionalFlowMatcher, + ExactOptimalTransportConditionalFlowMatcher, + TargetConditionalFlowMatcher, + VariancePreservingConditionalFlowMatcher, +) +from torchcfm.models.unet.unet import UNetModelWrapper + +FLAGS = flags.FLAGS + +flags.DEFINE_string("model", "otcfm", help="flow matching model type") +flags.DEFINE_string("output_dir", "./results/", help="output_directory") +# UNet +flags.DEFINE_integer("num_channel", 128, help="base channel of UNet") + +# Training +flags.DEFINE_float("lr", 2e-4, help="target learning rate") # TRY 2e-4 +flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping") +flags.DEFINE_integer( + "total_steps", 400001, help="total training steps" +) # Lipman et al uses 400k but double batch size +flags.DEFINE_integer("warmup", 5000, help="learning rate warmup") +flags.DEFINE_integer("batch_size", 128, help="batch size") # Lipman et al uses 128 +flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader") +flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate") +flags.DEFINE_bool("parallel", False, help="multi gpu training") + +# Evaluation +flags.DEFINE_integer( + "save_step", + 20000, + help="frequency of saving checkpoints, 0 to disable during training", +) + + +use_cuda = torch.cuda.is_available() +device = torch.device("cuda" if use_cuda else "cpu") + + +def warmup_lr(step): + return min(step, FLAGS.warmup) / FLAGS.warmup + + +def train(argv): + print( + "lr, total_steps, ema decay, save_step:", + FLAGS.lr, + FLAGS.total_steps, + FLAGS.ema_decay, + FLAGS.save_step, + ) + + # DATASETS/DATALOADER + dataset = datasets.CIFAR10( + root="./data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ), + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=FLAGS.batch_size, + shuffle=True, + num_workers=FLAGS.num_workers, + drop_last=True, + ) + + datalooper = infiniteloop(dataloader) + + # MODELS + net_model = UNetModelWrapper( + dim=(3, 32, 32), + num_res_blocks=2, + num_channels=FLAGS.num_channel, + channel_mult=[1, 2, 2, 2], + num_heads=4, + num_head_channels=64, + attention_resolutions="16", + dropout=0.1, + ).to( + device + ) # new dropout + bs of 128 + + ema_model = copy.deepcopy(net_model) + optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) + sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr) + if FLAGS.parallel: + print( + "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory." + ) + net_model = torch.nn.DataParallel(net_model) + ema_model = torch.nn.DataParallel(ema_model) + + # show model size + model_size = 0 + for param in net_model.parameters(): + model_size += param.data.nelement() + print("Model params: %.2f M" % (model_size / 1024 / 1024)) + + ################################# + # OT-CFM + ################################# + + sigma = 0.0 + if FLAGS.model == "otcfm": + FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "icfm": + FM = ConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "fm": + FM = TargetConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "si": + FM = VariancePreservingConditionalFlowMatcher(sigma=sigma) + else: + raise NotImplementedError( + f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']" + ) + + savedir = FLAGS.output_dir + FLAGS.model + "/" + os.makedirs(savedir, exist_ok=True) + + with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar: + for step in pbar: + optim.zero_grad() + x1 = next(datalooper).to(device) + x0 = torch.randn_like(x1) + t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1) + vt = net_model(t, xt) + loss = torch.mean((vt - ut) ** 2) + loss.backward() + torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip) # new + optim.step() + sched.step() + ema(net_model, ema_model, FLAGS.ema_decay) # new + + # sample and Saving the weights + if FLAGS.save_step > 0 and step % FLAGS.save_step == 0: + generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal") + generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema") + torch.save( + { + "net_model": net_model.state_dict(), + "ema_model": ema_model.state_dict(), + "sched": sched.state_dict(), + "optim": optim.state_dict(), + "step": step, + }, + savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt", + ) + + +if __name__ == "__main__": + app.run(train)