forked from luigifreda/pyslam
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils_tf.py
90 lines (81 loc) · 3.22 KB
/
utils_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
* This file is part of PYSLAM
* Adapted from https://github.com/lzx551402/contextdesc/blob/master/utils/tf.py; see the license therein.
* Copyright (C) 2016-present Luigi Freda <luigi dot freda at gmail dot com>
*
* PYSLAM is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PYSLAM is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with PYSLAM. If not, see <http://www.gnu.org/licenses/>.
"""
import os
import warnings  # to disable tensorflow-numpy warnings: from https://github.com/tensorflow/tensorflow/issues/30427

# Installed before TensorFlow is imported so that FutureWarnings raised while
# importing it are silenced (see the issue linked above).
warnings.filterwarnings('ignore', category=FutureWarning)

# Use the TF1-style API (tf.Session, tf.GraphDef, tf.train.Saver, ...) exposed
# by tensorflow.compat.v1. To use the native top-level API instead, replace
# this with `import tensorflow as tf`.
# from https://stackoverflow.com/questions/56820327/the-name-tf-session-is-deprecated-please-use-tf-compat-v1-session-instead
import tensorflow.compat.v1 as tf
def load_frozen_model(pb_path, prefix='', print_nodes=False):
    """Load a frozen model (.pb file) into a new TensorFlow graph.

    After loading, operators can be accessed with
    ``graph.get_tensor_by_name('<prefix>/<op_name>')``.

    Args:
        pb_path: path of the frozen model file.
        prefix: prefix added to every imported operator name (passed as
            ``name`` to ``tf.import_graph_def``).
        print_nodes: if True, print the name of every imported operation.
    Returns:
        graph: the ``tf.Graph`` holding the imported graph definition.
    Raises:
        SystemExit: with exit code -1 if ``pb_path`` does not exist (a
            message is printed first).
    """
    if not os.path.exists(pb_path):
        print('Model file does not exist', pb_path)
        # Raise SystemExit directly instead of calling the bare ``exit``
        # builtin: ``exit`` is injected by the ``site`` module for interactive
        # use and is not guaranteed to exist (e.g. ``python -S``, embedded or
        # frozen interpreters). The effect is otherwise identical.
        raise SystemExit(-1)

    # Read and parse the serialized GraphDef.
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph so the caller's default graph is untouched.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
        if print_nodes:
            for op in graph.get_operations():
                print(op.name)
    return graph
def recoverer(sess, model_path, meta_graph_path=None):
    """Restore pretrained parameters into a TensorFlow session.

    Args:
        sess: the tensorflow session instance to restore into.
        model_path: checkpoint path passed to ``restore``.
        meta_graph_path: optional path of a meta graph file. When given, the
            saver is obtained from ``tf.train.import_meta_graph``; otherwise a
            ``tf.train.Saver`` over ``tf.global_variables()`` is built.
    Returns:
        None.
    """
    if meta_graph_path is not None:
        saver = tf.train.import_meta_graph(meta_graph_path)
    else:
        all_global_vars = tf.global_variables()
        saver = tf.train.Saver(all_global_vars)
    saver.restore(sess, model_path)
# from https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
# Meaning of the TF_CPP_MIN_LOG_LEVEL values:
# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed
def set_tf_logging(logging_flag):
    """Enable verbose TensorFlow logging or restrict it to errors.

    Args:
        logging_flag: truthy -> TF_CPP_MIN_LOG_LEVEL="0" and the TF Python
            logger at "INFO"; falsy -> TF_CPP_MIN_LOG_LEVEL="2" and the TF
            Python logger at "ERROR".
    """
    print('setting tf logging:', logging_flag)
    # NOTE(review): the C++ level is set via an environment variable; it may
    # only take full effect if set before TensorFlow initializes — confirm.
    cpp_min_level, python_level = ("0", "INFO") if logging_flag else ("2", "ERROR")
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = cpp_min_level
    tf.get_logger().setLevel(python_level)