Write synthetic dataset to disk once for faster training
sarlinpe committed Mar 28, 2018
1 parent 2c7bbdf commit d0e0311
Showing 4 changed files with 246 additions and 104 deletions.
4 changes: 2 additions & 2 deletions notebooks/utils.py
@@ -1,12 +1,12 @@
import matplotlib.pyplot as plt


def plot_imgs(imgs, titles=None, cmap='brg', ylabel='', normalize=False, ax=None):
def plot_imgs(imgs, titles=None, cmap='brg', ylabel='', normalize=False, ax=None, dpi=100):
n = len(imgs)
if not isinstance(cmap, list):
cmap = [cmap]*n
if ax is None:
_, ax = plt.subplots(1, n, figsize=(6*n, 6), dpi=100)
_, ax = plt.subplots(1, n, figsize=(6*n, 6), dpi=dpi)
if n == 1:
ax = [ax]
else:
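For reference, a minimal usage sketch of the updated helper (the import path and the placeholder image are assumptions for illustration); the new dpi argument is simply forwarded to plt.subplots:

import numpy as np
from notebooks.utils import plot_imgs  # import path assumed from this repository layout

img = np.random.rand(240, 320)  # placeholder grayscale image
plot_imgs([img], titles=['sample'], cmap='gray', dpi=150)  # figure resolution is now configurable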
115 changes: 62 additions & 53 deletions notebooks/visualize_synthetic_shapes.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions superpoint/datasets/synthetic_dataset.py
@@ -198,7 +198,7 @@ def overlap(center, rad, centers, rads):
return flag


def draw_multiple_polygons(img, max_sides=6, nb_polygons=30):
def draw_multiple_polygons(img, max_sides=6, nb_polygons=30, **extra):
""" Draw multiple polygons with a random number of corners
and return the corner points
Parameters:
@@ -247,7 +247,8 @@ def draw_multiple_polygons(img, max_sides=6, nb_polygons=30):
# Color the polygon with a custom background
corners = new_points.reshape((-1, 1, 2))
mask = np.zeros(img.shape, np.uint8)
custom_background = generate_custom_background(img.shape, background_color)
custom_background = generate_custom_background(img.shape, background_color,
**extra)
cv.fillPoly(mask, [corners], 255)
locs = np.where(mask != 0)
img[locs[0], locs[1]] = custom_background[locs[0], locs[1]]
@@ -619,7 +620,7 @@ def draw_cube(img, min_size_ratio=0.2, min_angle_rot=math.pi / 10,
for i in [0, 1, 2]:
cv.fillPoly(img, [cube[faces[i]].reshape((-1, 1, 2))],
col_face)
thickness = random_state.randint(min_dim * 0.01, min_dim * 0.015)
thickness = random_state.randint(min_dim * 0.003, min_dim * 0.015)
for i in [0, 1, 2]:
for j in [0, 1, 2, 3]:
col_edge = (col_face + 128
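A minimal sketch of the forwarding behaviour added above (parameter values borrowed from the default config introduced in synthetic_shapes.py below); keyword arguments collected in **extra are passed through to generate_custom_background:

import numpy as np
from superpoint.datasets import synthetic_dataset

synthetic_dataset.set_random_state(np.random.RandomState(0))
img = synthetic_dataset.generate_background([960, 1280])
# kernel_boundaries is not consumed by draw_multiple_polygons itself; it reaches
# generate_custom_background via the new **extra parameter
points = synthetic_dataset.draw_multiple_polygons(img, kernel_boundaries=(40, 80))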
224 changes: 178 additions & 46 deletions superpoint/datasets/synthetic_shapes.py
@@ -1,17 +1,38 @@
import numpy as np
import tensorflow as tf
import cv2
import os
import tarfile
from pathlib import Path
from tqdm import tqdm
import shutil

from .base_dataset import BaseDataset
from superpoint.datasets import synthetic_dataset
from superpoint.settings import DATA_PATH


class SyntheticShapes(BaseDataset):
default_config = {
'image_size': [240, 320],
'primitive': 'all',
'n_background_blobs': 30,
'validation_size': 100,
'test_size': 500,
'primitives': 'all',
'validation_size': -1,
'test_size': -1,
'on-the-fly': False,
'cache_in_memory': False,
'generation': {
'split_sizes': {'training': 5000, 'validation': 200, 'test': 500},
'image_size': [960, 1280],
'random_seed': 0,
'params': {
'draw_stripes': {'transform_params': (0.1, 0.1)},
'generate_background': {'min_kernel_size': 100},
'draw_multiple_polygons': {'kernel_boundaries': (40, 80)}
},
},
'preprocessing': {
'resize': [240, 320],
'blur_size': 11,
},
}
primitives = [
'draw_lines',
@@ -25,52 +46,163 @@ class SyntheticShapes(BaseDataset):
'gaussian_noise'
]

def parse_primitives(self, primitives):
return self.primitives if (primitives == 'all') else \
(primitives if isinstance(primitives, list) else [primitives])

def dump_primitive_data(self, primitive, tar_path, config):
temp_dir = Path(os.environ['TMPDIR'], primitive)

tf.logging.info('Generating tarfile for primitive {}.'.format(primitive))
synthetic_dataset.set_random_state(np.random.RandomState(
config['generation']['random_seed']))
for split, size in self.config['generation']['split_sizes'].items():
im_dir, pts_dir = [Path(temp_dir, i, split) for i in ['images', 'points']]
im_dir.mkdir(parents=True, exist_ok=True)
pts_dir.mkdir(parents=True, exist_ok=True)

for i in tqdm(range(size), desc=split, leave=False):
image = synthetic_dataset.generate_background(
config['generation']['image_size'],
**config['generation']['params']['generate_background'])
points = np.array(getattr(synthetic_dataset, primitive)(
image, **config['generation']['params'].get(primitive, {})))
points = np.flip(points, 1) # reverse convention with opencv

b = config['preprocessing']['blur_size']
image = cv2.GaussianBlur(image, (b, b), 0)
points = (points * np.array(config['preprocessing']['resize'], np.float)
/ np.array(config['generation']['image_size'], np.float))
image = cv2.resize(image, tuple(config['preprocessing']['resize'][::-1]),
interpolation=cv2.INTER_LINEAR)

cv2.imwrite(str(Path(im_dir, '{}.png'.format(i))), image)
np.save(Path(pts_dir, '{}.npy'.format(i)), points)

# Pack into a tar file
tar = tarfile.open(tar_path, mode='w:gz')
tar.add(temp_dir, arcname=primitive)
tar.close()
shutil.rmtree(temp_dir)
tf.logging.info('Tarfile dumped to {}.'.format(tar_path))

def _init_dataset(self, **config):
assert config['primitive'] in ['all']+self.primitives
return synthetic_dataset

def _get_data(self, dataset, split_name, **config):
def _draw_shape(_):
if config['primitive'] == 'all':
primitive = np.random.choice(self.primitives)
else:
primitive = config['primitive']
im = dataset.generate_background(config['image_size'],
config['n_background_blobs'])
points = np.array(getattr(dataset, primitive)(im))
return im.astype(np.float32), points.astype(np.int32)

def _preprocess(e_in):
e_out = {}
keypoints = tf.reverse(e_in['keypoints'], axis=[-1])
e_out['keypoint_map'] = tf.scatter_nd(
keypoints,
tf.ones([tf.shape(keypoints)[0]], dtype=tf.int32),
tf.shape(e_in['image']))
e_out['image'] = tf.expand_dims(e_in['image'], axis=-1)
return e_out

def _dummy():
primitives = self.parse_primitives(config['primitives'])
assert set(primitives) <= set(self.primitives)

if config['on-the-fly']:
return None

basepath = Path(DATA_PATH, 'synthetic_shapes')
basepath.mkdir(parents=True, exist_ok=True)

splits = {s: {'images': [], 'points': []}
for s in ['training', 'validation', 'test']}
for primitive in primitives:
tar_path = Path(basepath, '{}.tar.gz'.format(primitive))
if not tar_path.exists():
self.dump_primitive_data(primitive, tar_path, config)

# Untar locally
tf.logging.info('Extracting archive for primitive {}.'.format(primitive))
tar = tarfile.open(tar_path)
temp_dir = Path(os.environ['TMPDIR'])
tar.extractall(path=temp_dir)
tar.close()

# Gather filenames in all splits
path = Path(temp_dir, primitive)
for s in splits:
for obj in ['images', 'points']:
splits[s][obj].extend([str(p) for p in Path(path, obj, s).iterdir()])

# Shuffle
for s in splits:
perm = np.random.permutation(len(splits[s]['images']))
for obj in ['images', 'points']:
splits[s][obj] = np.array(splits[s][obj])[perm].tolist()
return splits

def _get_data(self, filenames, split_name, **config):

def _gen_shape():
primitives = self.parse_primitives(config['primitives'])
while True:
yield 0

def _set_shapes(im, keypoints):
im.set_shape(tf.TensorShape(config['image_size']))
keypoints.set_shape(tf.TensorShape([None, 2]))
return im, keypoints

data = tf.data.Dataset.from_generator(_dummy, tf.int32, tf.TensorShape([]))
data = data.map(
lambda i: tuple(tf.py_func(_draw_shape, [i], [tf.float32, tf.int32])),
num_parallel_calls=8)
data = data.map(_set_shapes)
data = data.map(lambda im, keypoints: {'image': im, 'keypoints': keypoints})
data = data.map(_preprocess)

# Make the length of the validation and test sets finite
primitive = np.random.choice(primitives)
image = synthetic_dataset.generate_background(
config['generation']['image_size'],
**config['generation']['params']['generate_background'])
points = np.array(getattr(synthetic_dataset, primitive)(
image, **config['generation']['params'].get(primitive, {})))
yield (np.expand_dims(image, axis=-1).astype(np.float32),
np.flip(points.astype(np.float32), 1))

def _read_image(filename):
image = tf.read_file(filename)
image = tf.image.decode_png(image, channels=1)
return tf.cast(image, tf.float32)

def _read_points(filename):
return np.load(filename.decode('utf-8')).astype(np.float32)

def _downsample(image, coordinates):
with tf.name_scope('gaussian_blur'):
kernel = cv2.getGaussianKernel(config['preprocessing']['blur_size'], 0)
kernel = kernel[:, 0]
kernel = np.outer(kernel, kernel).astype(np.float32)
kernel = tf.convert_to_tensor(kernel)
kernel = tf.expand_dims(tf.expand_dims(kernel, axis=-1), axis=-1)
image = tf.expand_dims(image, axis=0) # add batch dim
image = tf.nn.depthwise_conv2d(image, kernel, [1, 1, 1, 1], 'SAME')
image = image[0] # remove batch dim

ratio = tf.divide(tf.convert_to_tensor(config['preprocessing']['resize']),
tf.shape(image)[0:2])
coordinates = coordinates * tf.cast(ratio, tf.float32)
image = tf.image.resize_images(image, config['preprocessing']['resize'],
method=tf.image.ResizeMethod.BILINEAR)
return image, coordinates

def _coordinates_to_kmap(image, coordinates):
# Round and clip to image size
coordinates = tf.to_int32(tf.round(coordinates))
coordinates = tf.minimum(coordinates,
tf.expand_dims(tf.stack([tf.shape(image)[0]-1,
tf.shape(image)[1]-1]),
axis=0))
kmap = tf.scatter_nd(
coordinates,
tf.ones([tf.shape(coordinates)[0]], dtype=tf.int32),
tf.shape(image)[:2])
return image, kmap

if config['on-the-fly']:
data = tf.data.Dataset.from_generator(
_gen_shape, (tf.float32, tf.float32),
(tf.TensorShape(config['generation']['image_size']+[1]),
tf.TensorShape([None, 2])))
data = data.map(_downsample)
else:
# Initialize the dataset with file names
data = tf.data.Dataset.from_tensor_slices(
(filenames[split_name]['images'], filenames[split_name]['points']))
# Read image and point coordinates
data = data.map(
lambda image, points:
(_read_image(image), tf.py_func(_read_points, [points], tf.float32)))
data = data.map(lambda image, points: (image, tf.reshape(points, [-1, 2])))

# Convert point coordinates to a dense keypoint map
data = data.map(_coordinates_to_kmap)
data = data.map(lambda image, kmap: {'image': image, 'keypoint_map': kmap})

if split_name == 'validation':
data = data.take(config['validation_size'])
elif split_name == 'test':
data = data.take(config['test_size'])

if config['cache_in_memory'] and not config['on-the-fly']:
tf.logging.info('Caching data, first access will take some time.')
data = data.cache()

return data
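A usage sketch of the new configuration surface: with 'on-the-fly' disabled (the new default), each primitive is generated once, packed into a tarfile under DATA_PATH/synthetic_shapes, and simply re-read on later runs. The constructor call and the get_tf_datasets() accessor below are assumptions about BaseDataset, which is not part of this diff:

from superpoint.datasets.synthetic_shapes import SyntheticShapes

config = {
    'primitives': ['draw_multiple_polygons', 'draw_cube'],  # subset instead of 'all'
    'on-the-fly': False,      # generate once, then read the cached tarfiles from disk
    'cache_in_memory': True,  # additionally keep the decoded tf.data pipeline in RAM
    'validation_size': 100,
    'test_size': 500,
}
dataset = SyntheticShapes(**config)  # assumed BaseDataset constructor signature
splits = dataset.get_tf_datasets()   # assumed accessor returning the tf.data.Dataset splits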
