forked from rasbt/machine-learning-book
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ch17_part2.py
528 lines (315 loc) · 12.4 KB
/
ch17_part2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
# coding: utf-8
import sys
#from google.colab import drive
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import grad as torch_grad

# Local helper module (presumably located in the parent folder): the folder
# must be on sys.path *before* the import, otherwise it fails when the script
# is run from its own directory.
sys.path.insert(0, '..')
from python_environment_check import check_packages
# # Machine Learning with PyTorch and Scikit-Learn
# # -- Code Examples
# ## Package version checks
# Add the parent folder to the module search path so that the helper module
# python_environment_check.py (which provides check_packages) can be found;
# presumably it lives one level up -- confirm against the repository layout.
sys.path.insert(0, '..')
# Check recommended package versions: package name -> recommended version,
# handed to check_packages (imported from python_environment_check).
d = {
'torch': '1.8.0',
'torchvision': '0.9.0',
'numpy': '1.21.2',
'matplotlib': '3.4.3',
}
check_packages(d)
# # Chapter 17 - Generative Adversarial Networks for Synthesizing New Data (Part 2/2)
# **Contents**
#
# - [Improving the quality of synthesized images using a convolutional and Wasserstein GAN](#Improving-the-quality-of-synthesized-images-using-a-convolutional-and-Wasserstein-GAN)
# - [Transposed convolution](#Transposed-convolution)
# - [Batch normalization](#Batch-normalization)
# - [Implementing the generator and discriminator](#Implementing-the-generator-and-discriminator)
# - [Dissimilarity measures between two distributions](#Dissimilarity-measures-between-two-distributions)
# - [Using EM distance in practice for GANs](#Using-EM-distance-in-practice-for-GANs)
# - [Gradient penalty](#Gradient-penalty)
# - [Implementing WGAN-GP to train the DCGAN model](#Implementing-WGAN-GP-to-train-the-DCGAN-model)
# - [Mode collapse](#Mode-collapse)
# - [Other GAN applications](#Other-GAN-applications)
# - [Summary](#Summary)
# Note that the optional watermark extension is a small IPython notebook plugin that I developed to make the code reproducible. It is only used in the notebook version, so the corresponding line(s) are not included in this script.
# # Improving the quality of synthesized images using a convolutional and Wasserstein GAN
# ## Transposed convolution
# ## Batch normalization
# ## Implementing the generator and discriminator
# * **Setting up the Google Colab**
#drive.mount('/content/drive/')
print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())
# Use a torch.device in both cases so that ``device`` has the same type
# (and supports e.g. ``device.type``) whichever hardware is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ## Train the DCGAN model
image_path = './'
# ToTensor maps pixel values to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the generator's final Tanh.
# One-element tuples make the single (grayscale) channel explicit -- ``(0.5)``
# is just the float 0.5, not a tuple.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# With download=False the MNIST files must already exist under image_path.
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
                                           train=True,
                                           transform=transform,
                                           download=False)
batch_size = 64
torch.manual_seed(1)
np.random.seed(1)
## Set up the dataset
# drop_last=True makes every batch contain exactly batch_size examples.
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size,
                      shuffle=True, drop_last=True)
def make_generator_network(input_size, n_filters):
    """Build the batch-normalized DCGAN generator.

    Latent tensors of shape ``(N, input_size, 1, 1)`` are upsampled through
    four transposed convolutions (spatial size 1 -> 4 -> 7 -> 14 -> 28) into
    single-channel images. A final ``nn.Tanh`` keeps the pixel values in
    (-1, 1).

    Args:
        input_size: number of latent channels (the noise dimension).
        n_filters: base number of feature maps; the hidden layers use
            4x, 2x and 1x this many channels.

    Returns:
        An ``nn.Sequential`` generator.
    """
    # (in_channels, out_channels, kernel, stride, padding) for each hidden
    # upsampling stage; each stage is followed by BatchNorm + LeakyReLU.
    stages = [
        (input_size, n_filters * 4, 4, 1, 0),     # 1x1 -> 4x4
        (n_filters * 4, n_filters * 2, 3, 2, 1),  # 4x4 -> 7x7
        (n_filters * 2, n_filters, 4, 2, 1),      # 7x7 -> 14x14
    ]
    layers = []
    for c_in, c_out, kernel, stride, pad in stages:
        layers.extend([
            nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
        ])
    # 14x14 -> 28x28 with a single output channel, squashed into (-1, 1).
    layers.append(nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)
