From bad11c1e2f6d8ed8ae2d589d9d8b05a08fc924b9 Mon Sep 17 00:00:00 2001 From: Kilian Date: Fri, 3 Nov 2023 16:35:03 -0400 Subject: [PATCH] update cfm to icfm --- examples/cifar10/README.md | 6 +++--- examples/cifar10/train_cifar10.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md index 70319f8..745aa6a 100644 --- a/examples/cifar10/README.md +++ b/examples/cifar10/README.md @@ -17,7 +17,7 @@ python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_si - For the Conditional Flow Matching method: ```bash -python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 +python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000 ``` - For the original Flow Matching method: @@ -32,9 +32,9 @@ To compute the FID from the OT-CFM model at end of training, run: python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5 ``` -For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here: +For the other models, change the "otcfm" argument by "icfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here: -- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt) +- [icfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt) - [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt) diff --git a/examples/cifar10/train_cifar10.py b/examples/cifar10/train_cifar10.py index 9a0cead..df2d2aa 100644 --- a/examples/cifar10/train_cifar10.py +++ b/examples/cifar10/train_cifar10.py @@ -128,13 +128,13 @@ def train(argv): sigma = 0.0 if FLAGS.model == "otcfm": FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) - elif FLAGS.model == "cfm": + elif FLAGS.model == "icfm": FM = ConditionalFlowMatcher(sigma=sigma) elif FLAGS.model == "fm": FM = TargetConditionalFlowMatcher(sigma=sigma) else: raise NotImplementedError( - f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'cfm', 'fm']" + f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']" ) savedir = FLAGS.output_dir + FLAGS.model + "/"