Update of Cifar10 example which leads to a FID of 3.5 #65

Merged 18 commits on Nov 1, 2023
2 changes: 2 additions & 0 deletions .gitignore
@@ -171,3 +171,5 @@ slurm*.out
*.jpg

notebooks/figures/

.DS_Store
477 changes: 56 additions & 421 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/169_generated_samples_otcfm.gif
Binary file added assets/169_generated_samples_otcfm.png
40 changes: 37 additions & 3 deletions examples/cifar10/README.md
@@ -1,13 +1,47 @@
# CIFAR-10 experiments using TorchCFM

-This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). It is a repository in construction and we will add more features and details in the future (including FID computations and pre-trained weights). We have followed the experimental details provided in [2](https://openreview.net/forum?id=PqvMRDCJT9t).
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that allows us to reach an __FID of 3.5__ on the CIFAR-10 dataset.

-To reproduce the experiment and save the weights, install the requirements from the main repository and then run:
<p align="center">
<img src="../../assets/169_generated_samples_otcfm.png" width="600"/>
</p>

To reproduce the experiments and save the weights, install the requirements from the main repository and then run one of the following commands (each runs on a single RTX 2080 GPU):

- For the OT-Conditional Flow Matching method:

```bash
python3 train_cifar10.py --model "otcfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the Conditional Flow Matching method:

```bash
python3 train_cifar10.py --model "cfm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

- For the original Flow Matching method:

```bash
-python3 train_cifar10.py
python3 train_cifar10.py --model "fm" --lr 2e-4 --ema_decay 0.9999 --batch_size 128 --total_steps 400001 --save_step 20000
```

To compute the FID for the OT-CFM model at the end of training, run:

```bash
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```
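
`dopri5` is an adaptive-step solver; the flags defined in `compute_fid.py` also allow a fixed-step Euler integration. For example, a 100-step Euler evaluation might look like:

```bash
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method euler --integration_steps 100
```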

For the other models, replace the "otcfm" argument with "cfm" or "fm". For easy reproducibility of our results, you can download the model weights at 400000 iterations here:

- [cfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/cfm_cifar10_weights_step_400000.pt)

- [otcfm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/otcfm_cifar10_weights_step_400000.pt)

- [fm weights](https://github.com/atong01/conditional-flow-matching/releases/download/1.0.4/fm_cifar10_weights_step_400000.pt)

To recompute the FID, change the PATH variable to point to where you saved the downloaded weights.
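
As a minimal sketch (assuming the default `--input_dir ./results`): `compute_fid.py` builds the checkpoint path as `<input_dir>/<model>/cifar10_weights_step_<step>.pt`, so the downloaded otcfm weights can be evaluated without editing the script by placing them accordingly:

```bash
mkdir -p results/otcfm
mv otcfm_cifar10_weights_step_400000.pt results/otcfm/cifar10_weights_step_400000.pt
python3 compute_fid.py --model "otcfm" --step 400000 --integration_method dopri5
```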

If you find this code useful in your research, please cite the following papers (expand for BibTeX):

<details>
105 changes: 105 additions & 0 deletions examples/cifar10/compute_fid.py
@@ -0,0 +1,105 @@
# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
# Alexander Tong

import sys

import torch
from absl import flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE

from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# Training
flags.DEFINE_bool("parallel", False, help="multi gpu training")
flags.DEFINE_string("input_dir", "./results", help="output_directory")
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use")
flags.DEFINE_integer("step", 400000, help="training steps")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance (absolute and relative)")
FLAGS(sys.argv)


# Define the model
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

new_net = UNetModelWrapper(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=FLAGS.num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
).to(device)


# Load the model
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/cifar10_weights_step_{FLAGS.step}.pt"
print("path: ", PATH)
checkpoint = torch.load(PATH)
state_dict = checkpoint["ema_model"]
try:
    new_net.load_state_dict(state_dict)
except RuntimeError:
    # Checkpoints trained with torch.nn.DataParallel prefix every key with
    # "module."; strip that 7-character prefix so the keys match a single-GPU model.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v
    new_net.load_state_dict(new_state_dict)
new_net.eval()


# Build a torchdyn NeuralODE only when the fixed-step euler solver is used
if FLAGS.integration_method == "euler":
    node = NeuralODE(new_net, solver=FLAGS.integration_method)


def gen_1_img(unused_latent):
    """Generate a batch of 500 uint8 images for the clean-fid generator interface."""
    with torch.no_grad():
        x = torch.randn(500, 3, 32, 32).to(device)
        if FLAGS.integration_method == "euler":
            print("Use method: ", FLAGS.integration_method)
            t_span = torch.linspace(0, 1, FLAGS.integration_steps + 1).to(device)
            traj = node.trajectory(x, t_span=t_span)
        else:
            print("Use method: ", FLAGS.integration_method)
            t_span = torch.linspace(0, 1, 2).to(device)
            traj = odeint(
                new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
            )
        traj = traj[-1, :]  # keep only the samples at t = 1
        img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8)  # map [-1, 1] to [0, 255]
    return img


print("Start computing FID")
score = fid.compute_fid(
    gen=gen_1_img,
    dataset_name="cifar10",
    batch_size=500,
    dataset_res=32,
    num_gen=FLAGS.num_gen,
    dataset_split="train",
    mode="legacy_tensorflow",
)
print()
print("FID has been computed")
print()
print("FID: ", score)