-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
41 lines (33 loc) · 1.13 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tensorflow as tf
# Announce on stdout that the model load is starting.
print('Loading model ...')
# Load the Keras model from the relative path 'saved_model'. This is a
# module-level statement, so it runs whenever this module is executed.
model = tf.keras.models.load_model('saved_model')
# Class labels; predict_class indexes this list with the index derived from
# the model's prediction.
# NOTE(review): presumably the order matches the model's output order --
# confirm against the training setup.
class_names = ['Level_0','Level_1','Level_2']
def load_and_prep_image(filename, img_shape=224):
    """
    Read a JPEG file and turn it into a resized, rescaled image tensor.

    Args:
        filename: Path of the JPEG image to read.
        img_shape: Height and width (in pixels) of the square output image.

    Returns:
        A float tensor of shape (img_shape, img_shape, 3) whose values are
        the decoded pixel values divided by 255.
    """
    img = tf.io.read_file(filename)
    # Decode it into a tensor. Force 3 channels so the channel count does not
    # depend on how the file was encoded (e.g. a grayscale JPEG would
    # otherwise decode to a single channel).
    # NOTE(review): presumably the model expects RGB input -- confirm.
    img = tf.image.decode_jpeg(img, channels=3)
    # Resize the image
    img = tf.image.resize(img, [img_shape, img_shape])
    # Rescale the image (get all values between 0 and 1)
    img = img / 255.
    return img
def predict_class(filename):
    """
    Predict the class of the image located at filename.

    Loads and preprocesses the image, runs the module-level model on it,
    prints the raw prediction and the predicted class name, and returns
    that name.

    Args:
        filename: Path of the image file to classify.

    Returns:
        The entry of class_names selected by the model's prediction.
    """
    print('Loading image ...')
    # Import the target image and preprocess it
    img = load_and_prep_image(filename)
    print('Predicting class of image ...')
    # Make a prediction; expand_dims adds a leading axis so the model
    # receives a batch containing this one image.
    pred = model.predict(tf.expand_dims(img, axis=0))
    print(pred)
    # Prediction for the single image in the batch
    scores = pred[0]
    if len(scores) > 1:
        # Multi-class output: take the index of the highest score
        class_index = int(tf.argmax(scores))
    else:
        # Single-output model: round the lone score to get the index
        # (presumably a probability in [0, 1] -- confirm). The scalar is
        # indexed explicitly so int() is not applied to a 1-element tensor.
        class_index = int(tf.round(scores[0]))
    pred_class = class_names[class_index]
    print('Predicted class:', pred_class)
    return pred_class
# Run a prediction on 'test_image.jpeg' as soon as the module is executed.
# NOTE(review): no `if __name__ == "__main__"` guard is visible here, so this
# presumably also runs when the module is imported -- confirm that is intended.
predict_class('test_image.jpeg')