-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_wgan.py
119 lines (94 loc) · 4.54 KB
/
train_wgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import time
import sys
from collections import deque
from model.gan import Generator, Discriminator
from util import device
from util import create_text_slice
from datasets import VoxelDataset
from torch.utils.data import DataLoader
# CLI flag: pass "nogui" to run headless without the live mesh preview.
show_viewer = "nogui" not in sys.argv

# NOTE(review): the scrape lost all indentation; this guard most plausibly
# encloses all executable setup so that DataLoader worker processes
# (num_workers > 0), which re-import this module on spawn platforms,
# do not re-run it — confirm against the original file.
if __name__ == "__main__":
    if show_viewer:
        # Imported lazily so headless machines never need the rendering stack.
        from rendering import MeshRenderer
        viewer = MeshRenderer()

    generator = Generator()
    generator.filename = "wgan-generator.to"

    # WGAN critic: a discriminator emitting raw scores, hence no sigmoid.
    critic = Discriminator()
    critic.filename = "wgan-critic.to"
    critic.use_sigmoid = False

    # Pass "continue" to resume from previously saved weights.
    if "continue" in sys.argv:
        generator.load()
        critic.load()

    LEARN_RATE = 0.000053
    BATCH_SIZE = 50
    # 1 means the generator is updated after every critic update.
    CRITIC_UPDATES_PER_GENERATOR_UPDATE = 1
    # Weight-clipping bound enforcing the critic's Lipschitz constraint (original WGAN).
    CRITIC_WEIGHT_LIMIT = 0.01

    dataset = VoxelDataset.glob('data/vox64/**.npy')
    data_loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=12)

    # RMSprop (not Adam) as recommended for weight-clipped WGAN training.
    generator_optimizer = optim.RMSprop(generator.parameters(), lr=LEARN_RATE)
    critic_optimizer = optim.RMSprop(critic.parameters(), lr=LEARN_RATE)

    # Append when resuming so earlier training history is preserved.
    log_file = open("plots/wgan_training.csv", "a" if "continue" in sys.argv else "w")
def train():
    """Run the WGAN training loop until interrupted.

    Alternates critic updates (Wasserstein loss with weight clipping) and
    generator updates, checkpoints both networks after every epoch, and
    appends one summary row per epoch to ``log_file``. Uses the module-level
    ``generator``, ``critic``, their optimizers, ``data_loader``, ``viewer``
    and ``log_file``. Loops forever; returns only on Ctrl+C.
    """
    # Rolling averages of recent critic outputs, used for reporting only.
    history_fake = deque(maxlen=50)
    history_real = deque(maxlen=50)

    for epoch in count():
        batch_index = 0
        epoch_start_time = time.time()
        for batch in data_loader:
            try:
                # --- critic update: minimize E[D(fake)] - E[D(real)] ---
                current_batch_size = batch.shape[0]  # equals BATCH_SIZE except possibly for the last batch
                generator.zero_grad()
                critic.zero_grad()
                # detach() so the critic step does not backprop into the generator.
                fake_sample = generator.generate(sample_size = current_batch_size).detach()
                fake_critic_output = critic(fake_sample)
                valid_critic_output = critic(batch.to(device))
                critic_loss = torch.mean(fake_critic_output) - torch.mean(valid_critic_output)
                critic_loss.backward()
                critic_optimizer.step()
                # Clip weights to enforce the Lipschitz constraint (original WGAN).
                critic.clip_weights(CRITIC_WEIGHT_LIMIT)

                # --- generator update: maximize the critic's score on fakes ---
                if batch_index % CRITIC_UPDATES_PER_GENERATOR_UPDATE == 0:
                    generator.zero_grad()
                    critic.zero_grad()
                    # NOTE: deliberately samples the full BATCH_SIZE here,
                    # independent of the (possibly smaller) final data batch.
                    fake_sample = generator.generate(sample_size = BATCH_SIZE)
                    if show_viewer:
                        viewer.set_voxels(fake_sample[0, :, :, :].squeeze().detach().cpu().numpy())
                    fake_critic_output = critic(fake_sample)
                    generator_loss = -torch.mean(fake_critic_output)
                    generator_loss.backward()
                    generator_optimizer.step()

                    history_fake.append(torch.mean(fake_critic_output).item())
                    history_real.append(torch.mean(valid_critic_output).item())
                    if "verbose" in sys.argv:
                        print(f"epoch {epoch}, batch {batch_index}"
                              + f": fake value: {history_fake[-1]:.1f}"
                              + f", valid value: {history_real[-1]:.1f}")
                batch_index += 1
            except KeyboardInterrupt:
                # Graceful shutdown on Ctrl+C during a batch.
                if show_viewer:
                    viewer.stop()
                return

        # --- end of epoch: checkpoint and report ---
        generator.save()
        critic.save()
        if epoch % 20 == 0:
            # Keep an additional permanent snapshot every 20 epochs.
            generator.save(epoch=epoch)
            critic.save(epoch=epoch)
        if "show_slice" in sys.argv:
            voxels = generator.generate().squeeze()
            print(create_text_slice(voxels))

        epoch_duration = time.time() - epoch_start_time
        fake_prediction = np.mean(history_fake)
        valid_prediction = np.mean(history_real)
        print('Epoch {:d} ({:.1f}s), critic values: {:.2f}, {:.2f}'.format(
            epoch, epoch_duration, fake_prediction, valid_prediction))
        log_file.write("{:d} {:.1f} {:.2f} {:.2f}\n".format(
            epoch, epoch_duration, fake_prediction, valid_prediction))
        log_file.flush()


# Fix: guard the entry point. The original called train() unconditionally at
# module scope; with DataLoader(num_workers=12), spawn-based platforms
# re-import this module in every worker process, which would re-enter
# training (and any `import` of this file would start training as well).
if __name__ == "__main__":
    train()