diff --git a/example/test_mnist_onnxruntime_stepping.py b/example/test_mnist_onnxruntime_stepping.py new file mode 100644 index 000000000..a5fcd5f65 --- /dev/null +++ b/example/test_mnist_onnxruntime_stepping.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import numpy as np + +import onnx +from onnx import helper +from onnx import TensorProto +import tensorflow as tf +import onnxruntime.backend as ort + +import onnx_tf.backend as otf +from onnx_tf.common import data_type + + +def find_between(s, first, last): + try: + start = s.index(first) + end = s.index(last) + len(last) + return s[start:end] + except ValueError: + return "" + + +class TestMnistModel(unittest.TestCase): + # Make sure the onnx file path is correct, assuming copied to the + # current directory + model_path = 'mnist-8.onnx' + + def test(self): + _model = onnx.load(self.model_path) + print("Total node count in model: ", len(_model.graph.node)) + + # The input tensors could be provided as constants + # The example below illustrates such a dictionary could be + # provided for models with unknown input shapes. Since + # mnist has known input shape, we don't provide input tensors. + # input_tensors = {'Input3': tf.constant(0, dtype = tf.float32, + # name='Input3', + # shape=[1, 1, 28, 28])} + input_tensors = {} + tensor_dict = otf.prepare(_model, + gen_tensor_dict=True, + input_tensor_dict=input_tensors).tensor_dict + more_outputs = [] + output_to_check = [] + for node in _model.graph.node: + # add the first output of each node to the model output + output_tensor = None + for i in range(len(_model.graph.value_info)): + if _model.graph.value_info[i].name == node.output[0]: + output_tensor = _model.graph.value_info[i] + + for i in range(len(_model.graph.initializer)): + if _model.graph.initializer[i].name == node.output[0]: + output_tensor = _model.graph.initializer[i] + + # assume the first output is a tensor + tensor = tensor_dict[node.output[0]] + output_tensor = helper.make_tensor_value_info( + node.output[0], data_type.tf2onnx(tensor.dtype), + tensor.shape) if output_tensor is None else output_tensor + more_outputs.append(output_tensor) + output_to_check.append(node.output[0]) + _model.graph.output.extend(more_outputs) + + tf_rep = otf.prepare(_model) + rt_rep = ort.prepare(_model) + + # prepare input data + mnist = tf.keras.datasets.mnist + (x_train, y_train), (x_test, y_test) = mnist.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32) + + inputs = [sample] + my_out = tf_rep.run(inputs) + rt_out = rt_rep.run(inputs) + + for op in output_to_check: + for i in range(len(my_out)): + # find the index of output in the list + if my_out[op] is my_out[i]: + + try: + np.savetxt(op.replace("/", "__") + ".rt", + rt_out[i].flatten(), + delimiter='\t') + np.savetxt(op.replace("/", "__") + ".tf", + my_out[i].flatten(), + delimiter='\t') + np.testing.assert_allclose(my_out[i], rt_out[i], rtol=1e-2) + print(op, "results of this layer are correct within tolerence.") + except Exception as e: + np.set_printoptions(threshold=np.inf) + mismatch_percent = (find_between(str(e), "(mismatch", "%)")) + print(op, "mismatch with percentage {} %".format(mismatch_percent)) + + +if __name__ == '__main__': + unittest.main() + pass diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index a8c80e06e..a88171d5c 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -64,12 +64,12 @@ def prepare(cls, super(TensorflowBackend, cls).prepare(model, device, **kwargs) common.logger.setLevel(logging_level) common.logger.handlers[0].setLevel(logging_level) - common.sys_config.auto_cast=auto_cast + common.sys_config.auto_cast = auto_cast - return cls.onnx_model_to_tensorflow_rep(model, strict) + return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs) @classmethod - def onnx_model_to_tensorflow_rep(cls, model, strict): + def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs): """ Convert ONNX model to TensorflowRep. :param model: ONNX ModelProto object. @@ -86,18 +86,27 @@ def onnx_model_to_tensorflow_rep(cls, model, strict): opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)] else: opset_import = model.opset_import - return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict) + return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict, + **kwargs) @classmethod - def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): + def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs): """ Convert ONNX graph to TensorflowRep. :param graph_def: ONNX GraphProto object. :param opset: ONNX OperatorSetIdProto list. :param strict: whether to enforce semantic equivalence between the original model and the converted tensorflow model. + :kwargs: additional arguements to generate tensor_dict for model debugging :return: TensorflowRep object. """ + # To generate tensor_dict or not, default is False + gen_tensor_dict = kwargs[ + 'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False + # User provided input tensors, in the case the model inputs have unknown shapes + input_tensor_dict = kwargs[ + 'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict() + handlers = cls._get_handlers(opset) # initializer: TensorProtos representing the values to initialize @@ -105,13 +114,15 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): # initialized: A list of names of the initialized tensors. if graph_def.initializer: + input_dict_items = cls._onnx_initializer_to_input_dict_items( + graph_def.initializer) initialized = {init.name for init in graph_def.initializer} else: + input_dict_items = [] initialized = set() module = BackendTFModule(handlers, opset, strict, graph_def, cls) signatures = dict() - for value_info in graph_def.input: if value_info.name in initialized: continue @@ -119,12 +130,24 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None for d in value_info.type.tensor_type.shape.dim) value_info_name = value_info.name.replace( - ":", "_tf_") + "_" + get_unique_suffix( - ) if ":" in value_info.name else value_info.name + ":", "_tf_") + "_" + get_unique_suffix( + ) if ":" in value_info.name else value_info.name - tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name) + tf_spec = tf.TensorSpec( + shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), + value_info_name) signatures[value_info.name] = tf_spec + if gen_tensor_dict: + x = tf.constant( + 0, + dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type), + name=value_info_name, + shape=shape + ) if value_info.name not in input_tensor_dict else input_tensor_dict[ + value_info.name] + input_dict_items.append((value_info_name, x)) + tf_rep = TensorflowRep() tf_rep.inputs = [ value_info.name @@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): module.outputs = tf_rep.outputs tf_rep.tf_module = module tf_rep.signatures = signatures + tf_rep.tensor_dict = module.gen_tensor_dict( + input_dict_items) if gen_tensor_dict else None + return tf_rep @classmethod @@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs): :param kwargs: Other args. :return: Outputs. """ + class TFModule(tf.Module): + def __init__(self, node): super(TFModule, self).__init__() self.node = node @@ -171,13 +199,16 @@ def __call__(self, **input_dict): feed_dict_raw = dict(zip(node.inputs, inputs)) # TODO: is constant the best way for feeding inputs? - input_dict = dict( - [(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()]) + input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items() + ]) module = TFModule(node) output_vals = module(**input_dict) - output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals] + output_vals = [ + val.numpy() if isinstance(val, tf.Tensor) else val + for val in output_vals + ] return namedtupledict('Outputs', node.outputs)(*output_vals) @@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls, """ handlers = handlers or cls._get_handlers(opset) if handlers: - handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None + handler = handlers[node.domain].get( + node.op_type, None) if node.domain in handlers else None if handler: return handler.handle(node, tensor_dict=tensor_dict, strict=strict) - raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type)) + raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format( + node.op_type)) @classmethod def _get_handlers(cls, opset): @@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls, nodes_outputs.append(o_name) for node in subgraph.node: for i_name in node.input: - if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(): + if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys( + ): subgraph_tensor_dict[i_name] = tensor_dict[i_name] onnx_node = OnnxNode(node) output_ops = cls._onnx_node_to_tensorflow_op(onnx_node, @@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls, return subgraph_tensor_dict @classmethod - def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True): + def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs): """ Converts ONNX graph to TensorflowRep Args: @@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True): """ # get the opset of the installed ONNX opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())] - return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict) + return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs) prepare = TensorflowBackend.prepare diff --git a/onnx_tf/backend_rep.py b/onnx_tf/backend_rep.py index 84c03c9e4..5c83302d0 100644 --- a/onnx_tf/backend_rep.py +++ b/onnx_tf/backend_rep.py @@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None): self._inputs = inputs or [] self._outputs = outputs or [] self._tensor_dict = tensor_dict or {} + self._tf_module = None @property def graph(self): @@ -49,6 +50,14 @@ def tensor_dict(self): def tensor_dict(self, tensor_dict): self._tensor_dict = tensor_dict + @property + def tf_module(self): + return self._tf_module + + @tf_module.setter + def tf_module(self, tf_module): + self._tf_module = tf_module + def run(self, inputs, **kwargs): """ Run TensorflowRep. diff --git a/onnx_tf/backend_tf_module.py b/onnx_tf/backend_tf_module.py index 3b7bc65e0..1e1f8659a 100644 --- a/onnx_tf/backend_tf_module.py +++ b/onnx_tf/backend_tf_module.py @@ -1,6 +1,7 @@ import tensorflow as tf from onnx_tf.pb_wrapper import OnnxNode + class BackendTFModule(tf.Module): def __init__(self, handlers, opset, strict, graph_def, backend): @@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend): self.backend = backend self.outputs = [] + @tf.function + def gen_tensor_dict(self, input_dict_items): + tensor_dict = dict(input_dict_items) + + for node in self.graph_def.node: + onnx_node = OnnxNode(node) + output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node, + tensor_dict, + self.handlers, + opset=self.opset, + strict=self.strict) + curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) + tensor_dict.update(curr_node_output_map) + + return tensor_dict + @tf.function def __call__(self, **kwargs): tensor_dict = kwargs @@ -26,8 +43,11 @@ def __call__(self, **kwargs): for node in self.graph_def.node: onnx_node = OnnxNode(node) - output_ops = self.backend._onnx_node_to_tensorflow_op( - onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict) + output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node, + tensor_dict, + self.handlers, + opset=self.opset, + strict=self.strict) curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) tensor_dict.update(curr_node_output_map)