Skip to content

Commit

Permalink
use tf.shape instead of .shape for dynamic axes
Browse files Browse the repository at this point in the history
Signed-off-by: masakistan <[email protected]>
  • Loading branch information
masakistan committed Sep 26, 2020
2 parents f884d78 + f616d65 commit eed87e3
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 19 deletions.
105 changes: 105 additions & 0 deletions example/test_mnist_onnxruntime_stepping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import numpy as np

import onnx
from onnx import helper
from onnx import TensorProto
import tensorflow as tf
import onnxruntime.backend as ort

import onnx_tf.backend as otf
from onnx_tf.common import data_type


def find_between(s, first, last):
try:
start = s.index(first)
end = s.index(last) + len(last)
return s[start:end]
except ValueError:
return ""


class TestMnistModel(unittest.TestCase):
# Make sure the onnx file path is correct, assuming copied to the
# current directory
model_path = 'mnist-8.onnx'

def test(self):
_model = onnx.load(self.model_path)
print("Total node count in model: ", len(_model.graph.node))

# The input tensors could be provided as constants
# The example below illustrates such a dictionary could be
# provided for models with unknown input shapes. Since
# mnist has known input shape, we don't provide input tensors.
# input_tensors = {'Input3': tf.constant(0, dtype = tf.float32,
# name='Input3',
# shape=[1, 1, 28, 28])}
input_tensors = {}
tensor_dict = otf.prepare(_model,
gen_tensor_dict=True,
input_tensor_dict=input_tensors).tensor_dict
more_outputs = []
output_to_check = []
for node in _model.graph.node:
# add the first output of each node to the model output
output_tensor = None
for i in range(len(_model.graph.value_info)):
if _model.graph.value_info[i].name == node.output[0]:
output_tensor = _model.graph.value_info[i]

for i in range(len(_model.graph.initializer)):
if _model.graph.initializer[i].name == node.output[0]:
output_tensor = _model.graph.initializer[i]

# assume the first output is a tensor
tensor = tensor_dict[node.output[0]]
output_tensor = helper.make_tensor_value_info(
node.output[0], data_type.tf2onnx(tensor.dtype),
tensor.shape) if output_tensor is None else output_tensor
more_outputs.append(output_tensor)
output_to_check.append(node.output[0])
_model.graph.output.extend(more_outputs)

tf_rep = otf.prepare(_model)
rt_rep = ort.prepare(_model)

# prepare input data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32)

inputs = [sample]
my_out = tf_rep.run(inputs)
rt_out = rt_rep.run(inputs)

for op in output_to_check:
for i in range(len(my_out)):
# find the index of output in the list
if my_out[op] is my_out[i]:

try:
np.savetxt(op.replace("/", "__") + ".rt",
rt_out[i].flatten(),
delimiter='\t')
np.savetxt(op.replace("/", "__") + ".tf",
my_out[i].flatten(),
delimiter='\t')
np.testing.assert_allclose(my_out[i], rt_out[i], rtol=1e-2)
print(op, "results of this layer are correct within tolerence.")
except Exception as e:
np.set_printoptions(threshold=np.inf)
mismatch_percent = (find_between(str(e), "(mismatch", "%)"))
print(op, "mismatch with percentage {} %".format(mismatch_percent))


if __name__ == '__main__':
unittest.main()
pass
68 changes: 51 additions & 17 deletions onnx_tf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def prepare(cls,
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
common.sys_config.auto_cast=auto_cast
common.sys_config.auto_cast = auto_cast

return cls.onnx_model_to_tensorflow_rep(model, strict)
return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)

