# coding: utf-8
"""
git clone https://github.com/cybertronai/transformer-xl.git
cd transformer-xl/pytorch
source activate pytorch_p36
# Match eval parameters from paper.
# https://github.com/kimiyoung/transformer-xl/blob/master/tf/scripts/wt103_large_tpu.sh
python eval.py --data=data/wikitext-103 --dataset=wt103 --batch_size=8 --tgt_len=128 --clamp_len=1000 --mem_len=1600 --work_dir=/ncluster/runs/ben-batch-sched-slow2.01/
# new dataset
python eval.py --data=data/wikiextracted/ --dataset=wiki --batch_size=8 --tgt_len=128 --clamp_len=1000 --mem_len=1600 --work_dir=/ncluster/runs.new/ben-txl-large-adam.05 --bpe
"""
import argparse
import math
import os
from typing import List, Tuple
import torch
import tqdm
import globals as g  # global state for the current run, shared between modules
from data_utils import get_lm_corpus
from generate import generate_text, prepare_git_context
from search import hidden_to_softmax
from util import unwrap_model
from utils.exp_utils import get_logger


def main():
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--data', type=str, default='../data/wikitext-103',
                        help='location of the data corpus')
    parser.add_argument('--dataset', type=str, default='wt103',
                        choices=['wt103', 'lm1b', 'enwik8', 'text8', 'wt2', 'wiki'],
                        help='dataset name')
    parser.add_argument('--split', type=str, default='all',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=5,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=0,
                        help='length of the retained previous hidden states (memory)')
    parser.add_argument('--clamp_len', type=int, default=-1,
                        help='max positional embedding index')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--bpe', action='store_true', default=False,
                        help='Use BPE instead of traditional vocabulary.')
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get logger
    logging = get_logger(os.path.join(args.work_dir, 'eval-log.txt'),
                         log_=not args.no_log)

    # Load dataset
    corpus = get_lm_corpus(args.data, args.dataset, use_bpe=args.bpe)
    ntokens = len(corpus.vocab)

    # Load the best saved model.
    with open(os.path.join(args.work_dir, 'model-best.pt'), 'rb') as f:
        model = torch.load(f)
    model_tokens = model.n_token if hasattr(model, 'n_token') else model.module.n_token
    assert model_tokens == ntokens, 'vocab size mismatch, did you mean `--bpe`?'
    model = model.to(device)

    logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    # Adjust attention/memory lengths for evaluation (the model may be wrapped
    # in DataParallel, hence the .module fallback).
    if hasattr(model, 'reset_length'):
        model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    else:
        model.module.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    # Run on the requested splits.
    for split in ('valid', 'test'):
        if args.split in (split, 'all'):
            it = corpus.get_iterator(split, args.batch_size, args.tgt_len,
                                     device=device, ext_len=args.ext_len)
            logging(format_log(args, evaluate(model, it, split), split))


def evaluate(model, eval_iter, label: str, max_eval_steps: int = 0, reset_mems_interval: int = None):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_count = 0, 0
    total_loss, total_top1, total_top5, MRR_total = 0., 0, 0, 0.
    with torch.no_grad():
        mems = tuple()
        bar = tqdm.tqdm(eval_iter, leave=False)
        for i, (data, target, seq_len) in enumerate(bar):
            if 0 < max_eval_steps <= i:
                break
            if reset_mems_interval is not None and i % reset_mems_interval == 0:
                mems = tuple()
            ret = model(data, target, *mems, return_hidden=True)
            pred_hid, loss, mems = ret[0].half(), ret[1], ret[2:]
            softmax = hidden_to_softmax(unwrap_model(model), pred_hid)

            # Loss calculation (loss is the per-token mean for this batch).
            loss = loss.mean()
            total_loss += seq_len * loss.item()

            # Accuracy calculation
            _, pred_top = torch.topk(softmax, 5)
            true_pos = pred_top == target.unsqueeze(-1).expand_as(pred_top)
            true_top1 = true_pos[:, :, 0].sum()
            true_top5 = true_pos[:, :, :5].sum()
            total_top1 += true_top1
            total_top5 += true_top5

            # MRR calculation: at most one of the top-5 predictions matches the
            # target, so dividing the hit mask by the ranks 1..5 and summing
            # yields the sum of reciprocal ranks for this batch.
            MRR_total += float(
                (
                    true_pos.double() / (
                        torch.arange(end=true_pos.size(-1), dtype=torch.double, device=true_pos.device) + 1)
                ).sum()
            )

            total_len += seq_len
            total_count += seq_len * target.size(1)

            MRR_top5 = MRR_total / total_count
            accuracy_top1 = float(total_top1) / total_count
            accuracy_top5 = float(total_top5) / total_count
            bar.set_description(f'{label} '
                                f'| loss: {total_loss / total_len:.2f} '
                                f'| accuracy@1: {accuracy_top1:.2f} '
                                f'| accuracy@5: {accuracy_top5:.2f} '
                                f'| MRR@5: {MRR_top5:.2f} ')

    metrics = {
        "total_loss": total_loss,
        "accuracy_top1": accuracy_top1,
        "accuracy_top5": accuracy_top5,
        "MRR_top5": MRR_top5,
        "total_len": total_len,
    }
    return metrics
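
# Illustrative standalone use of evaluate() (hypothetical snippet, not how the
# script invokes it above): cap evaluation at 1000 batches and clear the
# Transformer-XL memories every 100 batches.
#
#   val_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
#                                  device=device, ext_len=args.ext_len)
#   metrics = evaluate(model, val_iter, 'valid',
#                      max_eval_steps=1000, reset_mems_interval=100)
#   print(metrics['accuracy_top1'], metrics['accuracy_top5'], metrics['MRR_top5'])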


def sample_text(model, length: int, conditional_files: List[str] = None, temperature: float = 1.0) -> Tuple[str, str]:
    # NOTE: temperature is currently unused; generate_text is called with its defaults.
    if not conditional_files:
        context = prepare_git_context()
    else:
        context = prepare_git_context(conditional_files[-1],
                                      conditional_files[:-1] if len(conditional_files) > 1 else None)
    text = generate_text(unwrap_model(model), context, length, num_diversity_groups=1,
                         tokenizer=g.corpus.vocab.tokenizer, verbose=False)[0][0]
    return context, text
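
# Illustrative call (file names are hypothetical): condition generation on a
# couple of repository files and sample 256 tokens of text.
#
#   context, text = sample_text(model, length=256,
#                               conditional_files=['utils.py', 'train.py'])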


def format_log(args, metrics, split):
    if args.dataset in ['enwik8', 'text8']:
        # bpc: average per-token loss converted from nats to bits.
        special = f'bpc {metrics["total_loss"] / metrics["total_len"] / math.log(2):9.5f}'
    elif args.dataset == "git":
        special = f'accuracy@1 {metrics["accuracy_top1"]} ' \
                  f'| accuracy@5 {metrics["accuracy_top5"]} ' \
                  f'| MRR@5 {metrics["MRR_top5"]}'
    else:
        special = f'ppl {math.exp(metrics["total_loss"] / metrics["total_len"]):9.3f}'
    return f'| {split} loss\t{metrics["total_loss"] / metrics["total_len"]:5.4f} | {split}\t{special}\t' \
           f'loss {metrics["total_loss"]:.1f}\ttokens {metrics["total_len"]}\n'
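
# Note on units: ppl = exp(total_loss / total_len), where the loss is the
# average negative log-likelihood per token in nats; bpc is the same average
# loss converted to bits (divided by ln 2).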


if __name__ == '__main__':
    main()