-
Notifications
You must be signed in to change notification settings - Fork 15
/
args.py
112 lines (90 loc) · 3.92 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# coding: utf-8
'''
这里存放一些参数
'''
import os
import time
# 训练相关的超参数
num_epochs = 100
batch_size = 100
learning_rate = 3e-4
ss_factor = 24
use_cuda = True
use_checkpoint = False
time_format = '%m-%d_%X'
current_time = time.strftime(time_format, time.localtime())
env_tag = '%s_TA-RES_SS0.75' % (current_time)
log_environment = os.path.join('logs', env_tag) # tensorboard的记录环境
# 模型相关的超参数
projected_size = 500
hidden_size = 1000 # LSTM层的隐层单元数目
frame_shape = (3, 224, 224) # 视频帧的形状
frame_size = 4096 # 视频特征的维度
frame_sample_rate = 10 # 视频帧的采样率
num_frames = 20 # 图像序列的规定长度
num_words = 30 # 文本序列的规定长度
# 数据相关的参数
# 提供两个数据集:MSR-VTT和MSVD
msrvtt_video_root = './datasets/MSR-VTT/TrainValVideo/'
msrvtt_anno_json_path = './datasets/MSR-VTT/train_val_videodatainfo.json'
msrvtt_video_sort_lambda = lambda x: int(x[5:-4])
msrvtt_train_range = None
msrvtt_val_range = None
msrvtt_test_range = None
msvd_video_root = './datasets/MSVD/youtube_videos'
msvd_csv_path = './datasets/MSVD/MSR Video Description Corpus_refine.csv' # 手动修改一些数据集中的错误
msvd_video_name2id_map = './datasets/MSVD/youtube_mapping.txt'
msvd_anno_json_path = './datasets/MSVD/annotations.json' # MSVD并未提供这个文件,需要自己写代码生成(build_msvd_annotation.py)
msvd_video_sort_lambda = lambda x: int(x[3:-4])
msvd_train_range = (0, 1200)
msvd_val_range = (1200, 1300)
msvd_test_range = (1300, 1970)
dataset = {
'msr-vtt': [msrvtt_video_root, msrvtt_video_sort_lambda, msrvtt_anno_json_path,
msrvtt_train_range, msrvtt_val_range, msrvtt_test_range],
'msvd': [msvd_video_root, msvd_video_sort_lambda, msvd_anno_json_path,
msvd_train_range, msvd_val_range, msvd_test_range]
}
# 用video_root和anno_json_path这两个变量来切换所使用的数据集
# video_sort_lambda用来对视频按照名称进行排序
ds = 'msvd'
video_root, video_sort_lambda, anno_json_path, \
train_range, val_range, test_range = dataset[ds]
feat_dir = 'feats'
if not os.path.exists(feat_dir):
os.mkdir(feat_dir)
vocab_pkl_path = os.path.join(feat_dir, ds + '_vocab.pkl')
caption_pkl_path = os.path.join(feat_dir, ds + '_captions.pkl')
caption_pkl_base = os.path.join(feat_dir, ds + '_captions')
train_caption_pkl_path = caption_pkl_base + '_train.pkl'
val_caption_pkl_path = caption_pkl_base + '_val.pkl'
test_caption_pkl_path = caption_pkl_base + '_test.pkl'
sal_h5_path = os.path.join(feat_dir, ds + '_saliency.h5')
sal_h5_dataset = 'sal'
fore_h5_path = os.path.join(feat_dir, ds + '_foreground.h5')
fore_h5_dataset = 'sal'
back_h5_path = os.path.join(feat_dir, ds + '_background.h5')
back_h5_dataset = 'back'
full_h5_path = os.path.join(feat_dir, ds + '_videos.h5')
full_h5_dataset = 'feats'
feature_h5_path = os.path.join(feat_dir, ds + '_features.h5')
feature_h5_feats = 'feats'
# 结果评估相关的参数
result_dir = 'results'
if not os.path.exists(result_dir):
os.mkdir(result_dir)
val_reference_txt_path = os.path.join(result_dir, 'val_references.txt')
val_prediction_txt_path = os.path.join(result_dir, 'val_predictions.txt')
test_reference_txt_path = os.path.join(result_dir, 'test_references.txt')
test_prediction_txt_path = os.path.join(result_dir, 'test_predictions.txt')
# checkpoint相关的超参数
vgg_checkpoint = './models/vgg16-00b39a1b.pth' # 从caffe转换而来
# vgg_checkpoint = './models/vgg16-397923af.pth' # 直接用pytorch训练的模型
decoder_pth_path = os.path.join(result_dir, ds + '_decoder.pth')
best_decoder_pth_path = os.path.join(result_dir, ds + '_best_decoder.pth')
optimizer_pth_path = os.path.join(result_dir, ds + '_optimizer.pth')
best_optimizer_pth_path = os.path.join(result_dir, ds + '_best_optimizer.pth')
# 图示结果相关的超参数
visual_dir = 'visuals'
if not os.path.exists(visual_dir):
os.mkdir(visual_dir)