-
Notifications
You must be signed in to change notification settings - Fork 4
/
predict_func.py
22 lines (22 loc) · 1.03 KB
/
predict_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import cv2
import tensorflow as tf
from data import *
import numpy as np
def predict(model, image_test, label, color_mode, size):
    """Run ``model`` on one image file and blend the decoded prediction over it.

    The image is read with OpenCV, converted to the requested colour space,
    resized with nearest-neighbour interpolation, scaled by 1/255 and passed
    to ``model`` as a batch of one. The model output is arg-maxed over its
    last axis, decoded through ``decode_label`` and blended (30 %) onto the
    resized RGB version of the input (70 %).

    Args:
        model: Callable invoked as ``model(batch)``. Its output is squeezed on
            axis 0 and arg-maxed on axis 2, so it is presumably shaped
            ``(1, H, W, num_classes)`` -- TODO confirm against the model.
        image_test: Path handed to ``cv2.imread``.
        label: Passed through unchanged as the second argument of
            ``decode_label``.
        color_mode: One of ``'hsv'``, ``'rgb'`` or ``'gray'``; selects the
            colour space fed to the model.
        size: Forwarded as the ``size`` argument of ``tf.image.resize``.

    Returns:
        A tuple ``(overlay, raw_output)`` where ``overlay`` is the floored
        blend cast to ``int`` and ``raw_output`` is whatever ``model``
        returned.

    Raises:
        ValueError: If ``color_mode`` is not one of the supported modes.
        OSError: If ``cv2.imread`` could not load ``image_test``.
    """
    # Validate up front: previously an unknown mode fell through the if/elif
    # chain and failed later with an unbound ``image_cvt``.
    supported_modes = ('hsv', 'rgb', 'gray')
    if color_mode not in supported_modes:
        raise ValueError(
            "unsupported color_mode {!r}; expected one of {}".format(
                color_mode, supported_modes
            )
        )

    image = cv2.imread(image_test)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError("could not read image: {!r}".format(image_test))

    # The RGB view is always needed for the overlay, so convert it once and
    # reuse it when the model input is RGB as well.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if color_mode == 'hsv':
        image_cvt = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    elif color_mode == 'rgb':
        image_cvt = image_rgb
    else:  # 'gray'
        image_cvt = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # The grayscale result is 2-D; add a trailing channel axis so it can
        # be resized like the 3-channel modes.
        image_cvt = tf.expand_dims(image_cvt, axis=2)

    image_cvt = tf.image.resize(image_cvt, size, method='nearest')
    image_cvt = tf.cast(image_cvt, tf.float32)
    image_norm = image_cvt / 255.
    image_norm = tf.expand_dims(image_norm, axis=0)  # add the batch axis

    new_image = model(image_norm)
    # Drop the batch axis, then pick the highest-scoring index along axis 2.
    image_argmax = np.argmax(tf.squeeze(new_image, axis=0), axis=2)
    # decode_label is not defined in this file (presumably provided by
    # ``from data import *``); its result must be broadcastable with the
    # resized RGB image for the blend below.
    image_decode = decode_label(image_argmax, label)

    background = tf.cast(
        tf.image.resize(image_rgb, size, method='nearest'), tf.float32
    )
    predict_img = background * 0.7 + image_decode * 0.3
    return np.floor(predict_img).astype('int'), new_image