-
Notifications
You must be signed in to change notification settings - Fork 0
/
explore_ufet.py
181 lines (165 loc) · 7.27 KB
/
explore_ufet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Created by xunannancy at 2024/03/04
"""
dataset setting for UFET
Choi, Eunsol, et al. "Ultra-fine entity typing." arXiv preprint arXiv:1807.04905 (2018).
"""
import argparse
from transformers import BertTokenizer
from dataset_utilities import dataset_folder_dict, instance_counts_dict
import os
import pickle
import json
from tqdm import tqdm
from copy import deepcopy
import gzip
import argparse
def construct_dataset_with_context(dataset, max_context_token=256,
                                   context_token='[CONTEXT]', query_token='[QUERY]',
                                   splits='train', debug=False, label_type='combined', overwrite=False):
    """Build DPR-style retrieval instances for UFET and dump one JSON file per split.

    Each instance packs the mention context between a context marker and a
    query question, truncated so the assembled sequence fits within
    max_context_token BERT word pieces; the gold type labels become the
    positive passages.

    Args:
        dataset: key into dataset_folder_dict / instance_counts_dict (e.g. 'UFET_crowd').
        max_context_token: word-piece budget for the assembled question.
        context_token / query_token: marker tokens registered as special tokens.
        splits: comma-separated subset of 'train,dev,test'.
        debug: stop after ~100 instances per split.
        label_type: label-set variant; assumed to start with 'combined'
            (the suffix selects the labels folder) — TODO confirm with callers.
        overwrite: rebuild splits whose output file already exists.
    """
    parent_dataset = 'UFET'
    split_name_mapping = {
        'train': 'trn',
        'dev': 'dev',
        'test': 'tst',
    }
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # Register the markers so the tokenizer never splits them into word pieces.
    special_tokens_dict = {'additional_special_tokens': [context_token, query_token]}
    tokenizer.add_special_tokens(special_tokens_dict)
    # Everything after the literal 'combined' prefix is a folder-name suffix.
    nickname = label_type[len('combined'):]
    saved_folder = f'{dataset_folder_dict[dataset]}/dpr_processed/query_context_direct_{max_context_token}{nickname}'
    os.makedirs(saved_folder, exist_ok=True)
    print(f'saved_folder: {saved_folder}')
    left_part = context_token + ' '
    left_ids = tokenizer.encode(left_part)
    # NOTE: this label_title_id_mapping is universal across UFET variants.
    # FIX: file handles are now closed via context managers (were left open).
    with open(f'{dataset_folder_dict[parent_dataset]}/dpr_processed/labels/label_title_id_mapping.pkl', 'rb') as fin:
        label_title_id_mapping = pickle.load(fin)
    # Invert label -> [doc ids] into doc id -> label.
    label_id_title_mapping = dict()
    for label, id_list in label_title_id_mapping.items():
        for doc_id in id_list:  # renamed from `id` to avoid shadowing the builtin
            label_id_title_mapping[doc_id] = label
    with open(f'{dataset_folder_dict[parent_dataset]}/dpr_processed/labels{nickname}/combined_title_id_mapping.pkl', 'rb') as fin:
        combined_title_id_mapping = pickle.load(fin)
    combined_doc = dict()
    combined_path = f'{dataset_folder_dict[parent_dataset]}/dpr_processed/labels{nickname}/combined/combined.jsonl'
    with open(combined_path) as fin:
        passage_counts = sum(1 for _ in fin)
    with open(combined_path) as fin:
        for line in tqdm(fin, total=passage_counts):
            info = json.loads(line)
            combined_doc[int(info['id'])] = deepcopy(info['contents'])
    for split in splits.split(','):
        saved_path = f'{saved_folder}/{split_name_mapping[split]}.json'
        if os.path.exists(saved_path) and not overwrite:
            print(f'find {split} at {saved_path}...')
            continue
        num_instances = instance_counts_dict[dataset][split]
        data_path = f'{dataset_folder_dict[dataset]}/{split_name_mapping[split]}.json.gz'
        saved_instances = list()
        with gzip.open(data_path) as fin:
            for idx, line in tqdm(enumerate(fin), total=num_instances):
                info = json.loads(line)
                title, content = info['title'], info['content']
                mid_ids = tokenizer.encode(content)  # was redundantly re-reading info['content']
                right_part = ' '+query_token + f' What is {title}?'
                right_ids = tokenizer.encode(right_part)
                # Budget: total minus the outer [CLS]/[SEP] pair, minus the marker
                # pieces (each encode() adds its own [CLS]/[SEP], hence the -2s).
                preserved_context_token_num = max_context_token - 2 - (len(left_ids)-2+len(right_ids)-2)
                start_idx, end_idx = 0, preserved_context_token_num
                cur_truncated_mid_part = tokenizer.decode(mid_ids[1:-1][start_idx:end_idx])
                positive_ctxs, answers = list(), list()
                for target_ind in info['target_ind']:
                    target_text = label_id_title_mapping[target_ind]
                    if target_text in answers:
                        continue
                    # avoid label repetition
                    answers.append(target_text)
                    for doc_id in combined_title_id_mapping[target_text]:
                        positive_ctxs.append({
                            'title': target_text,
                            'text': combined_doc[doc_id],
                            'passage_id': doc_id,
                        })
                cur_instance = {
                    'question': left_part + cur_truncated_mid_part + right_part,
                    'answers': answers,
                    'positive_ctxs': positive_ctxs,
                    'negative_ctxs': [],
                    'hard_negatives_ctxs': [],
                }
                saved_instances.append(cur_instance)
                if debug and idx >= 100:
                    break
        print(f'{split}_instances: {len(saved_instances)}, {len(saved_instances)/num_instances}')
        with open(saved_path, 'w') as f:
            json.dump(saved_instances, f, indent=4)
    return
