Skip to content

Commit

Permalink
compile TF graph: support for MetaGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jun 29, 2018
1 parent a68b3a7 commit 47da54c
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tools/compile_tf_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main(argv):
argparser.add_argument('--search', type=int, default=0, help='beam search. 0 disable (default), 1 enable')
argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
argparser.add_argument("--summaries_tensor_name")
argparser.add_argument("--output_file", help='output pb or pbtxt file')
argparser.add_argument("--output_file", help='output pb, pbtxt or meta, metatxt file')
argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params")
argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars")
args = argparser.parse_args(argv[1:])
Expand Down Expand Up @@ -107,21 +107,28 @@ def main(argv):
assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?"
tf.identity(summaries_tensor, name=args.summaries_tensor_name)

if args.output_file and os.path.splitext(args.output_file)[1] in [".meta", ".metatxt"]:
# https://www.tensorflow.org/api_guides/python/meta_graph
saver = tf.train.Saver(
var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1)
graph_def = saver.export_meta_graph()
else:
graph_def = graph.as_graph_def(add_shapes=True)

print("Graph collection keys:", graph.get_all_collection_keys())
print("Graph num operations:", len(graph.get_operations()))
graph_def = graph.as_graph_def(add_shapes=True)
print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

if args.output_file:
filename = args.output_file
_, ext = os.path.splitext(filename)
assert ext in [".pb", ".pbtxt"], 'filename %r extension should be pb or pbtxt' % filename
assert ext in [".pb", ".pbtxt", ".meta", ".metatxt"], 'filename %r extension invalid' % filename
print("Write graph to file:", filename)
graph_io.write_graph(
graph_def,
logdir=os.path.dirname(filename),
name=os.path.basename(filename),
as_text=(ext == ".pbtxt"))
as_text=ext.endswith("txt"))
else:
print("Use --output_file if you want to store the graph.")

Expand Down

0 comments on commit 47da54c

Please sign in to comment.