# solver.py
import numpy as np
import os
import time
import datetime
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
from utils import *
from models import Generator, Discriminator
from data.sparse_molecular_dataset import SparseMolecularDataset
class Solver(object):
    """Solver for training and testing a GAN over molecular graphs.

    Three networks are managed here: a generator ``G``, a discriminator ``D``
    and a value network ``V`` (an instance of the same ``Discriminator`` class
    as ``D``).  ``G`` and ``V`` share one optimizer; ``D`` has its own.
    """

    def __init__(self, config):
        """Initialize configurations.

        Args:
            config: Object exposing the hyper-parameters, directories and
                step sizes read below as attributes.
        """
        # Data loader.
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.post_method = config.post_method

        # Comma-separated metric names multiplied together by `reward`.
        self.metric = 'validity,sas'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, discriminator, value network and optimizers."""
        self.G = Generator(self.g_conv_dim, self.z_dim,
                           self.data.vertexes,
                           self.data.bond_num_types,
                           self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)

        # The value network is optimized together with the generator.
        self.g_optimizer = torch.optim.Adam(
            list(self.G.parameters()) + list(self.V.parameters()),
            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    def print_network(self, model, name):
        """Print the network, its name and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def _checkpoint_paths(self, step):
        """Return the (G, D, V) checkpoint file paths for iteration `step`."""
        return tuple(
            os.path.join(self.model_save_dir, '{}-{}.ckpt'.format(step, tag))
            for tag in ('G', 'D', 'V'))

    def restore_model(self, resume_iters):
        """Restore the generator, discriminator and value network weights."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path, D_path, V_path = self._checkpoint_paths(resume_iters)
        # Load onto CPU storage first; load_state_dict copies into the
        # parameters wherever the modules currently live.
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rates of the generator and discriminator optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

        Args:
            y: Output tensor computed from `x`.
            x: Input tensor that requires grad.

        Returns:
            Scalar tensor: the penalty averaged over the first dimension.
        """
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Flatten everything but the first dimension before taking the norm.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors along a new last axis."""
        out = torch.zeros(list(labels.size()) + [dim]).to(self.device)
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss.

        Args:
            logit: Raw (pre-activation) predictions.
            target: Targets in the format the selected loss expects.
            dataset: 'CelebA' for summed binary cross entropy divided by the
                batch size, 'RaFD' for softmax cross entropy.

        Raises:
            ValueError: If `dataset` is neither 'CelebA' nor 'RaFD'.
        """
        if dataset == 'CelebA':
            # reduction='sum' replaces the deprecated size_average=False.
            return F.binary_cross_entropy_with_logits(
                logit, target, reduction='sum') / logit.size(0)
        if dataset == 'RaFD':
            return F.cross_entropy(logit, target)
        raise ValueError('Unknown dataset for classification loss: {}'.format(dataset))

    def sample_z(self, batch_size):
        """Draw a (batch_size, z_dim) NumPy array of standard-normal noise."""
        return np.random.normal(0, 1, size=(batch_size, self.z_dim))

    def postprocess(self, inputs, method, temperature=1.):
        """Turn logits into (relaxed) categorical samples over the last axis.

        Args:
            inputs: A logits tensor, or a list/tuple of logits tensors.
            method: 'soft_gumbel' or 'hard_gumbel' for Gumbel-softmax
                sampling; any other value applies a plain softmax.
            temperature: Divisor applied to the logits before normalizing.

        Returns:
            A list with one tensor per input, each with its input's shape.
        """
        logits_list = inputs if isinstance(inputs, (list, tuple)) else [inputs]

        if method in ('soft_gumbel', 'hard_gumbel'):
            hard = method == 'hard_gumbel'
            # gumbel_softmax is applied on a 2-D view, then reshaped back.
            outputs = [
                F.gumbel_softmax(
                    e_logits.contiguous().view(-1, e_logits.size(-1)) / temperature,
                    hard=hard).view(e_logits.size())
                for e_logits in logits_list]
        else:
            outputs = [F.softmax(e_logits / temperature, -1)
                       for e_logits in logits_list]

        # Return tensors untouched: unwrapping entries of length 1 would drop
        # the batch dimension whenever the batch holds a single sample.
        return outputs

    def reward(self, mols):
        """Compute the product of the configured metric scores for `mols`.

        The metrics are the comma-separated names in `self.metric`; the value
        'all' expands to 'logp,sas,qed,unique'.

        Returns:
            The combined scores reshaped to a column (-1, 1).

        Raises:
            RuntimeError: If a metric name is not recognized.
        """
        rr = 1.
        for m in ('logp,sas,qed,unique' if self.metric == 'all' else self.metric).split(','):
            if m == 'np':
                rr *= MolecularMetrics.natural_product_scores(mols, norm=True)
            elif m == 'logp':
                rr *= MolecularMetrics.water_octanol_partition_coefficient_scores(mols, norm=True)
            elif m == 'sas':
                rr *= MolecularMetrics.synthetic_accessibility_score_scores(mols, norm=True)
            elif m == 'qed':
                rr *= MolecularMetrics.quantitative_estimation_druglikeness_scores(mols, norm=True)
            elif m == 'novelty':
                # Dataset-dependent metrics use the solver's dataset.
                rr *= MolecularMetrics.novel_scores(mols, self.data)
            elif m == 'dc':
                rr *= MolecularMetrics.drugcandidate_scores(mols, self.data)
            elif m == 'unique':
                rr *= MolecularMetrics.unique_scores(mols)
            elif m == 'diversity':
                rr *= MolecularMetrics.diversity_scores(mols, self.data)
            elif m == 'validity':
                rr *= MolecularMetrics.valid_scores(mols)
            else:
                raise RuntimeError('{} is not defined as a metric'.format(m))

        return rr.reshape(-1, 1)

    def _decode_molecules(self, edges_logits, nodes_logits):
        """Sample discrete graphs from generator logits and build molecules."""
        (edges_hard, nodes_hard) = self.postprocess((edges_logits, nodes_logits), 'hard_gumbel')
        # Index of the selected category along the last axis.
        edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
        return [self.data.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
                for e_, n_ in zip(edges_hard, nodes_hard)]

    def train(self):
        """Run the adversarial training loop.

        Every iteration updates the discriminator; every `n_critic`-th
        iteration also updates the generator and the value network.  On log
        steps a validation batch is used and scores of generated molecules
        are printed (and sent to tensorboard when enabled).
        """
        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            is_log_step = (i + 1) % self.log_step == 0
            if is_log_step:
                mols, _, _, a, x, _, _, _, _ = self.data.next_validation_batch()
                z = self.sample_z(a.shape[0])
                print('[Valid]', '')
            else:
                mols, _, _, a, x, _, _, _, _ = self.data.next_train_batch(self.batch_size)
                z = self.sample_z(self.batch_size)

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            a = torch.from_numpy(a).to(self.device).long()            # Adjacency.
            x = torch.from_numpy(x).to(self.device).long()            # Nodes.
            a_tensor = self.label2onehot(a, self.b_dim)
            x_tensor = self.label2onehot(x, self.m_dim)
            z = torch.from_numpy(z).to(self.device).float()

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real molecules.
            logits_real, features_real = self.D(a_tensor, None, x_tensor)
            d_loss_real = - torch.mean(logits_real)

            # Compute loss with generated molecules.
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
            d_loss_fake = torch.mean(logits_fake)

            # Compute loss for gradient penalty on real/generated interpolations.
            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
            x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
            x_int1 = (eps.squeeze(-1) * x_tensor + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
            grad0, grad1 = self.D(x_int0, None, x_int1)
            d_loss_gp = self.gradient_penalty(grad0, x_int0) + self.gradient_penalty(grad1, x_int1)

            # Backward and optimize.
            d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # Molecules decoded from the generator in this iteration, if any.
            fake_mols = None

            if (i + 1) % self.n_critic == 0:
                # Z-to-target
                edges_logits, nodes_logits = self.G(z)
                # Postprocess with Gumbel softmax
                (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
                logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
                g_loss_fake = - torch.mean(logits_fake)

                # Real Reward (cast to float32 to match the network outputs).
                rewardR = torch.from_numpy(self.reward(mols)).to(self.device).float()
                # Fake Reward
                fake_mols = self._decode_molecules(edges_logits, nodes_logits)
                rewardF = torch.from_numpy(self.reward(fake_mols)).to(self.device).float()

                # Value loss: V regresses the reward of real and generated graphs.
                value_logit_real, _ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
                value_logit_fake, _ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)
                g_loss_value = torch.mean((value_logit_real - rewardR) ** 2 + (
                                           value_logit_fake - rewardF) ** 2)

                # Backward and optimize.
                g_loss = g_loss_fake + g_loss_value
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_value'] = g_loss_value.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if is_log_step:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)

                # The scores below describe generated molecules.  If the
                # generator step did not run this iteration, decode them from
                # the logits produced during the discriminator step instead
                # of scoring the real batch.
                if fake_mols is None:
                    with torch.no_grad():
                        fake_mols = self._decode_molecules(edges_logits, nodes_logits)

                # Log update
                m0, m1 = all_scores(fake_mols, self.data, norm=True)
                # Average each per-molecule score over its non-zero entries.
                m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
                m0.update(m1)
                loss.update(m0)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # Save model checkpoints.
            if (i + 1) % self.model_save_step == 0:
                G_path, D_path, V_path = self._checkpoint_paths(i + 1)
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                torch.save(self.V.state_dict(), V_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates linearly during the last `num_iters_decay` iterations.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self):
        """Load the checkpoint at `test_iters`, generate molecules and print their scores."""
        # Load the trained generator.
        self.restore_model(self.test_iters)

        with torch.no_grad():
            mols, _, _, a, x, _, _, _, _ = self.data.next_test_batch()
            # One latent vector per test molecule; the networks expect a
            # float tensor on the solver's device, not a NumPy array.
            z = self.sample_z(a.shape[0])
            z = torch.from_numpy(z).to(self.device).float()

            # Z-to-target
            edges_logits, nodes_logits = self.G(z)
            # Postprocess with Gumbel softmax
            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
            logits_fake, _ = self.D(edges_hat, None, nodes_hat)
            g_loss_fake = - torch.mean(logits_fake)

            # Generated molecules to be scored.
            mols = self._decode_molecules(edges_logits, nodes_logits)

            # Log update
            m0, m1 = all_scores(mols, self.data, norm=True)
            # Average each per-molecule score over its non-zero entries.
            m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
            m0.update(m1)

            log = "Test iteration [{}], G/loss_fake: {:.4f}".format(
                self.test_iters, g_loss_fake.item())
            for tag, value in m0.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)