From 52fbf7b4ac45472fd1d0ae8e5ea169efb8552f8b Mon Sep 17 00:00:00 2001 From: Winnie Tsang Date: Tue, 3 Nov 2020 14:48:14 -0800 Subject: [PATCH 1/3] Fix dynamic shape test and move handler's variable creation into (#795) BackendTFModule.__init__ 1. Add export_graph and use this graph to run in Tensorflow instead of using tf_rep.run for each dynamic shape testcases to verify our handlers can support unknown shape inputs. 2. Fix conv_mixin to support dynamic shape input for W. 3. Move tf.Variable creation into BackendTFModule.__init__ instead of creating them inside the handler. 4. Update non_max_suppression handler to create the tf.Variable in BackendTFModule.__init__ and use the variable in the handler 5. Fix random_normal_like handler to support dynamic shape input. 6. Fix dilated_pooling to handle unknown input shape. 7. Add ONNX operators count to tf_rep. 8. Minor update on CLI_template.md Signed-off-by: Winnie Tsang --- doc/CLI.md | 2 +- doc/CLI_template.md | 2 +- onnx_tf/backend.py | 31 ++ onnx_tf/backend_rep.py | 22 +- onnx_tf/backend_tf_module.py | 44 ++ onnx_tf/handlers/backend/conv_mixin.py | 15 +- onnx_tf/handlers/backend/dilated_pooling.py | 9 +- .../handlers/backend/non_max_suppression.py | 26 +- .../handlers/backend/random_normal_like.py | 3 +- onnx_tf/handlers/backend_handler.py | 15 +- test/backend/test_dynamic_shape.py | 381 ++++++++++++++---- 11 files changed, 431 insertions(+), 119 deletions(-) diff --git a/doc/CLI.md b/doc/CLI.md index 0377160ff..537823986 100644 --- a/doc/CLI.md +++ b/doc/CLI.md @@ -22,7 +22,7 @@ optional arguments: ### Convert: #### From ONNX to Tensorflow: -`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb` +`onnx-tf convert -i /path/to/input.onnx -o /path/to/output` More information: `onnx-tf convert -h` ``` diff --git a/doc/CLI_template.md b/doc/CLI_template.md index dd5e54f18..1ccfd7998 100644 --- a/doc/CLI_template.md +++ b/doc/CLI_template.md @@ -14,7 +14,7 @@ More information: `onnx-tf -h` ### Convert: #### From ONNX to Tensorflow: -`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb` +`onnx-tf convert -i /path/to/input.onnx -o /path/to/output` More information: `onnx-tf convert -h` ``` diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index 31188e914..c861e308d 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -160,8 +160,39 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs): tf_rep.signatures = signatures tf_rep.tensor_dict = module.gen_tensor_dict( input_dict) if gen_tensor_dict else None + tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def) return tf_rep + @classmethod + def _get_onnx_op_list(cls, graph_def): + """ Get ONNX operator counts of the model. + + :param graph_def: ONNX GraphProto object. + :return: Dictionary of all operators counts in the model. + """ + + def get_onnx_op_from_graph_and_subgraph(graph, op_list): + for node in graph.node: + op_list[node.op_type] = 1 if node.op_type not in op_list.keys( + ) else op_list[node.op_type] + 1 + if node.op_type in ['Loop', 'Scan']: + onnx_node = OnnxNode(node) + body = onnx_node.attrs["body"] + op_list = get_onnx_op_from_graph_and_subgraph(body, op_list) + elif node.op_type == 'If': + onnx_node = OnnxNode(node) + then_branch = onnx_node.attrs['then_branch'] + op_list = get_onnx_op_from_graph_and_subgraph(then_branch, op_list) + else_branch = onnx_node.attrs['else_branch'] + op_list = get_onnx_op_from_graph_and_subgraph(else_branch, op_list) + return op_list + + op_list = get_onnx_op_from_graph_and_subgraph(graph_def, dict()) + sorted_op_list = dict() + for key in sorted(op_list): + sorted_op_list[key] = op_list[key] + return sorted_op_list + @classmethod def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs): """ Run ONNX node. diff --git a/onnx_tf/backend_rep.py b/onnx_tf/backend_rep.py index 5c83302d0..a8a0953e8 100644 --- a/onnx_tf/backend_rep.py +++ b/onnx_tf/backend_rep.py @@ -50,6 +50,14 @@ def tensor_dict(self): def tensor_dict(self, tensor_dict): self._tensor_dict = tensor_dict + @property + def onnx_op_list(self): + return self._onnx_op_list + + @onnx_op_list.setter + def onnx_op_list(self, onnx_op_list): + self._onnx_op_list = onnx_op_list + @property def tf_module(self): return self._tf_module @@ -80,11 +88,13 @@ def run(self, inputs, **kwargs): # single input feed_dict = dict([(self.inputs[0], inputs)]) - input_dict = dict( - [(x[0], tf.constant(x[1])) for x in feed_dict.items()]) + input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict.items()]) output_values = self.tf_module(**input_dict) - output_values = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_values] + output_values = [ + val.numpy() if isinstance(val, tf.Tensor) else val + for val in output_values + ] return namedtupledict('Outputs', self.outputs)(*output_values) @@ -99,4 +109,8 @@ def export_graph(self, path): :returns: none. """ - tf.saved_model.save(self.tf_module, path, signatures=self.tf_module.__call__.get_concrete_function(**self.signatures)) + tf.saved_model.save( + self.tf_module, + path, + signatures=self.tf_module.__call__.get_concrete_function( + **self.signatures)) diff --git a/onnx_tf/backend_tf_module.py b/onnx_tf/backend_tf_module.py index 8774f60e6..fcf191e71 100644 --- a/onnx_tf/backend_tf_module.py +++ b/onnx_tf/backend_tf_module.py @@ -1,3 +1,4 @@ +from onnx.defs import ONNX_DOMAIN import tensorflow as tf from onnx_tf.pb_wrapper import OnnxNode @@ -14,6 +15,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend): self.outputs = [] self.initializer_dict = self._get_initializer_from_graph_and_subgraphs( self.graph_def, dict()) + self.handler_variables = self._create_handlers_variables( + self.graph_def, dict()) # get initializer from the main graph and all subgraphs in loop or if or scan # into tensor_dict @@ -37,10 +40,40 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict): else_branch, graph_tensor_dict) return graph_tensor_dict + # create tf.Variable for handlers that required to use variable in handler + def _create_handlers_variables(self, graph, vars_dict): + handlers = self.backend._get_handlers(self.opset) + for node in graph.node: + handler = handlers[node.domain].get( + node.op_type, None) if node.domain in handlers else None + if handler and bool(handler.get_req_vars_template()): + for v_name, v_template in handler.get_req_vars_template().items(): + v_init, v_shape = v_template + v_count = 0 + for var_name in vars_dict.keys(): + v_count = v_count + 1 if var_name.startswith(v_name) else v_count + v_name = v_name + '_' + str(v_count) + vars_dict[v_name] = tf.Variable(v_init, + dtype=v_init.dtype, + shape=v_shape, + name=v_name) + if node.op_type in ['Loop', 'Scan']: + onnx_node = OnnxNode(node) + body = onnx_node.attrs["body"] + vars_dict = self._create_handlers_variables(body, vars_dict) + elif node.op_type == 'If': + onnx_node = OnnxNode(node) + then_branch = onnx_node.attrs['then_branch'] + vars_dict = self._create_handlers_variables(then_branch, vars_dict) + else_branch = onnx_node.attrs['else_branch'] + vars_dict = self._create_handlers_variables(else_branch, vars_dict) + return vars_dict + @tf.function def gen_tensor_dict(self, input_dict): tensor_dict = dict(input_dict) tensor_dict.update(self.initializer_dict) + tensor_dict.update(self.handler_variables) for node in self.graph_def.node: onnx_node = OnnxNode(node) @@ -52,12 +85,18 @@ def gen_tensor_dict(self, input_dict): curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) tensor_dict.update(curr_node_output_map) + # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN) + # TODO update this when we support handlers in other domain + for _, handler in self.handlers[ONNX_DOMAIN].items(): + handler.VAR_COUNT = 0 + return tensor_dict @tf.function def __call__(self, **kwargs): tensor_dict = kwargs tensor_dict.update(self.initializer_dict) + tensor_dict.update(self.handler_variables) for node in self.graph_def.node: onnx_node = OnnxNode(node) @@ -70,4 +109,9 @@ def __call__(self, **kwargs): tensor_dict.update(curr_node_output_map) outputs = [tensor_dict[output] for output in self.outputs] + + # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN) + # TODO update this when we support handlers in other domain + for _, handler in self.handlers[ONNX_DOMAIN].items(): + handler.VAR_COUNT = 0 return outputs diff --git a/onnx_tf/handlers/backend/conv_mixin.py b/onnx_tf/handlers/backend/conv_mixin.py index d33e11e70..0f61092b4 100644 --- a/onnx_tf/handlers/backend/conv_mixin.py +++ b/onnx_tf/handlers/backend/conv_mixin.py @@ -46,14 +46,15 @@ def conv(cls, node, input_dict, transpose=False): if "kernel_shape" in node.attrs.keys(): kernel_shape = node.attrs["kernel_shape"] - assert in_weights.get_shape().as_list()[2:] == kernel_shape, ( - "kernel_shape " - "attr of convolution does not match the actual weight " - "passed to this operation, attr {}, actual {}").format( - kernel_shape, - in_weights.get_shape().as_list()) + if in_weights.get_shape().is_fully_defined(): + assert in_weights.get_shape().as_list()[2:] == kernel_shape, ( + "kernel_shape " + "attr of convolution does not match the actual weight " + "passed to this operation, attr {}, actual {}").format( + kernel_shape, + in_weights.get_shape().as_list()) else: - kernel_shape = in_weights.get_shape().as_list()[2:] + kernel_shape = tf_shape(in_weights, tf.int32)[2:] weights = tf.transpose(in_weights, perm) dilations = node.attrs.get("dilations", [1] * spatial_size) diff --git a/onnx_tf/handlers/backend/dilated_pooling.py b/onnx_tf/handlers/backend/dilated_pooling.py index 7274c3df5..edfcf35e2 100644 --- a/onnx_tf/handlers/backend/dilated_pooling.py +++ b/onnx_tf/handlers/backend/dilated_pooling.py @@ -601,9 +601,12 @@ def dilated_maxpool_with_argmax(self, force_custom_impl=False): # if there was padding, recalculate the returned index # to exclude the padding - count_nonzero_op = np.count_nonzero if self.is_known_shape else tf.math.count_nonzero - if count_nonzero_op(self.pads) != 0: - new_ind = self._calc_argmax_without_padding(new_ind) + if self.is_known_shape: + if np.count_nonzero(self.pads) != 0: + new_ind = self._calc_argmax_without_padding(new_ind) + else: + new_ind = tf.where(tf.not_equal(tf.math.count_nonzero(self.pads), 0), + self._calc_argmax_without_padding(new_ind), new_ind) return (pooled, new_ind) diff --git a/onnx_tf/handlers/backend/non_max_suppression.py b/onnx_tf/handlers/backend/non_max_suppression.py index 6b03d8025..3374b1bf7 100644 --- a/onnx_tf/handlers/backend/non_max_suppression.py +++ b/onnx_tf/handlers/backend/non_max_suppression.py @@ -7,8 +7,20 @@ @onnx_op("NonMaxSuppression") class NonMaxSuppression(BackendHandler): + var_prefix = 'non_max_suppression_result' - result = None + @classmethod + def get_req_vars_template(cls): + """ Get required variables template. + + :return: Dict. + """ + return { + cls.var_prefix: [ + tf.constant([[0, 0, 0]], dtype=tf.int64), + tf.TensorShape([None, 3]) + ] + } @classmethod def _common(cls, node, **kwargs): @@ -84,19 +96,13 @@ def create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold, result = output if tf.equal(batch_i, 0) and tf.equal( class_j, 0) else tf.concat([result, output], 0) + cls.VAR_COUNT = cls.VAR_COUNT + 1 return result - # Since tf.function doesn't support locals() and it require all the variables - # are defined before use in the "for loop" before it will perform any auto - # convertion of the python code. Therefore need to define "result" as a - # Variable here and send it in as a parameter to "create_nodes" - if cls.result is None: - cls.result = tf.Variable([[0, 0, 0]], - dtype=tf.int64, - shape=tf.TensorShape([None, 3])) + result = tensor_dict[cls.var_prefix + '_' + str(cls.VAR_COUNT)] return [ create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold, - score_threshold, cls.result) + score_threshold, result) ] @classmethod diff --git a/onnx_tf/handlers/backend/random_normal_like.py b/onnx_tf/handlers/backend/random_normal_like.py index 25acf8842..1e9120b18 100644 --- a/onnx_tf/handlers/backend/random_normal_like.py +++ b/onnx_tf/handlers/backend/random_normal_like.py @@ -1,5 +1,6 @@ import tensorflow as tf +from onnx_tf.common.tf_helper import tf_shape from onnx_tf.handlers.backend_handler import BackendHandler from onnx_tf.handlers.handler import onnx_op from onnx_tf.handlers.handler import tf_func @@ -15,5 +16,5 @@ def get_attrs_processor_param(cls): @classmethod def version_1(cls, node, **kwargs): - inputs = [kwargs["tensor_dict"][node.inputs[0]].get_shape()] + inputs = [tf_shape(kwargs["tensor_dict"][node.inputs[0]])] return [cls.make_tensor_from_onnx_node(node, inputs=inputs, **kwargs)] diff --git a/onnx_tf/handlers/backend_handler.py b/onnx_tf/handlers/backend_handler.py index 6ad8db57e..9933fce4e 100644 --- a/onnx_tf/handlers/backend_handler.py +++ b/onnx_tf/handlers/backend_handler.py @@ -24,6 +24,15 @@ class BackendHandler(Handler): """ TF_FUNC = None + VAR_COUNT = 0 + + @classmethod + def get_req_vars_template(cls): + """ Get required variables template. + + :return: Dict. + """ + return {} @classmethod def get_attrs_processor_param(cls): @@ -183,9 +192,9 @@ def _run_tf_func(cls, tf_func, inputs, attrs): attrs = {p: v for p, v in attrs.items() if p in params} kwargs = dict(zip(params, inputs)) - ambiguous_arguments = any(kwargs.get(p) is not None and v is not None - for p, v in attrs.items()) + ambiguous_arguments = any( + kwargs.get(p) is not None and v is not None for p, v in attrs.items()) if ambiguous_arguments: raise TypeError('Ambiguous arguments for {}()'.format(tf_func.__name__)) kwargs.update((p, v) for p, v in attrs.items() if v is not None) - return tf_func(**kwargs) \ No newline at end of file + return tf_func(**kwargs) diff --git a/test/backend/test_dynamic_shape.py b/test/backend/test_dynamic_shape.py index e380b3ca9..c6739ddbf 100644 --- a/test/backend/test_dynamic_shape.py +++ b/test/backend/test_dynamic_shape.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import unittest +import shutil from onnx import defs from onnx import helper @@ -53,11 +54,18 @@ def test_arg_max(self): ]) x = np.array([[1, 2, 3, 5, 3, 4, 5, 1], [2, 9, 3, 5, 9, 4, 5, 1]]).astype(np.float32) + # get tf_rep tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/arg_max' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) expected_output = np.argmax(np.flip(x, axis), axis=axis) expected_output = x.shape[axis] - expected_output - 1 - np.testing.assert_almost_equal(output['Y'], expected_output) + np.testing.assert_almost_equal(tf_model_output[0], expected_output) def test_arg_min(self): if legacy_opset_pre_ver(12): @@ -83,10 +91,16 @@ def test_arg_min(self): x = np.array([[1, 2, 3, 5, 3, 4, 5, 1], [2, 7, 3, 5, 2, 4, 5, 6]]).astype(np.float32) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/arg_min' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) expected_output = np.argmin(np.flip(x, axis), axis=axis) expected_output = x.shape[axis] - expected_output - 1 - np.testing.assert_almost_equal(output['Y'], expected_output) + np.testing.assert_almost_equal(tf_model_output[0], expected_output) def _batch_normalization(self, x, mean, variance, bias, scale, variance_epsilon): @@ -130,14 +144,14 @@ def test_batch_normalization(self): _bias = bias.reshape(_param_shape) golden = self._batch_normalization(x, _m, _v, _bias, _scale, 0.001) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({ - "X": x, - "scale": scale, - "bias": bias, - "mean": m, - "var": v - }) - np.testing.assert_almost_equal(output["Y"], golden, decimal=5) + # export to tf.saved_model + model_path = 'test_dynamic_shape/batch_normalization' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x, scale=scale, bias=bias, mean=m, var=v) + np.testing.assert_almost_equal(tf_model_output[0], golden, decimal=5) def test_compress(self): if legacy_opset_pre_ver(9): @@ -164,8 +178,15 @@ def test_compress(self): x = self._get_rnd_float32(shape=[5, 5, 5]) cond = np.array([1, 0, 1]).astype(np.bool) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x, "condition": cond}) - np.testing.assert_almost_equal(output['Y'], np.compress(cond, x, axis=axis)) + # export to tf.saved_model + model_path = 'test_dynamic_shape/compress' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x, condition=cond) + np.testing.assert_almost_equal(tf_model_output[0], + np.compress(cond, x, axis=axis)) def test_conv_transpose(self): # test dynamic batch size on transpose of 2d convolution @@ -192,7 +213,13 @@ def test_conv_transpose(self): ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x, "weights": weights}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/conv_transpose' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x, weights=weights) padh_left = weight_shape[2] - 1 - pads[0] padh_right = weight_shape[2] - 1 - pads[1] @@ -219,7 +246,7 @@ def test_conv_transpose(self): k2 - padw_left] * weights[c][m][kh + h - 1 - k1][kw + w - 1 - k2] - np.testing.assert_almost_equal(output["Y"], test_output, decimal=5) + np.testing.assert_almost_equal(tf_model_output[0], test_output, decimal=5) def test_eye_like(self): if legacy_opset_pre_ver(9): @@ -242,8 +269,14 @@ def test_eye_like(self): helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"x": x}) - np.testing.assert_equal(output["y"], y) + # export to tf.saved_model + model_path = 'test_dynamic_shape/eye_like' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(x=x) + np.testing.assert_equal(tf_model_output[0], y) def test_flatten(self): shape = [2, 3, 4] @@ -259,9 +292,15 @@ def test_flatten(self): ], outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/flatten' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) new_shape = (np.prod(shape[0:axis]).astype(int), -1) - np.testing.assert_almost_equal(output["Y"], np.reshape(x, new_shape)) + np.testing.assert_almost_equal(tf_model_output[0], np.reshape(x, new_shape)) def test_gather_nd(self): if legacy_opset_pre_ver(11): @@ -286,8 +325,14 @@ def test_gather_nd(self): helper.make_tensor_value_info("outputs", TensorProto.INT32, [None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"data": data, "indices": indices}) - np.testing.assert_almost_equal(output["outputs"], ref_output) + # export to tf.saved_model + model_path = 'test_dynamic_shape/gather_nd' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(data=data, indices=indices) + np.testing.assert_almost_equal(tf_model_output[0], ref_output) def test_is_inf(self): if legacy_opset_pre_ver(10): @@ -305,8 +350,14 @@ def test_is_inf(self): ], outputs=[helper.make_tensor_value_info("Y", TensorProto.BOOL, [None])]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": inp}) - np.testing.assert_equal(output["Y"], expected_output) + # export to tf.saved_model + model_path = 'test_dynamic_shape/is_inf' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=inp) + np.testing.assert_equal(tf_model_output[0], expected_output) def test_matmul_integer(self): if legacy_opset_pre_ver(10): @@ -343,13 +394,17 @@ def test_matmul_integer(self): [None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({ - "A": A, - "B": B, - "a_zero_point": a_zero_point, - "b_zero_point": b_zero_point - }) - np.testing.assert_almost_equal(output["Z"], z) + # export to tf.saved_model + model_path = 'test_dynamic_shape/matmul_integer' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(A=A, + B=B, + a_zero_point=a_zero_point, + b_zero_point=b_zero_point) + np.testing.assert_almost_equal(tf_model_output[0], z) # A & B are 4-D tensor and a_zero_point & b_zero_point are 1-D tensor A = self._get_rnd_int(-20, 20, shape=(2, 5, 3, 4), dtype=np.int8) B = self._get_rnd_int(-20, 20, shape=(2, 1, 4, 6), dtype=np.int8) @@ -385,13 +440,16 @@ def test_matmul_integer(self): [None, None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({ - "A": A, - "B": B, - "a_zero_point": a_zero_point, - "b_zero_point": b_zero_point - }) - np.testing.assert_almost_equal(output["Z"], z) + # export to tf.saved_model + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(A=A, + B=B, + a_zero_point=a_zero_point, + b_zero_point=b_zero_point) + np.testing.assert_almost_equal(tf_model_output[0], z) def test_non_max_suppression(self): if legacy_opset_pre_ver(10): @@ -437,14 +495,112 @@ def test_non_max_suppression(self): [None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({ - "boxes": boxes, - "scores": scores, - "max_output_boxes_per_class": max_output_boxes_per_class, - "iou_threshold": iou_threshold, - "score_threshold": score_threshold - }) - np.testing.assert_almost_equal(output["selected_indices"], selected_indices) + # export to tf.saved_model + model_path = 'test_dynamic_shape/non_max_suppression' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model( + boxes=boxes, + scores=scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold) + np.testing.assert_almost_equal(tf_model_output[0], selected_indices) + + def test_non_max_suppression_with_if(self): + # if cond + # return NonMaxSuppression suppress by IOU + # else + # return NonNaxSuppression suppress by IOU and score + boxes = np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], [0.0, 100.0, 1.0, + 101.0]]]).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = np.array([3]).astype(np.int64) + iou_threshold = np.array([0.5]).astype(np.float32) + score_threshold = np.array([0.4]).astype(np.float32) + selected_indices_1 = np.array([[0, 0, 3], [0, 0, 0], [0, 0, + 5]]).astype(np.int64) + selected_indices_2 = np.array([[0, 0, 3], [0, 0, 0]]).astype(np.int64) + + boxes_in = helper.make_tensor_value_info("boxes", TensorProto.FLOAT, + [None, None, None]) + scores_in = helper.make_tensor_value_info("scores", TensorProto.FLOAT, + [None, None, None]) + max_output_boxes_per_class_in = helper.make_tensor_value_info( + "max_output_boxes_per_class", TensorProto.INT64, [None]) + iou_threshold_in = helper.make_tensor_value_info("iou_threshold", + TensorProto.FLOAT, [None]) + score_threshold_in = helper.make_tensor_value_info("score_threshold", + TensorProto.FLOAT, + [None]) + cond_in = helper.make_tensor_value_info('cond', TensorProto.BOOL, []) + + selected_indices_1_out = helper.make_tensor_value_info( + "selected_indices_1", TensorProto.INT64, [None, None]) + selected_indices_2_out = helper.make_tensor_value_info( + "selected_indices_2", TensorProto.INT64, [None, None]) + selected_indices_out = helper.make_tensor_value_info( + "selected_indices", TensorProto.INT64, [None, None]) + + non_max_suppression_node_1 = helper.make_node( + "NonMaxSuppression", + ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"], + ["selected_indices_1"], + center_point_box=0) + non_max_suppression_node_2 = helper.make_node("NonMaxSuppression", [ + "boxes", "scores", "max_output_boxes_per_class", "iou_threshold", + "score_threshold" + ], ["selected_indices_2"], + center_point_box=0) + + then_graph = helper.make_graph(nodes=[non_max_suppression_node_1], + name="then_graph", + inputs=[ + boxes_in, scores_in, + max_output_boxes_per_class_in, + iou_threshold_in + ], + outputs=[selected_indices_1_out]) + else_graph = helper.make_graph(nodes=[non_max_suppression_node_2], + name="then_graph", + inputs=[ + boxes_in, scores_in, + max_output_boxes_per_class_in, + iou_threshold_in, score_threshold_in + ], + outputs=[selected_indices_2_out]) + if_node = helper.make_node('If', ['cond'], ["selected_indices"], + then_branch=then_graph, + else_branch=else_graph) + graph_def = helper.make_graph(nodes=[if_node], + name='test_if', + inputs=[ + boxes_in, scores_in, + max_output_boxes_per_class_in, + iou_threshold_in, score_threshold_in, + cond_in + ], + outputs=[selected_indices_out]) + tf_rep = onnx_graph_to_tensorflow_rep(graph_def) + # export to tf.saved_model + model_path = 'test_dynamic_shape/non_max_suppression/if' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + for cond, exp in [[True, selected_indices_1], [False, selected_indices_2]]: + tf_model_output = tf_model( + boxes=boxes, + scores=scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + cond=cond) + np.testing.assert_almost_equal(tf_model_output[0], exp) def test_scatter_nd(self): if legacy_opset_pre_ver(11): @@ -478,8 +634,14 @@ def test_scatter_nd(self): [None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"data": data, "indices": indices, "updates": updates}) - np.testing.assert_almost_equal(output["outputs"], ref_output) + # export to tf.saved_model + model_path = 'test_dynamic_shape/scatter_nd' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(data=data, indices=indices, updates=updates) + np.testing.assert_almost_equal(tf_model_output[0], ref_output) def test_max_pool_2d_dilations_ceil_pads(self): if legacy_opset_pre_ver(10): @@ -526,16 +688,21 @@ def test_max_pool_2d_dilations_ceil_pads(self): [None, None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) - - np.testing.assert_almost_equal(output["Y"], test_output) + # export to tf.saved_model + model_path = 'test_dynamic_shape/max_pool_2d_dilations_ceil_pads' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) + np.testing.assert_almost_equal(tf_model_output[0], test_output) def test_max_pool_with_argmax_2d_dilations_ceil_pads(self): if legacy_opset_pre_ver(10): raise unittest.SkipTest( "ONNX version {} doesn't support dilations nor ceil mode.".format( defs.onnx_opset_version())) - + kernel_shape = [3, 3] strides = [2, 2] dilations = [3, 3] @@ -567,7 +734,13 @@ def test_max_pool_with_argmax_2d_dilations_ceil_pads(self): ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/max_pool_with_argmax_2d_dilations_ceil_pads' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) test_output, test_ind = py_pool(x, kernel_shape=kernel_shape, @@ -577,8 +750,8 @@ def test_max_pool_with_argmax_2d_dilations_ceil_pads(self): ceil_mode=ceil_mode, pooling_type="MAX") - np.testing.assert_almost_equal(output["Y"], test_output) - np.testing.assert_almost_equal(output["Ind"], test_ind) + np.testing.assert_almost_equal(tf_model_output[0], test_output) + np.testing.assert_almost_equal(tf_model_output[1], test_ind) def test_average_pool_2d(self): kernel_shape = [1, 2] @@ -611,9 +784,14 @@ def test_average_pool_2d(self): [None, None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output = tf_rep.run({"X": x}) - - np.testing.assert_almost_equal(output["Y"], test_output) + # export to tf.saved_model + model_path = 'test_dynamic_shape/average_pool_2d' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) + np.testing.assert_almost_equal(tf_model_output[0], test_output) def test_max_unpool(self): input_shape = [10, 3, 24, 24] @@ -645,7 +823,13 @@ def test_max_unpool(self): [None, None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) - output_unpool = tf_rep.run({"X": x}) + # export to tf.saved_model + model_path = 'test_dynamic_shape/max_unpool' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model + tf_model_output = tf_model(X=x) test_output = np.zeros(input_shape) for i1 in range(0, input_shape[0]): @@ -660,7 +844,7 @@ def test_max_unpool(self): max_ind = (j1, j2) j1, j2 = max_ind test_output[i1][i2][j1][j2] = max_val - np.testing.assert_almost_equal(output_unpool["Y"], test_output) + np.testing.assert_almost_equal(tf_model_output[0], test_output) def test_slice(self): # test case 1 with normal inputs @@ -703,20 +887,20 @@ def test_slice(self): [None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) + # export to tf.saved_model + model_path = 'test_dynamic_shape/slice' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) if legacy_opset_pre_ver(10): x = self._get_rnd_float32(shape=[1000]).reshape([10, 10, 10]) - output = tf_rep.run({"X": x}) - np.testing.assert_almost_equal(output["S"], x[0:2, 0:2, 0:2]) + tf_model_output = tf_model(X=x) + np.testing.assert_almost_equal(tf_model_output[0], x[0:2, 0:2, 0:2]) else: x = self._get_rnd_float32(shape=[1000]).reshape([10, 10, 10]) - output = tf_rep.run({ - "X": x, - "starts": starts, - "ends": ends, - "axes": axes - }) - np.testing.assert_almost_equal(output["S"], x[0:2, 0:2, 0:2]) + tf_model_output = tf_model(X=x, starts=starts, ends=ends, axes=axes) + np.testing.assert_almost_equal(tf_model_output[0], x[0:2, 0:2, 0:2]) # test case 2 with negative, out-of-bound and default inputs axes = [0, 2] @@ -761,20 +945,24 @@ def test_slice(self): [None, None, None]) ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) + # export to tf.saved_model + model_path = 'test_dynamic_shape/slice' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + if legacy_opset_pre_ver(10): x = self._get_rnd_float32(shape=[1000]).reshape([10, 10, 10]) - output = tf_rep.run({"X": x}) - np.testing.assert_almost_equal(output["S"], x[0:-8, :, -7:20]) + tf_model_output = tf_model(X=x) + np.testing.assert_almost_equal(tf_model_output[0], x[0:-8, :, -7:20]) else: x = self._get_rnd_float32(shape=[1000]).reshape([10, 10, 10]) - output = tf_rep.run({ - "X": x, - "starts": starts, - "ends": ends, - "axes": axes, - "steps": steps - }) - np.testing.assert_almost_equal(output["S"], x[0:-8, :, -7:20]) + tf_model_output = tf_model(X=x, + starts=starts, + ends=ends, + axes=axes, + steps=steps) + np.testing.assert_almost_equal(tf_model_output[0], x[0:-8, :, -7:20]) # test case 3 with non-default steps axes = [0, 1, 2] @@ -784,14 +972,13 @@ def test_slice(self): if not legacy_opset_pre_ver(10): x = self._get_rnd_float32(shape=[1000]).reshape([10, 10, 10]) - output = tf_rep.run({ - "X": x, - "starts": starts, - "ends": ends, - "axes": axes, - "steps": steps - }) - np.testing.assert_almost_equal(output["S"], x[0:2:2, 0:2:-2, 0:2:-1]) + tf_model_output = tf_model(X=x, + starts=starts, + ends=ends, + axes=axes, + steps=steps) + np.testing.assert_almost_equal(tf_model_output[0], x[0:2:2, 0:2:-2, + 0:2:-1]) def test_split(self): shape = [12, 12] @@ -814,14 +1001,30 @@ def test_split(self): ]) tf_rep = onnx_graph_to_tensorflow_rep(graph_def) + # export to tf.saved_model + model_path = 'test_dynamic_shape/split' + tf_rep.export_graph(model_path) + # load the saved_model back + tf_model = tf.saved_model.load(model_path) + # run the model x = self._get_rnd_float32(shape=shape) - output = tf_rep.run({"X": x}) + tf_model_output = tf_model(X=x) per_part = shape[axis] // output_count split = [per_part] * output_count - for a, b in zip(list(output), np.split(x, np.cumsum(split))[:-1]): + for a, b in zip(list(tf_model_output), np.split(x, np.cumsum(split))[:-1]): np.testing.assert_almost_equal(a, b) + @classmethod + def tearDownClass(cls): + # clean up saved model folder + try: + model_path = 'test_dynamic_shape' + shutil.rmtree(model_path) + except FileNotFoundError: + # the model folder doesn't exist + pass + if __name__ == '__main__': unittest.main() From 7e4802c5296f5e7086b1f1a9fd51b52d971f692c Mon Sep 17 00:00:00 2001 From: Jason Plurad Date: Tue, 10 Nov 2020 18:49:29 -0500 Subject: [PATCH 2/3] Fixed Dropout logic when is_test == 1 and opset < 7 (#774) * Fixed Dropout logic when is_test == 1 and opset < 7 A recent change incorrectly moved the is_test check later in the conditional logic, causing failures in several model zoo models, like caffenet-3 and squeezenet1.0-6. Signed-off-by: Jason Plurad * Improved Dropout logic when is_test == 1 and opset < 7 Signed-off-by: Jason Plurad --- onnx_tf/handlers/backend/dropout.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnx_tf/handlers/backend/dropout.py b/onnx_tf/handlers/backend/dropout.py index 509886ca3..65ffef374 100644 --- a/onnx_tf/handlers/backend/dropout.py +++ b/onnx_tf/handlers/backend/dropout.py @@ -17,10 +17,12 @@ def _common(cls, node, **kwargs): x = tensor_dict[node.inputs[0]] attrs = copy.deepcopy(node.attrs) - if cls.SINCE_VERSION < 7: + if cls.SINCE_VERSION < 7 and attrs.pop("is_test", 0) == 0: attrs["keep_prob"] = 1 - attrs.pop("ratio", 0.5) return [cls.make_tensor_from_onnx_node(node, attrs=attrs, **kwargs)] - elif cls.SINCE_VERSION < 12 or attrs.pop("is_test", 0) == 1: # for Opset 7, 10 + elif cls.SINCE_VERSION < 12 : # for Opset 7, 10 + # at inference mode, is_test attribute is always set to 1 + # dropout at inference mode is a no-op return [x] else: # for Opset 12, 13 # ratio and training_mode are optional and passed as inputs @@ -30,7 +32,7 @@ def _common(cls, node, **kwargs): training_mode = False # default is false if len(node.inputs) == 3: training_mode = tensor_dict[node.inputs[2]] - + return_mask = len(node.outputs) == 2 # if there are 2 outputs, mask is requested if ratio == 0 or training_mode is False: # Inferencing if return_mask is True: From a742d2928a9d4f8d384d086f120b01f569fbd455 Mon Sep 17 00:00:00 2001 From: Winnie Tsang Date: Wed, 11 Nov 2020 09:45:45 -0800 Subject: [PATCH 3/3] Modify handlers variables creation process (#801) 1. Create unique handlers' variable name by adding node.name to it. If cannot create unique variable name with node.name then throw exception. 2. Allow handler to set the variable shape base on node.attrs values 3. Move TFModule class from backend.run_node to backend_tf_module.py 4. Create handlers' variables in TFModule.init Signed-off-by: Winnie Tsang --- onnx_tf/backend.py | 14 +-- onnx_tf/backend_tf_module.py | 104 ++++++++++++------ onnx_tf/common/__init__.py | 11 +- onnx_tf/common/exception.py | 15 +++ .../handlers/backend/non_max_suppression.py | 19 ++-- onnx_tf/handlers/backend_handler.py | 12 +- test/backend/test_dynamic_shape.py | 6 +- 7 files changed, 118 insertions(+), 63 deletions(-) diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index c861e308d..f51eda9cc 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -26,7 +26,7 @@ from onnx_tf.common import supports_device as common_supports_device from onnx_tf.common.handler_helper import get_all_backend_handlers from onnx_tf.pb_wrapper import OnnxNode -from onnx_tf.backend_tf_module import BackendTFModule +from onnx_tf.backend_tf_module import BackendTFModule, TFModule import onnx_tf.common as common @@ -205,16 +205,6 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs): :return: Outputs. """ - class TFModule(tf.Module): - - def __init__(self, node): - super(TFModule, self).__init__() - self.node = node - - @tf.function - def __call__(self, **input_dict): - return cls._onnx_node_to_tensorflow_op(self.node, input_dict) - super(TensorflowBackend, cls).run_node(node, inputs, device) common.sys_config.device = device @@ -233,7 +223,7 @@ def __call__(self, **input_dict): input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items() ]) - module = TFModule(node) + module = TFModule(node, cls) output_vals = module(**input_dict) output_vals = [ diff --git a/onnx_tf/backend_tf_module.py b/onnx_tf/backend_tf_module.py index fcf191e71..2124c5225 100644 --- a/onnx_tf/backend_tf_module.py +++ b/onnx_tf/backend_tf_module.py @@ -1,9 +1,13 @@ -from onnx.defs import ONNX_DOMAIN import tensorflow as tf +from onnx_tf.common import exception +from onnx_tf.common import get_variable_name from onnx_tf.pb_wrapper import OnnxNode class BackendTFModule(tf.Module): + """ BackendTFModule is the tf.Module class used in backend.prepare, + tf_rep.export_graph and tf_rep.run + """ def __init__(self, handlers, opset, strict, graph_def, backend): super(BackendTFModule, self).__init__() @@ -42,31 +46,34 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict): # create tf.Variable for handlers that required to use variable in handler def _create_handlers_variables(self, graph, vars_dict): - handlers = self.backend._get_handlers(self.opset) - for node in graph.node: - handler = handlers[node.domain].get( - node.op_type, None) if node.domain in handlers else None - if handler and bool(handler.get_req_vars_template()): - for v_name, v_template in handler.get_req_vars_template().items(): - v_init, v_shape = v_template - v_count = 0 - for var_name in vars_dict.keys(): - v_count = v_count + 1 if var_name.startswith(v_name) else v_count - v_name = v_name + '_' + str(v_count) - vars_dict[v_name] = tf.Variable(v_init, - dtype=v_init.dtype, - shape=v_shape, - name=v_name) - if node.op_type in ['Loop', 'Scan']: - onnx_node = OnnxNode(node) - body = onnx_node.attrs["body"] - vars_dict = self._create_handlers_variables(body, vars_dict) - elif node.op_type == 'If': - onnx_node = OnnxNode(node) - then_branch = onnx_node.attrs['then_branch'] - vars_dict = self._create_handlers_variables(then_branch, vars_dict) - else_branch = onnx_node.attrs['else_branch'] - vars_dict = self._create_handlers_variables(else_branch, vars_dict) + if self.handlers: + handlers = self.backend._get_handlers(self.opset) + for node in graph.node: + handler = handlers[node.domain].get( + node.op_type, None) if node.domain in handlers else None + if handler and bool( + handler.get_req_vars_template(node, self.initializer_dict)): + for v_name, v_template in handler.get_req_vars_template( + node, self.initializer_dict).items(): + v_init, v_shape = v_template + v_name = get_variable_name(node, v_name) + if v_name in vars_dict.keys(): + # found duplicated variable name due to non unique node name + exception.NON_UNIQUE_NODE_NAME_EXCEPT() + vars_dict[v_name] = tf.Variable(v_init, + dtype=v_init.dtype, + shape=v_shape, + name=v_name) + if node.op_type in ['Loop', 'Scan']: + onnx_node = OnnxNode(node) + body = onnx_node.attrs["body"] + vars_dict = self._create_handlers_variables(body, vars_dict) + elif node.op_type == 'If': + onnx_node = OnnxNode(node) + then_branch = onnx_node.attrs['then_branch'] + vars_dict = self._create_handlers_variables(then_branch, vars_dict) + else_branch = onnx_node.attrs['else_branch'] + vars_dict = self._create_handlers_variables(else_branch, vars_dict) return vars_dict @tf.function @@ -85,11 +92,6 @@ def gen_tensor_dict(self, input_dict): curr_node_output_map = dict(zip(onnx_node.outputs, output_ops)) tensor_dict.update(curr_node_output_map) - # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN) - # TODO update this when we support handlers in other domain - for _, handler in self.handlers[ONNX_DOMAIN].items(): - handler.VAR_COUNT = 0 - return tensor_dict @tf.function @@ -110,8 +112,40 @@ def __call__(self, **kwargs): outputs = [tensor_dict[output] for output in self.outputs] - # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN) - # TODO update this when we support handlers in other domain - for _, handler in self.handlers[ONNX_DOMAIN].items(): - handler.VAR_COUNT = 0 + return outputs + + +class TFModule(tf.Module): + """ TFModule is the tf.Module class used in backend.run_node. + """ + + def __init__(self, node, backend): + super(TFModule, self).__init__() + self.node = node + self.backend = backend + self.handlers = backend._get_handlers(opset=None) + self.handler_variables = self._create_handlers_variables(dict()) + + def _create_handlers_variables(self, vars_dict): + if self.handlers: + handler = self.handlers[self.node.domain].get( + self.node.op_type, + None) if self.node.domain in self.handlers else None + if handler and bool( + handler.get_req_vars_template(self.node, self.node.attrs)): + for v_name, v_template in handler.get_req_vars_template( + self.node, self.node.attrs).items(): + v_init, v_shape = v_template + v_name = get_variable_name(self.node, v_name) + vars_dict[v_name] = tf.Variable(v_init, + dtype=v_init.dtype, + shape=v_shape, + name=v_name) + return vars_dict + + @tf.function + def __call__(self, **input_dict): + input_dict.update(self.handler_variables) + outputs = self.backend._onnx_node_to_tensorflow_op(self.node, input_dict, + self.handlers) return outputs diff --git a/onnx_tf/common/__init__.py b/onnx_tf/common/__init__.py index c0ce036f7..eefa1c5a1 100644 --- a/onnx_tf/common/__init__.py +++ b/onnx_tf/common/__init__.py @@ -31,7 +31,6 @@ def __init__(self): self.device = 'CPU' - sys_config = SysConfig() @@ -183,6 +182,16 @@ def supports_device(device): return False +def get_variable_name(node, var_name): + """ Get variable name. + :param node: ONNX NodeProto object + :param var_name: name of the variable + :return: unique variable name + """ + v_name = node.op_type.lower() + '_' + var_name + return v_name + '_' + node.name.lower() if node.name else v_name + + CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32" CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32" CONST_ONE_INT32 = "_onnx_tf_internal_one_int32" diff --git a/onnx_tf/common/exception.py b/onnx_tf/common/exception.py index 1019af928..c1c9bac35 100644 --- a/onnx_tf/common/exception.py +++ b/onnx_tf/common/exception.py @@ -81,8 +81,23 @@ def get_message(self, op, supported_dtypes): return self._message.format(op, supported_dtypes) +class NonUniqueNodeNameException(object): + + def __init__(self): + super(NonUniqueNodeNameException, self).__init__() + self._func = RuntimeError + self._message = "Node name is not unique in your model. Please recreate your model with unique node name." + + def __call__(self): + raise self._func(self.get_message()) + + def get_message(self): + return self._message.format() + + IGNORE_UNIMPLEMENTED = False OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException() OP_UNSUPPORTED_EXCEPT = OpUnsupportedException() CONST_NOT_FOUND_EXCEPT = ConstNotFoundException() DTYPE_NOT_CAST_EXCEPT = DtypeNotCastException() +NONUNIQUE_NODE_NAME_EXCEPT = NonUniqueNodeNameException() diff --git a/onnx_tf/handlers/backend/non_max_suppression.py b/onnx_tf/handlers/backend/non_max_suppression.py index 3374b1bf7..91c88d3af 100644 --- a/onnx_tf/handlers/backend/non_max_suppression.py +++ b/onnx_tf/handlers/backend/non_max_suppression.py @@ -1,5 +1,6 @@ import tensorflow as tf +from onnx_tf.common import get_variable_name from onnx_tf.common.tf_helper import tf_shape from onnx_tf.handlers.backend_handler import BackendHandler from onnx_tf.handlers.handler import onnx_op @@ -7,16 +8,19 @@ @onnx_op("NonMaxSuppression") class NonMaxSuppression(BackendHandler): - var_prefix = 'non_max_suppression_result' + var_name = 'result' @classmethod - def get_req_vars_template(cls): - """ Get required variables template. - - :return: Dict. + def get_req_vars_template(cls, node, init_dict): + """ Get required variables template, which is a + dictionary of variable names with initial value and + shape. + :param node: ONNX NodeProto object. + :param init_dict: initializer dictionary of the graph. + :return: Dictionary. """ return { - cls.var_prefix: [ + cls.var_name: [ tf.constant([[0, 0, 0]], dtype=tf.int64), tf.TensorShape([None, 3]) ] @@ -96,10 +100,9 @@ def create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold, result = output if tf.equal(batch_i, 0) and tf.equal( class_j, 0) else tf.concat([result, output], 0) - cls.VAR_COUNT = cls.VAR_COUNT + 1 return result - result = tensor_dict[cls.var_prefix + '_' + str(cls.VAR_COUNT)] + result = tensor_dict[get_variable_name(node, cls.var_name)] return [ create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, result) diff --git a/onnx_tf/handlers/backend_handler.py b/onnx_tf/handlers/backend_handler.py index 9933fce4e..3f842fab5 100644 --- a/onnx_tf/handlers/backend_handler.py +++ b/onnx_tf/handlers/backend_handler.py @@ -24,13 +24,15 @@ class BackendHandler(Handler): """ TF_FUNC = None - VAR_COUNT = 0 @classmethod - def get_req_vars_template(cls): - """ Get required variables template. - - :return: Dict. + def get_req_vars_template(cls, node, init_dict): + """ Get required variables template, which is a + dictionary of variable names with initial value and + shape + :param node: ONNX NodeProto object. + :param init_dict: initializer dictionary of the graph. + :return: Dictionary. """ return {} diff --git a/test/backend/test_dynamic_shape.py b/test/backend/test_dynamic_shape.py index c6739ddbf..178fe4ebf 100644 --- a/test/backend/test_dynamic_shape.py +++ b/test/backend/test_dynamic_shape.py @@ -550,12 +550,14 @@ def test_non_max_suppression_with_if(self): "NonMaxSuppression", ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"], ["selected_indices_1"], - center_point_box=0) + center_point_box=0, + name='NonMaxSuppression_1') non_max_suppression_node_2 = helper.make_node("NonMaxSuppression", [ "boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold" ], ["selected_indices_2"], - center_point_box=0) + center_point_box=0, + name='NonMaxSuppression_2') then_graph = helper.make_graph(nodes=[non_max_suppression_node_1], name="then_graph",