-
Notifications
You must be signed in to change notification settings - Fork 1
/
testLite.py
44 lines (36 loc) · 1.22 KB
/
testLite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import absolute_import, division, print_function, unicode_literals

import functools
import os
import time

import IPython.display as display
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Switch TensorFlow into 2.x behavior. The bare `tf.enable_v2_behavior` alias
# does not exist in the TF 2.x top-level namespace (AttributeError there); the
# `tf.compat.v1` spelling is available in both TF 1.x and TF 2.x.
tf.compat.v1.enable_v2_behavior()

# Matplotlib defaults: large figures and no grid drawn over images.
mpl.rcParams['figure.figsize'] = (12, 12)
mpl.rcParams['axes.grid'] = False

PATH = './a.jpg'  # input picture path
MPATH = './pix2pix_v1_1.0.tflite'  # TFLite model path

# Size every input image is resized to before inference.
# NOTE(review): presumably must match the model's input shape
# (see get_input_details()); confirm if the model changes.
height = width = 256
def load_img(path_to_img):
    """Load an image file and return it as a one-image model input batch.

    The file is decoded with 3 channels, cast to float32, resized to the
    module-level ``height`` x ``width`` with nearest-neighbor sampling, and
    rescaled from [0, 255] to [-1, 1] (for 8-bit images). A leading batch
    axis of size 1 is added.

    NOTE(review): decode_image may yield an extra frame axis for animated
    GIFs; only still images are expected here.

    Args:
        path_to_img: Path of the image file to read.

    Returns:
        A float32 tensor of shape (1, height, width, 3).
    """
    raw_bytes = tf.io.read_file(path_to_img)
    pixels = tf.cast(tf.image.decode_image(raw_bytes, channels=3), tf.float32)
    pixels = tf.image.resize(
        pixels,
        [height, width],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    # Map the [0, 255] pixel range onto [-1, 1].
    scaled = pixels / 127.5 - 1
    return tf.expand_dims(scaled, axis=0)
# Load the TFLite generator and allocate its input/output buffers.
i = tf.lite.Interpreter(model_path=MPATH)
i.allocate_tensors()
input_details = i.get_input_details()
output_details = i.get_output_details()

# load_img() returns a tf tensor; hand the interpreter a plain NumPy array.
i.set_tensor(input_details[0]["index"], np.asarray(load_img(PATH)))
i.invoke()

# Generate image.
# get_tensor() returns a copy, so the result does not alias the interpreter's
# internal buffer the way the view returned by tensor(idx)() does.
g = i.get_tensor(output_details[0]["index"])

# Undo the [-1, 1] scaling used by load_img. This assumes the generator emits
# values in roughly [-1, 1]; TODO confirm. Clip so imshow gets valid [0, 1]
# float RGB instead of out-of-range values.
predicted = np.clip(g[0] * 0.5 + 0.5, 0.0, 1.0)

plt.figure(figsize=(5, 5))
plt.title('Predicted')
plt.imshow(predicted)
plt.axis('off')

# Save BEFORE show(): after the shown window is closed the figure is torn
# down, and a later savefig() writes an empty canvas. savefig() also fails
# if the output directory is missing, so create it first.
if not os.path.isdir('pictures'):
    os.makedirs('pictures')
plt.savefig('pictures/test.png')
plt.show()