-
Notifications
You must be signed in to change notification settings - Fork 108
/
utils.py
81 lines (63 loc) · 2.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Most code in this file was borrowed from https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/15_Style_Transfer.ipynb
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
"""Helper-functions for image manipulation"""
def load_image(filename, shape=None, max_size=None):
    """Load an image file and return its pixels as a ``np.float32`` array.

    The image can optionally be resized in one of two ways:

    * ``max_size``: rescale so the larger of width/height equals ``max_size``,
      keeping the aspect ratio.
    * ``shape``: resize to exactly ``shape``. Note this is PIL's size order,
      ``(width, height)``, not numpy's ``(rows, cols)``.

    If both are given, ``shape`` takes precedence: the final size is fully
    determined by ``shape``, so the image is resampled only once instead of
    being resampled twice with the intermediate result thrown away.

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.
        shape: Optional ``(width, height)`` target size in pixels.
        max_size: Optional target length in pixels of the longer side.

    Returns:
        ``np.float32`` array built from the PIL image (shape follows PIL's
        array interface, e.g. ``(height, width, channels)`` for RGB). Pixel
        values are not normalised, e.g. they stay in [0, 255] for 8-bit images.
    """
    # The context manager closes the underlying file even if resizing or the
    # conversion fails. The pixel data is read before the block exits.
    with PIL.Image.open(filename) as image:
        if shape is not None:
            resized = image.resize(tuple(shape), PIL.Image.LANCZOS)
        elif max_size is not None:
            # Rescale factor that makes the longer side equal max_size while
            # keeping the proportion between height and width.
            factor = float(max_size) / max(image.size)
            # Round instead of truncating. The product can land just below an
            # integer (e.g. 512 / 784 * 784 == 511.99999999999994), which
            # truncation would turn into 511. Also never let a side collapse
            # to 0 pixels, which PIL rejects. PIL wants a tuple of ints.
            size = tuple(max(1, round(side * factor)) for side in image.size)
            resized = image.resize(size, PIL.Image.LANCZOS)
        else:
            resized = image
        # Convert to numpy floating-point array.
        return np.float32(resized)
def save_image(image, filename):
    """Save an array of pixel values to ``filename`` as a JPEG file.

    Args:
        image: numpy array of pixel values nominally in [0, 255]. Values
            outside that range are clipped, and the rest are rounded to the
            nearest integer. The array must be something
            ``PIL.Image.fromarray`` can turn into a JPEG-compatible image
            (e.g. ``(height, width, 3)`` for RGB).
        filename: Destination path. The file is always written in JPEG
            format, regardless of its extension.
    """
    # Clip to the valid 8-bit range, then round to the nearest integer.
    # A bare astype(np.uint8) truncates, e.g. turning 254.9 into 254.
    pixels = np.rint(np.clip(image, 0.0, 255.0)).astype(np.uint8)
    # Build the PIL image before touching the destination. Opening the file
    # in 'wb' mode first would leave an empty file behind if the array
    # cannot be converted. Passing the path lets Pillow open and close the
    # file itself.
    PIL.Image.fromarray(pixels).save(filename, 'jpeg')
def plot_images(content_image, style_image, mixed_image):
    """Show the content, mixed and style images side by side.

    The images are drawn left to right in one row, labelled "Content",
    "Output" (the mixed image) and "Style", then displayed with
    ``plt.show()``.

    Args:
        content_image: Pixel array with values nominally in [0, 255].
        style_image: Pixel array with values nominally in [0, 255].
        mixed_image: Pixel array with values nominally in [0, 255].
    """
    # Create a figure with one row of three sub-plots.
    fig, axes = plt.subplots(1, 3, figsize=(10, 10))
    # Spacing between sub-plots. With a single row only the horizontal gap
    # (wspace) has any effect; hspace is kept to leave the layout unchanged.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    panels = (
        (content_image, "Content"),
        (mixed_image, "Output"),
        (style_image, "Style"),
    )
    for ax, (image, label) in zip(axes.flat, panels):
        # Normalise to [0.0, 1.0] by dividing by 255, then clip explicitly.
        # For RGB data matplotlib would clip out-of-range values itself, but
        # it would also emit a clipping warning on every call.
        ax.imshow(np.clip(image / 255.0, 0.0, 1.0), interpolation='sinc')
        ax.set_xlabel(label)
        # Remove ticks; only the label under each panel is kept.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()