-
Notifications
You must be signed in to change notification settings - Fork 35
/
datagen.py
69 lines (57 loc) · 1.91 KB
/
datagen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import print_function
import io
import os
import sys
import random
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
class ListDataset(data.Dataset):
    '''Dataset of images paired with integer fiiqa labels from an index file.

    Every non-blank line of the index file is split on whitespace: the first
    field is the image file name (joined onto ``root``), the second field is
    the label and is parsed with ``int``. Any further fields are ignored.
    '''

    def __init__(self, root, list_file, transform):
        '''Read the index file and store every (file name, label) pair.

        Args:
          root: (str) directory to images.
          list_file: (str) path to index file; it is decoded as GBK.
          transform: ([transforms]) image transforms, called on each
            loaded PIL image.

        Raises:
          IndexError: if a non-blank line has fewer than two fields.
          ValueError: if the second field of a line is not an integer.
        '''
        self.root = root
        self.transform = transform

        self.fname = []
        self.fiiqa = []

        with io.open(list_file, encoding='gbk') as f:
            for line in f:
                sp = line.split()
                if not sp:
                    # Blank lines (typically a trailing empty line) carry no
                    # sample; skip them instead of failing on sp[0].
                    continue
                self.fname.append(sp[0])
                self.fiiqa.append(int(sp[1]))

        # Count the samples actually parsed so that __len__ always agrees
        # with the lists indexed by __getitem__.
        self.num_imgs = len(self.fname)

    def __getitem__(self, idx):
        '''Load one image and its label.

        Args:
          idx: (int) image index.

        Returns:
          img: the RGB image after ``self.transform`` has been applied
            (a tensor when the transform ends with ``ToTensor``).
          fiiqa: (int) fiiqa label of that image.
        '''
        fname = self.fname[idx]
        fiiqa = self.fiiqa[idx]

        # Image.open reads lazily and keeps the file open; convert() loads
        # the pixels into a new image, so the file can be closed right after.
        with Image.open(os.path.join(self.root, fname)) as raw:
            img = raw.convert('RGB')

        img = self.transform(img)
        return img, fiiqa

    def __len__(self):
        '''Return the number of samples read from the index file.'''
        return self.num_imgs
def test():
    '''Smoke test: iterate the validation set and print batch sizes.

    Builds a ListDataset from the hard-coded paths under
    ``./data/validationset/val-faces/``, wraps it in a shuffling DataLoader
    with batch size 10, and prints the size of every image batch and label
    batch. Requires those files to exist on disk.
    '''
    # RandomSizedCrop is the deprecated name of RandomResizedCrop and is
    # missing from current torchvision; prefer the new name and fall back to
    # the old one so that very old installs keep working.
    crop_cls = getattr(transforms, 'RandomResizedCrop', None)
    if crop_cls is None:
        crop_cls = transforms.RandomSizedCrop

    transform = transforms.Compose([
        crop_cls(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
    ])
    dataset = ListDataset(root='./data/validationset/val-faces/', list_file='./data/validationset/val-faces/new_4people_val_standard.txt', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)
    for img, fiiqa in dataloader:
        print(img.size())
        print(fiiqa.size())