Commit 52fbf7b

Fix dynamic shape test and move handler's variable creation into BackendTFModule.__init__ (#795)

1. Add export_graph and run the exported graph in TensorFlow, instead
of calling tf_rep.run, for each dynamic-shape test case, to verify
that our handlers support inputs of unknown shape.
2. Fix conv_mixin to support a dynamic-shape input for W.
3. Move tf.Variable creation into BackendTFModule.__init__ instead
of creating the variables inside the handlers.
4. Update the non_max_suppression handler to create its tf.Variable
in BackendTFModule.__init__ and look the variable up in the handler.
5. Fix the random_normal_like handler to support dynamic-shape input.
6. Fix dilated_pooling to handle inputs of unknown shape.
7. Add an ONNX operator count to tf_rep.
8. Minor update to CLI_template.md.

Signed-off-by: Winnie Tsang <[email protected]>
winnietsang authored Nov 3, 2020
1 parent c63d435 commit 52fbf7b
Showing 11 changed files with 431 additions and 119 deletions.
2 changes: 1 addition & 1 deletion doc/CLI.md
@@ -22,7 +22,7 @@ optional arguments:
### Convert:

#### From ONNX to Tensorflow:
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`

More information: `onnx-tf convert -h`
```
2 changes: 1 addition & 1 deletion doc/CLI_template.md
@@ -14,7 +14,7 @@ More information: `onnx-tf -h`
### Convert:

#### From ONNX to Tensorflow:
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`

More information: `onnx-tf convert -h`
```
31 changes: 31 additions & 0 deletions onnx_tf/backend.py
@@ -160,8 +160,39 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
tf_rep.signatures = signatures
tf_rep.tensor_dict = module.gen_tensor_dict(
input_dict) if gen_tensor_dict else None
tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
return tf_rep

@classmethod
def _get_onnx_op_list(cls, graph_def):
""" Get ONNX operator counts of the model.
:param graph_def: ONNX GraphProto object.
:return: Dictionary of all operator counts in the model.
"""

def get_onnx_op_from_graph_and_subgraph(graph, op_list):
for node in graph.node:
op_list[node.op_type] = 1 if node.op_type not in op_list.keys(
) else op_list[node.op_type] + 1
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
op_list = get_onnx_op_from_graph_and_subgraph(body, op_list)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
op_list = get_onnx_op_from_graph_and_subgraph(then_branch, op_list)
else_branch = onnx_node.attrs['else_branch']
op_list = get_onnx_op_from_graph_and_subgraph(else_branch, op_list)
return op_list

op_list = get_onnx_op_from_graph_and_subgraph(graph_def, dict())
sorted_op_list = dict()
for key in sorted(op_list):
sorted_op_list[key] = op_list[key]
return sorted_op_list

@classmethod
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
""" Run ONNX node.
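For context, a minimal sketch (not part of this commit) of how the new operator count could be inspected; the model path is a placeholder:

```python
import onnx
from onnx_tf.backend import prepare

model = onnx.load("model.onnx")  # hypothetical model path
tf_rep = prepare(model)

# onnx_op_list maps each operator type to its occurrence count, sorted
# alphabetically, and includes ops inside Loop/Scan/If subgraphs.
for op_type, count in tf_rep.onnx_op_list.items():
    print(op_type, count)
```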
22 changes: 18 additions & 4 deletions onnx_tf/backend_rep.py
@@ -50,6 +50,14 @@ def tensor_dict(self):
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict

@property
def onnx_op_list(self):
return self._onnx_op_list

@onnx_op_list.setter
def onnx_op_list(self, onnx_op_list):
self._onnx_op_list = onnx_op_list

@property
def tf_module(self):
return self._tf_module
@@ -80,11 +88,13 @@ def run(self, inputs, **kwargs):
# single input
feed_dict = dict([(self.inputs[0], inputs)])

input_dict = dict(
[(x[0], tf.constant(x[1])) for x in feed_dict.items()])
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict.items()])

output_values = self.tf_module(**input_dict)
output_values = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_values]
output_values = [
val.numpy() if isinstance(val, tf.Tensor) else val
for val in output_values
]

return namedtupledict('Outputs', self.outputs)(*output_values)

@@ -99,4 +109,8 @@ def export_graph(self, path):
:returns: none.
"""
tf.saved_model.save(self.tf_module, path, signatures=self.tf_module.__call__.get_concrete_function(**self.signatures))
tf.saved_model.save(
self.tf_module,
path,
signatures=self.tf_module.__call__.get_concrete_function(
**self.signatures))
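A rough usage sketch of `export_graph` (continuing from the `tf_rep` above; the path, input shape, and dtype are placeholders): the updated dynamic-shape tests export the model and run it through the SavedModel signature instead of calling `tf_rep.run`:

```python
import numpy as np
import tensorflow as tf

tf_rep.export_graph("/tmp/exported_model")  # hypothetical output path
loaded = tf.saved_model.load("/tmp/exported_model")
infer = loaded.signatures["serving_default"]

# Feed each ONNX graph input by name; shape and dtype are placeholders.
x = np.random.randn(2, 3).astype(np.float32)
outputs = infer(**{tf_rep.inputs[0]: tf.constant(x)})
```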
44 changes: 44 additions & 0 deletions onnx_tf/backend_tf_module.py
@@ -1,3 +1,4 @@
from onnx.defs import ONNX_DOMAIN
import tensorflow as tf
from onnx_tf.pb_wrapper import OnnxNode

@@ -14,6 +15,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
self.outputs = []
self.initializer_dict = self._get_initializer_from_graph_and_subgraphs(
self.graph_def, dict())
self.handler_variables = self._create_handlers_variables(
self.graph_def, dict())

# get initializer from the main graph and all subgraphs in loop or if or scan
# into tensor_dict
@@ -37,10 +40,40 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
else_branch, graph_tensor_dict)
return graph_tensor_dict

# create tf.Variable for handlers that need to use variables inside the handler
def _create_handlers_variables(self, graph, vars_dict):
handlers = self.backend._get_handlers(self.opset)
for node in graph.node:
handler = handlers[node.domain].get(
node.op_type, None) if node.domain in handlers else None
if handler and bool(handler.get_req_vars_template()):
for v_name, v_template in handler.get_req_vars_template().items():
v_init, v_shape = v_template
v_count = 0
for var_name in vars_dict.keys():
v_count = v_count + 1 if var_name.startswith(v_name) else v_count
v_name = v_name + '_' + str(v_count)
vars_dict[v_name] = tf.Variable(v_init,
dtype=v_init.dtype,
shape=v_shape,
name=v_name)
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
vars_dict = self._create_handlers_variables(body, vars_dict)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
vars_dict = self._create_handlers_variables(then_branch, vars_dict)
else_branch = onnx_node.attrs['else_branch']
vars_dict = self._create_handlers_variables(else_branch, vars_dict)
return vars_dict

@tf.function
def gen_tensor_dict(self, input_dict):
tensor_dict = dict(input_dict)
tensor_dict.update(self.initializer_dict)
tensor_dict.update(self.handler_variables)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
@@ -52,12 +85,18 @@ def gen_tensor_dict(self, input_dict):
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)

# reset VAR_COUNT in handlers (currently all handlers are in ONNX_DOMAIN)
# TODO update this when we support handlers in other domains
for _, handler in self.handlers[ONNX_DOMAIN].items():
handler.VAR_COUNT = 0

return tensor_dict

@tf.function
def __call__(self, **kwargs):
tensor_dict = kwargs
tensor_dict.update(self.initializer_dict)
tensor_dict.update(self.handler_variables)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
@@ -70,4 +109,9 @@ def __call__(self, **kwargs):
tensor_dict.update(curr_node_output_map)

outputs = [tensor_dict[output] for output in self.outputs]

# reset VAR_COUNT in handlers (currently all handlers are in ONNX_DOMAIN)
# TODO update this when we support handlers in other domains
for _, handler in self.handlers[ONNX_DOMAIN].items():
handler.VAR_COUNT = 0
return outputs
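To illustrate the new contract with a hypothetical handler (names are for exposition only): a handler declares the variables it needs through `get_req_vars_template`, and `BackendTFModule.__init__` creates one `tf.Variable` per occurrence of the op, suffixing the template name with a running count (`<prefix>_0`, `<prefix>_1`, ...). Each occurrence then looks up its variable in the tensor dict via `VAR_COUNT`, which is reset after every call as shown above:

```python
import tensorflow as tf

class MyOpHandler:  # hypothetical handler, for illustration only
  var_prefix = 'my_op_state'

  @classmethod
  def get_req_vars_template(cls):
    # template name -> [initial value, (possibly partial) variable shape]
    return {
        cls.var_prefix: [
            tf.constant([0], dtype=tf.int64),
            tf.TensorShape([None])
        ]
    }
```

If the graph contained two such nodes, `_create_handlers_variables` would create `my_op_state_0` and `my_op_state_1`, which is exactly the naming scheme `non_max_suppression` relies on below.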
15 changes: 8 additions & 7 deletions onnx_tf/handlers/backend/conv_mixin.py
@@ -46,14 +46,15 @@ def conv(cls, node, input_dict, transpose=False):

if "kernel_shape" in node.attrs.keys():
kernel_shape = node.attrs["kernel_shape"]
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
"kernel_shape "
"attr of convolution does not match the actual weight "
"passed to this operation, attr {}, actual {}").format(
kernel_shape,
in_weights.get_shape().as_list())
if in_weights.get_shape().is_fully_defined():
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
"kernel_shape "
"attr of convolution does not match the actual weight "
"passed to this operation, attr {}, actual {}").format(
kernel_shape,
in_weights.get_shape().as_list())
else:
kernel_shape = in_weights.get_shape().as_list()[2:]
kernel_shape = tf_shape(in_weights, tf.int32)[2:]

weights = tf.transpose(in_weights, perm)
dilations = node.attrs.get("dilations", [1] * spatial_size)
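For reference, a sketch of the behavior this change relies on from `tf_shape` in `onnx_tf.common.tf_helper` (a paraphrase of the helper's contract, not its exact source): prefer the static shape when it is fully defined, otherwise fall back to the dynamic `tf.shape` op, so `kernel_shape` can still be derived from weights of unknown shape:

```python
import tensorflow as tf

def tf_shape_sketch(tensor, dtype=tf.int64):
  # Use the static shape when fully known; otherwise emit the dynamic
  # shape op so graphs with unknown dimensions keep working.
  if tensor.shape.is_fully_defined():
    return tf.constant(tensor.shape.as_list(), dtype=dtype)
  return tf.shape(tensor, out_type=dtype)
```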
9 changes: 6 additions & 3 deletions onnx_tf/handlers/backend/dilated_pooling.py
@@ -601,9 +601,12 @@ def dilated_maxpool_with_argmax(self, force_custom_impl=False):

# if there was padding, recalculate the returned index
# to exclude the padding
count_nonzero_op = np.count_nonzero if self.is_known_shape else tf.math.count_nonzero
if count_nonzero_op(self.pads) != 0:
new_ind = self._calc_argmax_without_padding(new_ind)
if self.is_known_shape:
if np.count_nonzero(self.pads) != 0:
new_ind = self._calc_argmax_without_padding(new_ind)
else:
new_ind = tf.where(tf.not_equal(tf.math.count_nonzero(self.pads), 0),
self._calc_argmax_without_padding(new_ind), new_ind)

return (pooled, new_ind)

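The dynamic-shape branch needs `tf.where` because a Python `if` on a symbolic scalar cannot be traced inside a `tf.function`; a minimal standalone illustration of the pattern (the subtraction is a stand-in for `_calc_argmax_without_padding`):

```python
import tensorflow as tf

@tf.function
def adjust_indices(ind, pads):
  # tf.where evaluates both branches, then selects by the (broadcast)
  # scalar condition, keeping the graph valid for unknown shapes.
  return tf.where(tf.not_equal(tf.math.count_nonzero(pads), 0),
                  ind - 1,  # stand-in for _calc_argmax_without_padding
                  ind)

print(adjust_indices(tf.constant([4, 7]), tf.constant([0, 1, 0, 1])))  # [3 6]
```

Note that, unlike `tf.cond`, `tf.where` materializes both branches before selecting.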
26 changes: 16 additions & 10 deletions onnx_tf/handlers/backend/non_max_suppression.py
@@ -7,8 +7,20 @@

@onnx_op("NonMaxSuppression")
class NonMaxSuppression(BackendHandler):
var_prefix = 'non_max_suppression_result'

result = None
@classmethod
def get_req_vars_template(cls):
""" Get required variables template.
:return: Dict.
"""
return {
cls.var_prefix: [
tf.constant([[0, 0, 0]], dtype=tf.int64),
tf.TensorShape([None, 3])
]
}

@classmethod
def _common(cls, node, **kwargs):
@@ -84,19 +96,13 @@ def create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
result = output if tf.equal(batch_i, 0) and tf.equal(
class_j, 0) else tf.concat([result, output], 0)

cls.VAR_COUNT = cls.VAR_COUNT + 1
return result

# Since tf.function doesn't support locals() and requires all the variables
# to be defined before use in the "for loop" before it will perform any auto
# conversion of the python code, we need to define "result" as a
# Variable here and send it in as a parameter to "create_nodes"
if cls.result is None:
cls.result = tf.Variable([[0, 0, 0]],
dtype=tf.int64,
shape=tf.TensorShape([None, 3]))
result = tensor_dict[cls.var_prefix + '_' + str(cls.VAR_COUNT)]
return [
create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, cls.result)
score_threshold, result)
]

@classmethod
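The reason this handler needs a variable at all (per the removed comment above): inside `tf.function`, a loop-carried tensor must exist before the loop, and the loop's shape invariant is taken from that initial value, so seeding `result` from a `[None, 3]`-shaped variable lets it grow across iterations. A minimal standalone illustration, not the handler's real logic:

```python
import tensorflow as tf

# The partial shape [None, 3] sets a loop invariant that allows growth.
seed = tf.Variable([[0, 0, 0]], dtype=tf.int64, shape=tf.TensorShape([None, 3]))

@tf.function
def grow(n):
  result = seed
  for i in tf.range(n):
    row = tf.reshape(tf.repeat(i, 3), [1, 3])
    # The first iteration replaces the seed row; later ones append.
    result = row if i == 0 else tf.concat([result, row], 0)
  return result

print(grow(tf.constant(3, dtype=tf.int64)))  # [[0 0 0] [1 1 1] [2 2 2]]
```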
3 changes: 2 additions & 1 deletion onnx_tf/handlers/backend/random_normal_like.py
@@ -1,5 +1,6 @@
import tensorflow as tf

from onnx_tf.common.tf_helper import tf_shape
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_func
@@ -15,5 +16,5 @@ def get_attrs_processor_param(cls):

@classmethod
def version_1(cls, node, **kwargs):
inputs = [kwargs["tensor_dict"][node.inputs[0]].get_shape()]
inputs = [tf_shape(kwargs["tensor_dict"][node.inputs[0]])]
return [cls.make_tensor_from_onnx_node(node, inputs=inputs, **kwargs)]
15 changes: 12 additions & 3 deletions onnx_tf/handlers/backend_handler.py
@@ -24,6 +24,15 @@ class BackendHandler(Handler):
"""

TF_FUNC = None
VAR_COUNT = 0

@classmethod
def get_req_vars_template(cls):
""" Get required variables template.
:return: Dict.
"""
return {}

@classmethod
def get_attrs_processor_param(cls):
@@ -183,9 +192,9 @@ def _run_tf_func(cls, tf_func, inputs, attrs):

attrs = {p: v for p, v in attrs.items() if p in params}
kwargs = dict(zip(params, inputs))
ambiguous_arguments = any(kwargs.get(p) is not None and v is not None
for p, v in attrs.items())
ambiguous_arguments = any(
kwargs.get(p) is not None and v is not None for p, v in attrs.items())
if ambiguous_arguments:
raise TypeError('Ambiguous arguments for {}()'.format(tf_func.__name__))
kwargs.update((p, v) for p, v in attrs.items() if v is not None)
return tf_func(**kwargs)
return tf_func(**kwargs)