# main.py (forked from MishaLaskin/vqvae)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import utils
from models.vqvae import VQVAE

parser = argparse.ArgumentParser()

"""
Hyperparameters
"""
timestamp = utils.readable_timestamp()

parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--n_updates", type=int, default=5000)
parser.add_argument("--n_hiddens", type=int, default=128)
parser.add_argument("--n_residual_hiddens", type=int, default=32)
parser.add_argument("--n_residual_layers", type=int, default=2)
parser.add_argument("--embedding_dim", type=int, default=64)
parser.add_argument("--n_embeddings", type=int, default=512)
parser.add_argument("--beta", type=float, default=.25)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--log_interval", type=int, default=50)
parser.add_argument("--dataset", type=str, default='CIFAR10')
# whether or not to save model
parser.add_argument("-save", action="store_true")
parser.add_argument("--filename", type=str, default=timestamp)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')
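
# Example invocation (using only the flags defined above):
#   python main.py --dataset CIFAR10 --batch_size 32 --n_updates 5000 -save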
"""
Load data and define batch data loaders
"""
training_data, validation_data, training_loader, validation_loader, x_train_var = utils.load_data_and_data_loaders(
    args.dataset, args.batch_size)

"""
Set up VQ-VAE model with components defined in ./models/ folder
"""
model = VQVAE(args.n_hiddens, args.n_residual_hiddens,
              args.n_residual_layers, args.n_embeddings, args.embedding_dim, args.beta).to(device)
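
# Optional sanity check of the forward pass (a minimal sketch, left commented
# out; the shapes assume CIFAR10's 3x32x32 RGB images, and the return
# signature matches the one unpacked in the training loop below):
# _x = torch.randn(2, 3, 32, 32).to(device)
# _embedding_loss, _x_hat, _perplexity = model(_x)
# assert _x_hat.shape == _x.shape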
"""
Set up optimizer and training loop
"""
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)
model.train()
results = {
    'n_updates': 0,
    'recon_errors': [],
    'loss_vals': [],
    'perplexities': [],
}
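# 'perplexities' tracks codebook usage: the exponentiated entropy of code
# assignments, ranging from 1 (all inputs mapped to a single code) up to
# n_embeddings (perfectly uniform usage).
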
def train():
    for i in range(args.n_updates):
        # iter() re-creates the loader iterator every step, so with a shuffled
        # loader this draws a fresh random batch rather than sweeping epochs.
        (x, _) = next(iter(training_loader))
        x = x.to(device)
        optimizer.zero_grad()

        embedding_loss, x_hat, perplexity = model(x)
        # Normalize the reconstruction error by the training-set variance so
        # its scale is comparable across datasets.
        recon_loss = torch.mean((x_hat - x)**2) / x_train_var
        loss = recon_loss + embedding_loss

        loss.backward()
        optimizer.step()

        results["recon_errors"].append(recon_loss.cpu().detach().numpy())
        results["perplexities"].append(perplexity.cpu().detach().numpy())
        results["loss_vals"].append(loss.cpu().detach().numpy())
        results["n_updates"] = i

        if i % args.log_interval == 0:
            """
            save model and print values
            """
            if args.save:
                hyperparameters = args.__dict__
                utils.save_model_and_results(
                    model, results, hyperparameters, args.filename)

            print('Update #', i, 'Recon Error:',
                  np.mean(results["recon_errors"][-args.log_interval:]),
                  'Loss', np.mean(results["loss_vals"][-args.log_interval:]),
                  'Perplexity:', np.mean(results["perplexities"][-args.log_interval:]))

if __name__ == "__main__":
    train()
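
# To revisit a saved run later (a sketch; assumes utils.save_model_and_results
# writes a torch-loadable checkpoint at the path printed at startup):
#   checkpoint = torch.load('./results/vqvae_' + args.filename + '.pth')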