def f1(p, r):
    """Return the F1 (harmonic mean) of precision *p* and recall *r*; 0 when r is 0."""
    if r == 0.:
        return 0.
    return 2 * p * r / float(p + r)


def macro(true_and_prediction):
    """Compute macro-averaged precision/recall/F1 over (gold, predicted) label-list pairs.

    Args:
        true_and_prediction: iterable of (true_labels, predicted_labels) pairs,
            each a list of label ids.

    Returns:
        (num_examples, pred_example_count, avg_elem_per_pred, precision, recall, f1)
        where precision averages only over examples with at least one prediction
        and recall only over examples with at least one gold label.
    """
    num_examples = len(true_and_prediction)
    p = 0.
    r = 0.
    pred_example_count = 0.
    pred_label_count = 0.
    gold_label_count = 0.
    for true_labels, predicted_labels in true_and_prediction:
        if predicted_labels:
            pred_example_count += 1
            pred_label_count += len(predicted_labels)
            per_p = len(set(predicted_labels).intersection(set(true_labels))) / float(len(predicted_labels))
            p += per_p
        if len(true_labels):
            gold_label_count += 1
            per_r = len(set(predicted_labels).intersection(set(true_labels))) / float(len(true_labels))
            r += per_r
    if pred_example_count > 0:
        precision = p / pred_example_count
        avg_elem_per_pred = pred_label_count / pred_example_count
    else:
        precision = 0
        avg_elem_per_pred = 0
    if gold_label_count > 0:
        recall = r / gold_label_count
    else:
        # FIX: `recall` was unbound (UnboundLocalError) when no example had gold labels.
        recall = 0
    return num_examples, pred_example_count, avg_elem_per_pred, precision, recall, f1(precision, recall)
def evaluate(label_range='general'):
    """Report the macro P/R/F1 upper bound when predicting only a label subset.

    Predictions are the gold labels filtered to the selected id range, so
    precision is 1.0 by construction and recall measures label-set coverage.

    9 general types: person, location, object, organization, place, entity,
    object, time, event.
    121 fine-grained types, mapped to fine-grained entity labels from prior
    work (Ling and Weld, 2012; Gillick et al., 2014) (e.g. film, athlete).

    Args:
        label_range: 'general' (ids 0-8) or 'fine-grained' (ids 0-129,
            i.e. general plus fine-grained).

    Raises:
        ValueError: if label_range is not one of the supported values.
    """
    if label_range == 'general':
        selected_labels = frozenset(range(9))
    elif label_range == 'fine-grained':
        selected_labels = frozenset(range(130))
    else:
        # FIX: previously fell through with `selected_labels` unbound -> NameError.
        raise ValueError(f'unsupported label_range: {label_range}')
    true_list, pred_list = list(), list()
    for line in gzip.open('./datasets/UFET/crowd/tst.json.gz'):
        true_label = json.loads(line)['target_ind']
        true_list.append(true_label)
        pred_list.append([i for i in true_label if i in selected_labels])
    res = macro(list(zip(true_list, pred_list)))
    saved_res = {
        'precision': res[-3],
        'recall': res[-2],
        'f1': res[-1],
    }
    print(f'{label_range}: {saved_res}')
    # Reference results observed on the crowd test split:
    # general:      {'precision': 1.0, 'recall': 0.22638986211182532, 'f1': 0.3691972171426553}
    # fine-grained: {'precision': 1.0, 'recall': 0.3479842087040223, 'f1': 0.5163030938449656}
    return
if __name__ == '__main__':
    # FIX: argparse.PARSER is not a constructor and add_arguments does not exist;
    # the script crashed on startup. Use ArgumentParser / add_argument.
    parser = argparse.ArgumentParser(description='ufet dataset preparation and evaluation code')
    parser.add_argument('--stage', type=str, choices=['data_preparation', 'evaluation'], default='data_preparation')
    args = parser.parse_args()
    if args.stage == 'data_preparation':
        construct_dataset_with_context(
            dataset='UFET_crowd',
            max_context_token=256,
            splits='train,dev,test',
            debug=False,
            label_type='combined',
            overwrite=False
        )
    elif args.stage == 'evaluation':
        evaluate(label_range='general')
        evaluate(label_range='fine-grained')