-
Notifications
You must be signed in to change notification settings - Fork 0
/
textcnn_main.py
73 lines (58 loc) · 3.41 KB
/
textcnn_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import argparse
from init_seeds import Seeds
from model_feature import GeneratePairsFeature
from textcnn_model import TextCNN
from textcnn_data_processer import TextCNNDataProcessor
from textcnn_trainer import TextCNNTrainer
def _str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type=``.

    ``type=bool`` is a trap: argparse calls ``bool('False')``, which is
    ``True`` because every non-empty string is truthy, so a boolean option
    could never be switched off from the command line.

    Args:
        value: the raw string from the command line (a ``bool`` is passed
            through unchanged).

    Returns:
        ``True`` for yes/true/t/y/1/on and ``False`` for no/false/f/n/0/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other value; argparse turns this
            into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1', 'on'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options for the TextCNN training / feature-extraction script.
# The user-facing help strings are in Chinese; English glosses are given in
# the comments next to each option.
parser = argparse.ArgumentParser(description='TextCNN text classifier')

# --- Optimisation -----------------------------------------------------------
# Learning rate.
parser.add_argument('-lr', type=float, default=0.001, help='学习率')
# Number of epochs; one epoch trains on every sample once.
parser.add_argument('-epoch', type=int, default=20, help='epoch,每一个epoch大小所有样本都训练了一次')
# Batch size; parameters are updated once per batch.
parser.add_argument('-batch-size', type=int, default=50, help='batch,每一个batch大小更新一次参数')

# --- Model architecture -----------------------------------------------------
# Number of convolution kernels.
parser.add_argument('-filter-num', type=int, default=100, help='卷积核的个数')
# Convolution kernel sizes, as a comma-separated string (parsed elsewhere).
parser.add_argument('-filter-sizes', type=str, default='3,4,5', help='不同卷积核大小')
# Dropout rate.
parser.add_argument('-dropout', type=float, default=0.5, help='随机失活率')
# Whether to use pretrained word vectors.
parser.add_argument('-static', type=_str2bool, default=True, help='是否使用预训练词向量')
# Number of labels (help text says it can be obtained automatically).
parser.add_argument('-label-num', type=int, default=20, help='标签个数(可自动获取)')
# Word-vector dimension (help text says pretrained vectors supply it automatically).
parser.add_argument('-embedding-dim', type=int, default=300, help='词向量的维度(预训练词向量可自动获取)')
# Help text: "whether the pretrained vectors should be fine-tuned; set True
# if NO fine-tuning is wanted".
# NOTE(review): that contradicts the flag name — confirm against how TextCNN
# reads ``args.fine_tune`` before relying on either reading.
parser.add_argument('-fine-tune', type=_str2bool, default=True, help='预训练词向量是否要微调,不需要微调设置为True')

# --- Runtime / bookkeeping --------------------------------------------------
# Disable the GPU even if one is available.
parser.add_argument('-disable-cuda', type=_str2bool, default=False, help='是否禁用GPU')
# Log the training state every N iterations.
parser.add_argument('-log-interval', type=int, default=1, help='经过多少iteration记录一次训练状态')
# Evaluate on the validation set every N iterations.
parser.add_argument('-test-interval', type=int, default=50, help='经过多少iteration对验证集进行测试')
# Iteration count used for early stopping (presumably iterations without
# improvement — confirm in TextCNNTrainer).
parser.add_argument('-early-stopping', type=int, default=1000, help='早停时迭代的次数')
# Whether to save the model when a better accuracy is reached.
parser.add_argument('-save-best', type=_str2bool, default=True, help='当得到更好的准确度是否要保存')
# Directory for saved trained models.
parser.add_argument('-model-save-dir', type=str, default='model_dir', help='存储训练模型位置')
# Directory for the extracted TextCNN text features.
parser.add_argument('-feature-save-dir', type=str, default='feature_dir', help='存储文本TextCNN特征')
args = parser.parse_args()


def main(args):
    """Run the TextCNN pipeline end to end.

    Steps: seed, select a device, load the data, build the model, train it,
    test it, then extract and save features.

    Args:
        args: the parsed command-line namespace. It is mutated here:
            ``device``, ``vocab_size`` and ``vectors`` are added, and
            ``embedding_dim`` / ``label_num`` are overwritten with the values
            reported by the data processor. The updated namespace is then
            handed to the model, trainer and feature extractor.
    """
    # Presumably seeds the RNGs for reproducibility (see init_seeds.Seeds).
    seeds = Seeds()
    seeds.init_seeds()

    print('0.加载可用设备...')
    # Use the GPU unless the user disabled it or none is available.
    cuda_available = torch.cuda.is_available()
    if not args.disable_cuda and cuda_available:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    args.device = device
    print('args.disable_cuda: ', args.disable_cuda)
    print('torch.cuda.is_available(): ', cuda_available)
    print('\n0.使用设备名称:{}\n'.format(device))

    print('1.正在加载数据...')
    processor = TextCNNDataProcessor(args)
    train_iter, val_iter, test_iter = processor.load_data()
    # These values come from the loaded data and overwrite the CLI values of
    # embedding_dim and label_num.
    args.vocab_size, args.embedding_dim, args.label_num = processor.get_args()
    args.vectors = processor.get_build_vocab()
    print('\n加载数据完成!\n')

    print('2.正在加载模型...')
    model = TextCNN(args).to(device)
    print('\n加载模型完成!\n')

    print('3.开始训练模型...')
    trainer = TextCNNTrainer(args)
    trainer.train(train_iter, val_iter, model)
    print('\n训练模型完成!\n')

    print('4.测试静态模型...')
    trainer.test(model, test_iter)
    print('\n测试模型完成!\n')

    print('5.获取特征...')
    # NOTE(review): ``text_model_type`` receives the trained model instance,
    # not a type name — confirm this is what GeneratePairsFeature expects.
    feature_extractor = GeneratePairsFeature(image_model_type='vgg19', text_model_type=model, args=args)
    feature_extractor.save_features_labels(train_iter, val_iter, test_iter)
    print('\n获取特征完成!\n')


if __name__ == '__main__':
    main(args)