-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
192 lines (180 loc) · 8.07 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import configargparse
import os
import json
import random
from functools import partial
from collections import defaultdict, OrderedDict, Counter
from tqdm import tqdm, trange
import torch
import torch.backends.cudnn as cudnn
from torchtext.data import Iterator, Dataset
from corpus import Corpus
from evaluate import evaluate_lm, evaluate_lm_at_t
from utils import DotDict, Logger, load_config, lm_factory, get_lm_optimizer
def main(opt):
    """Train a language model configured by `opt` and evaluate it on val/test.

    Loads the corpus and per-timestep data loaders, builds the model and
    optimizer (with optional LR scheduling), runs the training loop with
    periodic validation checkpoints, then logs final val/test results.

    Args:
        opt: DotDict of options (corpus, config, model, device, ...);
            mutated in place with config values and corpus statistics.

    Returns:
        int: shell-style exit code — 0 on success, 130 on KeyboardInterrupt.
    """
    exit_code = 0
    opt.hostname = os.uname()[1]
    opt.running = True
    # device selection: pin the visible GPU before creating the torch device
    if opt.device > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.device)
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    # seeding for reproducibility (seed is drawn at random when not given)
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.device > -1:
        torch.cuda.manual_seed_all(opt.manualSeed)
    ##################################################################################################################
    # Data
    ##################################################################################################################
    # load config
    data_opt = load_config(os.path.join('config', opt.corpus, opt.config, 'corpus.yaml'))
    opt.update(data_opt)
    # load data
    corpus = Corpus(opt.dataroot)
    # split
    trainset, valset, testset = corpus.split(opt.config, opt.min_freq)
    # dataloaders
    # -- train
    train_loader = Iterator(trainset, opt.batch_size, repeat=False, sort_within_batch=True, device=device)
    # -- val: one loader per timestep, in timestep order
    ts_val = sorted({ex.timestep for ex in valset})
    val_loaders = OrderedDict()
    for t in ts_val:
        # bind t as a default argument so each filter keeps its own timestep
        val_t = Dataset(valset.examples, valset.fields, filter_pred=lambda x, t=t: x.timestep == t)
        val_t.sort_key = lambda x: len(x.text)
        val_loaders[t] = Iterator(val_t, opt.batch_size, train=False, device=device)
    # -- test: in the 'prediction' config, validation timesteps are evaluated as well
    ts_tests = sorted({ex.timestep for ex in testset})
    test_loaders = OrderedDict()
    if opt.config == 'prediction':
        test_loaders.update(val_loaders)
    for t in ts_tests:
        test_t = Dataset(testset.examples, testset.fields, filter_pred=lambda x, t=t: x.timestep == t)
        test_t.sort_key = lambda x: len(x.text)
        test_loaders[t] = Iterator(test_t, opt.batch_size, train=False, device=device)
    # corpus-derived options
    opt.ntoken = corpus.vocab_size
    opt.padding_idx = corpus.pad_idx
    opt.nts = max(ex.timestep for ex in trainset) + 1
    opt.nwords = sum(len(ex.text) for ex in trainset)
    # print info
    print('Vocab size: {}'.format(opt.ntoken))
    print(f'{len(trainset)} training documents with {opt.nwords} tokens on {opt.nts} timesteps')
    ##################################################################################################################
    # Model
    ##################################################################################################################
    # load config
    model_opt = load_config(os.path.join('config', opt.corpus, opt.config, '{}.yaml'.format(opt.model)))
    opt.update(model_opt)
    # build model
    print('Building model...')
    model = lm_factory(opt).to(device)
    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = get_lm_optimizer(model, opt)
    # BUGFIX: always define lr_scheduler so the training loop can test it,
    # even when 'lr_scheduling' is absent from opt (previously a NameError).
    lr_scheduler = None
    if 'lr_scheduling' in opt:
        if opt.lr_scheduling == 'linear':
            opt.min_lr = 0  # BUGFIX: was `opt.min_lr == 0`, a no-op comparison
            opt.niter = opt.niter_burnin + opt.niter_scheduling
            niter = opt.niter_scheduling
            # LR decays linearly to 0 over niter scheduling steps
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda i: max(0, (niter - i) / niter))
        elif opt.lr_scheduling == 'reduce_on_plateau':
            # BUGFIX: was a second `if` whose `else` overwrote the linear
            # scheduler with None; `elif` keeps both branches exclusive.
            assert opt.min_lr > 0
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, patience=opt.patience, factor=opt.lr_decay)
    ##################################################################################################################
    # Log
    ##################################################################################################################
    opt.xproot = os.path.join(opt.xproot, opt.corpus, opt.config, opt.model, opt.name)
    print(f'New experiment logged at {opt.xproot}')
    logger = Logger(opt.xproot)
    logger.init(opt)
    ##################################################################################################################
    # Training
    ##################################################################################################################
    print('Training...')
    pb = trange(opt.niter, ncols=0)
    finished = False
    itr = -1
    log_train = {}  # BUGFIX: defined up front so final logging survives an immediate interrupt
    try:
        while not finished:
            for batch in train_loader:
                itr += 1
                model.train()
                # io — next-token prediction: inputs are tokens [:-1], targets are tokens [1:]
                text = batch.text[0][:-1]
                target = batch.text[0][1:]
                timestep = batch.timestep
                # closure performs the forward/backward/step and returns training logs
                log_train = model.closure(text, target, timestep, optimizer, opt)
                # periodic evaluation + checkpoint
                if itr > 0 and itr % opt.niter_checkpoint == 0:
                    model.eval()
                    with torch.no_grad():
                        score, log_val = evaluate_lm(model, val_loaders, opt)
                    # checkpoint
                    log_train['lr'] = optimizer.param_groups[0]['lr']
                    logger.log(itr, 'train', log_train)
                    logger.log(itr, 'val', log_val)
                    logger.checkpoint(itr)
                    # reduce_on_plateau lr scheduling steps on the validation score
                    if lr_scheduler and itr >= opt.niter_burnin and opt.lr_scheduling == 'reduce_on_plateau':
                        lr_scheduler.step(score)
                        if optimizer.param_groups[0]['lr'] < opt.min_lr:
                            finished = True
                            break
                    # BUGFIX: lr was previously unbound here unless the
                    # reduce_on_plateau branch had executed.
                    lr = optimizer.param_groups[0]['lr']
                    # progress bar
                    pb.update(opt.niter_checkpoint)
                    pb.set_postfix(chkpt=logger.chkpt, loss=log_train['loss'], score=score, lr=lr)
                # other lr scheduling steps every iteration
                if lr_scheduler and itr >= opt.niter_burnin and opt.lr_scheduling != 'reduce_on_plateau':
                    lr_scheduler.step()
                    if optimizer.param_groups[0]['lr'] < opt.min_lr:
                        finished = True
    except KeyboardInterrupt:
        exit_code = 130  # conventional exit code for SIGINT
    pb.close()
    # final evaluation on val and test, then persist everything
    print('Evaluating...')
    model.eval()
    with torch.no_grad():
        _, log_val = evaluate_lm(model, val_loaders, opt)
        _, results = evaluate_lm(model, test_loaders, opt)
    log_train['lr'] = optimizer.param_groups[0]['lr']
    logger.log(itr, 'train', log_train)
    logger.log(itr, 'val', log_val)
    logger.log(itr, 'test', results)
    logger.checkpoint(itr)
    logger.terminate(model, optimizer)
    return exit_code
if __name__ == '__main__':
    # command-line interface
    parser = configargparse.ArgParser()
    parser.add('--xproot', type=str, default='xp', help='Base saving directory')
    parser.add('--corpus', required=True, type=str, help='Corpus name')
    parser.add('--config', required=True, type=str, help='Evaluation configuration: prediction | modeling')
    parser.add('--model', required=True, type=str, help='Model name: lstm | drlm')
    parser.add('--name', type=str, default='debug', help='Experiment name')
    parser.add('--device', type=int, default=-1, help='-1: cpu; > -1: cuda device id')
    parser.add('--manualSeed', type=int, help='manual seed')
    # parse arguments and hand them to main as a dot-accessible dict
    args = parser.parse_args()
    main(DotDict(vars(args)))