
Modify handlers' variable creation process (#801)
1. Create a unique variable name for each handler by appending node.name
   to it. If a unique variable name cannot be created from node.name,
   throw an exception.
2. Allow handlers to set the variable shape based on node.attrs values.
3. Move the TFModule class from backend.run_node to backend_tf_module.py.
4. Create handlers' variables in TFModule.__init__.
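To illustrate items 1 and 2 concretely, here is a minimal standalone sketch of the new flow (the template values and the hardcoded node name 'nms_1' are hypothetical; the real code is in the diffs below):

    import tensorflow as tf

    # A handler's required-variables template now maps a variable name to
    # [initial value, shape]; the shape may be derived from node.attrs.
    template = {'result': [tf.constant([[0, 0, 0]], dtype=tf.int64),
                           tf.TensorShape([None, 3])]}
    vars_dict = {}
    for v_name, (v_init, v_shape) in template.items():
        # Unique name: op_type + variable name + node.name, all lowercased.
        v_name = 'nonmaxsuppression_' + v_name + '_nms_1'
        if v_name in vars_dict:  # duplicate node name in the model
            raise RuntimeError('Node name is not unique in your model.')
        vars_dict[v_name] = tf.Variable(v_init, dtype=v_init.dtype,
                                        shape=v_shape, name=v_name)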

Signed-off-by: Winnie Tsang <[email protected]>
winnietsang authored Nov 11, 2020
1 parent 7e4802c commit a742d29
Showing 7 changed files with 118 additions and 63 deletions.
14 changes: 2 additions & 12 deletions onnx_tf/backend.py
@@ -26,7 +26,7 @@
 from onnx_tf.common import supports_device as common_supports_device
 from onnx_tf.common.handler_helper import get_all_backend_handlers
 from onnx_tf.pb_wrapper import OnnxNode
-from onnx_tf.backend_tf_module import BackendTFModule
+from onnx_tf.backend_tf_module import BackendTFModule, TFModule
 import onnx_tf.common as common
 
 
@@ -205,16 +205,6 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
     :return: Outputs.
     """
 
-    class TFModule(tf.Module):
-
-      def __init__(self, node):
-        super(TFModule, self).__init__()
-        self.node = node
-
-      @tf.function
-      def __call__(self, **input_dict):
-        return cls._onnx_node_to_tensorflow_op(self.node, input_dict)
-
     super(TensorflowBackend, cls).run_node(node, inputs, device)
     common.sys_config.device = device

@@ -233,7 +223,7 @@ def __call__(self, **input_dict):
     input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
                       ])
 
-    module = TFModule(node)
+    module = TFModule(node, cls)
 
     output_vals = module(**input_dict)
     output_vals = [
104 changes: 69 additions & 35 deletions onnx_tf/backend_tf_module.py
@@ -1,9 +1,13 @@
-from onnx.defs import ONNX_DOMAIN
 import tensorflow as tf
+from onnx_tf.common import exception
+from onnx_tf.common import get_variable_name
 from onnx_tf.pb_wrapper import OnnxNode
 
 
 class BackendTFModule(tf.Module):
+  """ BackendTFModule is the tf.Module class used in backend.prepare,
+      tf_rep.export_graph and tf_rep.run
+  """
 
   def __init__(self, handlers, opset, strict, graph_def, backend):
     super(BackendTFModule, self).__init__()
@@ -42,31 +46,34 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
 
   # create tf.Variable for handlers that required to use variable in handler
   def _create_handlers_variables(self, graph, vars_dict):
-    handlers = self.backend._get_handlers(self.opset)
-    for node in graph.node:
-      handler = handlers[node.domain].get(
-          node.op_type, None) if node.domain in handlers else None
-      if handler and bool(handler.get_req_vars_template()):
-        for v_name, v_template in handler.get_req_vars_template().items():
-          v_init, v_shape = v_template
-          v_count = 0
-          for var_name in vars_dict.keys():
-            v_count = v_count + 1 if var_name.startswith(v_name) else v_count
-          v_name = v_name + '_' + str(v_count)
-          vars_dict[v_name] = tf.Variable(v_init,
-                                          dtype=v_init.dtype,
-                                          shape=v_shape,
-                                          name=v_name)
-      if node.op_type in ['Loop', 'Scan']:
-        onnx_node = OnnxNode(node)
-        body = onnx_node.attrs["body"]
-        vars_dict = self._create_handlers_variables(body, vars_dict)
-      elif node.op_type == 'If':
-        onnx_node = OnnxNode(node)
-        then_branch = onnx_node.attrs['then_branch']
-        vars_dict = self._create_handlers_variables(then_branch, vars_dict)
-        else_branch = onnx_node.attrs['else_branch']
-        vars_dict = self._create_handlers_variables(else_branch, vars_dict)
+    if self.handlers:
+      handlers = self.backend._get_handlers(self.opset)
+      for node in graph.node:
+        handler = handlers[node.domain].get(
+            node.op_type, None) if node.domain in handlers else None
+        if handler and bool(
+            handler.get_req_vars_template(node, self.initializer_dict)):
+          for v_name, v_template in handler.get_req_vars_template(
+              node, self.initializer_dict).items():
+            v_init, v_shape = v_template
+            v_name = get_variable_name(node, v_name)
+            if v_name in vars_dict.keys():
+              # found duplicated variable name due to non-unique node name
+              exception.NONUNIQUE_NODE_NAME_EXCEPT()
+            vars_dict[v_name] = tf.Variable(v_init,
+                                            dtype=v_init.dtype,
+                                            shape=v_shape,
+                                            name=v_name)
+        if node.op_type in ['Loop', 'Scan']:
+          onnx_node = OnnxNode(node)
+          body = onnx_node.attrs["body"]
+          vars_dict = self._create_handlers_variables(body, vars_dict)
+        elif node.op_type == 'If':
+          onnx_node = OnnxNode(node)
+          then_branch = onnx_node.attrs['then_branch']
+          vars_dict = self._create_handlers_variables(then_branch, vars_dict)
+          else_branch = onnx_node.attrs['else_branch']
+          vars_dict = self._create_handlers_variables(else_branch, vars_dict)
     return vars_dict
 
   @tf.function
@@ -85,11 +92,6 @@ def gen_tensor_dict(self, input_dict):
       curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
       tensor_dict.update(curr_node_output_map)
 
-    # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
-    # TODO update this when we support handlers in other domain
-    for _, handler in self.handlers[ONNX_DOMAIN].items():
-      handler.VAR_COUNT = 0
-
     return tensor_dict
 
   @tf.function
@@ -110,8 +112,40 @@ def __call__(self, **kwargs):
 
     outputs = [tensor_dict[output] for output in self.outputs]
 
-    # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
-    # TODO update this when we support handlers in other domain
-    for _, handler in self.handlers[ONNX_DOMAIN].items():
-      handler.VAR_COUNT = 0
     return outputs
+
+
+class TFModule(tf.Module):
+  """ TFModule is the tf.Module class used in backend.run_node.
+  """
+
+  def __init__(self, node, backend):
+    super(TFModule, self).__init__()
+    self.node = node
+    self.backend = backend
+    self.handlers = backend._get_handlers(opset=None)
+    self.handler_variables = self._create_handlers_variables(dict())
+
+  def _create_handlers_variables(self, vars_dict):
+    if self.handlers:
+      handler = self.handlers[self.node.domain].get(
+          self.node.op_type,
+          None) if self.node.domain in self.handlers else None
+      if handler and bool(
+          handler.get_req_vars_template(self.node, self.node.attrs)):
+        for v_name, v_template in handler.get_req_vars_template(
+            self.node, self.node.attrs).items():
+          v_init, v_shape = v_template
+          v_name = get_variable_name(self.node, v_name)
+          vars_dict[v_name] = tf.Variable(v_init,
+                                          dtype=v_init.dtype,
+                                          shape=v_shape,
+                                          name=v_name)
+    return vars_dict
+
+  @tf.function
+  def __call__(self, **input_dict):
+    input_dict.update(self.handler_variables)
+    outputs = self.backend._onnx_node_to_tensorflow_op(self.node, input_dict,
+                                                       self.handlers)
+    return outputs
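For context, a minimal end-to-end sketch of exercising the relocated TFModule through run_node (this assumes onnx and this version of onnx-tf are installed; the Relu node and input values are illustrative):

    import numpy as np
    from onnx import helper
    from onnx_tf.backend import run_node

    # run_node now builds TFModule(node, cls), which creates any handler
    # variables once in __init__ and injects them into input_dict on call.
    node = helper.make_node("Relu", ["x"], ["y"], name="relu_1")
    outputs = run_node(node, [np.array([-1.0, 0.0, 2.0], dtype=np.float32)])
    print(outputs[0])  # expected: [0. 0. 2.]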
11 changes: 10 additions & 1 deletion onnx_tf/common/__init__.py
@@ -31,7 +31,6 @@ def __init__(self):
     self.device = 'CPU'
 
 
-
 sys_config = SysConfig()


@@ -183,6 +182,16 @@ def supports_device(device):
   return False
 
 
+def get_variable_name(node, var_name):
+  """ Get variable name.
+  :param node: ONNX NodeProto object
+  :param var_name: name of the variable
+  :return: unique variable name
+  """
+  v_name = node.op_type.lower() + '_' + var_name
+  return v_name + '_' + node.name.lower() if node.name else v_name
+
+
 CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
 CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
 CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
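A quick usage illustration for this helper (the node construction is hypothetical; note that a node without a name falls back to the non-suffixed form, which is why duplicate unnamed ops can still collide):

    from onnx import helper
    from onnx_tf.common import get_variable_name

    named = helper.make_node("NonMaxSuppression", ["boxes", "scores"],
                             ["out"], name="NMS_1")
    unnamed = helper.make_node("NonMaxSuppression", ["boxes", "scores"],
                               ["out"])
    print(get_variable_name(named, "result"))    # nonmaxsuppression_result_nms_1
    print(get_variable_name(unnamed, "result"))  # nonmaxsuppression_result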
15 changes: 15 additions & 0 deletions onnx_tf/common/exception.py
@@ -81,8 +81,23 @@ def get_message(self, op, supported_dtypes):
     return self._message.format(op, supported_dtypes)
 
 
+class NonUniqueNodeNameException(object):
+
+  def __init__(self):
+    super(NonUniqueNodeNameException, self).__init__()
+    self._func = RuntimeError
+    self._message = "Node name is not unique in your model. Please recreate your model with unique node names."
+
+  def __call__(self):
+    raise self._func(self.get_message())
+
+  def get_message(self):
+    return self._message.format()
+
+
 IGNORE_UNIMPLEMENTED = False
 OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
 OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
 CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
 DTYPE_NOT_CAST_EXCEPT = DtypeNotCastException()
+NONUNIQUE_NODE_NAME_EXCEPT = NonUniqueNodeNameException()
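For reference, a short sketch of how the new singleton is meant to be invoked, mirroring the other *_EXCEPT objects in this module (illustrative only):

    from onnx_tf.common import exception

    # The singleton is callable; invoking it raises RuntimeError with the
    # canned message.
    try:
        exception.NONUNIQUE_NODE_NAME_EXCEPT()
    except RuntimeError as e:
        print(e)  # Node name is not unique in your model. ...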
19 changes: 11 additions & 8 deletions onnx_tf/handlers/backend/non_max_suppression.py
@@ -1,22 +1,26 @@
 import tensorflow as tf
 
+from onnx_tf.common import get_variable_name
 from onnx_tf.common.tf_helper import tf_shape
 from onnx_tf.handlers.backend_handler import BackendHandler
 from onnx_tf.handlers.handler import onnx_op
 
 
 @onnx_op("NonMaxSuppression")
 class NonMaxSuppression(BackendHandler):
-  var_prefix = 'non_max_suppression_result'
+  var_name = 'result'
 
   @classmethod
-  def get_req_vars_template(cls):
-    """ Get required variables template.
-    :return: Dict.
+  def get_req_vars_template(cls, node, init_dict):
+    """ Get required variables template, which is a
+        dictionary of variable names with initial value and
+        shape.
+    :param node: ONNX NodeProto object.
+    :param init_dict: initializer dictionary of the graph.
+    :return: Dictionary.
     """
     return {
-        cls.var_prefix: [
+        cls.var_name: [
             tf.constant([[0, 0, 0]], dtype=tf.int64),
             tf.TensorShape([None, 3])
         ]
@@ -96,10 +100,9 @@ def create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
       result = output if tf.equal(batch_i, 0) and tf.equal(
           class_j, 0) else tf.concat([result, output], 0)
 
-      cls.VAR_COUNT = cls.VAR_COUNT + 1
       return result
 
-    result = tensor_dict[cls.var_prefix + '_' + str(cls.VAR_COUNT)]
+    result = tensor_dict[get_variable_name(node, cls.var_name)]
     return [
         create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
                      score_threshold, result)
12 changes: 7 additions & 5 deletions onnx_tf/handlers/backend_handler.py
@@ -24,13 +24,15 @@ class BackendHandler(Handler):
   """
 
   TF_FUNC = None
-  VAR_COUNT = 0
 
   @classmethod
-  def get_req_vars_template(cls):
-    """ Get required variables template.
-    :return: Dict.
+  def get_req_vars_template(cls, node, init_dict):
+    """ Get required variables template, which is a
+        dictionary of variable names with initial value and
+        shape.
+    :param node: ONNX NodeProto object.
+    :param init_dict: initializer dictionary of the graph.
+    :return: Dictionary.
     """
     return {}

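As a sketch of how a subclass can use the new signature, for example deriving the variable's shape from a node attribute per item 2 of the commit message, consider this hypothetical handler (SomeOp and its 'width' attribute are invented for illustration; it assumes the NodeProto form of node passed from BackendTFModule):

    import tensorflow as tf
    from onnx_tf.handlers.backend_handler import BackendHandler
    from onnx_tf.handlers.handler import onnx_op

    @onnx_op("SomeOp")  # hypothetical op
    class SomeOp(BackendHandler):
      var_name = 'state'

      @classmethod
      def get_req_vars_template(cls, node, init_dict):
        # Shape the variable's second dimension from a node attribute when
        # present; otherwise leave it dynamic.
        width = next((a.i for a in node.attribute if a.name == 'width'), None)
        return {
            cls.var_name: [
                tf.constant([[0]], dtype=tf.int64),
                tf.TensorShape([None, width])
            ]
        }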
6 changes: 4 additions & 2 deletions test/backend/test_dynamic_shape.py
@@ -550,12 +550,14 @@ def test_non_max_suppression_with_if(self):
         "NonMaxSuppression",
         ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"],
         ["selected_indices_1"],
-        center_point_box=0)
+        center_point_box=0,
+        name='NonMaxSuppression_1')
     non_max_suppression_node_2 = helper.make_node("NonMaxSuppression", [
         "boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
         "score_threshold"
     ], ["selected_indices_2"],
-                                                  center_point_box=0)
+                                                  center_point_box=0,
+                                                  name='NonMaxSuppression_2')
 
     then_graph = helper.make_graph(nodes=[non_max_suppression_node_1],
                                    name="then_graph",
