Merge branch 'master' into fix_instance_norm

Signed-off-by: Stanley Fujimoto <[email protected]>

masakistan authored and EC2 Default User committed Nov 12, 2020
2 parents ac891dc + a742d29 commit 7bab839
Showing 14 changed files with 504 additions and 135 deletions.
2 changes: 1 addition & 1 deletion doc/CLI.md
@@ -22,7 +22,7 @@ optional arguments:
### Convert:

#### From ONNX to Tensorflow:
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`

More information: `onnx-tf convert -h`
```
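For reference, a Python equivalent of the CLI call above, as a minimal sketch (paths are placeholders): the `-o` argument now names a TensorFlow SavedModel directory rather than a frozen `.pb` file.

```python
import onnx
from onnx_tf.backend import prepare

# Convert an ONNX model and export it as a SavedModel directory.
tf_rep = prepare(onnx.load("/path/to/input.onnx"))
tf_rep.export_graph("/path/to/output")  # writes a directory, not a single .pb
```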
2 changes: 1 addition & 1 deletion doc/CLI_template.md
@@ -14,7 +14,7 @@ More information: `onnx-tf -h`
### Convert:

#### From ONNX to Tensorflow:
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`

More information: `onnx-tf convert -h`
```
45 changes: 33 additions & 12 deletions onnx_tf/backend.py
@@ -26,7 +26,7 @@
from onnx_tf.common import supports_device as common_supports_device
from onnx_tf.common.handler_helper import get_all_backend_handlers
from onnx_tf.pb_wrapper import OnnxNode
from onnx_tf.backend_tf_module import BackendTFModule
from onnx_tf.backend_tf_module import BackendTFModule, TFModule
import onnx_tf.common as common


@@ -160,8 +160,39 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
tf_rep.signatures = signatures
tf_rep.tensor_dict = module.gen_tensor_dict(
input_dict) if gen_tensor_dict else None
tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
return tf_rep

@classmethod
def _get_onnx_op_list(cls, graph_def):
""" Get ONNX operator counts of the model.
:param graph_def: ONNX GraphProto object.
:return: Dictionary of all operator counts in the model.
"""

def get_onnx_op_from_graph_and_subgraph(graph, op_list):
for node in graph.node:
op_list[node.op_type] = 1 if node.op_type not in op_list.keys(
) else op_list[node.op_type] + 1
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
op_list = get_onnx_op_from_graph_and_subgraph(body, op_list)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
op_list = get_onnx_op_from_graph_and_subgraph(then_branch, op_list)
else_branch = onnx_node.attrs['else_branch']
op_list = get_onnx_op_from_graph_and_subgraph(else_branch, op_list)
return op_list

op_list = get_onnx_op_from_graph_and_subgraph(graph_def, dict())
sorted_op_list = dict()
for key in sorted(op_list):
sorted_op_list[key] = op_list[key]
return sorted_op_list

@classmethod
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
""" Run ONNX node.
@@ -174,16 +205,6 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
:return: Outputs.
"""

class TFModule(tf.Module):

def __init__(self, node):
super(TFModule, self).__init__()
self.node = node

@tf.function
def __call__(self, **input_dict):
return cls._onnx_node_to_tensorflow_op(self.node, input_dict)

super(TensorflowBackend, cls).run_node(node, inputs, device)
common.sys_config.device = device

@@ -202,7 +223,7 @@ def __call__(self, **input_dict):
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
])

module = TFModule(node)
module = TFModule(node, cls)

output_vals = module(**input_dict)
output_vals = [
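A quick sketch (not part of the diff) of the new `onnx_op_list` bookkeeping: after `prepare`, the rep carries a per-operator count for the whole model, including the subgraphs of `Loop`/`Scan` bodies and `If` branches. The model path below is a placeholder.

```python
import onnx
from onnx_tf.backend import prepare

tf_rep = prepare(onnx.load("model.onnx"))  # hypothetical model file
# Dictionary of operator counts, sorted by op name,
# e.g. {'Add': 3, 'Conv': 2, 'Relu': 2}
print(tf_rep.onnx_op_list)
```

The inner `TFModule` class is also hoisted out of `run_node` into `onnx_tf/backend_tf_module.py` (see below) so it can pre-create handler variables before tracing.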
22 changes: 18 additions & 4 deletions onnx_tf/backend_rep.py
@@ -50,6 +50,14 @@ def tensor_dict(self):
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict

@property
def onnx_op_list(self):
return self._onnx_op_list

@onnx_op_list.setter
def onnx_op_list(self, onnx_op_list):
self._onnx_op_list = onnx_op_list

@property
def tf_module(self):
return self._tf_module
@@ -80,11 +88,13 @@ def run(self, inputs, **kwargs):
# single input
feed_dict = dict([(self.inputs[0], inputs)])

input_dict = dict(
[(x[0], tf.constant(x[1])) for x in feed_dict.items()])
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict.items()])

output_values = self.tf_module(**input_dict)
output_values = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_values]
output_values = [
val.numpy() if isinstance(val, tf.Tensor) else val
for val in output_values
]

return namedtupledict('Outputs', self.outputs)(*output_values)

@@ -99,4 +109,8 @@ def export_graph(self, path):
:returns: none.
"""
tf.saved_model.save(self.tf_module, path, signatures=self.tf_module.__call__.get_concrete_function(**self.signatures))
tf.saved_model.save(
self.tf_module,
path,
signatures=self.tf_module.__call__.get_concrete_function(
**self.signatures))
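Usage is unchanged by the reformat; a minimal end-to-end sketch (the model path and input shape are assumptions):

```python
import numpy as np
import onnx
from onnx_tf.backend import prepare

tf_rep = prepare(onnx.load("model.onnx"))            # hypothetical model
outputs = tf_rep.run(np.zeros((1, 3, 224, 224), dtype=np.float32))
print(outputs[0].shape)                              # named tuple of output arrays
tf_rep.export_graph("saved_model_dir")               # SavedModel directory
```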
78 changes: 78 additions & 0 deletions onnx_tf/backend_tf_module.py
@@ -1,8 +1,13 @@
import tensorflow as tf
from onnx_tf.common import exception
from onnx_tf.common import get_variable_name
from onnx_tf.pb_wrapper import OnnxNode


class BackendTFModule(tf.Module):
""" BackendTFModule is the tf.Module class used in backend.prepare,
tf_rep.export_graph and tf_rep.run
"""

def __init__(self, handlers, opset, strict, graph_def, backend):
super(BackendTFModule, self).__init__()
@@ -14,6 +19,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
self.outputs = []
self.initializer_dict = self._get_initializer_from_graph_and_subgraphs(
self.graph_def, dict())
self.handler_variables = self._create_handlers_variables(
self.graph_def, dict())

# get initializer from the main graph and all subgraphs in loop or if or scan
# into tensor_dict
@@ -37,10 +44,43 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
else_branch, graph_tensor_dict)
return graph_tensor_dict

# create tf.Variable objects for handlers that require variables
def _create_handlers_variables(self, graph, vars_dict):
if self.handlers:
handlers = self.backend._get_handlers(self.opset)
for node in graph.node:
handler = handlers[node.domain].get(
node.op_type, None) if node.domain in handlers else None
if handler and bool(
handler.get_req_vars_template(node, self.initializer_dict)):
for v_name, v_template in handler.get_req_vars_template(
node, self.initializer_dict).items():
v_init, v_shape = v_template
v_name = get_variable_name(node, v_name)
if v_name in vars_dict.keys():
# found duplicated variable name due to non unique node name
exception.NONUNIQUE_NODE_NAME_EXCEPT()
vars_dict[v_name] = tf.Variable(v_init,
dtype=v_init.dtype,
shape=v_shape,
name=v_name)
if node.op_type in ['Loop', 'Scan']:
onnx_node = OnnxNode(node)
body = onnx_node.attrs["body"]
vars_dict = self._create_handlers_variables(body, vars_dict)
elif node.op_type == 'If':
onnx_node = OnnxNode(node)
then_branch = onnx_node.attrs['then_branch']
vars_dict = self._create_handlers_variables(then_branch, vars_dict)
else_branch = onnx_node.attrs['else_branch']
vars_dict = self._create_handlers_variables(else_branch, vars_dict)
return vars_dict

@tf.function
def gen_tensor_dict(self, input_dict):
tensor_dict = dict(input_dict)
tensor_dict.update(self.initializer_dict)
tensor_dict.update(self.handler_variables)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
@@ -58,6 +98,7 @@ def gen_tensor_dict(self, input_dict):
def __call__(self, **kwargs):
tensor_dict = kwargs
tensor_dict.update(self.initializer_dict)
tensor_dict.update(self.handler_variables)

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
@@ -70,4 +111,41 @@ def __call__(self, **kwargs):
tensor_dict.update(curr_node_output_map)

outputs = [tensor_dict[output] for output in self.outputs]

return outputs


class TFModule(tf.Module):
""" TFModule is the tf.Module class used in backend.run_node.
"""

def __init__(self, node, backend):
super(TFModule, self).__init__()
self.node = node
self.backend = backend
self.handlers = backend._get_handlers(opset=None)
self.handler_variables = self._create_handlers_variables(dict())

def _create_handlers_variables(self, vars_dict):
if self.handlers:
handler = self.handlers[self.node.domain].get(
self.node.op_type,
None) if self.node.domain in self.handlers else None
if handler and bool(
handler.get_req_vars_template(self.node, self.node.attrs)):
for v_name, v_template in handler.get_req_vars_template(
self.node, self.node.attrs).items():
v_init, v_shape = v_template
v_name = get_variable_name(self.node, v_name)
vars_dict[v_name] = tf.Variable(v_init,
dtype=v_init.dtype,
shape=v_shape,
name=v_name)
return vars_dict

@tf.function
def __call__(self, **input_dict):
input_dict.update(self.handler_variables)
outputs = self.backend._onnx_node_to_tensorflow_op(self.node, input_dict,
self.handlers)
return outputs
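Both modules pre-create every `tf.Variable` a handler requests via `get_req_vars_template` before the `@tf.function` is traced, since `tf.function` forbids creating variables anew on each trace. A sketch of the template contract the diff assumes (values are illustrative):

```python
import tensorflow as tf

# One template per required variable: (initial value, shape); an unknown
# shape lets the handler later assign values of varying shape.
v_init, v_shape = tf.constant(0.), tf.TensorShape(None)
var = tf.Variable(v_init, dtype=v_init.dtype, shape=v_shape,
                  name="loop_iter_count")  # name is illustrative
```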
11 changes: 10 additions & 1 deletion onnx_tf/common/__init__.py
@@ -31,7 +31,6 @@ def __init__(self):
self.device = 'CPU'



sys_config = SysConfig()


@@ -183,6 +182,16 @@ def supports_device(device):
return False


def get_variable_name(node, var_name):
""" Get variable name.
:param node: ONNX NodeProto object
:param var_name: name of the variable
:return: unique variable name
"""
v_name = node.op_type.lower() + '_' + var_name
return v_name + '_' + node.name.lower() if node.name else v_name


CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
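An illustration of the naming scheme in `get_variable_name` (node values assumed): the variable name is derived from the lower-cased op type, and the node name is appended when present to keep names unique across nodes.

```python
from onnx import helper
from onnx_tf.common import get_variable_name

node = helper.make_node("Loop", inputs=[], outputs=[], name="loop_1")
print(get_variable_name(node, "iter_count"))     # -> "loop_iter_count_loop_1"

unnamed = helper.make_node("Loop", inputs=[], outputs=[])
print(get_variable_name(unnamed, "iter_count"))  # -> "loop_iter_count"
```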
15 changes: 15 additions & 0 deletions onnx_tf/common/exception.py
@@ -81,8 +81,23 @@ def get_message(self, op, supported_dtypes):
return self._message.format(op, supported_dtypes)


class NonUniqueNodeNameException(object):

def __init__(self):
super(NonUniqueNodeNameException, self).__init__()
self._func = RuntimeError
self._message = "Node name is not unique in your model. Please recreate your model with unique node name."

def __call__(self):
raise self._func(self.get_message())

def get_message(self):
return self._message.format()


IGNORE_UNIMPLEMENTED = False
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
DTYPE_NOT_CAST_EXCEPT = DtypeNotCastException()
NONUNIQUE_NODE_NAME_EXCEPT = NonUniqueNodeNameException()
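A small sketch of how the new guard surfaces (matching the call in `backend_tf_module.py` above):

```python
from onnx_tf.common import exception

# Raised when two nodes resolve to the same handler-variable name,
# i.e. the model contains duplicate node names.
exception.NONUNIQUE_NODE_NAME_EXCEPT()  # raises RuntimeError
```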
15 changes: 8 additions & 7 deletions onnx_tf/handlers/backend/conv_mixin.py
@@ -46,14 +46,15 @@ def conv(cls, node, input_dict, transpose=False):

if "kernel_shape" in node.attrs.keys():
kernel_shape = node.attrs["kernel_shape"]
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
"kernel_shape "
"attr of convolution does not match the actual weight "
"passed to this operation, attr {}, actual {}").format(
kernel_shape,
in_weights.get_shape().as_list())
if in_weights.get_shape().is_fully_defined():
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
"kernel_shape "
"attr of convolution does not match the actual weight "
"passed to this operation, attr {}, actual {}").format(
kernel_shape,
in_weights.get_shape().as_list())
else:
kernel_shape = in_weights.get_shape().as_list()[2:]
kernel_shape = tf_shape(in_weights, tf.int32)[2:]

weights = tf.transpose(in_weights, perm)
dilations = node.attrs.get("dilations", [1] * spatial_size)
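Why the guard matters, in a standalone sketch (using `tf.shape` where the handler uses its `tf_shape` helper): with a dynamically shaped weight tensor the static shape is not fully defined, so the old assertion would compare `kernel_shape` against `None` entries even for a valid model; the actual dims are only available at run time.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([64, 3, None, None], tf.float32)])
def kernel_dims(w):
    # Static spatial dims are unknown here: [None, None]
    assert not w.get_shape().is_fully_defined()
    return tf.shape(w)[2:]  # resolved dynamically at run time

print(kernel_dims(tf.zeros([64, 3, 3, 3])).numpy())  # [3 3]
```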
9 changes: 6 additions & 3 deletions onnx_tf/handlers/backend/dilated_pooling.py
@@ -601,9 +601,12 @@ def dilated_maxpool_with_argmax(self, force_custom_impl=False):

# if there was padding, recalculate the returned index
# to exclude the padding
count_nonzero_op = np.count_nonzero if self.is_known_shape else tf.math.count_nonzero
if count_nonzero_op(self.pads) != 0:
new_ind = self._calc_argmax_without_padding(new_ind)
if self.is_known_shape:
if np.count_nonzero(self.pads) != 0:
new_ind = self._calc_argmax_without_padding(new_ind)
else:
new_ind = tf.where(tf.not_equal(tf.math.count_nonzero(self.pads), 0),
self._calc_argmax_without_padding(new_ind), new_ind)

return (pooled, new_ind)

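The same selection pattern in a standalone sketch: when `pads` is only known at run time, the padding check must live in the graph, so both candidate index tensors are computed and `tf.where` picks one (`ind - 1` stands in for `_calc_argmax_without_padding`):

```python
import tensorflow as tf

pads = tf.constant([0, 1, 0, 0])
ind = tf.constant([4, 7, 9])
adjusted = tf.where(tf.not_equal(tf.math.count_nonzero(pads), 0),
                    ind - 1,   # placeholder for the recalculated indices
                    ind)
print(adjusted.numpy())  # [3 6 8]: pads has a nonzero entry
```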
8 changes: 5 additions & 3 deletions onnx_tf/handlers/backend/dropout.py
@@ -17,10 +17,12 @@ def _common(cls, node, **kwargs):
x = tensor_dict[node.inputs[0]]
attrs = copy.deepcopy(node.attrs)

if cls.SINCE_VERSION < 7:
if cls.SINCE_VERSION < 7 and attrs.pop("is_test", 0) == 0:
attrs["keep_prob"] = 1 - attrs.pop("ratio", 0.5)
return [cls.make_tensor_from_onnx_node(node, attrs=attrs, **kwargs)]
elif cls.SINCE_VERSION < 12 or attrs.pop("is_test", 0) == 1: # for Opset 7, 10
elif cls.SINCE_VERSION < 12:  # for Opset 7, 10
# at inference mode, is_test attribute is always set to 1
# dropout at inference mode is a no-op
return [x]
else: # for Opset 12, 13
# ratio and training_mode are optional and passed as inputs
@@ -30,7 +32,7 @@
training_mode = False # default is false
if len(node.inputs) == 3:
training_mode = tensor_dict[node.inputs[2]]

return_mask = len(node.outputs) == 2 # if there are 2 outputs, mask is requested
if ratio == 0 or training_mode is False: # Inferencing
if return_mask is True:
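A condensed sketch of the opset-12+ behaviour the handler implements (function name and signature are illustrative, not the handler's API): `ratio` and `training_mode` arrive as optional inputs, and dropout is a no-op unless training mode is on with a nonzero ratio.

```python
import tensorflow as tf

def dropout12(x, ratio=0.5, training_mode=False):
    if ratio == 0 or training_mode is False:
        return x                     # inference: identity
    return tf.nn.dropout(x, rate=ratio)
```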