preprocess.py
from argparse import ArgumentParser
from multiprocessing import Pool
import os
import os.path as path
import pickle

from Transformer.handle import init_preprocess_options

def handle_data_single_worker(src_lines, tgt_lines, word2ind, is_test):
    # Map each whitespace-tokenized line to a list of vocabulary ids.
    # Target lines of the test split are kept as raw tokens (the reference
    # text is needed for evaluation); everything else is mapped to ids,
    # with out-of-vocabulary tokens falling back to <unk>.
    src_res = []
    tgt_res = []
    for src_line, tgt_line in zip(src_lines, tgt_lines):
        src_res.append(
            [word2ind.get(i, word2ind["<unk>"]) for i in src_line.strip().split(" ")]
        )
        if is_test:
            tgt_res.append(tgt_line.strip().split(" "))
        else:
            tgt_res.append(
                [
                    word2ind.get(i, word2ind["<unk>"])
                    for i in tgt_line.strip().split(" ")
                ]
            )
    return src_res, tgt_res
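
# For illustration (hypothetical inputs, not part of the original script): with
# word2ind = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "hello": 4},
# handle_data_single_worker(["hello world"], ["hello"], word2ind, False)
# returns ([[4, 3]], [[4]]): "world" is out of vocabulary and maps to <unk>.
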
def solve(args):
    # Ids 0-3 are reserved for special tokens; corpus words start at 4.
    word2ind = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
    root_dir: str = args.data_path
    if root_dir[-1] in ["/", "\\"]:
        root_dir = root_dir[:-1]
    dist_dir = path.join(args.dist_dir, path.basename(root_dir))
    # Build the word-to-id table from the vocabulary file; the first
    # whitespace-separated column of each line is the word.
    with open(path.join(root_dir, args.vocab_name), "r", encoding="utf-8") as vocab:
        vocab_size = 0
        for ind, line in enumerate(vocab):
            line = line.strip().split()
            word2ind[line[0]] = ind + 4
            vocab_size = ind + 5  # 4 special tokens + (ind + 1) corpus words
    os.makedirs(dist_dir, exist_ok=True)  # exist_ok so reruns do not crash
    # Write the final dictionary: vocabulary size first, then "word id" pairs.
    with open(path.join(dist_dir, "dict.txt"), "w", encoding="utf-8") as fl:
        print(vocab_size, file=fl)
        for k, v in word2ind.items():
            print(f"{k} {v}", file=fl)
    # Tokenize each split in parallel: lines are batched into chunks of
    # args.lines_per_thread and dispatched to a worker pool.
    for split in ["test", "train", "valid"]:
        pool = Pool(args.workers)
        with open(
            path.join(root_dir, f"{split}.{args.src_lang}"), "r", encoding="utf-8"
        ) as src, open(
            path.join(root_dir, f"{split}.{args.tgt_lang}"), "r", encoding="utf-8"
        ) as tgt:
            src_res = []
            tgt_res = []
            result = []
            src_lines = []
            tgt_lines = []
            for ind, (src_line, tgt_line) in enumerate(zip(src, tgt)):
                # Dispatch a full chunk and start collecting the next one.
                if ind > 0 and ind % args.lines_per_thread == 0:
                    result.append(
                        pool.apply_async(
                            handle_data_single_worker,
                            (src_lines, tgt_lines, word2ind, split == "test"),
                        )
                    )
                    src_lines = []
                    tgt_lines = []
                src_lines.append(src_line)
                tgt_lines.append(tgt_line)
            # Flush the final, possibly short, chunk.
            if src_lines:
                result.append(
                    pool.apply_async(
                        handle_data_single_worker,
                        (src_lines, tgt_lines, word2ind, split == "test"),
                    )
                )
            pool.close()
            pool.join()
            # Reassemble worker outputs in submission order.
            for res in result:
                src_part, tgt_part = res.get()
                src_res += src_part
                tgt_res += tgt_part
        # Persist the id sequences as pickled lists of lists.
        with open(
            path.join(dist_dir, f"{split}.{args.src_lang}"), "wb"
        ) as src, open(path.join(dist_dir, f"{split}.{args.tgt_lang}"), "wb") as tgt:
            pickle.dump(src_res, src)
            pickle.dump(tgt_res, tgt)
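
# For illustration: with lines_per_thread = 1000 and a 2500-line split, solve()
# dispatches chunks covering lines 0-999 and 1000-1999 from inside the loop,
# and the trailing 500-line chunk from the flush that follows it.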

if __name__ == "__main__":
    parser = ArgumentParser()
    init_preprocess_options(parser)
    args = parser.parse_args()
    solve(args)
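
# A minimal sketch of consuming the output (paths are hypothetical, not part
# of the original script). Each pickled file holds a list of token-id lists:
#
#     with open("dist/corpus/train.en", "rb") as f:
#         train_src = pickle.load(f)  # e.g. [[4, 17, 3], ...]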