Commit

Sort the features in the forward pass instead of in the dataloader.
ruotianluo committed Apr 29, 2018
1 parent 9ebae0a commit 8c1b8aa
Showing 2 changed files with 18 additions and 4 deletions.
dataloader.py: 6 changes (4 additions, 2 deletions)
@@ -151,9 +151,11 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
             info_dict['file_path'] = self.info['images'][ix]['file_path']
             infos.append(info_dict)
 
-        #sort by att_feat length
+        # #sort by att_feat length
+        # fc_batch, att_batch, label_batch, gts, infos = \
+        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
         fc_batch, att_batch, label_batch, gts, infos = \
-            zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
+            zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True))
         data = {}
         data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch]))
         # merge att_feats
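
Note on the dataloader change: the sort by attention-feature length is commented out, and the replacement line keeps the same code shape but uses a constant key, so Python's stable sort (stability also holds with reverse=True) leaves the batch in its original order. A minimal sketch of that behavior, using hypothetical stand-ins for fc_batch and att_batch rather than the real batch variables:

# Sketch (not from the repository): a constant sort key plus Python's stable
# sort preserves the incoming batch order, so get_batch no longer reorders
# its outputs. The lists below are made-up stand-ins.
import numpy as np

att_batch = [np.zeros((5, 4)), np.zeros((2, 4)), np.zeros((3, 4))]  # per-image region features
fc_batch = ['img0', 'img1', 'img2']                                 # stand-in fc features

# Old dataloader behavior: images with more regions come first.
by_length = sorted(zip(fc_batch, att_batch), key=lambda x: len(x[1]), reverse=True)
print([name for name, _ in by_length])   # ['img0', 'img2', 'img1']

# New behavior: constant key, stable sort -> original order is kept.
no_op = sorted(zip(fc_batch, att_batch), key=lambda x: 0, reverse=True)
print([name for name, _ in no_op])       # ['img0', 'img1', 'img2']
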
models/AttModel.py: 16 changes (14 additions, 2 deletions)
@@ -25,10 +25,22 @@
 
 from .CaptionModel import CaptionModel
 
+def sort_pack_padded_sequence(input, lengths):
+    sorted_lengths, indices = torch.sort(lengths, descending=True)
+    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
+    inv_ix = indices.clone()
+    inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
+    return tmp, inv_ix
+
+def pad_unsort_packed_sequence(input, inv_ix):
+    tmp, _ = pad_packed_sequence(input, batch_first=True)
+    tmp = tmp[inv_ix]
+    return tmp
+
 def pack_wrapper(module, att_feats, att_masks):
     if att_masks is not None:
-        packed = pack_padded_sequence(att_feats, list(att_masks.data.long().sum(1)), batch_first=True)
-        return pad_packed_sequence(PackedSequence(module(packed[0]), packed[1]), batch_first=True)[0]
+        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
+        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
     else:
         return module(att_feats)
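
With the dataloader no longer sorting, pack_wrapper takes over: it sorts the features by their true length so pack_padded_sequence accepts the batch, runs the wrapped module on the packed data, then uses the inverse index to put every image back in its original batch slot. Below is a minimal round-trip check, assuming a CPU batch and using torch.nn.Identity as a stand-in for the embedding module; it is a sketch that mirrors the helpers above, not code from the repository:

# Sketch: the two helpers below mirror the functions added in this commit; the
# final assertion shows that features return to their original batch positions
# after sort -> pack -> module -> pad -> unsort.
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

def sort_pack_padded_sequence(input, lengths):
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix

def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    return tmp[inv_ix]

att_feats = torch.randn(3, 5, 4)                 # 3 images, padded to 5 regions, 4-dim features
att_masks = torch.tensor([[1, 1, 0, 0, 0],       # hypothetical masks: 2, 5 and 3 real regions
                          [1, 1, 1, 1, 1],
                          [1, 1, 1, 0, 0]])
lengths = att_masks.long().sum(1)                # unsorted lengths, as pack_wrapper computes them

module = torch.nn.Identity()                     # stand-in for the attention embedding module
packed, inv_ix = sort_pack_padded_sequence(att_feats, lengths)
out = pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)

# Image 1 has all 5 regions, so its row survives packing unchanged and lands
# back in slot 1 of the batch.
assert torch.equal(out[1], att_feats[1])
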
