eval.py
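"""Evaluation script for NLI-based entity typing.

For each mention in the evaluation set, every candidate type in the type
vocabulary is scored by pairing the premise with the hypothesis
"<entity> is a <label>." and feeding it to the fine-tuned roberta_mnli_typing
model (see model.py). The per-type confidences are written out as a ranked
list, one JSON record per mention.
"""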
import os
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import time
import random
import json
from heapq import nlargest
from torch.utils.data import DataLoader, SequentialSampler
from transformers import RobertaForSequenceClassification, RobertaConfig
from transformers import AutoTokenizer
from dataset import TypingDataset
from model import roberta_mnli_typing
import argparse
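
# Each evaluation record pairs a mention ('id', 'premise', 'entity', gold 'annotation')
# with a 'confidence_ranking' dict mapping candidate types to their scores.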
def eval(args, eval_dataset, model, tokenizer):
    curr_time = time.strftime("%H_%M_%S_%b_%d_%Y", time.localtime())
    eval_sampler = SequentialSampler(eval_dataset)
    # One mention per step; the collate_fn regroups the dataset tuple field-wise.
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1, collate_fn=lambda x: zip(*x))
    type_vocab = eval_dataset.label_lst
    eval_res = []
    for sample in tqdm(eval_dataloader, desc='eval progress'):
        premise, entity, annotation, _, _, _, idx = [items for items in sample]
        premise = str(premise[0])
        entity = str(entity[0])
        annotation = list(annotation[0])
        idx = str(idx[0])
        res = {'id': idx, 'premise': premise, 'entity': entity, 'annotation': annotation}
        res_buffer = {}
        # Score the full type vocabulary in chunks of args.batch candidates.
        for batch_id in range(0, len(type_vocab), args.batch):
            dat_buffer = type_vocab[batch_id: batch_id + args.batch]
            # Build one NLI pair per candidate type: premise </s></s> "<entity> is a <label>."
            sequence = [f'{premise}{2*tokenizer.sep_token}{entity} is a {label}.' for label in dat_buffer]
            inputs = tokenizer(sequence, padding=True, return_tensors='pt').to(args.device)
            with torch.no_grad():
                # Keep only the last (entailment) score as the confidence for each candidate.
                outputs = model(**inputs)[:, -1]
            confidence = outputs.detach().cpu().numpy().tolist()
            for i, label in enumerate(dat_buffer):
                res_buffer[label] = confidence[i]
        # Keep candidates above the threshold and sort them by descending confidence.
        confidence_ranking = {label: score for label, score in res_buffer.items()
                              if score > args.threshold}
        confidence_ranking = {k: v for k, v in sorted(confidence_ranking.items(), key=lambda item: -item[1])}
        res['confidence_ranking'] = confidence_ranking
        eval_res.append(res)
    return eval_res

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir',
                        type=str,
                        default='',
                        help='fine-tuned model dir path')
    parser.add_argument('--eval_data_path',
                        type=str,
                        default='/data/processed_data/dev_processed.json',
                        help='dev/test file path')
    parser.add_argument('--type_vocab_file',
                        type=str,
                        default='',
                        help='type vocab file path')
    parser.add_argument('--batch',
                        type=int,
                        default=8,
                        help='batch size used to score candidate type words or phrases')
    parser.add_argument('--threshold',
                        type=float,
                        default=0.0,
                        help='threshold for the confidence score; 0 prints the full ranking of candidates')
    args = parser.parse_args()

    if not os.path.exists(args.model_dir):
        raise ValueError("Cannot find model checkpoint: {}".format(args.model_dir))
    try:
        # Output file is named Evaluation_<evalFileName> and saved inside model_dir.
        output_suffix = args.eval_data_path.split('/')[-1]
        output_path = os.path.join(args.model_dir, f'Evaluation_{output_suffix}')
    except Exception:
        raise ValueError("Cannot generate output file name, please set it manually")

    model = roberta_mnli_typing()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")

    # Load the fine-tuned checkpoint saved under <model_dir>/model.
    chkpt_path = os.path.join(args.model_dir, 'model')
    chkpt = torch.load(chkpt_path, map_location='cpu')
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    print(f'Evaluating {args.model_dir}\n on {args.eval_data_path} '
          f'\n result file will be saved to {output_path}')
    eval_dataset = TypingDataset(args.eval_data_path, args.type_vocab_file)
    eval_res = eval(args, eval_dataset, model, tokenizer)

    # Save the results, one JSON object per line.
    with open(output_path, 'w+') as fout:
        fout.write("\n".join([json.dumps(items) for items in eval_res]))


if __name__ == "__main__":
    main()
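
# Example invocation (paths below are illustrative, not part of the repo):
#   python eval.py --model_dir checkpoints/typing_run1 \
#                  --eval_data_path data/processed_data/dev_processed.json \
#                  --type_vocab_file data/type_vocab.txt \
#                  --batch 8 --threshold 0.0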