# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master.

# Authors: Kilian Fatras
#          Alexander Tong

import copy
import os

import torch
from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

FLAGS = flags.FLAGS
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results/", help="output directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channels of UNet")

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate")  # TRY 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
    "total_steps", 400001, help="total training steps"
)  # Lipman et al. use 400k steps with double the batch size
flags.DEFINE_integer("warmup", 5000, help="learning rate warmup steps")
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Lipman et al. use 128
flags.DEFINE_integer("num_workers", 4, help="number of DataLoader workers")
flags.DEFINE_float("ema_decay", 0.9999, help="EMA decay rate")
flags.DEFINE_bool("parallel", False, help="multi-GPU training")

# Evaluation
flags.DEFINE_integer(
    "save_step",
    20000,
    help="frequency of saving checkpoints, 0 to disable during training",
)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def warmup_lr(step):
    return min(step, FLAGS.warmup) / FLAGS.warmup
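# Note: warmup_lr is used below as the LambdaLR multiplier, so the effective
# learning rate is FLAGS.lr * min(step, warmup) / warmup, i.e. a linear ramp
# from 0 to FLAGS.lr over the first `warmup` steps and constant afterwards.
# For example, with warmup=5000 and lr=2e-4, step 2500 trains at lr=1e-4.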
def train(argv):
    print(
        "lr, total_steps, ema decay, save_step:",
        FLAGS.lr,
        FLAGS.total_steps,
        FLAGS.ema_decay,
        FLAGS.save_step,
    )

    # DATASETS/DATALOADER
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map pixels to [-1, 1]
            ]
        ),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )

    datalooper = infiniteloop(dataloader)

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)  # new: dropout + batch size of 128

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
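    # The EMA copy starts as an exact clone of net_model and is updated after
    # every optimizer step (see the ema() call in the training loop); sampling
    # from the EMA weights is what typically gives the best image quality.
    # LambdaLR simply multiplies the base Adam learning rate by warmup_lr(step).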
    if FLAGS.parallel:
        print(
            "Warning: parallel training performs slightly worse than single-GPU training "
            "due to the batch statistics computed in DataParallel. We recommend training "
            "on a single GPU, which requires around 8 GB of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    # show model size
    model_size = 0
    for param in net_model.parameters():
        model_size += param.data.nelement()
    print("Model params: %.2f M" % (model_size / 1024 / 1024))
    #################################
    #    Flow matcher selection
    #################################

    sigma = 0.0
    if FLAGS.model == "otcfm":
        FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "icfm":
        FM = ConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "fm":
        FM = TargetConditionalFlowMatcher(sigma=sigma)
    elif FLAGS.model == "si":
        FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
    else:
        raise NotImplementedError(
            f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
        )
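    # All four matchers expose the same sample_location_and_conditional_flow
    # interface; they differ only in how (x0, x1) pairs are coupled and in the
    # probability path. As a rough guide (based on the torchcfm documentation,
    # not on this diff): 'otcfm' re-pairs each minibatch with an exact optimal
    # transport plan, 'icfm' keeps the independent coupling, 'fm' is the
    # Gaussian-target flow matching of Lipman et al., and 'si' is a
    # variance-preserving stochastic interpolant.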
    savedir = FLAGS.output_dir + FLAGS.model + "/"
    os.makedirs(savedir, exist_ok=True)

    with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
        for step in pbar:
            optim.zero_grad()
            x1 = next(datalooper).to(device)  # data sample
            x0 = torch.randn_like(x1)  # Gaussian noise source
            # sample a time t, a point xt on the conditional path, and the
            # target conditional vector field ut at (t, xt)
            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            # CFM regression objective: mean squared error between the
            # predicted and the target vector field
            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
            optim.step()
            sched.step()
            ema(net_model, ema_model, FLAGS.ema_decay)  # new: update EMA weights
            # Sample and save the weights
            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal")
                generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema")
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "sched": sched.state_dict(),
                        "optim": optim.state_dict(),
                        "step": step,
                    },
                    savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt",
                )
if __name__ == "__main__":
    app.run(train)
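# Example usage (a minimal sketch; the filename train_cifar10.py is an
# assumption, it is not given in this diff):
#
#   python train_cifar10.py --model otcfm --lr 2e-4 --batch_size 128 --total_steps 400001
#
# Any of the absl flags defined above can be overridden on the command line in
# the same way, e.g. --parallel=True to enable DataParallel training.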