From f6a1c91675935a9cbfdc983ce105fd3de0fef1a9 Mon Sep 17 00:00:00 2001 From: "ben.liu" Date: Mon, 5 Jul 2021 19:03:16 +0800 Subject: [PATCH] upgrade to tensorflow 2.x --- utils/face_seg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/face_seg.py b/utils/face_seg.py index e32132a..b6f1685 100644 --- a/utils/face_seg.py +++ b/utils/face_seg.py @@ -10,10 +10,10 @@ class FaceSeg: def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')): - config = tf.ConfigProto() + config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True self._graph = tf.Graph() - self._sess = tf.Session(config=config, graph=self._graph) + self._sess = tf.compat.v1.Session(config=config, graph=self._graph) self.pb_file_path = model_path self._restore_from_pb() @@ -24,8 +24,8 @@ def _restore_from_pb(self): with self._sess.as_default(): with self._graph.as_default(): with gfile.FastGFile(self.pb_file_path, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) + graph_def = tf.compat.v1.GraphDef.FromString(f.read()) + # graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def, name='') def input_transform(self, image):