make_extraction_labels.py
"""produce the dataset with (psudo) extraction label"""
import os
from os.path import exists, join
import json
from time import time
from datetime import timedelta
import multiprocessing as mp
from cytoolz import curry, compose
from utils import count_data
from metric import compute_rouge_l
try:
DATA_DIR = os.environ['DATA']
except KeyError:
print('please use environment variable to specify data directories')
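
# Expected on-disk layout (an assumption inferred from the path handling below):
# each split lives under $DATA/<split>/ as numbered JSON files, e.g.
#   export DATA=/path/to/finished_files
#   $DATA/train/0.json, $DATA/train/1.json, ..., $DATA/val/0.json, ...
# count_data (from utils) is assumed to return how many such files a split contains.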


def _split_words(texts):
    return map(lambda t: t.split(), texts)


def get_extract_label(art_sents, abs_sents):
    """ greedily match summary sentences to article sentences"""
    extracted = []
    scores = []
    indices = list(range(len(art_sents)))
    for abst in abs_sents:
        # ROUGE-L recall of every article sentence against this abstract sentence
        rouges = list(map(compute_rouge_l(reference=abst, mode='r'),
                          art_sents))
        # pick the best-scoring article sentence that has not been chosen yet
        ext = max(indices, key=lambda i: rouges[i])
        indices.remove(ext)
        extracted.append(ext)
        scores.append(rouges[ext])
        if not indices:
            break
    return extracted, scores
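
# A minimal illustration (not part of the original file) of the greedy matching
# above, assuming compute_rouge_l(mode='r') returns an LCS-based recall in [0, 1]:
#
#   art_sents = [['the', 'cat', 'sat'], ['dogs', 'bark', 'loudly']]
#   abs_sents = [['the', 'cat', 'sat', 'down']]
#   extracted, scores = get_extract_label(art_sents, abs_sents)
#   # extracted == [0]      (the first article sentence matches best)
#   # scores    == [0.75]   (LCS of length 3 over a 4-token reference)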


@curry
def process(split, i):
    data_dir = join(DATA_DIR, split)
    with open(join(data_dir, '{}.json'.format(i))) as f:
        data = json.loads(f.read())
    tokenize = compose(list, _split_words)
    art_sents = tokenize(data['article'])
    abs_sents = tokenize(data['abstract'])
    if art_sents and abs_sents:  # some data contains empty article/abstract
        extracted, scores = get_extract_label(art_sents, abs_sents)
    else:
        extracted, scores = [], []
    data['extracted'] = extracted
    data['score'] = scores
    with open(join(data_dir, '{}.json'.format(i)), 'w') as f:
        json.dump(data, f, indent=4)
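
# Note: because process is curried, process(split) returns a one-argument
# callable over the file index i, which is exactly what pool.imap_unordered
# in label_mp below expects.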


def label_mp(split):
    """ process the data split with multi-processing"""
    start = time()
    print('start processing {} split...'.format(split))
    data_dir = join(DATA_DIR, split)
    n_data = count_data(data_dir)
    with mp.Pool() as pool:
        list(pool.imap_unordered(process(split),
                                 list(range(n_data)), chunksize=1024))
    print('finished in {}'.format(timedelta(seconds=time() - start)))


def label(split):
    """ process the data split sequentially in a single process"""
    start = time()
    print('start processing {} split...'.format(split))
    data_dir = join(DATA_DIR, split)
    n_data = count_data(data_dir)
    for i in range(n_data):
        print('processing {}/{} ({:.2f}%)\r'.format(i, n_data, 100 * i / n_data),
              end='')
        with open(join(data_dir, '{}.json'.format(i))) as f:
            data = json.loads(f.read())
        tokenize = compose(list, _split_words)
        art_sents = tokenize(data['article'])
        abs_sents = tokenize(data['abstract'])
        extracted, scores = get_extract_label(art_sents, abs_sents)
        data['extracted'] = extracted
        data['score'] = scores
        with open(join(data_dir, '{}.json'.format(i)), 'w') as f:
            json.dump(data, f, indent=4)
    print('finished in {}'.format(timedelta(seconds=time() - start)))


def main():
    for split in ['val', 'train']:  # no need of extraction label when testing
        label_mp(split)


if __name__ == '__main__':
    main()
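
# Usage sketch (assuming the $DATA layout described above):
#   DATA=/path/to/finished_files python make_extraction_labels.py
# Afterwards each train/val JSON carries two extra fields, e.g.
#   "extracted": [3, 0, 7]        # article sentence index chosen for each abstract sentence
#   "score": [0.61, 0.48, 0.55]   # the corresponding ROUGE-L recall scores (illustrative values)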