From 8c1b8aa211a70213aacd939516baf07735e4484a Mon Sep 17 00:00:00 2001
From: Ruotian Luo
Date: Sat, 28 Apr 2018 19:43:06 -0500
Subject: [PATCH] Sort the features in the forwarding instead of dataloader.

---
 dataloader.py      |  6 ++++--
 models/AttModel.py | 16 ++++++++++++++--
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/dataloader.py b/dataloader.py
index 79f7fbac..d90cea71 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -151,9 +151,11 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
             info_dict['file_path'] = self.info['images'][ix]['file_path']
             infos.append(info_dict)
 
-        #sort by att_feat length
+        # #sort by att_feat length
+        # fc_batch, att_batch, label_batch, gts, infos = \
+        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
         fc_batch, att_batch, label_batch, gts, infos = \
-            zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
+            zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True))
         data = {}
         data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch]))
         # merge att_feats
diff --git a/models/AttModel.py b/models/AttModel.py
index ba7c8df5..40184f9d 100644
--- a/models/AttModel.py
+++ b/models/AttModel.py
@@ -25,10 +25,22 @@
 
 from .CaptionModel import CaptionModel
 
+def sort_pack_padded_sequence(input, lengths):
+    sorted_lengths, indices = torch.sort(lengths, descending=True)
+    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
+    inv_ix = indices.clone()
+    inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
+    return tmp, inv_ix
+
+def pad_unsort_packed_sequence(input, inv_ix):
+    tmp, _ = pad_packed_sequence(input, batch_first=True)
+    tmp = tmp[inv_ix]
+    return tmp
+
 def pack_wrapper(module, att_feats, att_masks):
     if att_masks is not None:
-        packed = pack_padded_sequence(att_feats, list(att_masks.data.long().sum(1)), batch_first=True)
-        return pad_packed_sequence(PackedSequence(module(packed[0]), packed[1]), batch_first=True)[0]
+        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
+        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
     else:
         return module(att_feats)
 