-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain.py
111 lines (102 loc) · 3.98 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torchvision as tv
from PIL import Image
from models.model import CNNPlusCNN
from models.model import NewCNNPlusCNNHierAtt
class Config(object):
    """Training configuration for the ``CNNPlusCNN`` captioning model.

    Every setting is a plain attribute that the model reads. The checkpoint,
    log and result paths are derived from the architecture settings, so each
    variant writes to its own files.
    """

    def __init__(self):
        # --- schedule ---
        self.max_epoch = 50
        self.batch_size = 32

        # --- architecture ---
        self.encoder_name = 'vgg16'
        self.kernel_size = 3
        self.num_layers = 6
        self.channels = 300
        self.prediction_dim = 4096
        # ak_token: 9489
        self.voc_size = 9489
        self.attention_tracker = False

        # --- input / data loading ---
        self.width = 224
        self.height = 224
        self.is_gpu = True
        self.is_train = True
        self.shuffle = True
        self.num_workers = 8
        # torchvision sizes are (height, width). RandomCrop's second positional
        # parameter is ``padding``, so the crop size must be one tuple: passing
        # (width, height) as two arguments padded every image by ``height``
        # pixels on each side before cropping.
        self.transformer = tv.transforms.Compose(
            [
                tv.transforms.Resize((self.height + 32, self.width + 32)),
                tv.transforms.RandomCrop((self.height, self.width)),
                tv.transforms.RandomHorizontalFlip(),
            ]
        )

        # --- attention / data files ---
        self.is_dotatt = True
        self.image_dir = ''  # dir of images (empty by default; set before use)
        self.train_ann_file = 'data/files/captions_train2014.json'
        self.val_ann_file = 'data/files/captions_val2014.json'
        self.split_file = 'data/files/caption_id.pkl'
        self.vocab_file = 'data/files/vocab.pkl'
        # GLU over dim 1 (presumably the channel axis -- confirm in the model).
        self.f = torch.nn.GLU(1)
        # NOTE(review): by name, the probability of keeping a unit; confirm how
        # the model consumes it.
        self.keep_prob = 0.5

        # --- output paths (encode the architecture in the file name) ---
        att_tag = 'dotatt' if self.is_dotatt else 'nnatt'
        run_name = '%s_numl%s_kz%s_%s_keepprob%s' % (
            self.encoder_name, self.num_layers, self.kernel_size,
            att_tag, self.keep_prob)
        self.trained_model = 'trained_models/' + run_name + '.pth'
        # Built from the checkpoint file name, so it keeps the '.pth' part.
        self.log_file = 'logs/' + self.trained_model.split('/')[1] + '.txt'
        self.result_file = 'results/captions_val2014_' + run_name + '_results.json'
        self.annfile = None
class HierConfig(object):
    """Training configuration for the hierarchical-attention captioner.

    Used with ``NewCNNPlusCNNHierAtt``. Every setting is a plain attribute
    that the model reads. The checkpoint, log and result paths are derived
    from the architecture settings, so each variant writes to its own files.
    """

    def __init__(self):
        # --- schedule ---
        self.max_epoch = 50
        self.batch_size = 24

        # --- architecture ---
        self.encoder_name = 'resnet101'
        self.kernel_size = 3
        self.num_layers = 6
        self.channels = 300
        self.hier_att_hidden_size = 512
        self.hier_att_lang_hidden_size = 512
        self.prediction_dim = 4096
        self.voc_size = 9489

        # --- input / data loading ---
        self.width = 224
        self.height = 224
        self.is_gpu = True
        self.is_train = True
        self.shuffle = True
        self.num_workers = 8
        # torchvision sizes are (height, width). RandomCrop's second positional
        # parameter is ``padding``, so the crop size must be one tuple: passing
        # (width, height) as two arguments padded every image by ``height``
        # pixels on each side before cropping.
        self.transformer = tv.transforms.Compose(
            [
                tv.transforms.Resize((self.height + 32, self.width + 32)),
                tv.transforms.RandomCrop((self.height, self.width)),
                tv.transforms.RandomHorizontalFlip(),
            ]
        )

        # --- attention / data files ---
        self.is_dotatt = True
        self.image_dir = ''  # dir of images (empty by default; set before use)
        self.train_ann_file = 'data/files/captions_train2014.json'
        self.val_ann_file = 'data/files/captions_val2014.json'
        self.split_file = 'data/files/caption_id.pkl'
        self.vocab_file = 'data/files/vocab.pkl'
        # NOTE(review): by name, the probability of keeping a unit; confirm how
        # the model consumes it.
        self.keep_prob = 0.5

        # --- output paths (encode the architecture in the file name) ---
        att_tag = 'hierdotatt' if self.is_dotatt else 'hiernnatt'
        run_name = '%s_conv3x_numl%s_kz%s_%s_keepprob%s' % (
            self.encoder_name, self.num_layers, self.kernel_size,
            att_tag, self.keep_prob)
        self.trained_model = 'trained_models/' + run_name + '.pth'
        # Built from the checkpoint file name, so it keeps the '.pth' part.
        self.log_file = 'logs/' + self.trained_model.split('/')[1] + '.txt'
        self.result_file = 'results/captions_val2014_' + run_name + '_results.json'
        self.annfile = None
        self.weight_decay = 0.0
# Model switch: True trains NewCNNPlusCNNHierAtt with HierConfig,
# False trains CNNPlusCNN with Config.
HIERATT = True
if not HIERATT:
    config = Config()
    model = CNNPlusCNN(config)
else:
    config = HierConfig()
    model = NewCNNPlusCNNHierAtt(config)
# Both branches start training the same way.
model.trainer()