-
Notifications
You must be signed in to change notification settings - Fork 1
/
vis.py
126 lines (98 loc) · 3.85 KB
/
vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import json
import pathlib
import os
import numpy as np
import tensorflow as tf
from PIL import Image
from models import *
# from metrics import get_metric
# from .metrics import get_metric_experimental
# Directory that holds saved models; `visualize` resolves `load_name` relative to it.
# TODO: expose this as a command-line argument instead of a module constant.
model_save_dir = pathlib.Path('saved-models')
def get_vis_data(input_size=(256, 256, 3), imgs=None):
    """Load every image file in a directory for visualization.

    Also ensures an ``out`` directory exists next to the input directory
    (``imgs.parent / 'out'``).

    Args:
        input_size: Not used by this function; kept so existing callers that
            pass the model input size keep working.
        imgs: Directory (``str`` or ``pathlib.Path``) containing the input
            images. Only regular files directly inside it are loaded;
            subdirectories are skipped.

    Returns:
        dict with keys:
            ``'dir'``: the input directory as a ``pathlib.Path``.
            ``'data'``: list of ``(PIL image, (width, height), file_name)``
            tuples, ordered by file name.

    Raises:
        ValueError: if ``imgs`` is None.
    """
    if imgs is None:
        raise ValueError('Invalid Dataset directory paths provided')
    imgs = pathlib.Path(imgs)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    (imgs.parent / 'out').mkdir(parents=True, exist_ok=True)

    data = {
        'dir': imgs,
        'data': [],
    }
    # os.listdir order is arbitrary; sort so runs process files deterministically.
    img_names = sorted(name for name in os.listdir(imgs) if (imgs / name).is_file())
    for img_name in img_names:
        # Image.open is lazy and keeps the file open until the image object is
        # garbage-collected. Copying inside `with` reads the pixels and closes
        # the file right away, so large directories don't exhaust file handles.
        with Image.open(imgs / img_name) as src:
            img = src.copy()
        data['data'].append((img, img.size, img_name))
    return data
def _img_to_batch(img, input_size):
    """Convert a PIL image into a ``(1, H, W, 3)`` float64 array scaled to ``[0, 1]``.

    ``input_size`` is assumed to follow the Keras ``(height, width, channels)``
    order; only height and width are used. The image is resized with Pillow's
    default resampling filter, then converted to RGB. Converting drops alpha and
    expands grayscale/palette images, which would otherwise be 2-D arrays.
    """
    height, width = input_size[0], input_size[1]
    # PIL sizes are (width, height), the reverse of the (H, W, C) model shape.
    resized = img.resize((width, height))
    arr = np.asarray(resized.convert('RGB'), dtype=float) / 255.
    return np.expand_dims(arr, axis=0)


def _batch_to_img(arr, size):
    """Convert a ``(1, H, W[, C])`` model output in ``[0, 1]`` to a PIL image resized to ``size``.

    ``size`` is a PIL ``(width, height)`` tuple. Values are clipped to
    ``[0, 255]`` and cast to uint8.
    """
    arr = np.squeeze(np.asarray(arr), axis=0)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        arr = arr[..., 0]
    arr = np.clip(arr * 255., 0., 255.).astype('uint8')
    return Image.fromarray(arr).resize(size)


def visualize(vis_data, input_size=(256, 256, 3), load_name=None, *, save_components=False):
    """Run the visualization model over ``vis_data`` and save the output images.

    Each output is resized back to the input image's original size. It is
    written to ``vis_data['dir'].parent / 'out'`` under the input file name.

    Args:
        vis_data: dict with a ``'dir'`` path and a ``'data'`` list of
            ``(PIL image, original (width, height), file_name)`` tuples, as
            produced by ``get_vis_data``.
        input_size: model input shape, assumed ``(height, width, channels)``.
        load_name: name of a saved model inside ``model_save_dir``; when None
            the model is built with ``load_path=None``.
        save_components: when True, also save the model's four additional
            outputs with ``R_``, ``I_``, ``R2_`` and ``I2_`` file-name prefixes.

    Raises:
        ValueError: if ``load_name`` is given but no such saved model exists.
    """
    load_path = None
    if load_name is not None:
        load_path = model_save_dir / load_name
        if not load_path.exists():
            raise ValueError('No saved model with the name \'%s\' exists!' % load_name)

    model = build_vis_model(input_size=input_size, load_path=load_path)
    model.trainable = False
    model.summary()

    # Same output directory that get_vis_data creates.
    out_dir = vis_data['dir'].parent / 'out'
    for img, orig_size, img_name in vis_data['data']:
        out_img, r_img, i_img, r2_img, i2_img = model(
            _img_to_batch(img, input_size), training=False)
        _batch_to_img(out_img, orig_size).save(out_dir / img_name)
        if save_components:
            components = (('R_', r_img), ('I_', i_img), ('R2_', r2_img), ('I2_', i2_img))
            for prefix, component in components:
                _batch_to_img(component, orig_size).save(out_dir / (prefix + img_name))
    print('Saved output images.')
def parse_args(argv=None):
    """Parse the visualization script's command-line arguments.

    Args:
        argv: argument list to parse; None means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_path``, ``load_name`` and ``input_size``
        (a 2-sequence of ints: height, width).
    """
    parser = argparse.ArgumentParser(description='Visualization Args')
    # nargs='?' is required for `default` to take effect on a positional argument;
    # without it argparse always demands the argument and ignores the default.
    parser.add_argument('data_path', metavar='I', nargs='?', default='../dataset/test/imgs/',
                        help='Path to the test directory containing input images')
    parser.add_argument('--load-name', default=None, dest='load_name',
                        help='Name of already saved model to load')
    parser.add_argument('--input-size', nargs=2, type=int, default=(512, 512),
                        metavar=('H', 'W'), dest='input_size',
                        help='Model input height and width (default: 512 512)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    input_size = (args.input_size[0], args.input_size[1], 3)
    vis_data = get_vis_data(input_size=input_size, imgs=args.data_path)
    visualize(vis_data, input_size=input_size, load_name=args.load_name)