-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
450 lines (368 loc) · 16.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
import os
import sys
import glob
import re
import json
import math
from collections import defaultdict
import numpy as np
from PIL import Image
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset, ConcatDataset
import torchvision
import nltk
def show_versions():
    """Print the Python and library versions plus the CUDA/cuDNN status.

    One line is printed per item. Each value is looked up right before its
    line is printed.
    """
    probes = (
        ('python =', lambda: sys.version),
        ('nltk =', lambda: nltk.__version__),
        ('torch =', lambda: torch.__version__),
        ('torchvision =', lambda: torchvision.__version__),
        ('cuda =', lambda: torch.version.cuda),
        ('cudnn =', lambda: torch.backends.cudnn.version()),
        ('cuda.is_available =', lambda: torch.cuda.is_available()),
    )
    for label, probe in probes:
        print(label, probe())
def basename(fname, split=None):
    """Return the last path component of ``fname`` without its final extension.

    :param fname: file path, e.g. ``'/a/b/c.txt'``
    :param split: optional separator string.
        NOTE(review): ``fname`` is split on it but the result is discarded,
        so it does not influence the returned value - confirm the intended
        behaviour with the callers before changing this.
    :return: the file name stem, e.g. ``'c'``
    """
    if split is not None:
        # Kept only to preserve the existing behaviour; the result is unused.
        fname.split(split)
    stem, _extension = os.path.splitext(os.path.basename(fname))
    return stem
def feats_to_str(feats):
    """Build a ``+``-separated description of a feature set.

    :param feats: object with two list attributes: ``internal`` (names used
        verbatim) and ``external`` (file paths, each reduced to its file name
        without the final extension)
    :return: internal names followed by external stems, joined with ``'+'``
    """
    external_names = []
    for path in feats.external:
        stem, _extension = os.path.splitext(os.path.basename(path))
        external_names.append(stem)
    return '+'.join(feats.internal + external_names)
def f2s(f):
    """Format a float in fixed-point notation without trailing zeros.

    Plain fixed-point formatting, e.g. ``'{:f}'.format(0.01)``, produces
    ``"0.010000"``. Here the value is rendered with 16 decimals (never in
    exponential notation) and the trailing zeros are stripped. Only zeros are
    stripped, so a whole number keeps its decimal point: ``1.0`` -> ``'1.'``.
    """
    fixed_point = '{:0.16f}'.format(f)
    return fixed_point.rstrip('0')
def get_model_name(args, params):
    """Determine the name of the model.

    The name is chosen in this order:

    1. ``args.model_name`` when it is not ``None``;
    2. the name of the directory containing ``args.load_model`` when that
       option is set;
    3. otherwise a name assembled from ``args.model_basename``, the main
       hyper-parameters and the feature description, joined with ``-``.

    :param args: command line arguments
    :param params: model parameters
    :return: model name string
    """
    if args.model_name is not None:
        return args.model_name

    if args.load_model:
        return os.path.split(os.path.dirname(args.load_model))[-1]

    feat_spec = feats_to_str(params.features)
    if params.has_persist_features():
        feat_spec += '-' + feats_to_str(params.persist_features)

    name_parts = (
        args.model_basename,
        params.embed_size,
        params.hidden_size,
        params.num_layers,
        params.batch_size,
        args.optimizer,
        f2s(params.learning_rate),
        f2s(args.weight_decay),
        params.dropout,
        params.encoder_dropout,
        feat_spec,
    )
    return '{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}'.format(*name_parts)
def get_model_path(args, params, epoch):
    """Return the checkpoint path for the given epoch.

    The path has the form
    ``<args.output_root>/<args.model_path>/<model name>/ep<epoch>.model``.
    """
    checkpoint_dir = os.path.join(args.output_root, args.model_path,
                                  get_model_name(args, params))
    return os.path.join(checkpoint_dir, 'ep{}.model'.format(epoch))