@classmethod
def onnx_model_to_tensorflow_rep(cls, model, strict):
def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
""" Convert ONNX model to TensorflowRep.
:param model: ONNX ModelProto object.
Expand All @@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
else:
opset_import = model.opset_import
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
**kwargs)

@classmethod
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
""" Convert ONNX graph to TensorflowRep.
:param graph_def: ONNX GraphProto object.
:param opset: ONNX OperatorSetIdProto list.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:kwargs: additional arguements to generate tensor_dict for model debugging
:return: TensorflowRep object.
"""
# To generate tensor_dict or not, default is False
gen_tensor_dict = kwargs[
'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
# User provided input tensors, in the case the model inputs have unknown shapes
input_tensor_dict = kwargs[
'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()

handlers = cls._get_handlers(opset)

# initializer: TensorProtos representing the values to initialize
# a given tensor.
# initialized: A list of names of the initialized tensors.

if graph_def.initializer:
input_dict_items = cls._onnx_initializer_to_input_dict_items(
graph_def.initializer)
initialized = {init.name for init in graph_def.initializer}
else:
input_dict_items = []
initialized = set()

module = BackendTFModule(handlers, opset, strict, graph_def, cls)
signatures = dict()

for value_info in graph_def.input:
if value_info.name in initialized:
continue
shape = list(
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
for d in value_info.type.tensor_type.shape.dim)
value_info_name = value_info.name.replace(
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name

tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name)
tf_spec = tf.TensorSpec(
shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
value_info_name)
signatures[value_info.name] = tf_spec

if gen_tensor_dict:
x = tf.constant(
0,
dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type),
name=value_info_name,
shape=shape
) if value_info.name not in input_tensor_dict else input_tensor_dict[
value_info.name]
input_dict_items.append((value_info_name, x))

tf_rep = TensorflowRep()
tf_rep.inputs = [
value_info.name
Expand All @@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
module.outputs = tf_rep.outputs
tf_rep.tf_module = module
tf_rep.signatures = signatures
tf_rep.tensor_dict = module.gen_tensor_dict(
input_dict_items) if gen_tensor_dict else None

return tf_rep

@classmethod
Expand All @@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
:param kwargs: Other args.
:return: Outputs.
"""

class TFModule(tf.Module):

def __init__(self, node):
super(TFModule, self).__init__()
self.node = node
Expand All @@ -171,13 +199,16 @@ def __call__(self, **input_dict):
feed_dict_raw = dict(zip(node.inputs, inputs))

# TODO: is constant the best way for feeding inputs?
input_dict = dict(
[(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()])
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
])

module = TFModule(node)

output_vals = module(**input_dict)
output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals]
output_vals = [
val.numpy() if isinstance(val, tf.Tensor) else val
for val in output_vals
]

return namedtupledict('Outputs', node.outputs)(*output_vals)

Expand Down Expand Up @@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
"""
handlers = handlers or cls._get_handlers(opset)
if handlers:
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
handler = handlers[node.domain].get(
node.op_type, None) if node.domain in handlers else None
if handler:
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)

raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
node.op_type))

@classmethod
def _get_handlers(cls, opset):
Expand Down Expand Up @@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
nodes_outputs.append(o_name)
for node in subgraph.node:
for i_name in node.input:
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys():
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
):
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
Expand All @@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
return subgraph_tensor_dict

@classmethod
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
"""
Converts ONNX graph to TensorflowRep
Args:
Expand All @@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
"""
# get the opset of the installed ONNX
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)


prepare = TensorflowBackend.prepare
Expand Down
9 changes: 9 additions & 0 deletions onnx_tf/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
self._inputs = inputs or []
self._outputs = outputs or []
self._tensor_dict = tensor_dict or {}
self._tf_module = None

@property
def graph(self):
Expand Down Expand Up @@ -49,6 +50,14 @@ def tensor_dict(self):
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict

@property
def tf_module(self):
return self._tf_module

@tf_module.setter
def tf_module(self, tf_module):
self._tf_module = tf_module

def run(self, inputs, **kwargs):
""" Run TensorflowRep.
Expand Down
24 changes: 22 additions & 2 deletions onnx_tf/backend_tf_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from onnx_tf.pb_wrapper import OnnxNode


class BackendTFModule(tf.Module):

def __init__(self, handlers, opset, strict, graph_def, backend):
Expand All @@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
self.backend = backend
self.outputs = []

@tf.function
def gen_tensor_dict(self, input_dict_items):
tensor_dict = dict(input_dict_items)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
tensor_dict,
self.handlers,
opset=self.opset,
strict=self.strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)

return tensor_dict

@tf.function
def __call__(self, **kwargs):
tensor_dict = kwargs
Expand All @@ -26,8 +43,11 @@ def __call__(self, **kwargs):

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
output_ops = self.backend._onnx_node_to_tensorflow_op(
onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict)
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
tensor_dict,
self.handlers,
opset=self.opset,
strict=self.strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)

Expand Down

0 comments on commit eed87e3

Please sign in to comment.