-
Notifications
You must be signed in to change notification settings - Fork 55
/
datasets.py
46 lines (39 loc) · 1.5 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import imghdr
import torch.utils.data as data
from PIL import Image
class ImageDataset(data.Dataset):
    """Map-style dataset that serves every image file found in a directory.

    Image files are discovered once, at construction time, by sniffing file
    headers with ``imghdr`` (file extensions are not consulted). The
    discovered paths are kept in sorted order so that a given index refers
    to the same file regardless of the filesystem's directory-listing order.

    Args:
        data_dir: Directory to scan; a leading ``~`` is expanded.
        transform: Optional callable applied to each loaded PIL image before
            it is returned from ``__getitem__``.
        recursive_search: If True, scan ``data_dir`` and all of its
            subdirectories; otherwise only its immediate entries.

    NOTE(review): ``imghdr`` is deprecated since Python 3.11 and removed in
    3.13, so the module-level ``import imghdr`` fails there. Consider moving
    image detection to Pillow.
    """

    def __init__(self, data_dir, transform=None, recursive_search=False):
        super().__init__()
        self.data_dir = os.path.expanduser(data_dir)
        self.transform = transform
        # Paths (each prefixed with data_dir) of every detected image file,
        # sorted so dataset indices are reproducible.
        self.imgpaths = self.__load_imgpaths_from_dir(self.data_dir, walk=recursive_search)

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.imgpaths)

    def __getitem__(self, index, color_format='RGB'):
        """Load the image at ``index``.

        Args:
            index: Position in ``self.imgpaths``.
            color_format: PIL mode passed to ``Image.convert``.

        Returns:
            The converted PIL image, or whatever ``self.transform`` returns
            for it when a transform is set.
        """
        # Close the underlying file promptly instead of leaving the handle to
        # the garbage collector; convert() decodes the pixels and returns a
        # new, independent image, so it stays usable after the file closes.
        with Image.open(self.imgpaths[index]) as raw:
            img = raw.convert(color_format)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __is_imgfile(self, filepath):
        """Return True if ``filepath`` is a regular file with a recognized image header."""
        filepath = os.path.expanduser(filepath)
        # imghdr.what returns the detected format name, or None if unrecognized.
        return os.path.isfile(filepath) and imghdr.what(filepath) is not None

    def __load_imgpaths_from_dir(self, dirpath, walk=False):
        """Collect image file paths under ``dirpath``.

        Args:
            dirpath: Directory to scan; a leading ``~`` is expanded.
            walk: If True, descend into subdirectories via ``os.walk``;
                otherwise only list ``dirpath`` itself.

        Returns:
            A sorted list of paths (joined onto ``dirpath``) that pass the
            image-header check.
        """
        dirpath = os.path.expanduser(dirpath)
        if walk:
            candidates = (
                os.path.join(root, name)
                for root, _, files in os.walk(dirpath)
                for name in files
            )
        else:
            candidates = (os.path.join(dirpath, name) for name in os.listdir(dirpath))
        # os.listdir / os.walk yield entries in arbitrary, filesystem-dependent
        # order; sorting makes index -> file mapping stable across machines.
        return sorted(path for path in candidates if self.__is_imgfile(path))