# TODO: convert parameters to **kwargs
def save_model(args, params, encoder, decoder, optimizer, epoch, vocab):
    """Store a training checkpoint with ``torch.save``.

    The checkpoint holds the encoder/decoder/optimizer state dicts, the
    vocabulary, the model hyper-parameters copied from ``params`` and the
    command history extended with the current command line. The stored epoch
    and the checkpoint file name both use ``epoch + 1``.

    :param args: command line arguments (output location, verbosity)
    :param params: model parameters copied into the checkpoint
    :param encoder: encoder module, or ``None``
    :param decoder: decoder module
    :param optimizer: optimizer whose state is stored
    :param epoch: index of the epoch that has just finished
    :param vocab: vocabulary object stored as-is
    """
    def params_subset(names):
        # Every checkpoint key below equals the name of the ``params`` attribute.
        return {name: getattr(params, name) for name in names}

    saved_epoch = epoch + 1

    state = {
        'hierarchical_model': params.hierarchical_model,
        'epoch': saved_epoch,
        # Attention models can in principle be trained without an encoder:
        'encoder': None if encoder is None else encoder.state_dict(),
        'decoder': decoder.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    state.update(params_subset((
        'embed_size', 'hidden_size', 'num_layers', 'batch_size',
        'learning_rate', 'dropout', 'encoder_dropout', 'encoder_non_lin',
        'features', 'ext_features_dim', 'persist_features', 'attention',
    )))
    state['vocab'] = vocab
    state.update(params_subset((
        'skip_start_token', 'rnn_arch', 'rnn_hidden_init',
        'share_embedding_weights',
    )))
    state['command_history'] = params.command_history + [' '.join(sys.argv)]

    if params.hierarchical_model:
        state.update(params_subset((
            'max_sentences', 'dropout_stopping', 'dropout_fc', 'fc_size',
            'coherent_sentences', 'coupling_alpha', 'coupling_beta',
        )))

    model_path = get_model_path(args, params, saved_epoch)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(state, model_path)
    print('Saved model as {}'.format(model_path))

    if args.verbose:
        print(params)
def stats_filename(args, params, postfix):
    """Return the path of the training statistics JSON file of the model.

    The file lives in the model directory and is called ``train_stats.json``,
    or ``train_stats-<postfix>.json`` when ``postfix`` is not ``None``.
    """
    json_name = ('train_stats.json' if postfix is None
                 else 'train_stats-{}.json'.format(postfix))
    return os.path.join(args.output_root, args.model_path,
                        get_model_name(args, params), json_name)
def init_stats(args, params, postfix=None):
    """Load previously saved training statistics.

    :return: the parsed content of the statistics JSON file when that file
        exists, otherwise an empty dict
    """
    stats_path = stats_filename(args, params, postfix)
    if not os.path.exists(stats_path):
        return dict()

    with open(stats_path, 'r') as stats_file:
        return json.load(stats_file)
def save_stats(args, params, all_stats, postfix=None, writer=None):
    """Write the training statistics to the statistics JSON file of the model.

    :param all_stats: dict of per-epoch statistics; when ``writer`` is given
        its keys must be convertible to ``int`` and it must not be empty
    :param postfix: optional postfix of the statistics file name
    :param writer: optional tensorboardX-style writer; when given, the entry
        with the largest integer key is logged with ``add_scalars``
    """
    stats_path = stats_filename(args, params, postfix)
    os.makedirs(os.path.dirname(stats_path), exist_ok=True)
    with open(stats_path, 'w') as stats_file:
        json.dump(all_stats, stats_file, indent=2)
    print('Wrote stats to {}'.format(stats_path))

    if writer is None:
        return

    # Write events to tensorboardx if available:
    latest_epoch = max(int(key) for key in all_stats)
    writer.add_scalars('train_stats', all_stats[str(latest_epoch)], latest_epoch)
def log_model_data(params, model, n_iter, writer):
    """Log histograms of selected model weights using tensorboard.

    For a hierarchical model the coupling unit (only when
    ``params.coherent_sentences`` is set) and the sentence RNN are logged
    first. The word RNN weights are logged for every model, followed by the
    embedding projection when ``params.share_embedding_weights`` is set.

    :param params: model parameters selecting what is logged
    :param model: model whose ``decoder`` weights are logged
    :param n_iter: step value passed to every ``add_histogram`` call
    :param writer: object providing ``add_histogram(tag, values, step)``
    """
    def log_weights(tag, tensor):
        # Log a detached CPU numpy copy of the tensor.
        writer.add_histogram(tag, tensor.clone().cpu().detach().numpy(), n_iter)

    if params.hierarchical_model:
        sent_decoder = model.decoder
        word_decoder = sent_decoder.word_decoder

        # Log Coherent model data:
        if params.coherent_sentences:
            coupling = sent_decoder.coupling_unit
            log_weights('weights/coupling/linear_1', coupling.linear1.weight)
            log_weights('weights/coupling/linear_2', coupling.linear2.weight)
            log_weights('weights/coupling/gru_hh_l0', coupling.gate.weight_hh_l0)
            log_weights('weights/coupling/gru_ih_l0', coupling.gate.weight_ih_l0)

        # Log SentenceRNN data
        log_weights('weights/sentence_RNN/linear_1', sent_decoder.linear1.weight)
        log_weights('weights/sentence_RNN/linear_2', sent_decoder.linear2.weight)
        log_weights('weights/sentence_RNN/rnn_hh_l0',
                    sent_decoder.sentence_rnn.weight_hh_l0)
        log_weights('weights/sentence_RNN/rnn_ih_l0',
                    sent_decoder.sentence_rnn.weight_ih_l0)
    else:
        word_decoder = model.decoder

    # Log WordRNN data:
    log_weights('weights/word_RNN/embed', word_decoder.embed.weight)
    log_weights('weights/word_RNN/rnn_hh_l0', word_decoder.rnn.weight_hh_l0)
    log_weights('weights/word_RNN/rnn_ih_l0', word_decoder.rnn.weight_ih_l0)

    if params.share_embedding_weights:
        log_weights('weights/word_RNN/embedding_projection',
                    word_decoder.projection.weight)
def get_teacher_prob(k, i, beta=1):
    """Inverse sigmoid decay schedule used for teacher forcing.

    Returns ``k / (k + exp(i * beta / k))``, which starts at ``k / (k + 1)``
    for ``i == 0`` and decreases as ``i`` grows. ``k == 0`` disables the
    schedule and always yields ``1.0``. More info here:
    https://arxiv.org/pdf/1506.03099.pdf

    NOTE(review): confirm against the caller how the returned probability is
    consumed (probability of keeping vs. dropping teacher forcing).

    :param k: decay speed parameter of the schedule
    :param i: current iteration
    :param beta: multiplier applied to ``i``
    """
    if k == 0:
        return 1.0

    scaled_iteration = i * beta
    return k / (k + np.exp(scaled_iteration / k))
# Simple gradient clipper from tutorial, can be replaced with torch's own
# using it now to stay close to reference Attention implementation
def clip_gradients(optimizer, grad_clip):
    """
    Clamp every gradient element to the range [-grad_clip, grad_clip], in place,
    to avoid explosion of gradients. Parameters without a gradient are skipped.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    all_params = (param
                  for group in optimizer.param_groups
                  for param in group['params'])
    for param in all_params:
        if param.grad is None:
            continue
        param.grad.data.clamp_(-grad_clip, grad_clip)
def prepare_hierarchical_targets(last_sentence_indicator, max_sentences, lengths, captions, device):
    """Prepares the training targets used by hierarchical model

    For each sentence position the captions of the samples that have a
    sentence there (length > 0) are packed with ``pack_padded_sequence`` and
    the packed data tensor is kept. Iteration stops at the first position
    where the first sample of the batch has length 0.

    :param last_sentence_indicator: tensor moved to ``device`` and returned
        as the first target
    :param max_sentences: maximum number of sentence positions to look at
    :param lengths: sentence lengths, indexed as ``[sample, sentence]``
    :param captions: padded captions, indexed as ``[sample, sentence]``
    :param device: device for ``last_sentence_indicator``
    :return: tuple ``(last_sentence_indicator, word_rnn_targets)`` where the
        second item is a list with one packed tensor per sentence position
    """
    last_sentence_indicator = last_sentence_indicator.to(device)

    word_rnn_targets = []
    for position in range(max_sentences):
        # The first sample decides whether any sentence is left at this
        # position (NOTE(review): relies on the batch ordering - confirm).
        if lengths[0, position] == 0:
            break

        position_lengths = lengths[:, position]
        has_sentence = position_lengths > 0

        # Pack the non-zero values for each sentence position:
        packed_data = pack_padded_sequence(captions[:, position][has_sentence],
                                           position_lengths[has_sentence],
                                           batch_first=True)[0]
        word_rnn_targets.append(packed_data)

    return (last_sentence_indicator, word_rnn_targets)
def cyclical_lr(step_sz, min_lr=0.001, max_lr=1, mode='triangular', scale_func=None,
                scale_md='cycles', gamma=1.):
    """implements a cyclical learning rate policy (CLR).

    Returns a function mapping the iteration number to a learning rate that
    moves linearly from ``min_lr`` up to ``max_lr`` and back, one half-cycle
    taking ``step_sz`` iterations; the amplitude is scaled by the selected
    scaling function.

    Notes: the learning rate of optimizer should be 1

    Parameters:
    ----------
    step_sz : number of iterations in half a cycle
    min_lr, max_lr : lower and upper learning rate bounds
    mode : str, optional
        one of {triangular, triangular2, exp_range}; only used when
        scale_func is None.
    scale_func : callable, optional
        custom scaling function; used together with scale_md.
    scale_md : str, optional
        {'cycles', 'iterations'}; the argument passed to scale_func.
        An invalid value raises ValueError when the returned function is called.
    gamma : float, optional
        constant in 'exp_range' scaling function: gamma**(cycle iterations)

    Examples:
    --------
    >>> # the learning rate of optimizer should be 1
    >>> optimizer = optim.SGD(model.parameters(), lr=1.)
    >>> step_size = 2*len(train_loader)
    >>> clr = cyclical_lr(step_size, min_lr=0.001, max_lr=0.005)
    >>> scheduler = lr_scheduler.LambdaLR(optimizer, [clr])
    >>> # some other operations
    >>> scheduler.step()
    >>> optimizer.step()

    Source: https://github.com/pytorch/pytorch/pull/2016#issuecomment-387755710
    """
    if scale_func is not None:
        scale_fn, scale_mode = scale_func, scale_md
    else:
        # Built-in policies: (scaling function, what the function is applied to)
        presets = {
            'triangular': (lambda x: 1., 'cycles'),
            'triangular2': (lambda x: 1 / (2. ** (x - 1)), 'cycles'),
            'exp_range': (lambda x: gamma ** (x), 'iterations'),
        }
        try:
            scale_fn, scale_mode = presets[mode]
        except (KeyError, TypeError):
            raise ValueError('The {} is not valid value!'.format(mode)) from None

    def rel_val(iteration, stepsize, mode):
        # Position inside the current cycle, as a triangle wave in [0, 1].
        cycle = math.floor(1 + iteration / (2 * stepsize))
        x = abs(iteration / stepsize - 2 * cycle + 1)
        triangle = max(0, (1 - x))
        if mode == 'cycles':
            return triangle * scale_fn(cycle)
        if mode == 'iterations':
            return triangle * scale_fn(iteration)
        raise ValueError('The {} is not valid value!'.format(scale_mode))

    def lr_lambda(iters):
        return min_lr + (max_lr - min_lr) * rel_val(iters, step_sz, scale_mode)

    return lr_lambda
def path_from_id(image_dir, image_id):
    """Return image path based on image directory, image id and
    glob matching for extension

    The directory and the id are matched literally: glob metacharacters in
    them (``*``, ``?``, ``[``) are escaped, only the extension is a wildcard.
    When several files share the id, the alphabetically first path is
    returned so that the result does not depend on directory listing order.

    :param image_dir: directory containing the image
    :param image_id: file name of the image without extension
    :return: path of the matching file
    :raises IndexError: if no file ``<image_dir>/<image_id>.*`` exists
    """
    literal_prefix = glob.escape(os.path.join(image_dir, image_id))
    matches = sorted(glob.glob(literal_prefix + '.*'))
    if not matches:
        # IndexError is kept for backward compatibility with the previous
        # ``glob.glob(...)[0]`` implementation; only the message is new.
        raise IndexError('No image file found for id "{}" in "{}"'.
                         format(image_id, image_dir))
    return matches[0]
def load_image(image_path, transform=None):
    """Load an image, resize it to 224x224 and make sure it is RGB.

    The image is resized with the LANCZOS filter first; a non-RGB image is
    then converted to RGB after a warning is printed.

    :param image_path: path of the image file
    :param transform: optional callable applied to the PIL image; its result
        gets ``unsqueeze(0)`` applied before being returned
    :return: the PIL image when ``transform`` is None, otherwise the
        transformed result
    """
    picture = Image.open(image_path).resize([224, 224], Image.LANCZOS)

    if picture.mode != 'RGB':
        print('WARNING: converting {} from {} to RGB'.
              format(image_path, picture.mode))
        picture = picture.convert('RGB')

    if transform is None:
        return picture
    return transform(picture).unsqueeze(0)
def fix_caption(caption):
    """Turn a tokenized caption into a more readable sentence.

    The leading ``<start> `` token and an optional trailing `` <end>`` token
    are removed, the whitespace in front of a ``.`` or ``,`` that is followed
    by whitespace or the end of the string is dropped, and the result is
    capitalized with ``str.capitalize``.

    A caption that does not start with ``<start> `` is reported with an error
    message and returned capitalized but otherwise unchanged.
    """
    match = re.match(r'^<start> (.*?)( <end>)?$', caption)
    if match is not None:
        sentence = re.sub(r'\s([.,])(\s|$)', r'\1\2', match.group(1))
        return sentence.capitalize()

    print('ERROR: unexpected caption format: "{}"'.format(caption))
    return caption.capitalize()
def torchify_sequence(batch):
    """Stack a sequence of equally shaped tensors along a new first dimension.

    A single ``torch.stack`` call replaces repeated concatenation onto a
    float32 CPU seed tensor, which copied the growing result once per item
    and tied the result to the seed's dtype and device. The stacked result
    keeps the dtype and device of the inputs.

    :param batch: iterable of tensors that all have the same shape
    :return: tensor of shape ``(len(batch), *item_shape)``; an empty input
        yields an empty 1-D tensor, i.e. ``torch.tensor([])``
    """
    images = list(batch)
    if not images:
        # torch.stack rejects an empty list; keep the historical empty result.
        return torch.tensor([])
    return torch.stack(images)
def to_contiguous(tensor):
    """Return ``tensor`` itself when it is already contiguous, otherwise the
    result of ``tensor.contiguous()``."""
    return tensor if tensor.is_contiguous() else tensor.contiguous()
def get_ground_truth_captions(dataset):
    """
    Collect the ground truth captions of a dataset, grouped by label.

    Every dataset involved must provide a ``data`` attribute that yields
    ``(label, text, idxs)`` triples. For a ConcatDataset the labels must be
    unique across the concatenated datasets (each of them has to set
    ``unique_ids``), otherwise captions of different items would be mixed.

    :param dataset: Dataset or ConcatDataset class
    :return: defaultdict mapping every label to the list of its captions
    :raises NotImplementedError: if ``dataset`` is neither of the two types
    """
    # ConcatDataset must be tested first, it is itself a Dataset subclass.
    if isinstance(dataset, ConcatDataset):
        assert all([d.unique_ids for d in dataset.datasets]), \
            'All datasets in the concatenation must ensure that the labels are unique to not mix them'
        members = dataset.datasets
    elif isinstance(dataset, Dataset):
        members = [dataset]
    else:
        raise NotImplementedError

    captions_by_label = defaultdict(list)
    for member in members:
        for label, text, _idxs in member.data:
            captions_by_label[label].append(text)
    return captions_by_label
def trigram_penalty(i, batch_size, sampled_ids, logprobs, trigrams, alpha=2.0):
    """
    Inference constraint that penalizes the log-probabilities of words that would result in repeated trigrams.
    The penalty is proportional to the number of times the trigram has already been generated.

    Nothing happens for ``i < 3``: ``logprobs`` is returned unchanged and
    ``trigrams`` is not touched.

    Source:
    https://www.aclweb.org/anthology/D18-1084
    https://github.com/lukemelas/image-paragraph-captioning

    :param i: current step; the entries ``sampled_ids[i - 3:i]`` are read.
        The threshold of 3 is counting out the start token (use 2 on the
        conditions if no start token is generated).
    :param batch_size: number of items in the batch
    :param sampled_ids: ids sampled so far, indexed as ``[step][batch item]``
    :param logprobs: log-probabilities for the next word, indexed as
        ``[batch item, word id]``
    :param trigrams: list that is modified in place. The ``i == 3`` call
        appends one dict per batch item, later calls update the dict at the
        index of the batch item, so the list is expected to be empty before
        the ``i == 3`` call. Each dict maps a pair of ids to the list of ids
        that followed that pair (stored as taken from ``sampled_ids``).
    :param alpha: penalty strength; every earlier occurrence of a trigram
        adds ``-0.693 * alpha`` to the log-probability of its last word
    :return: the penalized log-probabilities
    """
    if i < 3:
        return logprobs

    def id_pair(two_steps, item):
        # Hashable (int, int) key built from two consecutive steps.
        return (two_steps[0][item].item(), two_steps[1][item].item())

    # Store trigram generated at last step
    finished_prefixes = sampled_ids[i - 3:i - 1]
    for j in range(batch_size):  # = seq.size(0)
        prefix = id_pair(finished_prefixes, j)
        current = sampled_ids[i - 1][j]
        if i == 3:  # initialize
            trigrams.append({prefix: [current]})
        else:  # add to the existing list or create it
            trigrams[j].setdefault(prefix, []).append(current)

    # Block used trigrams at next step
    open_prefixes = sampled_ids[i - 2:i]
    mask = torch.zeros(logprobs.size()).to(logprobs.device)  # batch_size x vocab_size
    for j in range(batch_size):
        for k in trigrams[j].get(id_pair(open_prefixes, j), ()):
            mask[j, k] += 1

    # Apply mask to log probs
    # ln(1/2) * alpha (alpha -> infty works best)
    return logprobs + (mask * -0.693 * alpha)