data_loader.py
forked from udacity/CVND---Gesture-Recognition

import os
import glob

import numpy as np
import torch
from PIL import Image

from data_parser import JpegDataset
from torchvision.transforms import *

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG']


def default_loader(path):
    # Read a single frame from disk and force 3-channel RGB.
    return Image.open(path).convert('RGB')

class VideoFolder(torch.utils.data.Dataset):

    def __init__(self, root, csv_file_input, csv_file_labels, clip_size,
                 nclips, step_size, is_val, transform=None,
                 loader=default_loader):
        self.dataset_object = JpegDataset(csv_file_input, csv_file_labels, root)
        self.csv_data = self.dataset_object.csv_data
        self.classes = self.dataset_object.classes
        self.classes_dict = self.dataset_object.classes_dict
        self.root = root
        self.transform = transform
        self.loader = loader
        self.clip_size = clip_size    # frames per clip
        self.nclips = nclips          # clips per video; -1 means use all frames
        self.step_size = step_size    # temporal stride between sampled frames
        self.is_val = is_val          # True disables the random temporal offset

    def __getitem__(self, index):
        item = self.csv_data[index]
        img_paths = self.get_frame_names(item.path)
        imgs = []
        for img_path in img_paths:
            img = self.loader(img_path)
            img = self.transform(img)
            imgs.append(torch.unsqueeze(img, 0))
        target_idx = self.classes_dict[item.label]

        # format data to torch
        data = torch.cat(imgs)
        data = data.permute(1, 0, 2, 3)
        return (data, target_idx)
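
    # Shape sketch (added note, not from the original file): with the __main__
    # settings below (clip_size=18, nclips=1, CenterCrop(84) + ToTensor()), one
    # item returns data of shape (3, 18, 84, 84): torch.cat stacks the frames
    # into (frames, channels, H, W) and the permute above moves channels first.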

    def __len__(self):
        return len(self.csv_data)

    def get_frame_names(self, path):
        # collect and sort every frame path in this video's folder
        frame_names = []
        for ext in IMG_EXTENSIONS:
            frame_names.extend(glob.glob(os.path.join(path, "*" + ext)))
        frame_names = list(sorted(frame_names))
        num_frames = len(frame_names)

        # set number of necessary frames
        if self.nclips > -1:
            num_frames_necessary = self.clip_size * self.nclips * self.step_size
        else:
            num_frames_necessary = num_frames

        # pick frames
        offset = 0
        if num_frames_necessary > num_frames:
            # pad with the last frame if the video is shorter than necessary
            frame_names += [frame_names[-1]] * (num_frames_necessary - num_frames)
        elif num_frames_necessary < num_frames:
            # if there are more frames than needed, sample a starting offset
            diff = (num_frames - num_frames_necessary)
            # temporal augmentation: random offset for training, fixed for validation
            if not self.is_val:
                offset = np.random.randint(0, diff)
        frame_names = frame_names[offset:num_frames_necessary +
                                  offset:self.step_size]
        return frame_names
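
    # Worked example (illustrative numbers, not from the original file): with
    # clip_size=18, nclips=1, step_size=2, a video needs 18 * 1 * 2 = 36 frames.
    # A 40-frame video trains with offset drawn from [0, 4) (offset stays 0 for
    # validation), and frame_names[offset:offset + 36:2] yields 18 paths; a
    # 30-frame video is first padded with 6 copies of its last frame.
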
if __name__ == '__main__':
    transform = Compose([
        CenterCrop(84),
        ToTensor(),
        # Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
    ])
    loader = VideoFolder(root="/hdd/20bn-datasets/20bn-jester-v1/",
                         csv_file_input="csv_files/jester-v1-validation.csv",
                         csv_file_labels="csv_files/jester-v1-labels.csv",
                         clip_size=18,
                         nclips=1,
                         step_size=2,
                         is_val=False,
                         transform=transform,
                         loader=default_loader)

    # data_item, target_idx = loader[0]
    # save_images_for_debug("input_images", data_item.unsqueeze(0))

    train_loader = torch.utils.data.DataLoader(
        loader,
        batch_size=10, shuffle=False,
        num_workers=5, pin_memory=True)
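
    # Minimal usage sketch (added, not part of the original file): assuming the
    # Jester frames and CSV files referenced above actually exist at those paths,
    # this pulls one batch and prints the shapes the loader produces.
    for data, targets in train_loader:
        # data: (batch, channels, frames, height, width) -> (10, 3, 18, 84, 84)
        # targets: (batch,) integer class indices from jester-v1-labels.csv
        print(data.shape, targets.shape)
        break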