class Discriminator(nn.Module):
    """DCGAN discriminator for 28x28 single-channel images.

    Four strided convolutions reduce an ``(N, 1, 28, 28)`` batch to a single
    value per image. A final sigmoid turns that value into the probability
    that the image is real.

    Args:
        n_filters: number of feature maps in the first convolution; the
            following layers use 2x and 4x as many.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),             # 28 -> 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),   # 14 -> 7
            nn.BatchNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False), # 7 -> 4
            nn.BatchNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False),           # 4 -> 1
            nn.Sigmoid())

    def forward(self, input):
        """Return the probabilities that the images are real, as an ``(N, 1)`` tensor.

        The result is always ``(N, 1)``, including for N == 1, so it matches
        the ``(N, 1)`` label tensors used with ``nn.BCELoss`` in training.
        """
        output = self.network(input)
        return output.view(-1, 1)
# DCGAN hyperparameters: noise dimension, output image size, base filters.
z_size = 100
image_size = (28, 28)
n_filters = 32
# Build (and print) generator first, then discriminator.
gen_model = make_generator_network(z_size, n_filters).to(device)
print(gen_model)
disc_model = Discriminator(n_filters).to(device)
print(disc_model)
## Loss function and optimizers:
# Binary cross-entropy on the discriminator's probabilities; generator and
# discriminator each get their own Adam optimizer.
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(gen_model.parameters(), lr=0.0003)
d_optimizer = torch.optim.Adam(disc_model.parameters(), lr=0.0002)
def create_noise(batch_size, z_size, mode_z):
    """Sample a batch of latent vectors for the generator.

    Args:
        batch_size: number of noise vectors to draw.
        z_size: dimensionality of each noise vector.
        mode_z: ``'uniform'`` for values drawn uniformly from [-1, 1), or
            ``'normal'`` for standard-normal values.

    Returns:
        Tensor of shape ``(batch_size, z_size, 1, 1)`` on torch's default
        device; callers move it with ``.to(device)``.

    Raises:
        ValueError: if ``mode_z`` is not one of the supported modes.
    """
    if mode_z == 'uniform':
        # torch.rand samples [0, 1); rescale to [-1, 1).
        return torch.rand(batch_size, z_size, 1, 1)*2 - 1
    if mode_z == 'normal':
        return torch.randn(batch_size, z_size, 1, 1)
    # Fail loudly instead of falling through with no noise tensor defined.
    raise ValueError(
        f"mode_z must be 'uniform' or 'normal', got {mode_z!r}")
## Train the discriminator
def d_train(x):
    """Run one discriminator update on a real batch ``x`` plus a fake batch.

    The discriminator is trained with ``loss_fn`` (binary cross-entropy) to
    output 1 for the real images and 0 for images produced by ``gen_model``
    from fresh noise. Only ``d_optimizer`` takes a step.

    Uses the module-level ``disc_model``, ``gen_model``, ``loss_fn``,
    ``d_optimizer``, ``z_size``, ``mode_z`` and ``device``.

    Args:
        x: batch of real images (moved to ``device`` here).

    Returns:
        Tuple ``(d_loss, d_proba_real, d_proba_fake)``. ``d_loss`` is the
        combined loss as a Python float; the other two are the detached
        discriminator outputs for the real and fake batches.
    """
    disc_model.zero_grad()
    # Train discriminator with a real batch
    batch_size = x.size(0)
    x = x.to(device)
    d_labels_real = torch.ones(batch_size, 1, device=device)
    d_proba_real = disc_model(x)
    d_loss_real = loss_fn(d_proba_real, d_labels_real)
    # Train discriminator on a fake batch. The generator output is detached:
    # this step only updates D, so back-propagating into G would be wasted
    # work (g_train zeroes G's gradients before its own backward pass).
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z).detach()
    d_proba_fake = disc_model(g_output)
    d_labels_fake = torch.zeros(batch_size, 1, device=device)
    d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)
    # gradient backprop & optimize ONLY D's parameters
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()
## Train the generator
def g_train(x):
    """Run one generator update with the non-saturating GAN loss.

    Fresh noise is pushed through ``gen_model`` and scored by ``disc_model``.
    The loss is ``loss_fn`` against all-ones ("real") targets, so G is
    rewarded for fooling D. Only ``g_optimizer`` takes a step. ``x`` is used
    only for its batch size.

    Returns:
        The generator loss as a Python float.
    """
    gen_model.zero_grad()
    n = x.size(0)
    noise = create_noise(n, z_size, mode_z).to(device)
    real_targets = torch.ones((n, 1), device=device)
    fake_scores = disc_model(gen_model(noise))
    g_loss = loss_fn(fake_scores, real_targets)
    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
# Noise distribution used by the training helpers, plus one fixed noise batch
# that is re-used after every epoch to visualise the generator's progress.
mode_z = 'uniform'
fixed_z = create_noise(batch_size, z_size, mode_z=mode_z).to(device)
def create_samples(g_model, input_z, out_size=None):
    """Generate images from ``input_z`` and rescale them to [0, 1] for plotting.

    Args:
        g_model: generator network. Both generators in this script end with
            ``nn.Tanh``, so outputs are expected in [-1, 1].
        input_z: batch of latent vectors, e.g. from ``create_noise``.
        out_size: optional ``(height, width)`` of each image; defaults to the
            module-level ``image_size``.

    Returns:
        Tensor of shape ``(len(input_z), *out_size)``, with values mapped
        from [-1, 1] to [0, 1].
    """
    if out_size is None:
        out_size = image_size
    g_output = g_model(input_z)
    # Use the actual batch dimension rather than the global batch_size, so
    # that any number of latent vectors can be visualised.
    images = torch.reshape(g_output, (g_output.size(0), *out_size))
    return (images+1)/2.0
# Train the DCGAN: one discriminator step followed by one generator step per
# mini-batch. After each epoch, log average losses and store the generator's
# output for the fixed noise batch.
epoch_samples = []
num_epochs = 100
torch.manual_seed(1)
for epoch in range(1, num_epochs + 1):
    gen_model.train()
    d_losses, g_losses = [], []
    for x, _ in mnist_dl:
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))
    print(f'Epoch {epoch:03d} | Avg Losses >>'
          f' G/D {torch.FloatTensor(g_losses).mean():.4f}'
          f'/{torch.FloatTensor(d_losses).mean():.4f}')
    # Snapshot generated images (as a NumPy array on the CPU) in eval mode.
    gen_model.eval()
    epoch_samples.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())
# Show the first 5 generated samples (from the fixed noise batch) for a few
# selected epochs, one row per epoch. Epochs that were not actually trained
# (e.g. after lowering num_epochs) are skipped instead of raising IndexError.
selected_epochs = [e for e in [1, 2, 4, 10, 50, 100]
                   if e <= len(epoch_samples)]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            # Label the row with its epoch, to the left of the first image.
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)
        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')
# plt.savefig('figures/ch17-dcgan-samples.pdf')
plt.show()
# ## Dissimilarity measures between two distributions
# ## Using EM distance in practice for GANs
# ## Gradient penalty
# ## Implementing WGAN-GP to train the DCGAN model
def make_generator_network_wgan(input_size, n_filters):
    """Build the generator used for WGAN-GP training.

    The architecture is the same as the DCGAN generator: four transposed
    convolutions upsample ``(N, input_size, 1, 1)`` noise to
    ``(N, 1, 28, 28)`` images, with a final ``nn.Tanh``. The hidden stages
    use ``nn.InstanceNorm2d`` instead of batch normalization.

    Args:
        input_size: number of latent channels (the noise dimension).
        n_filters: base number of feature maps; the hidden layers use
            4x, 2x and 1x this many channels.

    Returns:
        An ``nn.Sequential`` generator.
    """
    # (in_channels, out_channels, kernel, stride, padding) per hidden stage;
    # each is followed by InstanceNorm + LeakyReLU.
    stages = [
        (input_size, n_filters * 4, 4, 1, 0),     # 1x1 -> 4x4
        (n_filters * 4, n_filters * 2, 3, 2, 1),  # 4x4 -> 7x7
        (n_filters * 2, n_filters, 4, 2, 1),      # 7x7 -> 14x14
    ]
    layers = []
    for c_in, c_out, kernel, stride, pad in stages:
        layers.extend([
            nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
            nn.InstanceNorm2d(c_out),
            nn.LeakyReLU(0.2),
        ])
    # 14x14 -> 28x28, single channel, squashed into (-1, 1).
    layers.append(nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)
class DiscriminatorWGAN(nn.Module):
    """WGAN-GP critic for 28x28 single-channel images.

    It uses the same convolutional stack as the DCGAN discriminator, with
    two differences:

    - Instance normalization replaces batch normalization, as recommended
      for WGAN-GP, since batch norm mixes examples within a batch.
    - There is no final sigmoid. A Wasserstein critic outputs an unbounded
      real-valued score, not a probability; squashing it into (0, 1) would
      saturate the gradients that the gradient penalty constrains.

    Args:
        n_filters: number of feature maps in the first convolution; the
            following layers use 2x and 4x as many.
    """

    def __init__(self, n_filters):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, n_filters, 4, 2, 1, bias=False),             # 28 -> 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters, n_filters*2, 4, 2, 1, bias=False),   # 14 -> 7
            nn.InstanceNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters*2, n_filters*4, 3, 2, 1, bias=False), # 7 -> 4
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_filters*4, 1, 4, 1, 0, bias=False))           # 4 -> 1

    def forward(self, input):
        """Return one critic score per image as an ``(N, 1)`` tensor.

        The result is always ``(N, 1)``, including for N == 1.
        """
        output = self.network(input)
        return output.view(-1, 1)
# Fresh generator/critic pair for WGAN-GP, each with its own Adam optimizer
# (learning rate 2e-4 for both). Generator is created before the critic.
gen_model = make_generator_network_wgan(z_size, n_filters).to(device)
disc_model = DiscriminatorWGAN(n_filters).to(device)
g_optimizer = torch.optim.Adam(gen_model.parameters(), lr=2e-4)
d_optimizer = torch.optim.Adam(disc_model.parameters(), lr=2e-4)
def gradient_penalty(real_data, generated_data):
    """Compute the WGAN-GP gradient penalty for one batch.

    For each example, a random point on the straight line between the real
    and generated image is scored by the module-level critic ``disc_model``.
    The penalty is ``lambda_gp`` times the mean of
    ``(||grad_x critic(x)||_2 - 1) ** 2`` over those points.

    Args:
        real_data: batch of real images.
        generated_data: batch of generated images with the same shape.

    Returns:
        A scalar tensor. It is built with ``create_graph=True``, so it can be
        back-propagated into the critic's parameters.
    """
    n = real_data.size(0)
    # One mixing coefficient per example, broadcast over the remaining dims.
    # requires_grad=True makes the mixed batch part of the autograd graph, so
    # the critic output can be differentiated with respect to it.
    eps = torch.rand(real_data.shape[0], 1, 1, 1,
                     requires_grad=True, device=device)
    mixed = eps * real_data + (1 - eps) * generated_data
    critic_out = disc_model(mixed)
    (grads,) = torch_grad(
        outputs=critic_out,
        inputs=mixed,
        grad_outputs=torch.ones(critic_out.size(), device=device),
        create_graph=True,
        retain_graph=True,
    )
    grad_norms = grads.view(n, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norms - 1) ** 2).mean()
## Train the discriminator
def d_train_wgan(x):
    """Run one WGAN-GP critic update on a real batch ``x``.

    The loss is the mean critic score on generated images, minus the mean
    score on real images, plus ``gradient_penalty``. Only ``d_optimizer``
    takes a step.

    Uses the module-level ``disc_model``, ``gen_model``, ``d_optimizer``,
    ``z_size``, ``mode_z`` and ``device``.

    Args:
        x: batch of real images (moved to ``device`` here).

    Returns:
        The critic loss as a Python float.
    """
    disc_model.zero_grad()
    batch_size = x.size(0)
    x = x.to(device)
    # Calculate critic scores on real and generated data
    d_real = disc_model(x)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    # Detach the generator output: only the critic is updated here, and this
    # function runs several times per generator step, so back-propagating
    # into G would be repeated wasted work (g_train_wgan zeroes G's gradients
    # before its own backward pass).
    g_output = gen_model(input_z).detach()
    d_generated = disc_model(g_output)
    d_loss = (d_generated.mean() - d_real.mean()
              + gradient_penalty(x.detach(), g_output))
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
## Train the generator
def g_train_wgan(x):
    """Run one WGAN generator update.

    The generator tries to maximise the critic's score on its samples, i.e.
    to minimise the negative mean score. Only ``g_optimizer`` takes a step.
    ``x`` is used only for its batch size.

    Returns:
        The generator loss as a Python float.
    """
    gen_model.zero_grad()
    noise = create_noise(x.size(0), z_size, mode_z).to(device)
    critic_scores = disc_model(gen_model(noise))
    g_loss = -critic_scores.mean()
    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
# Train with WGAN-GP: critic_iterations critic updates (on the same real
# batch) per generator update. Only the last critic loss of each batch is
# recorded. After each epoch, log the mean critic loss and store the
# generator's output for the fixed noise batch.
epoch_samples_wgan = []
lambda_gp = 10.0
num_epochs = 100
torch.manual_seed(1)
critic_iterations = 5
for epoch in range(1, num_epochs + 1):
    gen_model.train()
    d_losses, g_losses = [], []
    for x, _ in mnist_dl:
        for _ in range(critic_iterations):
            d_loss = d_train_wgan(x)
        d_losses.append(d_loss)
        g_losses.append(g_train_wgan(x))
    print(f'Epoch {epoch:03d} | D Loss >>'
          f' {torch.FloatTensor(d_losses).mean():.4f}')
    # Snapshot generated images (as a NumPy array on the CPU) in eval mode.
    gen_model.eval()
    epoch_samples_wgan.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())
# Show the first 5 WGAN-GP samples (from the fixed noise batch) for a few
# selected epochs, one row per epoch. Epochs that were not actually trained
# (e.g. after lowering num_epochs) are skipped instead of raising IndexError.
selected_epochs = [e for e in [1, 2, 4, 10, 50, 100]
                   if e <= len(epoch_samples_wgan)]
# selected_epochs = [1, 10, 20, 30, 50, 70]
fig = plt.figure(figsize=(10, 14))
for i, e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            # Label the row with its epoch, to the left of the first image.
            ax.text(
                -0.06, 0.5, f'Epoch {e}',
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center',
                transform=ax.transAxes)
        image = epoch_samples_wgan[e-1][j]
        ax.imshow(image, cmap='gray_r')
# plt.savefig('figures/ch17-wgan-gp-samples.pdf')
plt.show()
# ## Mode collapse
#
# ----
#
#
# Readers may ignore the next cell.
#
#