forked from rosinality/vq-vae-2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_vqvae.py
executable file
·109 lines (79 loc) · 2.79 KB
/
train_vqvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm
from dataset import ABCDFrameDataset
from vqvae import VQVAE
from scheduler import CycleScheduler
def train(epoch, loader, model, optimizer, scheduler, device):
    """Run one training epoch of the VQ-VAE over ``loader``.

    For each batch the model reconstructs the input. The optimized loss is
    the MSE reconstruction error plus ``latent_loss_weight`` times the mean
    latent loss returned by the model. A tqdm progress bar shows the
    per-batch MSE, the latent loss, the running average MSE and the current
    learning rate.

    Args:
        epoch: Zero-based epoch index; displayed as ``epoch + 1``.
        loader: Iterable of image-batch tensors. Each item must support
            ``.to(device)`` and ``.shape[0]``, so it must be a bare tensor,
            not an ``(img, label)`` tuple.
        model: Module whose forward returns ``(reconstruction, latent_loss)``.
        optimizer: Optimizer over ``model``'s parameters. Its first param
            group's ``'lr'`` is what gets displayed.
        scheduler: Optional object with a ``step()`` method, called once per
            batch before ``optimizer.step()``. ``None`` disables it.
        device: Device each batch is moved to.

    Returns:
        The sample-weighted average reconstruction MSE over the epoch, or
        ``None`` if ``loader`` yielded no batches.
    """
    # Guarantee training mode even if a caller left the model in eval().
    model.train()

    criterion = nn.MSELoss()
    latent_loss_weight = 0.25

    mse_sum = 0.0
    mse_n = 0

    # Context manager closes the progress bar even if a batch raises.
    with tqdm(loader) as progress:
        for img in progress:
            model.zero_grad()

            img = img.to(device)
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            # Reduce to a scalar; under nn.DataParallel this is presumably
            # one value per replica, gathered onto the output device.
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss
            loss.backward()

            # Order kept from the original: the project's scheduler is
            # stepped before the optimizer. It presumably sets the lr used by
            # this step; verify against CycleScheduler before reordering.
            if scheduler is not None:
                scheduler.step()
            optimizer.step()

            batch_mse = recon_loss.item()
            batch_size = img.shape[0]
            # Weight by batch size so a short final batch does not skew
            # the running average.
            mse_sum += batch_mse * batch_size
            mse_n += batch_size

            lr = optimizer.param_groups[0]['lr']
            progress.set_description(
                f'epoch: {epoch + 1}; mse: {batch_mse:.5f}; '
                f'latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; '
                f'lr: {lr:.5f}'
            )

    return mse_sum / mse_n if mse_n else None
if __name__ == '__main__':
    # Local import keeps this entry-point change self-contained; only the
    # script body needs filesystem helpers.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=4, help='batch size')
    parser.add_argument('--epoch', type=int, default=1,
                        help='number of training epochs')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='learning rate passed to Adam (and to CycleScheduler)')
    parser.add_argument('--sched', type=str,
                        help="learning-rate schedule; only 'cycle' is recognized, "
                             "any other value keeps the lr constant")
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of DataLoader worker processes')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='directory for per-epoch checkpoints (created if missing)')
    parser.add_argument('path', type=str,
                        help='dataset location passed to ABCDFrameDataset')
    args = parser.parse_args()
    print(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # torch.save does not create parent directories. Without this, a missing
    # directory would only fail after the first (possibly long) epoch had
    # already been trained.
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    dataset = ABCDFrameDataset(args.path)
    loader = DataLoader(dataset, batch_size=args.size, shuffle=True,
                        num_workers=args.num_workers)

    # Checkpoints are saved from .module (the wrapped VQVAE) so their keys
    # carry no DataParallel "module." prefix.
    model = nn.DataParallel(VQVAE()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    scheduler = None
    if args.sched == 'cycle':
        # n_iter counts batches across all epochs, matching train()'s
        # one scheduler.step() call per batch.
        scheduler = CycleScheduler(
            optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(
            model.module.state_dict(),
            os.path.join(args.checkpoint_dir, f'vqvae_{i + 1:03d}.pt'),
        )