train.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
pytorch-dl
Created by raj at 6:59 AM, 7/30/20

Train models.
"""
import os
import time
from math import inf

import torch
from torch import nn

import onmt.opts as opts
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
    load_old_vocab, old_style_vocab, build_dataset_iter_multiple

from dataset.iwslt_data import rebatch, rebatch_onmt, SimpleLossCompute, NoamOpt, LabelSmoothing
from models.transformer import TransformerEncoderDecoder
from models.utils.model_utils import load_model_state, save_state, get_perplexity
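
# NOTE: this script reuses OpenNMT-py's option parsing, vocabulary handling and
# dataset iterators, but trains the project's own TransformerEncoderDecoder
# rather than an OpenNMT model.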


def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)
    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    src_vocab = fields['src'].base_field.vocab
    trg_vocab = fields['tgt'].base_field.vocab
    src_vocab_size = len(src_vocab)
    trg_vocab_size = len(trg_vocab)
    pad_idx = src_vocab.stoi["<blank>"]
    unk_idx = src_vocab.stoi["<unk>"]
    start_symbol = trg_vocab.stoi["<s>"]

    # patch for fields that may be missing in old data/model
    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    model_dir = opt.save_model
    try:
        os.makedirs(model_dir)
    except OSError:
        pass
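
    # Model hyperparameters: the width (k), number of attention heads and depth
    # come from the OpenNMT-style options (state_dim, heads, enc_layers), while
    # the maximum sequence length is fixed at 100 here rather than read from opt.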
    model_dim = opt.state_dim
    heads = opt.heads
    depth = opt.enc_layers
    max_len = 100
    model = TransformerEncoderDecoder(k=model_dim, heads=heads, dropout=opt.dropout[0],
                                      depth=depth,
                                      num_emb=src_vocab_size,
                                      num_emb_target=trg_vocab_size,
                                      max_len=max_len,
                                      mask_future_steps=True)

    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    start_steps = load_model_state(os.path.join(model_dir, 'checkpoints_best.pt'), opt, model,
                                   data_parallel=False)

    criterion = LabelSmoothing(size=trg_vocab_size, padding_idx=pad_idx, smoothing=opt.label_smoothing)
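
    # Assuming NoamOpt follows the Annotated Transformer signature
    # (model_size, factor, warmup, optimizer), this wraps Adam with the
    # inverse-square-root learning-rate schedule from "Attention Is All You Need",
    # using factor 1 and 2000 warmup steps.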
    optimizer = NoamOpt(model_dim, 1, 2000, torch.optim.Adam(model.parameters(),
                                                             lr=opt.learning_rate,
                                                             betas=(0.9, 0.98), eps=1e-9))
    compute_loss = SimpleLossCompute(model.generator, criterion, optimizer)

    cuda_condition = torch.cuda.is_available() and opt.gpu_ranks
    device = torch.device("cuda:0" if cuda_condition else "cpu")
    if cuda_condition:
        model.cuda()
    if cuda_condition and torch.cuda.device_count() > 1:
        print("Using %d GPUs for training" % torch.cuda.device_count())
        # Wrap the model over all visible GPUs instead of a hard-coded device list.
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
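
    # Training loop: global_steps resumes from the step count stored in the best
    # checkpoint (if any); each pass over train_iter counts as one iteration, a
    # checkpoint is written per iteration, and the best (lowest average loss)
    # checkpoint is kept separately as checkpoints_best.pt.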
    previous_best = inf
    # start_steps is non-zero when a previous run was interrupted and resumed
    global_steps = start_steps
    iterations = 0
    max_steps = opt.train_steps
    while global_steps <= max_steps:
        start = time.time()
        total_tokens = 0
        total_loss = 0
        tokens = 0
        iterations += 1
        for i, batch in enumerate(rebatch_onmt(pad_idx, b, device=device) for b in train_iter):
            global_steps += 1
            model.train()
            out = model(batch.src, batch.src_mask, batch.trg, batch.trg_mask)
            loss = compute_loss(out, batch.trg_y, batch.ntokens)
            total_loss += loss
            total_tokens += batch.ntokens
            tokens += batch.ntokens
            if i % opt.report_every == 0 and i > 0:
                elapsed = time.time() - start
                print("Global Steps/Max Steps: %d/%d Step: %d Loss: %f PPL: %f Tokens per Sec: %f" %
                      (global_steps, max_steps, i, loss / batch.ntokens,
                       get_perplexity(loss / batch.ntokens), tokens / elapsed))
                start = time.time()
                tokens = 0

        # checkpoint = "checkpoint.{}.".format(total_loss / total_tokens) + 'epoch' + str(epoch) + ".pt"
        # save_state(os.path.join(model_dir, checkpoint), model, criterion, optimizer, epoch)
        loss_average = total_loss / total_tokens
        checkpoint = "checkpoint.{}.epoch{}.pt".format(loss_average, iterations)
        save_state(os.path.join(model_dir, checkpoint), model, criterion, optimizer, global_steps, fields, opt)
        if previous_best > loss_average:
            save_state(os.path.join(model_dir, 'checkpoints_best.pt'), model, criterion, optimizer, global_steps, fields, opt)
            previous_best = loss_average


def _get_parser():
    parser = ArgumentParser(description='train.py')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser


def main():
    parser = _get_parser()

    opt = parser.parse_args()
    train(opt)


if __name__ == "__main__":
    main()
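

# Example invocation (hypothetical paths and values; the -data prefix must point
# to files produced by OpenNMT-py preprocessing, e.g. demo.vocab.pt and the
# train_* shards, and -state_dim is assumed to be registered in this project's
# copy of onmt.opts):
#
#   python train.py -data data/demo -save_model models/demo \
#       -state_dim 512 -heads 8 -enc_layers 6 -dropout 0.1 \
#       -label_smoothing 0.1 -train_steps 100000 \
#       -report_every 50 -gpu_ranks 0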