diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md
index 7ee61a4..4d1e3ae 100644
--- a/examples/images/cifar10/README.md
+++ b/examples/images/cifar10/README.md
@@ -29,7 +29,7 @@ python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size
 Note that you can train all our methods in parallel using multiple GPUs and DistributedDataParallel. You can do this by providing the number of GPUs, setting the parallel flag to True and providing the master address and port in the command line. As an example:
 
 ```bash
-torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
+torchrun --nproc_per_node=NUM_GPUS_YOU_HAVE train_cifar10_ddp.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 --parallel True --master_addr "MASTER_ADDR" --master_port "MASTER_PORT"
 ```
 
 To compute the FID from the OT-CFM model at end of training, run:
diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10_ddp.py
similarity index 100%
rename from examples/images/cifar10/train_cifar10.py
rename to examples/images/cifar10/train_cifar10_ddp.py
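
For reference, below is a minimal sketch of the DDP boilerplate that a torchrun-launched script such as the renamed `train_cifar10_ddp.py` typically contains. The network, loss, and training loop here are placeholders, not the example's actual flow-matching code; only the distributed setup pattern (process group init, `DistributedSampler`, `DDP` wrapper) is the point.

```python
# Minimal DDP sketch (placeholder model/loss, not the repo's actual training code).
# Launch with: torchrun --nproc_per_node=NUM_GPUS this_script.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms


def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each spawned process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataset = datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=transforms.ToTensor(),
    )
    # DistributedSampler shards the dataset so each rank sees a distinct subset.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=128, sampler=sampler, num_workers=4)

    # Placeholder network; the CIFAR-10 example trains a UNet instead.
    model = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)
    ).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    for epoch in range(1):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            loss = torch.nn.functional.cross_entropy(model(x), y)
            optimizer.zero_grad()
            loss.backward()  # DDP all-reduces gradients across ranks here
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```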