Use user defined device choice during prepare model (#726)
1. Added sys_config.device and set it to the user-defined device choice
("CUDA" or "CPU") in prepare and run_node
2. Updated conv_mixin.py to use sys_config.device and eliminate an
unnecessary transpose
3. Updated pool_mixin.py to eliminate the mandatory conversion of input
x to channel-last (NHWC) format. Added a function to convert NHWC indices
to NCHW indices for MaxPool_with_Argmax to fix issue #719 (a standalone
sketch of this index conversion appears after this list)
4. Updated unpool_mixin.py to eliminate the mandatory conversion of
input x to channel-last (NHWC) format
5. Updated dilated_pooling.py to process NCHW-format input instead of
NHWC format for all pooling operators, except MaxPool_with_Argmax and
MaxPool_with_dilation_not_equal_to_1_and_spatial_size_equal_to_2
(tf.nn.max_pool_with_argmax and tf.nn.dilation2d only support NHWC)
6. Added a dynamic_shape test for MaxPool_with_Argmax
7. Set device in run_node for operators that behave differently in
NCHW/NHWC format in test_node.py
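Item 3 above is the heart of the #719 fix: tf.nn.max_pool_with_argmax returns indices flattened in NHWC order, while ONNX expects NCHW-order indices. The following is a minimal standalone sketch of such a conversion, not the commit's actual helper; it assumes a 4-D input of shape (N, C, H, W) and indices produced with include_batch_in_index=True:

    import tensorflow as tf

    def nhwc_to_nchw_indices(argmax, c, h, w):
      """Convert flat NHWC indices from tf.nn.max_pool_with_argmax
      (include_batch_in_index=True) into flat NCHW indices."""
      argmax = tf.cast(argmax, tf.int64)
      # NHWC flat index decomposes as ((b*H + y)*W + x)*C + ch
      ch = argmax % c
      x = (argmax // c) % w
      y = (argmax // (c * w)) % h
      b = argmax // (c * w * h)
      # Recompose in NCHW order: ((b*C + ch)*H + y)*W + x
      return ((b * c + ch) * h + y) * w + x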

Signed-off-by: Winnie Tsang <[email protected]>

Co-authored-by: Chin Huang <[email protected]>
winnietsang and chinhuang007 authored Oct 29, 2020
1 parent 651c90f commit c63d435
Showing 12 changed files with 706 additions and 552 deletions.
doc/API.md (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@ _params_:

 `model` : The ONNX model to be converted.

-`device` : The device to execute this model on.
+`device` : The device to execute this model on. It can be either CPU (default) or CUDA.

 `strict` : Whether to enforce semantic equivalence between the original model
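For reference, a minimal sketch of passing the new parameter through the Python API (the model filename is a placeholder):

    import onnx
    from onnx_tf.backend import prepare

    model = onnx.load("model.onnx")         # placeholder path
    tf_rep = prepare(model, device="CUDA")  # device defaults to "CPU"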
doc/CLI.md (4 changes: 2 additions & 2 deletions)

@@ -40,8 +40,8 @@ optional arguments:
                         Output directory.
 backend arguments (onnx -> tf):
-  --device DEVICE       The device to execute this model on. (from
-                        onnx_tf.backend.prepare)
+  --device DEVICE       The device to execute this model on. It can be either
+                        CPU (default) or CUDA. (from onnx_tf.backend.prepare)
   --strict STRICT       Whether to enforce semantic equivalence between the
                         original model and the converted tensorflow model,
                         defaults to True (yes, enforce semantic equivalence).
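And a hypothetical command-line invocation of the same option, assuming the standard onnx-tf convert entry point described in CLI.md (paths are placeholders):

    onnx-tf convert -i model.onnx -o converted_model --device CUDA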
onnx_tf/backend.py (4 changes: 3 additions & 1 deletion)

@@ -49,7 +49,7 @@ def prepare(cls,
       the converted representation.
     :param model: The ONNX model to be converted.
-    :param device: The device to execute this model on.
+    :param device: The device to execute this model on. It can be either CPU (default) or CUDA.
     :param strict: Whether to enforce semantic equivalence between the original model
       and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
       Changing to False is strongly discouraged.
@@ -65,6 +65,7 @@ def prepare(cls,
     common.logger.setLevel(logging_level)
     common.logger.handlers[0].setLevel(logging_level)
     common.sys_config.auto_cast = auto_cast
+    common.sys_config.device = device

     return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)

@@ -184,6 +185,7 @@ def __call__(self, **input_dict):
         return cls._onnx_node_to_tensorflow_op(self.node, input_dict)

     super(TensorflowBackend, cls).run_node(node, inputs, device)
+    common.sys_config.device = device

     node = OnnxNode(node)
     input_tensors = []
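With the second change, run_node records the per-call device in sys_config before translating the node. A minimal sketch of exercising it, assuming run_node is exported from onnx_tf.backend as shown above (Relu is used only to keep the example short):

    import numpy as np
    from onnx import helper
    from onnx_tf.backend import run_node

    node = helper.make_node("Relu", inputs=["x"], outputs=["y"])
    x = np.array([-1.0, 0.0, 2.0], dtype=np.float32)
    outputs = run_node(node, [x], device="CPU")  # stored in sys_config.device
    print(outputs[0])  # [0. 0. 2.]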
onnx_tf/common/__init__.py (5 changes: 3 additions & 2 deletions)

@@ -28,6 +28,8 @@ class SysConfig:

   def __init__(self):
     self.auto_cast = False
+    self.device = 'CPU'
+

 sys_config = SysConfig()
@@ -160,7 +162,7 @@ def get_data_format(x_rank):
   sp_dim_string = "".join(reversed(sp_dim_lst))
   storage_format = "NC" + sp_dim_string

-  if supports_device("CUDA"):
+  if sys_config.device == "CUDA":
     compute_format = "NC" + sp_dim_string
   else:
     compute_format = "N" + sp_dim_string + "C"
@@ -169,7 +171,6 @@

 def supports_device(device):
   """ Check if support target device.
-  :param device: CUDA or CPU.
   :return: If supports.
   """
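The net effect in get_data_format: the compute format now follows the user's declared device instead of probing the machine for CUDA support. A reimplementation sketch of the resulting behavior (not the module's exact code):

    def get_data_format(x_rank, device):
      sp = "DHW"[-(x_rank - 2):]  # rank 4 -> "HW", rank 5 -> "DHW"
      storage_format = "NC" + sp  # ONNX storage is always channel-first
      if device == "CUDA":
        compute_format = "NC" + sp
      else:
        compute_format = "N" + sp + "C"
      return storage_format, compute_format

    print(get_data_format(4, "CPU"))   # ('NCHW', 'NHWC')
    print(get_data_format(4, "CUDA"))  # ('NCHW', 'NCHW')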
onnx_tf/common/pooling_helper.py (8 changes: 7 additions & 1 deletion)

@@ -158,6 +158,9 @@ def py_pool(input, kernel_shape, strides=None, dilations=None,

   def _loop_over_output(batch, channel):
     dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
+    image_size = 1
+    for d in input_shape[2:]:
+      image_size *= d
     for counters in itertools.product(*dims):
       input_ranges = []
       for dim in range(spatial_size):
@@ -189,7 +192,10 @@ def _loop_over_output(batch, channel):
         else:
           if val > maxval:
             maxval = val
-            ind = 0
+            # batch_offset = batch * C * image_size
+            # channel_offset = channel * image_size
+            # ind = batch_offset + channel_offset
+            ind = image_size * (batch * input_shape[1] + channel)
             for i in range(spatial_size):
               coef = 1
               for j in range(i+1, spatial_size):
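The new initialization seeds ind with the flat NCHW offset of the (batch, channel) image before the loop folds in the spatial coordinates. A quick worked check of the arithmetic, assuming input_shape = (1, 2, 3, 3):

    import numpy as np

    input_shape = (1, 2, 3, 3)  # (N, C, H, W)
    image_size = 3 * 3          # product of the spatial dims
    batch, channel = 0, 1
    ind = image_size * (batch * input_shape[1] + channel)  # 9
    ind += 2 * 3 + 1 * 1        # spatial coords (h, w) = (2, 1), coefs (W, 1)
    flat = np.arange(18).reshape(input_shape)
    assert ind == flat[0, 1, 2, 1] == 16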
onnx_tf/handlers/backend/conv_mixin.py (11 changes: 5 additions & 6 deletions)

@@ -1,10 +1,10 @@
 import tensorflow as tf

+from onnx_tf.common import exception
 from onnx_tf.common import get_data_format
 from onnx_tf.common import get_perm_from_formats
-from onnx_tf.common import supports_device
-from onnx_tf.common import exception
 from onnx_tf.common.tf_helper import tf_shape
+from onnx_tf.common import sys_config
 from .broadcast_mixin import BroadcastMixin
 from .pad_mixin import PadMixin
@@ -31,7 +31,6 @@ def conv(cls, node, input_dict, transpose=False):
     x_shape = tf_shape(x, tf.int32)
     spatial_size = x_rank - 2

-    support_cuda = supports_device("CUDA")
     storage_format, compute_format = get_data_format(x_rank)
     compute_c_idx = compute_format.find("C")
     spatial_format = "".join([d for d in compute_format if d not in ["N", "C"]])
@@ -94,7 +93,7 @@ def conv(cls, node, input_dict, transpose=False):

       weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)

-      if support_cuda:
+      if sys_config.device == 'CUDA':
         xs = tf.split(x, num_or_size_splits=group, axis=1)
       else:
         x = tf.transpose(x,
@@ -236,7 +235,7 @@ def conv(cls, node, input_dict, transpose=False):
     ]

     if len(node.inputs) == 2:
-      if support_cuda:
+      if sys_config.device == 'CUDA':
         output = tf.concat(convolved, axis=1)
       else:
         output = tf.concat(convolved, axis=-1)
@@ -247,7 +246,7 @@ def conv(cls, node, input_dict, transpose=False):
       bias = input_dict[node.inputs[2]]
       bias = cls.explicit_broadcast([x, bias], compute_c_idx)

-      if support_cuda:
+      if sys_config.device == 'CUDA':
         output = tf.concat(convolved, axis=1)
         output = tf.add(output, bias)
       else:
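The three branches above pick the split/concat axis by compute format: channel axis 1 for NCHW (CUDA), the last axis for NHWC (CPU, after a single transpose). A standalone sketch of the NCHW grouped-convolution path with illustrative shapes (note that TensorFlow's conv2d only accepts data_format="NCHW" on GPU):

    import tensorflow as tf

    group = 2
    x = tf.random.normal([1, 8, 16, 16])      # NCHW, 8 input channels
    weights = tf.random.normal([3, 3, 4, 8])  # HWIO, 4 in-channels per group
    weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)
    xs = tf.split(x, num_or_size_splits=group, axis=1)
    convolved = [
        tf.nn.conv2d(xg, wg, strides=[1, 1, 1, 1], padding="SAME",
                     data_format="NCHW")
        for xg, wg in zip(xs, weight_groups)
    ]
    output = tf.concat(convolved, axis=1)     # 8 output channels on axis 1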
[Diff for the remaining 6 changed files not shown]
