-
Notifications
You must be signed in to change notification settings - Fork 28
/
wgan_v2.py
108 lines (89 loc) · 4.11 KB
/
wgan_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import time
import argparse
import importlib
import tensorflow as tf
from scipy.misc import imsave
from visualize import *
class WassersteinGAN(object):
    """Wasserstein GAN with gradient penalty (WGAN-GP), TensorFlow 1.x graph API.

    The constructor builds the whole graph (losses, gradient penalty, Adam
    optimizers) and opens a session; ``train`` runs the alternating updates.

    Args:
        g_net: Generator. Called as ``g_net(z)``; must expose ``z_dim`` and
            ``vars`` (the variables its optimizer updates).
        d_net: Critic. Called as ``d_net(x, reuse=False)`` and ``d_net(x)``;
            must expose ``x_dim`` and ``vars``.
        x_sampler: ``x_sampler(batch_size)`` returns a real batch that is fed
            into a ``[None, x_dim]`` placeholder. Also used for image dumps
            through its ``data2img`` method and ``shape`` attribute.
        z_sampler: ``z_sampler(batch_size, z_dim)`` returns a noise batch.
        data: Dataset name; sample grids are written under ``logs/<data>/``.
        model: Model name (stored on the instance).
        scale: Weight of the gradient-penalty term.
    """

    def __init__(self, g_net, d_net, x_sampler, z_sampler, data, model, scale=10.0):
        self.model = model
        self.data = data
        self.g_net = g_net
        self.d_net = d_net
        self.x_sampler = x_sampler
        self.z_sampler = z_sampler
        self.x_dim = self.d_net.x_dim
        self.z_dim = self.g_net.z_dim

        self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x')
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

        self.x_ = self.g_net(self.z)
        # First critic call passes reuse=False; later calls omit it and
        # presumably share the same variables (confirm in the model code).
        self.d = self.d_net(self.x, reuse=False)
        self.d_ = self.d_net(self.x_)

        # Sign convention: minimising d_loss pushes the critic's score on
        # real samples down and on generated samples up; the generator
        # minimises the critic's score on its own samples.
        self.g_loss = tf.reduce_mean(self.d_)
        self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)

        # Gradient penalty on random interpolates between real and generated
        # samples. One coefficient per example (shape [batch, 1], broadcast
        # over the feature axis) so every real/fake pair gets its own point
        # on the segment between them, instead of one shared scalar per batch.
        epsilon = tf.random_uniform([tf.shape(self.x)[0], 1], 0.0, 1.0)
        x_hat = epsilon * self.x + (1 - epsilon) * self.x_
        d_hat = self.d_net(x_hat)
        ddx = tf.gradients(d_hat, x_hat)[0]
        ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))  # per-example L2 norm
        ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale)
        self.d_loss = self.d_loss + ddx

        # Run whatever is registered in UPDATE_OPS (e.g. batch-norm moving
        # averages, if the nets add any) before each optimizer step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)\
                .minimize(self.d_loss, var_list=self.d_net.vars)
            self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9)\
                .minimize(self.g_loss, var_list=self.g_net.vars)

        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    def train(self, batch_size=64, num_batches=1000000, d_iters=5):
        """Run alternating critic/generator updates.

        Each iteration takes ``d_iters`` critic steps followed by one
        generator step. Every 100 iterations the losses on a fresh batch are
        printed and a grid of generated samples is saved to
        ``logs/<data>/<t // 100>.png`` (the directory is created if missing).

        Args:
            batch_size: Samples per real/noise batch.
            num_batches: Number of outer (generator) iterations.
            d_iters: Critic steps per generator step; must be >= 1.

        Raises:
            ValueError: If ``d_iters`` is less than 1.
        """
        if d_iters < 1:
            raise ValueError('d_iters must be >= 1, got %r' % (d_iters,))
        # `plt` and `grid_transform` are presumably provided by
        # `from visualize import *`.
        plt.ion()
        self.sess.run(tf.global_variables_initializer())
        log_dir = os.path.join('logs', self.data)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        start_time = time.time()
        for t in range(num_batches):
            for _ in range(d_iters):
                bx = self.x_sampler(batch_size)
                bz = self.z_sampler(batch_size, self.z_dim)
                self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz})
            bz = self.z_sampler(batch_size, self.z_dim)
            # The last real batch is still fed: g_adam sits under the
            # UPDATE_OPS control dependency, which may include ops reading x.
            self.sess.run(self.g_adam, feed_dict={self.z: bz, self.x: bx})

            if t % 100 == 0:
                bx = self.x_sampler(batch_size)
                bz = self.z_sampler(batch_size, self.z_dim)
                d_loss = self.sess.run(self.d_loss, feed_dict={self.x: bx, self.z: bz})
                g_loss = self.sess.run(self.g_loss, feed_dict={self.z: bz})
                print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' %
                      (t, time.time() - start_time, d_loss, g_loss))

                # Render through the sampler this instance was given, not a
                # module-level global.
                bz = self.z_sampler(batch_size, self.z_dim)
                samples = self.sess.run(self.x_, feed_dict={self.z: bz})
                samples = self.x_sampler.data2img(samples)
                samples = grid_transform(samples, self.x_sampler.shape)
                # Integer division keeps file names as 0.png, 1.png, ...
                imsave(os.path.join(log_dir, '{}.png'.format(t // 100)), samples)
if __name__ == '__main__':
    # Command-line options: dataset package, model submodule, visible GPUs.
    cli = argparse.ArgumentParser(prog='')
    cli.add_argument('--data', type=str, default='mnist')
    cli.add_argument('--model', type=str, default='dcgan')
    cli.add_argument('--gpus', type=str, default='0')
    opts = cli.parse_args()

    # Set GPU visibility before any TensorFlow session is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus

    # <data> provides the samplers; <data>.<model> provides the networks.
    data_module = importlib.import_module(opts.data)
    model_module = importlib.import_module('{}.{}'.format(opts.data, opts.model))

    xs = data_module.DataSampler()
    noise = data_module.NoiseSampler()
    critic = model_module.Discriminator()
    generator = model_module.Generator()

    gan = WassersteinGAN(
        g_net=generator,
        d_net=critic,
        x_sampler=xs,
        z_sampler=noise,
        data=opts.data,
        model=opts.model,
    )
    gan.train()