-
Notifications
You must be signed in to change notification settings - Fork 0
/
flowers.py
85 lines (73 loc) · 2.55 KB
/
flowers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Adapted from https://www.tensorflow.org/datasets/catalog/tf_flowers
@ONLINE {tfflowers,
author = "The TensorFlow Team",
title = "Flowers",
month = "jan",
year = "2019",
url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }
"""
# Location of the flower-photos archive; TFFlowers.__init__ passes it to
# download_and_unpack() after switching into the dataset directory.
URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"
from torch.utils.data import Dataset
import os
import subprocess as sp
from PIL import Image
import numpy as np
from filelock import FileLock
from tempfile import mktemp
def download_and_unpack(URL):
    """Download the archive at ``URL`` and unpack it into the current directory.

    Each step is skipped when its result already exists. The archive is looked
    up under its basename, and extraction is skipped when a ``flower_photos``
    directory is present. A cross-process file lock serialises concurrent
    callers so the archive is fetched only once.

    Failures are never silently ignored. The download goes to a ``.part`` file
    that is renamed into place only after ``wget`` succeeds. Extraction happens
    in a staging directory that is moved into place only after ``tar``
    succeeds. A failed run therefore leaves nothing that a later run would
    mistake for a finished download or extraction.

    Args:
        URL: address of the ``.tgz`` archive. Requires ``wget`` and ``tar`` on
            ``PATH``.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``tar`` exits non-zero.
    """
    import shutil  # local import: only needed to clean up staging directories

    fname = os.path.basename(URL)
    with FileLock("/tmp/flowers_download"):
        print(f"Downloading into {os.path.abspath(os.curdir)}")
        if not os.path.exists(fname):
            partial = fname + ".part"
            try:
                sp.check_call(["wget", "-O", partial, URL])
            except BaseException:
                # Never leave a truncated archive where the next run would
                # treat it as a completed download.
                if os.path.exists(partial):
                    os.remove(partial)
                raise
            os.replace(partial, fname)
        if not os.path.exists("flower_photos"):
            staging = "flower_photos.partial"
            # A previous failed run may have left a half-extracted tree here.
            shutil.rmtree(staging, ignore_errors=True)
            os.makedirs(staging)
            # Absolute archive path so tar's -C cannot change how it is resolved.
            sp.check_call(
                ["tar", "-x", "-f", os.path.abspath(fname), "-C", staging]
            )
            os.replace(os.path.join(staging, "flower_photos"), "flower_photos")
            shutil.rmtree(staging, ignore_errors=True)
class TFFlowers(Dataset):
    """Map-style dataset over the TensorFlow "flower_photos" archive.

    On construction the archive is downloaded and unpacked under ``data_dir``
    (via ``download_and_unpack``, which skips work already done). Every image
    file is then indexed together with its integer class label.

    Args:
        data_dir: directory that holds (or will hold) the archive and the
            extracted ``flower_photos`` folder. Created if missing.
        img_size: size handed to ``PIL.Image.resize``, i.e. ``(width, height)``.

    Items are ``(image, label)`` pairs. ``image`` is a float32 array of shape
    ``(3, height, width)`` with values in ``[0, 1]``. ``label`` is the index of
    the class in ``ind_to_class``.
    """

    # Class sub-folders of flower_photos; their position defines the label.
    _CLASS_NAMES = ("daisy", "dandelion", "roses", "sunflowers", "tulips")

    def __init__(
        self, data_dir=os.path.expanduser("~/.datasets/tfflowers"), img_size=(240, 240)
    ):
        super().__init__()
        data_dir = os.path.abspath(data_dir)
        self.img_size = img_size
        os.makedirs(data_dir, exist_ok=True)
        # download_and_unpack works relative to the current directory, so run
        # it from data_dir. Always restore the caller's cwd, even if the
        # download or extraction raises.
        cwd = os.path.abspath(os.curdir)
        os.chdir(data_dir)
        try:
            download_and_unpack(URL)
        finally:
            os.chdir(cwd)
        self.data_dir = os.path.join(data_dir, "flower_photos")
        self.ind_to_class = dict(enumerate(self._CLASS_NAMES))
        self.class_to_ind = {name: ind for ind, name in self.ind_to_class.items()}
        paths, classes = self._collect_samples(self.data_dir, self._CLASS_NAMES)
        self.classes = np.array(classes)
        self.paths = np.array(paths)

    @staticmethod
    def _collect_samples(root, class_names):
        """Return ``(paths, labels)`` for every file in ``root/<class_name>/``.

        A file's label is the position of its class in ``class_names``. Paths
        are absolute. Files within a class are sorted by name so dataset
        indices are reproducible, because ``os.listdir`` order is arbitrary.
        """
        paths, labels = [], []
        for label, name in enumerate(class_names):
            class_dir = os.path.join(root, name)
            for fname in sorted(os.listdir(class_dir)):
                paths.append(os.path.abspath(os.path.join(class_dir, fname)))
                labels.append(label)
        return paths, labels

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        The image is converted to RGB before resizing. Grayscale, palette or
        RGBA files would otherwise produce an array that cannot be transposed
        to channel-first ``(3, H, W)``.
        """
        label = self.classes[item]
        with Image.open(self.paths[item]) as img:
            rgb = img.convert("RGB").resize(self.img_size)
        pixels = np.asarray(rgb).transpose(2, 0, 1).astype(np.float32)
        return pixels / 255.0, label

    def __len__(self):
        """Number of indexed images."""
        return len(self.classes)
if __name__ == "__main__":
    # Smoke test: build the dataset and print the set of distinct image shapes.
    # A single (3, 240, 240) entry is expected with the default img_size.
    # (The previously imported matplotlib was unused and only added a hard
    # dependency for running this script.)
    ds = TFFlowers(data_dir="/tmp/flowers")
    shapes = {ds[i][0].shape for i in range(len(ds))}
    print(shapes)