This repository has been archived by the owner on Apr 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 45
/
dm_input.py
62 lines (49 loc) · 2.09 KB
/
dm_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import dm_celeba
FLAGS = tf.app.flags.FLAGS
def input_data(sess, mode, filenames, capacity_factor=3):
    """Build a batched TF1 queue pipeline of 80x100 face images.

    JPEG files are read through a filename queue and decoded to 3 channels.
    In 'train'/'test' mode they are cropped to the face region; in
    'inference' mode they are resized instead. Pixels are scaled to float32
    in [0, 1], the images are batched, and the queue runners are started on
    ``sess``.

    Args:
        sess: TF session on which the queue-runner threads are started.
        mode: One of 'train', 'test' or 'inference'.
            'train' uses ``filenames[FLAGS.test_vectors:]`` plus random
            horizontal flips. 'test' uses ``filenames[:FLAGS.test_vectors]``.
            Both use ``FLAGS.batch_size``. 'inference' uses every file with
            a batch size of 1.
        filenames: Sequence of JPEG file paths. In 'train'/'test' mode the
            first ``FLAGS.test_vectors`` entries are the test split.
        capacity_factor: Capacity of the batching queue, as a multiple of
            the batch size.

    Returns:
        The output of ``tf.train.batch`` for the single float32 feature.
        Each batched image has shape [batch_size, 100, 80, 3] (height, width,
        channels), with values in [0, 1].

    Raises:
        ValueError: If ``mode`` is unknown, if fewer than
            ``FLAGS.test_vectors`` files are given in 'train'/'test' mode,
            or if the selected train/test split is empty.
    """
    # Validate the mode first so that an unknown mode always raises the same
    # ValueError, regardless of how many files were passed.
    if mode not in ('train', 'test', 'inference'):
        raise ValueError('Unknown mode `%s`' % (mode,))

    if mode == 'inference':
        filenames = filenames[:]
        batch_size = 1
    else:
        # Separate training and test sets.
        # TBD: Use partition given by dataset creators.
        # Explicit checks (not `assert`) so they still run under `python -O`.
        n_test = FLAGS.test_vectors
        if len(filenames) < n_test:
            raise ValueError(
                'Need at least %d files (FLAGS.test_vectors), got %d'
                % (n_test, len(filenames)))
        if mode == 'train':
            filenames = filenames[n_test:]
        else:
            filenames = filenames[:n_test]
        if len(filenames) == 0:
            # e.g. exactly `test_vectors` files leaves nothing to train on.
            raise ValueError('No input files selected for mode `%s`' % (mode,))
        batch_size = FLAGS.batch_size

    # Read each JPEG file. The file key returned by the reader is discarded.
    # NOTE(review): tf.train.string_input_producer shuffles by default, and
    # tf.train.batch below runs 4 threads. Output order (including in
    # 'inference' mode) therefore need not follow `filenames`. Confirm that
    # callers do not rely on ordering.
    reader = tf.WholeFileReader()
    filename_queue = tf.train.string_input_producer(filenames)
    _, value = reader.read(filename_queue)

    channels = 3
    image = tf.image.decode_jpeg(value, channels=channels, name="dataset_image")
    image.set_shape([None, None, channels])

    # Random augmentation, training only.
    if mode == 'train':
        image = tf.image.random_flip_left_right(image)
        # Disabled augmentations, kept for experimentation:
        # image = tf.image.random_saturation(image, .95, 1.05)
        # image = tf.image.random_brightness(image, .05)
        # image = tf.image.random_contrast(image, .95, 1.05)

    size_x, size_y = 80, 100
    if mode == 'inference':
        # Resize (rather than crop) the whole image to the target size.
        # A fixed (height, width) does not preserve the input aspect ratio.
        image = tf.image.resize_images(
            image, (size_y, size_x), method=tf.image.ResizeMethod.AREA)
    else:
        # Dataset samples are 178x218 pixels. Select the face only, without
        # hair. The crop is horizontally centred (49 + 80 + 49 = 178), so the
        # random flip above gives the same result as flipping after cropping.
        off_x, off_y = 49, 90
        image = tf.image.crop_to_bounding_box(image, off_y, off_x, size_y, size_x)

    feature = tf.cast(image, tf.float32) / 255.0

    # Using asynchronous queues.
    features = tf.train.batch([feature],
                              batch_size=batch_size,
                              num_threads=4,
                              capacity=capacity_factor * batch_size,
                              name='features')

    # Starts every queue runner registered in the graph, not only the ones
    # created here. The returned threads are not kept.
    tf.train.start_queue_runners(sess=sess)
    return features