diff --git a/batch_gen_resize.py b/batch_gen_resize.py index 202b4af..6752286 100644 --- a/batch_gen_resize.py +++ b/batch_gen_resize.py @@ -18,8 +18,15 @@ def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rat file_ptr = open(vid, 'r') self.list_of_examples = file_ptr.read().split('\n')[:-1] file_ptr.close() - - self.mask = torch.ones(self.num_classes, 2000, dtype=torch.float) + + if '50salads' in gt_path: + self.fix_size = 5000 + elif 'breakfast' in gt_path: + self.fix_size = 2000 + elif 'gtea' in gt_path: + self.fix_size = 1000 + + self.mask = torch.ones(self.num_classes, self.fix_size, dtype=torch.float) def __getitem__(self, idx): @@ -45,8 +52,8 @@ def __getitem__(self, idx): batch_input = torch.from_numpy(batch_input) batch_target = torch.from_numpy(batch_target) - batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=5000, mode='nearest').squeeze() - batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=5000, mode='nearest').squeeze().long() + batch_input = torch.nn.functional.interpolate(batch_input.unsqueeze(0), size=self.fix_size, mode='nearest').squeeze() + batch_target = torch.nn.functional.interpolate(batch_target.unsqueeze(0).unsqueeze(0), size=self.fix_size, mode='nearest').squeeze().long() np.save(self.features_path + batch.split('.')[0] + '_fix', batch_input.numpy()) np.save(self.features_path + batch.split('.')[0] + '_fix_label', batch_target.numpy())