Skip to content

Commit

Permalink
Merge branch 'main' into forest_flow
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras authored Nov 10, 2023
2 parents 1ae5690 + 21cd0c8 commit 703dbb6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions examples/cifar10/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ To reproduce the experiments and save the weights, install the requirements from
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the Conditional Flow Matching method:
- For the Independent Conditional Flow Matching (I-CFM) method:

```bash
python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
python3 train_cifar10.py --model "icfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the original Flow Matching method:
Expand All @@ -32,9 +32,9 @@ To compute the FID from the OT-CFM model at end of training, run:
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```

For the other models, change the "otcfm" argument by "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:
For the other models, change the "otcfm" argument by "icfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:

- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)
- [icfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)

- [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt)

Expand Down
4 changes: 2 additions & 2 deletions examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def train(argv):
sigma = 0.0
if FLAGS.model == "otcfm":
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "cfm":
elif FLAGS.model == "icfm":
FM = ConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "fm":
FM = TargetConditionalFlowMatcher(sigma=sigma)
else:
raise NotImplementedError(
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'cfm', 'fm']"
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']"
)

savedir = FLAGS.output_dir + FLAGS.model + "/"
Expand Down

0 comments on commit 703dbb6

Please sign in to comment.