From 824df084165ba8a4fc7d217bf438917bdd60309d Mon Sep 17 00:00:00 2001 From: Qi Han <37238621+Hanqer@users.noreply.github.com> Date: Fri, 12 Nov 2021 20:48:24 +0800 Subject: [PATCH] fix the data fix length in global search --- batch_gen_resize.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/batch_gen_resize.py b/batch_gen_resize.py index 202b4af..6752286 100644 --- a/batch_gen_resize.py +++ b/batch_gen_resize.py @@ -18,8 +18,15 @@ def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rat file_ptr = open(vid, 'r') self.list_of_examples = file_ptr.read().split('\n')[:-1] file_ptr.close() - - self.mask = torch.ones(self.num_classes, 2000, dtype=torch.float) + + if '50salads' in gt_path: + self.fix_size = 5000 + elif 'breakfast' in gt_path: + self.fix_size = 2000 + elif 'gtea' in gt_path: + self.fix_size = 1000 + + self.mask = torch.ones(self.num_classes, self.fix_size, dtype=torch.float) def __getitem__(self, idx): @@ -45,8 +52,8 @@ def __getitem__(self, idx): batch_input = torch.from_numpy(batch_input) batch_target = torch.from_numpy(batch_target) - batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=5000, mode='nearest').squeeze() - batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=5000, mode='nearest').squeeze().long() + batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=self.fix_size, mode='nearest').squeeze() + batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=self.fix_size, mode='nearest').squeeze().long() np.save(self.features_path + batch.split('.')[0] + '_fix', batch_input.numpy()) np.save(self.features_path + batch.split('.')[0] + '_fix_label', batch_target.numpy())