cm_main.py
"""
Discrete training of the consistency model on a toy task.
For better performance, one can pre-train the model with the Karras diffusion objective
and then use the weights as initialization for the consistency model.
"""
from tqdm import tqdm

from cm.cm import ConsistencyModel
from cm.toy_tasks.data_generator import DataGenerator
from cm.visualization.vis_utils import plot_main_figure
if __name__ == "__main__":
    device = 'cpu'
    use_pretraining = True
    n_sampling_steps = 20
    cm = ConsistencyModel(
        lr=1e-4,
        sampler_type='onestep',
        sigma_data=0.5,
        sigma_min=0.05,
        sigma_max=1,
        conditioned=False,
        device=device,  # keep the model on the same device as the training data
        rho=7,
        t_steps_min=100,
        t_steps=100,
        ema_rate=0.999,
        n_sampling_steps=n_sampling_steps,
        use_karras_noise_conditioning=True,
    )
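    # The usual reading of these names: sigma_min/sigma_max bound the noise
    # schedule, rho is the Karras discretization exponent, sigma_data the data
    # scale, and ema_rate the decay of the target (EMA) network; the values
    # above are the ones used for this toy task.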
    train_epochs = 2001

    # choose one of the following toy tasks: 'three_gmm_1D', 'uneven_two_gmm_1D', 'two_gmm_1D', 'single_gaussian_1D'
    data_manager = DataGenerator('single_gaussian_1D')
    samples, cond = data_manager.generate_samples(10000)
    samples = samples.reshape(-1, 1).to(device)
    pbar = tqdm(range(train_epochs))
    # Pretraining if desired
    if use_pretraining:
        for i in range(train_epochs):
            cond = cond.reshape(-1, 1).to(device)
            loss = cm.diffusion_train_step(samples, cond, i, train_epochs)
            pbar.set_description(f"Step {i}, Loss: {loss:.8f}")
            pbar.update(1)
        # plot the results of the pretrained diffusion model to compare with the consistency model
        plot_main_figure(
            data_manager.compute_log_prob,
            cm,
            100,
            train_epochs,
            sampling_method='euler',
            n_sampling_steps=n_sampling_steps,
            x_range=[-4, 4],
            save_path='./plots'
        )
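        # Presumably update_target_network copies the current (pretrained) weights
        # into the target network, so consistency training starts from the
        # diffusion-pretrained initialization mentioned in the docstring.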
        cm.update_target_network()
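    # Stage 2: discrete consistency training on the same samples; train_step is
    # assumed to perform one consistency-model update over the discretized
    # noise schedule defined above.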
    pbar = tqdm(range(train_epochs))
    for i in range(train_epochs):
        cond = cond.reshape(-1, 1).to(device)
        loss = cm.train_step(samples, cond, i, train_epochs)
        pbar.set_description(f"Step {i}, Loss: {loss:.8f}")
        pbar.update(1)
    # Plotting the results of the training
    # We do this for the one-step and the multi-step sampler to compare the results
    plot_main_figure(
        data_manager.compute_log_prob,
        cm,
        100,
        train_epochs,
        sampling_method='onestep',
        x_range=[-4, 4],
        save_path='./plots'
    )
    plot_main_figure(
        data_manager.compute_log_prob,
        cm,
        100,
        train_epochs,
        sampling_method='multistep',
        n_sampling_steps=n_sampling_steps,
        x_range=[-4, 4],
        save_path='./plots'
    )
    print('done')
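    # Run the script directly (python cm_main.py); all figures from both stages
    # end up in ./plots, as set by save_path above.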