From 3bc773cb8b8887d8566dd39373e293c7185c3de6 Mon Sep 17 00:00:00 2001 From: masakistan Date: Sat, 19 Sep 2020 16:38:45 -0600 Subject: [PATCH 1/3] use tf.shape instead of .shape for dynamic axes Signed-off-by: masakistan --- onnx_tf/handlers/backend/instance_normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_tf/handlers/backend/instance_normalization.py b/onnx_tf/handlers/backend/instance_normalization.py index 59beba3bf..0e05ae577 100644 --- a/onnx_tf/handlers/backend/instance_normalization.py +++ b/onnx_tf/handlers/backend/instance_normalization.py @@ -31,7 +31,7 @@ def _common(cls, node, **kwargs): beta = tensor_dict[node.inputs[2]] inputs = tensor_dict[node.inputs[0]] - inputs_shape = inputs.shape + inputs_shape = tf.shape(inputs) inputs_rank = inputs.shape.ndims moments_axes = list(range(inputs_rank))[2:] From 7b27f5d80f8b783ebf875b9d5f6c7b09bb5cf149 Mon Sep 17 00:00:00 2001 From: Chin Huang Date: Wed, 23 Sep 2020 09:56:19 +0800 Subject: [PATCH 2/3] Add model stepping test for Mnist (#734) * Add model stepping test for Mnist Add model stepping test for Mnist using ONNX runtime. The assumption is that ONNX runtime is installed and the mnist model from ONNX model zoo is downloaded. Signed-off-by: Chin Huang * add tensor_dict back in TFRep Signed-off-by: Chin Huang --- example/test_mnist_onnxruntime_stepping.py | 105 +++++++++++++++++++++ onnx_tf/backend.py | 68 +++++++++---- onnx_tf/backend_rep.py | 9 ++ onnx_tf/backend_tf_module.py | 24 ++++- 4 files changed, 187 insertions(+), 19 deletions(-) create mode 100644 example/test_mnist_onnxruntime_stepping.py diff --git a/example/test_mnist_onnxruntime_stepping.py b/example/test_mnist_onnxruntime_stepping.py new file mode 100644 index 000000000..a5fcd5f65 --- /dev/null +++ b/example/test_mnist_onnxruntime_stepping.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +import numpy as np + +import onnx +from onnx import helper +from onnx import TensorProto +import tensorflow as tf +import onnxruntime.backend as ort + +import onnx_tf.backend as otf +from onnx_tf.common import data_type + + +def find_between(s, first, last): + try: + start = s.index(first) + end = s.index(last) + len(last) + return s[start:end] + except ValueError: + return "" + + +class TestMnistModel(unittest.TestCase): + # Make sure the onnx file path is correct, assuming copied to the + # current directory + model_path = 'mnist-8.onnx' + + def test(self): + _model = onnx.load(self.model_path) + print("Total node count in model: ", len(_model.graph.node)) + + # The input tensors could be provided as constants + # The example below illustrates such a dictionary could be + # provided for models with unknown input shapes. Since + # mnist has known input shape, we don't provide input tensors. + # input_tensors = {'Input3': tf.constant(0, dtype = tf.float32, + # name='Input3', + # shape=[1, 1, 28, 28])} + input_tensors = {} + tensor_dict = otf.prepare(_model, + gen_tensor_dict=True, + input_tensor_dict=input_tensors).tensor_dict + more_outputs = [] + output_to_check = [] + for node in _model.graph.node: + # add the first output of each node to the model output + output_tensor = None + for i in range(len(_model.graph.value_info)): + if _model.graph.value_info[i].name == node.output[0]: + output_tensor = _model.graph.value_info[i] + + for i in range(len(_model.graph.initializer)): + if _model.graph.initializer[i].name == node.output[0]: + output_tensor = _model.graph.initializer[i] + + # assume the first output is a tensor + tensor = tensor_dict[node.output[0]] + output_tensor = helper.make_tensor_value_info( + node.output[0], data_type.tf2onnx(tensor.dtype), + tensor.shape) if output_tensor is None else output_tensor + more_outputs.append(output_tensor) + output_to_check.append(node.output[0]) + _model.graph.output.extend(more_outputs) + + tf_rep = otf.prepare(_model) + rt_rep = ort.prepare(_model) + + # prepare input data + mnist = tf.keras.datasets.mnist + (x_train, y_train), (x_test, y_test) = mnist.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32) + + inputs = [sample] + my_out = tf_rep.run(inputs) + rt_out = rt_rep.run(inputs) + + for op in output_to_check: + for i in range(len(my_out)): + # find the index of output in the list + if my_out[op] is my_out[i]: + + try: + np.savetxt(op.replace("/", "__") + ".rt", + rt_out[i].flatten(), + delimiter='\t') + np.savetxt(op.replace("/", "__") + ".tf", + my_out[i].flatten(), + delimiter='\t') + np.testing.assert_allclose(my_out[i], rt_out[i], rtol=1e-2) + print(op, "results of this layer are correct within tolerence.") + except Exception as e: + np.set_printoptions(threshold=np.inf) + mismatch_percent = (find_between(str(e), "(mismatch", "%)")) + print(op, "mismatch with percentage {} %".format(mismatch_percent)) + + +if __name__ == '__main__': + unittest.main() + pass diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index a8c80e06e..a88171d5c 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -64,12 +64,12 @@ def prepare(cls, super(TensorflowBackend, cls).prepare(model, device, **kwargs) common.logger.setLevel(logging_level) common.logger.handlers[0].setLevel(logging_level) - common.sys_config.auto_cast=auto_cast + common.sys_config.auto_cast = auto_cast - return cls.onnx_model_to_tensorflow_rep(model, strict) + return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs) @classmethod - def onnx_model_to_tensorflow_rep(cls, model, strict): + def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs): """ Convert ONNX model to TensorflowRep. :param model: ONNX ModelProto object. @@ -86,18 +86,27 @@ def onnx_model_to_tensorflow_rep(cls, model, strict): opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)] else: opset_import = model.opset_import - return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict) + return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict, + **kwargs) @classmethod - def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): + def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs): """ Convert ONNX graph to TensorflowRep. :param graph_def: ONNX GraphProto object. :param opset: ONNX OperatorSetIdProto list. :param strict: whether to enforce semantic equivalence between the original model and the converted tensorflow model. + :kwargs: additional arguements to generate tensor_dict for model debugging :return: TensorflowRep object. """ + # To generate tensor_dict or not, default is False + gen_tensor_dict = kwargs[ + 'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False + # User provided input tensors, in the case the model inputs have unknown shapes + input_tensor_dict = kwargs[ + 'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict() + handlers = cls._get_handlers(opset) # initializer: TensorProtos representing the values to initialize @@ -105,13 +114,15 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): # initialized: A list of names of the initialized tensors. if graph_def.initializer: + input_dict_items = cls._onnx_initializer_to_input_dict_items( + graph_def.initializer) initialized = {init.name for init in graph_def.initializer} else: + input_dict_items = [] initialized = set() module = BackendTFModule(handlers, opset, strict, graph_def, cls) signatures = dict() - for value_info in graph_def.input: if value_info.name in initialized: continue @@ -119,12 +130,24 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None for d in value_info.type.tensor_type.shape.dim) value_info_name = value_info.name.replace( - ":", "_tf_") + "_" + get_unique_suffix( - ) if ":" in value_info.name else value_info.name + ":", "_tf_") + "_" + get_unique_suffix( + ) if ":" in value_info.name else value_info.name - tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name) + tf_spec = tf.TensorSpec( + shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), + value_info_name) signatures[value_info.name] = tf_spec + if gen_tensor_dict: + x = tf.constant( + 0, + dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type), + name=value_info_name, + shape=shape + ) if value_info.name not in input_tensor_dict else input_tensor_dict[ + value_info.name] + input_dict_items.append((value_info_name, x)) + tf_rep = TensorflowRep() tf_rep.inputs = [ value_info.name @@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict): module.outputs = tf_rep.outputs tf_rep.tf_module = module tf_rep.signatures = signatures + tf_rep.tensor_dict = module.gen_tensor_dict( + input_dict_items) if gen_tensor_dict else None + return tf_rep @classmethod @@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs): :param kwargs: Other args. :return: Outputs. """ + class TFModule(tf.Module): + def __init__(self, node): super(TFModule, self).__init__() self.node = node @@ -171,13 +199,16 @@ def __call__(self, **input_dict): feed_dict_raw = dict(zip(node.inputs, inputs)) # TODO: is constant the best way for feeding inputs? - input_dict = dict( - [(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()]) + input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items() + ]) module = TFModule(node) output_vals = module(**input_dict) - output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals] + output_vals = [ + val.numpy() if isinstance(val, tf.Tensor) else val + for val in output_vals + ] return namedtupledict('Outputs', node.outputs)(*output_vals) @@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls, """ handlers = handlers or cls._get_handlers(opset) if handlers: - handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None + handler = handlers[node.domain].get( + node.op_type, None) if node.domain in handlers else None if handler: return handler.handle(node, tensor_dict=tensor_dict, strict=strict) - raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type)) + raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format( + node.op_type)) @classmethod def _get_handlers(cls, opset): @@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls, nodes_outputs.append(o_name) for node in subgraph.node: for i_name in node.input: - if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(): + if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys( + ): subgraph_tensor_dict[i_name] = tensor_dict[i_name] onnx_node = OnnxNode(node) output_ops = cls._onnx_node_to_tensorflow_op(onnx_node, @@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls, return subgraph_tensor_dict @classmethod - def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True): + def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs): """ Converts ONNX graph to TensorflowRep Args: @@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True): """ # get the opset of the installed ONNX opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())] - return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict) + return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs) prepare = TensorflowBackend.prepare diff --git a/onnx_tf/backend_rep.py b/onnx_tf/backend_rep.py index 84c03c9e4..5c83302d0 100644 --- a/onnx_tf/backend_rep.py +++ b/onnx_tf/backend_rep.py @@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None): self._inputs = inputs or [] self._outputs = outputs or [] self._tensor_dict = tensor_dict or {} + self._tf_module = None @property def graph(self): @@ -49,6 +50,14 @@ def tensor_dict(self): def tensor_dict(self, tensor_dict): self._tensor_dict = tensor_dict + @property + def tf_module(self): + return self._tf_module + + @tf_module.setter + def tf_module(self, tf_module): + self._tf_module = tf_module + def run(self, inputs, **kwargs): """ Run TensorflowRep. diff --git a/onnx_tf/backend_tf_module.py b/onnx_tf/backend_tf_module.py index 3b7bc65e0..1e1f8659a 100644 --- a/onnx_tf/backend_tf_module.py +++ b/onnx_tf/backend_tf_module.py @@ -1,6 +1,7 @@ import tensorflow as tf from onnx_tf.pb_wrapper import OnnxNode + class BackendTFModule(tf.Module): def __init__(self, handlers, opset, strict, graph_def, backend): @@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend): self.backend = backend self.outputs = [] + @tf.function + def gen_tensor_dict(self, input_dict_items): + tensor_dict = dict(input_dict_items) + + for node in self.graph_def.node: + onnx_node = OnnxNode(node) + output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node, + tensor_dict, + self.handlers, + opset=self.opset, + strict=self.strict) + curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) + tensor_dict.update(curr_node_output_map) + + return tensor_dict + @tf.function def __call__(self, **kwargs): tensor_dict = kwargs @@ -26,8 +43,11 @@ def __call__(self, **kwargs): for node in self.graph_def.node: onnx_node = OnnxNode(node) - output_ops = self.backend._onnx_node_to_tensorflow_op( - onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict) + output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node, + tensor_dict, + self.handlers, + opset=self.opset, + strict=self.strict) curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) tensor_dict.update(curr_node_output_map) From 9ef09b822d423c587b14e65561870d110d3d7467 Mon Sep 17 00:00:00 2001 From: masakistan Date: Thu, 12 Nov 2020 15:13:25 -0700 Subject: [PATCH 3/3] use onnx-tf tf_shape instead of tf.shape Signed-off-by: masakistan --- onnx_tf/handlers/backend/instance_normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_tf/handlers/backend/instance_normalization.py b/onnx_tf/handlers/backend/instance_normalization.py index 0e05ae577..04d83b68c 100644 --- a/onnx_tf/handlers/backend/instance_normalization.py +++ b/onnx_tf/handlers/backend/instance_normalization.py @@ -3,7 +3,7 @@ from onnx_tf.handlers.backend_handler import BackendHandler from onnx_tf.handlers.handler import onnx_op from onnx_tf.handlers.handler import tf_func - +from onnx_tf.common.tf_helper import tf_shape @onnx_op("InstanceNormalization") @tf_func(tf.nn.batch_normalization) @@ -31,7 +31,7 @@ def _common(cls, node, **kwargs): beta = tensor_dict[node.inputs[2]] inputs = tensor_dict[node.inputs[0]] - inputs_shape = tf.shape(inputs) + inputs_shape = tf_shape(inputs) inputs_rank = inputs.shape.ndims moments_axes = list(range(inputs_rank))[2:]