-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
35 lines (32 loc) · 1.2 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from PIL import Image
from torch.utils.data import Dataset,DataLoader
class TrainingDataset(Dataset):
    """Map-style dataset yielding nested groups of (input frames, ground-truth frame).

    ``img[index]`` is iterated as a collection of tasks; each task is iterated as a
    collection of frame sequences; each frame sequence is indexed as
    ``[input_0, input_1, input_2, gt]``. Every entry is a path handed to the loader.
    """

    # Number of input frames per sequence; the ground-truth path is stored
    # immediately after them (i.e. at this index).
    NUM_INPUT_FRAMES = 3

    def __init__(self, img, transform=None, loader=None):
        """Store the nested path structure and the per-image processing hooks.

        Args:
            img: Indexable collection; see the class docstring for its nesting.
            transform: Optional callable applied to every loaded image, both the
                input frames and the ground-truth frame.
            loader: Optional callable mapping a path to an image. When ``None``,
                the file is opened with PIL and converted to RGB.
        """
        self.img = img
        self.transform = transform
        # Kept as None (not a bound default) so the dataset stays picklable
        # when DataLoader workers are spawned.
        self.loader = loader

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL, return an RGB copy, and close the file handle."""
        with Image.open(path) as im:
            return im.convert('RGB')

    def _load(self, path):
        """Load one image via the configured loader, then apply ``self.transform`` if set."""
        image = self.loader(path) if self.loader is not None else self._load_rgb(path)
        if self.transform is not None:
            image = self.transform(image)
        return image

    # NOTE(review): the original author noted that the dataset length equals the
    # number of epochs (6), i.e. each item is one "epoch" of tasks — confirm with callers.
    def __getitem__(self, index):
        """Return the loaded images for item ``index``.

        Returns:
            A list (one entry per task) of lists (one entry per frame sequence) of
            ``[inputs, gt]`` pairs, where ``inputs`` is a list of the
            ``NUM_INPUT_FRAMES`` loaded input frames and ``gt`` is the loaded
            ground-truth frame. Input frames are always included, whether or not
            a transform was supplied.
        """
        epoch = []
        for task in self.img[index]:
            task_list = []
            for frame_sequence in task:
                inputs = [
                    self._load(frame_sequence[i])
                    for i in range(self.NUM_INPUT_FRAMES)
                ]
                gt = self._load(frame_sequence[self.NUM_INPUT_FRAMES])
                task_list.append([inputs, gt])
            epoch.append(task_list)
        return epoch

    def __len__(self):
        """Return the number of top-level items in ``img``."""
        return len(self.img)