forked from pytorch/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
142 lines (110 loc) · 5.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import time
import glob
import torch
import torch.optim as O
import torch.nn as nn
from torchtext import data
from torchtext import datasets
from model import SNLIClassifier
from util import get_args, makedirs
# Parse the command-line options, then pick the device everything runs on:
# the GPU selected by ``args.gpu`` when CUDA is usable, otherwise the CPU.
args = get_args()
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    torch.cuda.set_device(args.gpu)
    device = torch.device('cuda:{}'.format(args.gpu))
# Field for the text inputs (spaCy-tokenised, optionally lower-cased) and a
# non-sequential field for the labels.
inputs = data.Field(lower=args.lower, tokenize='spacy')
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs, answers)

# The input vocabulary is built over all three splits; the label vocabulary
# (further down) over the training split only.
inputs.build_vocab(train, dev, test)
if args.word_vectors:
    if os.path.isfile(args.vector_cache):
        # Reuse the vectors cached by an earlier run.
        inputs.vocab.vectors = torch.load(args.vector_cache)
    else:
        inputs.vocab.load_vectors(args.word_vectors)
        # A bare filename has an empty dirname; creating '' would fail if
        # makedirs wraps os.makedirs (helper not visible here -- TODO confirm),
        # so only create the directory when there is one to create.
        cache_dir = os.path.dirname(args.vector_cache)
        if cache_dir:
            makedirs(cache_dir)
        torch.save(inputs.vocab.vectors, args.vector_cache)
answers.build_vocab(train)

# One iterator per split, all using the same batch size and device.
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test), batch_size=args.batch_size, device=device)
# The parsed arguments double as the model configuration; the sizes that
# depend on the data are filled in here.
config = args
config.n_embed = len(inputs.vocab)
config.d_out = len(answers.vocab)
config.n_cells = config.n_layers

# double the number of cells for bidirectional networks
if config.birnn:
    config.n_cells *= 2

if args.resume_snapshot:
    # NOTE(review): torch.load unpickles a complete model object, so only
    # resume from snapshot files you trust.
    model = torch.load(args.resume_snapshot, map_location=device)
else:
    model = SNLIClassifier(config)
    if args.word_vectors:
        # Initialise the embedding table from the loaded word vectors.
        model.embed.weight.data.copy_(inputs.vocab.vectors)

# Move the model to the selected device unconditionally. If this only ran
# when word vectors were given, a fresh model without them would stay on
# the CPU while the batches live on ``device``.
model.to(device)
# Loss function and optimizer over all model parameters.
criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=args.lr)

# Bookkeeping for the training loop: number of processed batches, best
# validation accuracy seen so far, and the wall-clock start time.
iterations, best_dev_acc = 0, -1
start = time.time()
# Column titles printed once before training starts.
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'

# Row templates are written with commas between columns; each comma is then
# turned into a single space. The validation row formats all ten values,
# while the plain progress row leaves the two dev columns as free-form
# placeholders.
dev_log_template = '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.replace(',', ' ')
log_template = '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.replace(',', ' ')
def _save_snapshot(model, snapshot_prefix, snapshot_path):
    """Save ``model`` to ``snapshot_path`` and remove older snapshots.

    Every other file whose path starts with ``snapshot_prefix`` is deleted,
    so only the newest snapshot for that prefix stays on disk.
    """
    torch.save(model, snapshot_path)
    for stale_path in glob.glob(snapshot_prefix + '*'):
        if stale_path != snapshot_path:
            os.remove(stale_path)


makedirs(args.save_path)
print(header)

for epoch in range(args.epochs):
    train_iter.init_epoch()
    # running totals for the training accuracy of the current epoch
    n_correct, n_total = 0, 0
    for batch_idx, batch in enumerate(train_iter):

        # switch model to training mode (validation below switches it to
        # evaluation mode), clear gradient accumulators
        model.train()
        opt.zero_grad()

        iterations += 1

        # forward pass
        answer = model(batch)

        # calculate accuracy of predictions in the current batch
        n_correct += (torch.max(answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
        n_total += batch.batch_size
        train_acc = 100. * n_correct / n_total

        # calculate loss of the network output with respect to training labels
        loss = criterion(answer, batch.label)

        # backpropagate and update the model parameters
        loss.backward()
        opt.step()

        # checkpoint model periodically
        if iterations % args.save_every == 0:
            snapshot_prefix = os.path.join(args.save_path, 'snapshot')
            snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.item(), iterations)
            _save_snapshot(model, snapshot_prefix, snapshot_path)

        # evaluate performance on validation set periodically
        if iterations % args.dev_every == 0:

            # switch model to evaluation mode
            model.eval()
            dev_iter.init_epoch()

            # calculate accuracy and loss on the whole validation set
            n_dev_correct, dev_loss_sum = 0, 0.
            with torch.no_grad():
                for dev_batch in dev_iter:
                    answer = model(dev_batch)
                    n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()) == dev_batch.label).sum().item()
                    # the criterion yields the mean loss of the batch; weight
                    # it by the batch size so the total is a per-example sum
                    dev_loss_sum += criterion(answer, dev_batch.label).item() * dev_batch.batch_size
            dev_acc = 100. * n_dev_correct / len(dev)
            # mean loss over the validation set, not just its last batch
            dev_loss = dev_loss_sum / len(dev)

            print(dev_log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.item(), dev_loss, train_acc, dev_acc))

            # update best validation set accuracy
            if dev_acc > best_dev_acc:

                # found a model with better validation set accuracy
                best_dev_acc = dev_acc
                snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
                snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss, iterations)

                # save model, delete previous 'best_snapshot' files
                _save_snapshot(model, snapshot_prefix, snapshot_path)

        elif iterations % args.log_every == 0:

            # print progress message (the dev columns are left blank)
            print(log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.item(), ' '*8, train_acc, ' '*12))