main.py (forked from brucexia6116/pytorch_Chinese_NER_POS)
# @Author : bamtercelboo
# @Datetime : 2018/1/30 19:50
# @File : main.py
# @Last Modify Time : 2018/1/30 19:50
# @Contact : bamtercelboo@{gmail.com, 163.com}
"""
FILE : main.py
FUNCTION : main
"""
import os
import sys
import random
import argparse
import datetime

import torch

import Config.config as configurable
from DataUtils.mainHelp import *
from DataUtils.Alphabet import *
from test import load_test_data
from test import T_Inference
from trainer import Train

# solve default encoding problem (Python 2 legacy shim; on Python 3 the
# default encoding is already utf-8, so this branch is skipped)
defaultencoding = 'utf-8'
if sys.getdefaultencoding() != defaultencoding:
    from imp import reload
    reload(sys)
    sys.setdefaultencoding(defaultencoding)

# random seed (seed_num is expected to be provided by the DataUtils wildcard imports)
torch.manual_seed(seed_num)
random.seed(seed_num)

def start_train(train_iter, dev_iter, test_iter, model, config):
    """
    :param train_iter: train batch data iterator
    :param dev_iter: dev batch data iterator
    :param test_iter: test batch data iterator
    :param model: nn model
    :param config: config
    :return: None
    """
    t = Train(train_iter=train_iter, dev_iter=dev_iter, test_iter=test_iter, model=model, config=config)
    t.train()
    print("Finished Train.")

def start_test(train_iter, dev_iter, test_iter, model, alphabet, config):
    """
    :param train_iter: train batch data iterator
    :param dev_iter: dev batch data iterator
    :param test_iter: test batch data iterator
    :param model: nn model
    :param alphabet: alphabet dict
    :param config: config
    :return: None
    """
    print("\nTesting Start......")
    data, path_source, path_result = load_test_data(train_iter, dev_iter, test_iter, config)
    infer = T_Inference(model=model, data=data, path_source=path_source, path_result=path_result,
                        alphabet=alphabet, use_crf=config.use_crf, config=config)
    infer.infer2file()
    print("Finished Test.")

def main():
    """
    main entry: build the save directory, load data, build the model, then train or test
    :return: None
    """
    # save directory, named with the current timestamp
    config.mulu = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # config.add_args(key="mulu", value=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    config.save_dir = os.path.join(config.save_direction, config.mulu)
    if not os.path.isdir(config.save_dir):
        os.makedirs(config.save_dir)
    # get data, iterators, alphabet
    train_iter, dev_iter, test_iter, alphabet = load_data(config=config)
    # get params
    get_params(config=config, alphabet=alphabet)
    # save dictionary
    save_dictionary(config=config)
    model = load_model(config)
    # print("Training Start......")
    if config.train is True:
        start_train(train_iter, dev_iter, test_iter, model, config)
        exit()
    elif config.test is True:
        start_test(train_iter, dev_iter, test_iter, model, alphabet, config)
        exit()

def parse_argument():
    """
    parse command-line arguments and build the runtime config
    :return: config
    """
    parser = argparse.ArgumentParser(description="NER & POS")
    parser.add_argument("-c", "--config", dest="config_file", type=str, default="./Config/config.cfg", help="config path")
    parser.add_argument("-device", "--device", dest="device", type=str, default="cuda:0", help="device ['cpu', 'cuda:0', 'cuda:1', ......]")
    parser.add_argument("--train", dest="train", action="store_true", default=True, help="train model")
    parser.add_argument("-p", "--process", dest="process", action="store_true", default=True, help="data process")
    parser.add_argument("-t", "--test", dest="test", action="store_true", default=False, help="test model")
    parser.add_argument("--t_model", dest="t_model", type=str, default=None, help="model for test")
    parser.add_argument("--t_data", dest="t_data", type=str, default=None, help="data[train, dev, test, None] for test model")
    parser.add_argument("--predict", dest="predict", action="store_true", default=False, help="predict model")
    args = parser.parse_args()
    # print(vars(args))
    config = configurable.Configurable(config_file=args.config_file)
    config.device = args.device
    config.train = args.train
    config.process = args.process
    config.test = args.test
    config.t_model = args.t_model
    config.t_data = args.t_data
    config.predict = args.predict
    # test mode turns training off and requires a valid t_data value
    if config.test is True:
        config.train = False
    if config.t_data not in [None, "train", "dev", "test"]:
        print("\nUsage")
        parser.print_help()
        print("t_data : {}, not in [None, 'train', 'dev', 'test']".format(config.t_data))
        exit()
    print("***************************************")
    print("Device : {}".format(config.device))
    print("Data Process : {}".format(config.process))
    print("Train model : {}".format(config.train))
    print("Test model : {}".format(config.test))
    print("t_model : {}".format(config.t_model))
    print("t_data : {}".format(config.t_data))
    print("predict : {}".format(config.predict))
    print("***************************************")
    return config

if __name__ == "__main__":
    print("Process ID {}, Process Parent ID {}".format(os.getpid(), os.getppid()))
    config = parse_argument()
    # cpu_device is expected to be provided by the DataUtils wildcard imports
    if config.device != cpu_device:
        print("Using GPU To Train......")
        # the device string is expected to end with the cuda index, e.g. "cuda:0"
        device_number = config.device[-1]
        torch.cuda.set_device(int(device_number))
        print("Current Cuda Device {}".format(torch.cuda.current_device()))
        # torch.backends.cudnn.enabled = True
        # torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(seed_num)
        torch.cuda.manual_seed_all(seed_num)
        print("torch.cuda.initial_seed", torch.cuda.initial_seed())
    main()
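
# ---------------------------------------------------------------------------
# Example invocations (a sketch, not part of the original repository): these
# assume the default ./Config/config.cfg exists and that the Config, DataUtils,
# test, and trainer modules from this repo are importable from the working
# directory; a CUDA device is needed when "-device cuda:0" is passed.
#
#   # train (training is on by default; -device selects cpu or a cuda device)
#   python main.py -c ./Config/config.cfg -device cuda:0
#
#   # test a saved model on the test split ("-t" sets test mode and turns training off)
#   python main.py -c ./Config/config.cfg -device cpu -t --t_model <path-to-saved-model> --t_data test
# ---------------------------------------------------------------------------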