-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dataset.py
74 lines (63 loc) · 1.98 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
import os
import glob
import random
import cv2
import numpy as np
import Spectrum
def read_rgb(path):
    """Load an image file as a normalised RGB float array.

    The image is resized to 256x192 with area interpolation (cv2's ``dsize``
    is ``(width, height)``, so the array is 192 rows by 256 columns), converted
    from OpenCV's BGR channel order to RGB, and scaled from [0, 255] to
    [-1, 1] via ``(x - 127.5) / 127.5``.

    :param path: path of an image file readable by ``cv2.imread``.
    :return: ``float32`` ndarray of shape ``(192, 256, 3)``.
    :raises IOError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error only surfaces later as an opaque
        # cv2.resize assertion that does not name the offending file.
        raise IOError('cv2.imread could not read image file: {!r}'.format(path))
    img = cv2.resize(img, (256, 192), interpolation=cv2.INTER_AREA)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype('float32')
    return (img - 127.5) / 127.5
def read_scr(path):
    """Load a ``.scr`` file and scale its decoded image to [-1, 1].

    The raw bytes are decoded by ``Spectrum.scr_to_image`` (defined in another
    module). The result is normalised with the same ``(x - 127.5) / 127.5``
    mapping that ``read_rgb`` uses.

    :param path: path of the ``.scr`` file.
    :return: the decoded, normalised image. NOTE(review): shape and dtype come
        from ``Spectrum.scr_to_image`` — presumably a (192, 256, 3) array of
        0..255 values so batches match ``read_rgb``; confirm.
    """
    # 'with' closes the file deterministically instead of leaving the handle
    # open until the garbage collector reclaims it.
    with open(path, 'rb') as scr_file:
        data = scr_file.read()
    img = Spectrum.scr_to_image(data)
    return (img - 127.5) / 127.5
def to_image(nn):
    """Convert a [-1, 1] float array back to an 8-bit image.

    This is the inverse of the ``(x - 127.5) / 127.5`` normalisation used by
    ``read_rgb`` / ``read_scr``.

    Values are rounded to the nearest integer rather than truncated. As a
    result, a normalise/denormalise round trip reproduces the original pixel
    values despite float error (e.g. 99.99999 becomes 100, not 99).

    Values are also clipped to [0, 255]. Out-of-range inputs therefore
    saturate instead of wrapping around when cast to ``uint8``.

    :param nn: array-like of floats, any shape.
    :return: ``uint8`` ndarray with the same shape as ``nn``.
    """
    pixels = np.asarray(nn) * 127.5 + 127.5
    return np.clip(np.rint(pixels), 0, 255).astype('uint8')
def batch(name_glob, read_func, batch_size):
    """Yield ``(epoch, batch)`` pairs forever from the files matching a glob.

    Behaviour:

    - Every matching file is loaded up front with ``read_func`` and stacked
      into one array.
    - The array is shuffled, and consecutive non-overlapping slices of
      ``batch_size`` items are yielded.
    - When fewer than ``batch_size`` unused items remain, the array is
      reshuffled and ``epoch`` is incremented. The first pass is epoch 0.
    - Leftover items that cannot fill a whole batch are skipped for that pass.

    Each yielded batch is a copy, so a later in-place reshuffle cannot change
    a batch the caller is still holding.

    :param name_glob: pattern passed to ``glob.glob``.
    :param read_func: callable mapping a file path to an array; all results
        must have the same shape.
    :param batch_size: items per batch; must be between 1 and the number of
        matching files (inclusive).
    :raises ValueError: on the first ``next()`` if ``batch_size`` is out of
        range (including when no files match).
    """
    # Sort so the shuffle order depends only on the RNG state, not on the
    # order in which the filesystem happens to list directory entries.
    paths = sorted(glob.glob(name_glob))
    data = np.array([read_func(path) for path in paths])
    count = data.shape[0]
    if not 1 <= batch_size <= count:
        raise ValueError(
            'batch_size must be between 1 and the number of files matching '
            '{!r} ({}), got {}'.format(name_glob, count, batch_size))
    np.random.shuffle(data)
    epoch = 0
    pos = 0
    while True:
        # '>' rather than '>=': a batch ending exactly at the last item is
        # still complete and must not be skipped.
        if pos + batch_size > count:
            pos = 0
            epoch += 1
            np.random.shuffle(data)
        yield epoch, data[pos:pos + batch_size].copy()
        pos += batch_size
def minibatch(batch_size, rgb_glob='image_rgb/*', scr_glob='image_scr/*.scr'):
    """Yield ``(epoch, rgb_images, scr_images)`` triples forever.

    Each step draws one batch from the files matching ``rgb_glob`` (loaded
    with ``read_rgb``) and one from the files matching ``scr_glob`` (loaded
    with ``read_scr``). The reported epoch is the larger of the two sources'
    epoch counters.
    """
    rgb_source = batch(rgb_glob, read_rgb, batch_size)
    scr_source = batch(scr_glob, read_scr, batch_size)
    while True:
        rgb_epoch, rgb_images = next(rgb_source)
        scr_epoch, scr_images = next(scr_source)
        epoch = rgb_epoch if rgb_epoch >= scr_epoch else scr_epoch
        yield epoch, rgb_images, scr_images
class ImagePool:
    """Fixed-size buffer of previously seen images.

    For each incoming image, ``replace`` returns either that image or a
    randomly chosen image stored by an earlier call. In the latter case, the
    incoming image takes that image's slot in the buffer. Until ``size``
    images have been stored, every incoming image is stored and returned
    unchanged.
    """

    def __init__(self, size=200):
        """Create an empty pool.

        :param size: maximum number of stored images. ``0`` (or less)
            disables the buffer, so ``replace`` returns its input unchanged.
        """
        self.size = size
        self.n = 0  # number of images stored so far (equals len(self.images))
        self.images = []

    def replace(self, images):
        """Mix a batch of new images with images drawn from the buffer.

        Once the buffer is full, each image has a 50% chance of being swapped
        for a random stored one.

        :param images: iterable of equally-shaped arrays (e.g. a batch array,
            iterated along its first axis).
        :return: ndarray stacking one output image per input image.
        """
        if self.size <= 0:
            # No history to draw from; random.randint(0, -1) would raise.
            return np.stack(list(images), axis=0)
        new_images = []
        for image in images:
            if self.n < self.size:
                # Buffer not full yet: remember the image and pass it through.
                self.n += 1
                # Store a copy. Iterating a batch array yields views, which
                # would alias (and keep alive) the caller's whole batch buffer.
                self.images.append(np.array(image, copy=True))
                new_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Swap: emit a random stored image and keep the new one.
                i = random.randint(0, self.size - 1)
                new_images.append(self.images[i])
                self.images[i] = np.array(image, copy=True)
            else:
                new_images.append(image)
        return np.stack(new_images, axis=0)