-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
85 lines (60 loc) · 2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# coding=utf-8
"""Train your model and save it in model/."""
import argparse
import os
import glob
import pickle
class Model:
    """Frequency table of n-gram strings.

    ``elements`` maps each added string to the number of times it was added.
    The empty-string key ``""`` is a running total of how many single-character
    strings have been added (it starts at 0). ``dim`` is the n-gram order this
    model was created for; the class itself only stores it.
    """

    def __init__(self, dim):
        self.elements = {"": 0}
        self.dim = dim

    def __getitem__(self, key):
        """Return the stored count for ``key``, or 0 if it was never added."""
        return self.elements.get(key, 0)

    def add(self, word):
        """Count one occurrence of ``word``.

        Length-1 words additionally bump the ``""`` unigram total.
        """
        self.elements[word] = self.elements.get(word, 0) + 1
        if len(word) == 1:
            self.elements[""] += 1
def arg_parse(argv=None):
    """Parse the training command-line options.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with ``data``, ``n_gram`` and ``save`` attributes.

    Exits with status 2 (via ``parser.error``) when ``--n-gram`` is not in 2..4.
    """
    parser = argparse.ArgumentParser(description="Train Model")
    parser.add_argument("--data", type=str, default="train/2016-*.pkl",
                        help="training data (.pkl) path; glob patterns allowed")
    parser.add_argument("-n", "--n-gram", type=int, default=2,
                        help="train an n-gram model (2 <= n <= 4)")
    parser.add_argument("--save", type=str, default="model/model.pkl",
                        help="output model pickle path")
    args = parser.parse_args(argv)
    # Validate explicitly: an ``assert`` is stripped when Python runs with -O.
    if not 2 <= args.n_gram <= 4:
        parser.error("--n-gram must be between 2 and 4, got {}".format(args.n_gram))
    return args
def load_train(train_path):
    """Load and concatenate pickled training data from all files matching a glob.

    Args:
        train_path: File path or glob pattern. Each matched file must unpickle
            to an iterable (presumably a list of sentences; TODO confirm
            against the data-prep script).

    Returns:
        One list holding the items of every matched file, taken in sorted
        filename order so repeated runs see the data in the same order.

    Raises:
        FileNotFoundError: if ``train_path`` matches no files.
    """
    # glob.glob returns files in arbitrary order; sort for reproducibility.
    paths = sorted(glob.glob(train_path))
    if not paths:
        # Fail fast instead of silently training (and saving) an empty model.
        raise FileNotFoundError("no training files match {!r}".format(train_path))
    train_data = []
    for path in paths:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            train_data.extend(pickle.load(f))
    return train_data
def train(model, train_data):
    """Feed every k-gram (1 <= k <= model.dim) of each sentence into ``model``.

    For each sentence, all 1-grams are added first, then all 2-grams, and so on
    up to ``model.dim``. The k-gram starting at position ``start`` is the
    concatenation of items ``start .. start + k - 1``. A progress line is
    printed after every 10000 sentences.
    """
    order = model.dim
    print("Now training a", order, "grams model.")
    for done, sentence in enumerate(train_data, start=1):
        for k in range(1, order + 1):
            # Sentences shorter than k yield an empty range: no k-grams.
            for start in range(len(sentence) - k + 1):
                model.add("".join(sentence[start:start + k]))
        if done % 10000 == 0:
            print(done, "sentences trained.")
    print("Training done!")
def main():
    """Parse CLI options, train an n-gram Model, and pickle it to ``--save``."""
    args = arg_parse()
    # Create the output directory before training: with the default
    # "model/model.pkl" a missing directory would otherwise only surface as a
    # FileNotFoundError after the entire (long) training run.
    save_dir = os.path.dirname(args.save)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    model = Model(args.n_gram)
    train_data = load_train(args.data)  # default training db: 1121k+ data
    train(model, train_data)
    # NOTE(review): when this file runs as a script, pickle records the class as
    # ``__main__.Model``, so whatever loads model.pkl must expose ``Model``
    # under that name; confirm how the model is consumed.
    with open(args.save, "wb") as f:
        pickle.dump(model, f)


if __name__ == "__main__":
    main()