diff --git a/doc/API.md b/doc/API.md index 08a7c2b0d..7849fd5d7 100644 --- a/doc/API.md +++ b/doc/API.md @@ -20,7 +20,7 @@ _params_: `model` : The ONNX model to be converted. -`device` : The device to execute this model on. +`device` : The device to execute this model on. It can be either CPU (default) or CUDA. `strict` : Whether to enforce semantic equivalence between the original model diff --git a/doc/CLI.md b/doc/CLI.md index cb55b0b40..0377160ff 100644 --- a/doc/CLI.md +++ b/doc/CLI.md @@ -40,8 +40,8 @@ optional arguments: Output directory. backend arguments (onnx -> tf): - --device DEVICE The device to execute this model on. (from - onnx_tf.backend.prepare) + --device DEVICE The device to execute this model on. It can be either + CPU (default) or CUDA. (from onnx_tf.backend.prepare) --strict STRICT Whether to enforce semantic equivalence between the original model and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence). diff --git a/doc/support_status.md b/doc/support_status.md index 3fe973206..66bee5623 100644 --- a/doc/support_status.md +++ b/doc/support_status.md @@ -1,9 +1,9 @@ # ONNX-Tensorflow Support Status ||| |-:|:-| -|ONNX-Tensorflow Version|Master ( commit id: 6bfd631e0daedbc773b76636a5ea19e77a4b63ed )| -|ONNX Version|Master ( commit id: b2ed660d0a065b8346816f2c3a95d79ca79b88c9 )| -|Tensorflow Version|v2.3.0| +|ONNX-Tensorflow Version|Master ( commit id: f64afb48034af7121341f4ba5d6f56e275c5aedb )| +|ONNX Version|Master ( commit id: a7a0fec7f25cae567429af62b7eaaee1c3f0e247 )| +|Tensorflow Version|v2.3.1| Notes: * Values that are new or updated from a previous opset version are in bold. @@ -51,7 +51,7 @@ Notes: |Div|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**:small_red_triangle:|Div| |Dropout|**1**|1|1|1|1|**6**|**7**|7|7|**10**|10|**12**|**13**|Dropout| |DynamicQuantizeLinear|-|-|-|-|-|-|-|-|-|-|**11**|11|11|DynamicQuantizeLinear| -|Einsum|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|12:small_red_triangle:|Einsum| +|Einsum|-|-|-|-|-|-|-|-|-|-|-|**12**|12|Einsum| |Elu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|Elu| |Equal|**1**|1|1|1|1|1|**7**|7|7|7|**11**|11|**13**|Equal| |Erf|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**|Erf| @@ -72,11 +72,11 @@ Notes: |GreaterOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|GreaterOrEqual| |HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|HardSigmoid| |Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Hardmax| -|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**:small_red_triangle:|Identity| +|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|Identity| |If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|If| |InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|InstanceNormalization| |IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|10|IsInf| -|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**:small_red_triangle:|IsNaN| +|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**|IsNaN| |LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|LRN| |LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM| |LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|LeakyRelu| @@ -100,7 +100,7 @@ Notes: |Mul|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**|Mul| 
|Multinomial|-|-|-|-|-|-|**7**:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|7:small_red_triangle:|Multinomial| |Neg|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Neg| -|NegativeLogLikelihoodLoss|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|12:small_red_triangle:|NegativeLogLikelihoodLoss| +|NegativeLogLikelihoodLoss|-|-|-|-|-|-|-|-|-|-|-|**12**:small_red_triangle:|**13**:small_red_triangle:|NegativeLogLikelihoodLoss| |NonMaxSuppression|-|-|-|-|-|-|-|-|-|**10**|**11**|11|11|NonMaxSuppression| |NonZero|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**:small_red_triangle:|NonZero| |Not|**1**|1|1|1|1|1|1|1|1|1|1|1|1|Not| @@ -123,10 +123,10 @@ Notes: |ReduceL2|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceL2| |ReduceLogSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceLogSum| |ReduceLogSumExp|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceLogSumExp| -|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ReduceMax| -|ReduceMean|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceMean| -|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**:small_red_triangle:|ReduceMin| -|ReduceProd|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceProd| +|ReduceMax|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**|ReduceMax| +|ReduceMean|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|ReduceMean| +|ReduceMin|**1**|1|1|1|1|1|1|1|1|1|**11**|**12**|**13**|ReduceMin| +|ReduceProd|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|ReduceProd| |ReduceSum|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceSum| |ReduceSumSquare|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|ReduceSumSquare| |Relu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**:small_red_triangle:|Relu| @@ -179,7 +179,7 @@ Notes: |Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Where| |Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|Xor| -ONNX-TF Supported Operators / ONNX Operators: 105 / 162 +ONNX-TF Supported Operators / ONNX Operators: 118 / 162 Notes: 1. Cast: Cast string to data types other than float32/float64/int32/int64 is not supported in Tensorflow diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index 362782bf3..31188e914 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -49,7 +49,7 @@ def prepare(cls, the converted representation. :param model: The ONNX model to be converted. - :param device: The device to execute this model on. + :param device: The device to execute this model on. It can be either CPU (default) or CUDA. :param strict: Whether to enforce semantic equivalence between the original model and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence). Changing to False is strongly discouraged. 
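For reference, a minimal usage sketch of the `device` argument documented in the hunks above, assuming a hypothetical `model.onnx` file and a hypothetical input named `X` (the exact input names depend on the model):

```python
import numpy as np
import onnx
from onnx_tf.backend import prepare

# Load the ONNX model to convert; "model.onnx" is a placeholder path.
model = onnx.load("model.onnx")

# device can be either "CPU" (the default) or "CUDA"; per get_data_format,
# "CUDA" keeps the channel-first (NC...) compute format instead of
# transposing to channel-last.
tf_rep = prepare(model, device="CPU")

# Run the converted representation on a dict of input names to numpy arrays;
# "X" and the shape below are illustrative only.
x = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = tf_rep.run({"X": x})
```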
@@ -65,6 +65,7 @@ def prepare(cls, common.logger.setLevel(logging_level) common.logger.handlers[0].setLevel(logging_level) common.sys_config.auto_cast = auto_cast + common.sys_config.device = device return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs) @@ -184,6 +185,7 @@ def __call__(self, **input_dict): return cls._onnx_node_to_tensorflow_op(self.node, input_dict) super(TensorflowBackend, cls).run_node(node, inputs, device) + common.sys_config.device = device node = OnnxNode(node) input_tensors = [] diff --git a/onnx_tf/backend_tf_module.py b/onnx_tf/backend_tf_module.py index 3eb6f2279..8774f60e6 100644 --- a/onnx_tf/backend_tf_module.py +++ b/onnx_tf/backend_tf_module.py @@ -12,6 +12,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend): self.graph_def = graph_def self.backend = backend self.outputs = [] + self.initializer_dict = self._get_initializer_from_graph_and_subgraphs( + self.graph_def, dict()) # get initializer from the main graph and all subgraphs in loop or if or scan # into tensor_dict @@ -37,8 +39,8 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict): @tf.function def gen_tensor_dict(self, input_dict): - tensor_dict = self._get_initializer_from_graph_and_subgraphs( - self.graph_def, dict(input_dict)) + tensor_dict = dict(input_dict) + tensor_dict.update(self.initializer_dict) for node in self.graph_def.node: onnx_node = OnnxNode(node) @@ -54,8 +56,8 @@ def gen_tensor_dict(self, input_dict): @tf.function def __call__(self, **kwargs): - tensor_dict = self._get_initializer_from_graph_and_subgraphs( - self.graph_def, kwargs) + tensor_dict = kwargs + tensor_dict.update(self.initializer_dict) for node in self.graph_def.node: onnx_node = OnnxNode(node) diff --git a/onnx_tf/common/__init__.py b/onnx_tf/common/__init__.py index f788c8b9d..c0ce036f7 100644 --- a/onnx_tf/common/__init__.py +++ b/onnx_tf/common/__init__.py @@ -28,6 +28,8 @@ class SysConfig: def __init__(self): self.auto_cast = False + self.device = 'CPU' + sys_config = SysConfig() @@ -160,7 +162,7 @@ def get_data_format(x_rank): sp_dim_string = "".join(reversed(sp_dim_lst)) storage_format = "NC" + sp_dim_string - if supports_device("CUDA"): + if sys_config.device == "CUDA": compute_format = "NC" + sp_dim_string else: compute_format = "N" + sp_dim_string + "C" @@ -169,7 +171,6 @@ def get_data_format(x_rank): def supports_device(device): """ Check if support target device. - :param device: CUDA or CPU. :return: If supports. 
""" diff --git a/onnx_tf/common/pooling_helper.py b/onnx_tf/common/pooling_helper.py index 7d5f5d92c..1e28aeaec 100644 --- a/onnx_tf/common/pooling_helper.py +++ b/onnx_tf/common/pooling_helper.py @@ -158,6 +158,9 @@ def py_pool(input, kernel_shape, strides=None, dilations=None, def _loop_over_output(batch, channel): dims = [range(output_sp_shape[d]) for d in range(spatial_size)] + image_size = 1 + for d in input_shape[2:]: + image_size *= d for counters in itertools.product(*dims): input_ranges = [] for dim in range(spatial_size): @@ -189,7 +192,10 @@ def _loop_over_output(batch, channel): else: if val > maxval: maxval = val - ind = 0 + # batch_offset = batch * C * image_size + # channel_offset = channel * image_size + # ind = batch_offset + channel_offset + ind = image_size * (batch * input_shape[1] + channel) for i in range(spatial_size): coef = 1 for j in range(i+1, spatial_size): diff --git a/onnx_tf/handlers/backend/conv_mixin.py b/onnx_tf/handlers/backend/conv_mixin.py index ad9ecba63..d33e11e70 100644 --- a/onnx_tf/handlers/backend/conv_mixin.py +++ b/onnx_tf/handlers/backend/conv_mixin.py @@ -1,10 +1,10 @@ import tensorflow as tf +from onnx_tf.common import exception from onnx_tf.common import get_data_format from onnx_tf.common import get_perm_from_formats -from onnx_tf.common import supports_device -from onnx_tf.common import exception from onnx_tf.common.tf_helper import tf_shape +from onnx_tf.common import sys_config from .broadcast_mixin import BroadcastMixin from .pad_mixin import PadMixin @@ -31,7 +31,6 @@ def conv(cls, node, input_dict, transpose=False): x_shape = tf_shape(x, tf.int32) spatial_size = x_rank - 2 - support_cuda = supports_device("CUDA") storage_format, compute_format = get_data_format(x_rank) compute_c_idx = compute_format.find("C") spatial_format = "".join([d for d in compute_format if d not in ["N", "C"]]) @@ -94,7 +93,7 @@ def conv(cls, node, input_dict, transpose=False): weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1) - if support_cuda: + if sys_config.device == 'CUDA': xs = tf.split(x, num_or_size_splits=group, axis=1) else: x = tf.transpose(x, @@ -236,7 +235,7 @@ def conv(cls, node, input_dict, transpose=False): ] if len(node.inputs) == 2: - if support_cuda: + if sys_config.device == 'CUDA': output = tf.concat(convolved, axis=1) else: output = tf.concat(convolved, axis=-1) @@ -247,7 +246,7 @@ def conv(cls, node, input_dict, transpose=False): bias = input_dict[node.inputs[2]] bias = cls.explicit_broadcast([x, bias], compute_c_idx) - if support_cuda: + if sys_config.device == 'CUDA': output = tf.concat(convolved, axis=1) output = tf.add(output, bias) else: diff --git a/onnx_tf/handlers/backend/dilated_pooling.py b/onnx_tf/handlers/backend/dilated_pooling.py index a9e9a279a..7274c3df5 100644 --- a/onnx_tf/handlers/backend/dilated_pooling.py +++ b/onnx_tf/handlers/backend/dilated_pooling.py @@ -3,6 +3,8 @@ import tensorflow as tf import numpy as np +from onnx_tf.common import get_data_format +from onnx_tf.common import get_perm_from_formats from onnx_tf.common import pooling_helper from onnx_tf.common.tf_helper import tf_shape from onnx_tf.common.tf_helper import tf_product @@ -181,6 +183,10 @@ def __init__(self, else: self.padding_constant = 0 + self.storage_format, self.compute_format = get_data_format( + self.spatial_size + 2) + self.need_trans = self.storage_format != self.compute_format + def _calc_input_ind(self, output_ind, kernel, dilation, stride): """ This function maps index from the output of _remove_dilations @@ -228,11 
+234,11 @@ def _calc_orig_argmax(self, ind): Maps indices generated by maxpool_with_argmax on top of the dilation reduced input to the orignal input indices - """ + """ - in_width = self.orig_input_shape[2] - num_channels = self.orig_input_shape[3] - output_width = self.output_shape[2] + in_width = self.orig_input_shape[3] + num_channels = self.orig_input_shape[1] + output_width = self.output_shape[3] # mod_floor op is not implemented on GPU # implement it using: a % b = a - (a // b) * b @@ -284,18 +290,12 @@ def _remove_dilations(self): the result is: [[ 10, 11], [ 14, 15]] - """ + """ input_shape = tf_shape(self.input) - in_spatial_shape = input_shape[1:self.spatial_size + 1] + in_spatial_shape = input_shape[2:] - channels_count = input_shape[self.spatial_size + 1] - # Initialize gather_ind with the range of channels - # e.g. [0 1] - gather_ind = tf.range(channels_count, dtype=tf.int64) - # convert the vector to column vector - # in the following logic we use column vectors - gather_ind = tf.expand_dims(gather_ind, 1) + channels_count = input_shape[1] # initilize the output_shape with zeros # self.output_shape will contain the shape of the @@ -366,15 +366,15 @@ def _remove_dilations(self): These are the indices used for gather_nd operation to collect the values from the input data. - """ + """ for dim in range(self.spatial_size - 1, -1, -1): filter_size = (self.kernel_shape[dim] - 1) * \ self.dilations[dim] + 1 output_size = (( - (in_spatial_shape[dim] - filter_size) // self.strides[dim]) + 1 - ) * self.kernel_shape[dim] - self.output_shape[dim + 1] = output_size + (in_spatial_shape[dim] - filter_size) // self.strides[dim]) + + 1) * self.kernel_shape[dim] + self.output_shape[dim + 2] = output_size # initialize the output dimension index with the range of the # dimension output size (e.g. 4): [0, 1, 2, 3] @@ -388,16 +388,25 @@ def _remove_dilations(self): # convert to column vector dim_ind = tf.expand_dims(dim_ind, 1) - # "combine" current dimension indices with the previous dimensions - # using cartesian product - gather_ind = tf_product(dim_ind, gather_ind) + if (dim == self.spatial_size - 1): + gather_ind = dim_ind + else: + # "combine" current dimension indices with the previous dimensions + # using cartesian product + gather_ind = tf_product(dim_ind, gather_ind) # The result from the above loop for 2D data will be: - # [[y1, x1, c], [y2, x2, c], ..., [yn, xm, c]] where n is the height, - # m is the width and c is the channel number. + # [[y1, x1], [y2, x2], ..., [yn, xm]] where n is the height, + # m is the width. # set the channels count in the output_shape - self.output_shape[self.spatial_size + 1] = channels_count + self.output_shape[1] = channels_count + # create the channel indices + channel_ind = tf.range(channels_count, dtype=tf.int64) + # convert to column vector + channel_ind = tf.expand_dims(channel_ind, 1) + # "combine" channel indices with the result from the loop + gather_ind = tf_product(channel_ind, gather_ind) # expand the dimensions to match the input dimensions + 1 for x in range(self.spatial_size): @@ -416,7 +425,7 @@ def _remove_dilations(self): def _calc_pads_same(self, in_spatial_shape): """ Calculate SAME_* paddings. 
- """ + """ pad_ops = pooling_helper.pad_numpy_ops if self.is_known_shape else \ pooling_helper.pad_tf_ops @@ -428,7 +437,7 @@ def _calc_pads_same(self, in_spatial_shape): def _calc_pads_explicit(self): """ Calculate explicit padding - """ + """ assert type(self.padding) is list pads = [] @@ -439,7 +448,7 @@ def _calc_pads_explicit(self): def _calc_pads_ceil_mode(self, in_spatial_shape): """ Calculate padding in ceil_mode - """ + """ pads = [] for i in range(self.spatial_size): @@ -481,7 +490,7 @@ def _calc_pads(self, in_spatial_shape): def _pad_input(self): """ Pad the input according to the parameters - """ + """ # check if we need to do any padding at all if not self.ceil_mode and ((type(self.padding) is list and self.padding == [0] * self.spatial_size * 2) or @@ -489,23 +498,23 @@ def _pad_input(self): self.pads = np.array([0] * self.spatial_size * 2) return (self.input, self.pads) - in_spatial_shape = self.input_shape[1:self.spatial_size + 1] + in_spatial_shape = self.input_shape[2:] pads = self._calc_pads(in_spatial_shape) if self.is_known_shape and np.count_nonzero(pads) == 0: self.pads = pads return (self.input, pads) - tf_paddings = [[0, 0]] + # no padding on the NC dimensions + tf_paddings = [[0, 0], [0, 0]] + # padding for the (D)HW dimensions for i in range(self.spatial_size): tf_paddings += [[pads[i * 2], pads[i * 2 + 1]]] - tf_paddings += [[0, 0]] - self.input = tf.pad( - self.input, - tf_paddings, - mode='CONSTANT', - constant_values=self.padding_constant) + self.input = tf.pad(self.input, + tf_paddings, + mode='CONSTANT', + constant_values=self.padding_constant) # update input shape and pads values self.input_shape = tf_shape(self.input) self.pads = pads @@ -513,10 +522,10 @@ def _pad_input(self): def _calc_argmax_without_padding(self, ind): """ Calculate the original indices as they would be without padding - """ - in_width = self.orig_input_shape[2] - padded_width = self.input_shape[2] - num_channels = self.input_shape[3] + """ + in_width = self.orig_input_shape[3] + padded_width = self.input_shape[3] + num_channels = self.input_shape[1] # mod_floor op is not implemented on GPU # implement it using: a % b = a - (a // b) * b @@ -535,21 +544,35 @@ def _calc_argmax_without_padding(self, ind): def dilated_maxpool_with_argmax(self, force_custom_impl=False): """ Do a dilated maxpool and return indices/argmax - """ + """ # Tensorflow does not support maxpool_with_argmax on # spatial_size != 2 assert self.spatial_size == 2 + # tf.nn.max_pool_with_argmax only support data_format='NHWC' + self.compute_format = 'NHWC' + self.need_trans = self.storage_format != self.compute_format + if list(self.dilations) != [1] * self.spatial_size or \ force_custom_impl: + # pad the input self._pad_input() new_input = self._remove_dilations() kernel_shape = [1] + list(self.kernel_shape) + [1] - pooled, new_ind = tf.nn.max_pool_with_argmax( - new_input, ksize=kernel_shape, strides=kernel_shape, padding="VALID") + + if self.need_trans: + new_input = tf.transpose(new_input, + perm=get_perm_from_formats( + self.storage_format, self.compute_format)) + + pooled, new_ind = tf.nn.max_pool_with_argmax(new_input, + ksize=kernel_shape, + strides=kernel_shape, + padding="VALID") new_ind = self._calc_orig_argmax(new_ind) + else: self.pads = np.array([0] * self.spatial_size * 2) if type(self.padding) is list or \ @@ -565,20 +588,34 @@ def dilated_maxpool_with_argmax(self, force_custom_impl=False): strides = [1] + list(self.strides) + [1] kernel_shape = [1] + list(self.kernel_shape) + [1] - pooled, new_ind = 
tf.nn.max_pool_with_argmax( - self.input, ksize=kernel_shape, strides=strides, padding=padding_) + + if self.need_trans: + self.input = tf.transpose(self.input, + perm=get_perm_from_formats( + self.storage_format, self.compute_format)) + + pooled, new_ind = tf.nn.max_pool_with_argmax(self.input, + ksize=kernel_shape, + strides=strides, + padding=padding_) + # if there was padding, recalculate the returned index # to exclude the padding - if np.count_nonzero(self.pads) != 0: + count_nonzero_op = np.count_nonzero if self.is_known_shape else tf.math.count_nonzero + if count_nonzero_op(self.pads) != 0: new_ind = self._calc_argmax_without_padding(new_ind) return (pooled, new_ind) - def _lp_pool(self, input, ksize, strides, padding): + def _lp_pool(self, input, ksize, strides, padding, data_format): window_size = np.prod(ksize) input = tf.math.pow(tf.math.abs(input), self.p) * window_size - pooled = tf.nn.avg_pool(input, ksize=ksize, strides=strides, padding=padding) + pooled = tf.nn.avg_pool(input, + ksize=ksize, + strides=strides, + padding=padding, + data_format=data_format) pooled = tf.math.pow(pooled, 1.0 / self.p) return pooled @@ -587,7 +624,7 @@ def dilated_pool(self, force_custom_impl=False): """ Does N-D dilated max/avg pooling. Pads the input if explicit or SAME_* padding is provided or ceil_mode is True - """ + """ assert self.is_supported() @@ -611,29 +648,42 @@ def dilated_pool(self, force_custom_impl=False): strides = [1] + list(self.strides) + [1] dilations = [1] + list(self.dilations) + [1] + # tf.nn.dilation2d only support data_format='NHWC' + self.compute_format = 'NHWC' + self.need_trans = self.storage_format.startswith("NC") + if self.need_trans: + self.input = tf.transpose(self.input, + perm=get_perm_from_formats( + self.storage_format, self.compute_format)) + filter = tf.zeros( - [self.kernel_shape[0], self.kernel_shape[1], self.input_shape[3]], + [self.kernel_shape[0], self.kernel_shape[1], self.input_shape[1]], self.input.dtype) - pooled = tf.nn.dilation2d( - input=self.input, - filters=filter, - strides=strides, - dilations=dilations, - padding=padding_, - data_format="NHWC") + pooled = tf.nn.dilation2d(input=self.input, + filters=filter, + strides=strides, + dilations=dilations, + padding=padding_, + data_format="NHWC") # if spatial_size < 4 and strides == 1 or dilation == 1 use tf.nn.pool elif self.spatial_size < 4 and (self.strides == [1] * self.spatial_size or self.dilations == [1] * self.spatial_size) and \ not force_custom_impl: + + if self.need_trans: + self.input = tf.transpose(self.input, + perm=get_perm_from_formats( + self.storage_format, self.compute_format)) + # if strides == 1 and not LpPool use tf.nn.pool directly if self.strides == [1] * self.spatial_size and self.pooling_type != "LP": - pooled = tf.nn.pool( - self.input, - window_shape=self.kernel_shape, - dilations=self.dilations, - strides=self.strides, - padding=padding_, - pooling_type=self.pooling_type) + pooled = tf.nn.pool(self.input, + window_shape=self.kernel_shape, + dilations=self.dilations, + strides=self.strides, + padding=padding_, + pooling_type=self.pooling_type, + data_format=self.compute_format) else: # othwerwise check the pooling_type and use the correct op if self.pooling_type.startswith("MAX"): @@ -645,8 +695,11 @@ def dilated_pool(self, force_custom_impl=False): else: raise ValueError("%d-D %s pooling is not supported." 
% (self.spatial_size, self.pooling_type)) - pooled = op(self.input, ksize=self.kernel_shape, strides=self.strides, - padding=padding_) + pooled = op(self.input, + ksize=self.kernel_shape, + strides=self.strides, + padding=padding_, + data_format=self.compute_format) # in any other case we use custom implementation _remove_dilations # to reduce atrous/dilated pooling into regular pooling and selecting # only the values of the input that should have been selected by @@ -657,27 +710,34 @@ def dilated_pool(self, force_custom_impl=False): # pad the input self._pad_input() input_ = self._remove_dilations() - if self.pooling_type=="LP": - pooled = self._lp_pool( - input_, - ksize=self.kernel_shape, - strides=self.kernel_shape, - padding="VALID") + + if self.need_trans: + input_ = tf.transpose(input_, + perm=get_perm_from_formats( + self.storage_format, self.compute_format)) + + if self.pooling_type == "LP": + pooled = self._lp_pool(input_, + ksize=self.kernel_shape, + strides=self.kernel_shape, + padding="VALID", + data_format=self.compute_format) else: - pooled = tf.nn.pool( - input_, - window_shape=self.kernel_shape, - strides=self.kernel_shape, - padding="VALID", - pooling_type=self.pooling_type) + pooled = tf.nn.pool(input_, + window_shape=self.kernel_shape, + strides=self.kernel_shape, + padding="VALID", + pooling_type=self.pooling_type, + data_format=self.compute_format) + return pooled def is_supported(self): """ Function to check if the current set of arguments are supported for average pool - """ + """ # check for maxpool if self.pooling_type.startswith("MAX") or \ self.pooling_type=="LP": diff --git a/onnx_tf/handlers/backend/einsum.py b/onnx_tf/handlers/backend/einsum.py new file mode 100644 index 000000000..8aecdd3fd --- /dev/null +++ b/onnx_tf/handlers/backend/einsum.py @@ -0,0 +1,14 @@ +import tensorflow as tf + +from onnx_tf.handlers.backend_handler import BackendHandler +from onnx_tf.handlers.handler import onnx_op + +@onnx_op("Einsum") + +class Einsum(BackendHandler): + + @classmethod + def version_12(cls, node, **kwargs): + equation = node.attrs.get("equation", "") + inputs = [kwargs["tensor_dict"][inp] for inp in node.inputs] + return [tf.einsum(equation, *inputs)] diff --git a/onnx_tf/handlers/backend/identity.py b/onnx_tf/handlers/backend/identity.py index b974b2fbf..3a0336f62 100644 --- a/onnx_tf/handlers/backend/identity.py +++ b/onnx_tf/handlers/backend/identity.py @@ -12,3 +12,7 @@ class Identity(BackendHandler): @classmethod def version_1(cls, node, **kwargs): return [cls.make_tensor_from_onnx_node(node, **kwargs)] + + @classmethod + def version_13(cls, node, **kwargs): + return [cls.make_tensor_from_onnx_node(node, **kwargs)] diff --git a/onnx_tf/handlers/backend/is_nan.py b/onnx_tf/handlers/backend/is_nan.py index 5517d64ff..2328ef76f 100644 --- a/onnx_tf/handlers/backend/is_nan.py +++ b/onnx_tf/handlers/backend/is_nan.py @@ -12,3 +12,7 @@ class IsNaN(BackendHandler): @classmethod def version_9(cls, node, **kwargs): return [cls.make_tensor_from_onnx_node(node, **kwargs)] + + @classmethod + def version_13(cls, node, **kwargs): + return [cls.make_tensor_from_onnx_node(node, **kwargs)] diff --git a/onnx_tf/handlers/backend/pool_mixin.py b/onnx_tf/handlers/backend/pool_mixin.py index 2cd151b74..b81ad61f9 100644 --- a/onnx_tf/handlers/backend/pool_mixin.py +++ b/onnx_tf/handlers/backend/pool_mixin.py @@ -1,13 +1,15 @@ import tensorflow as tf from onnx_tf.common import exception -from onnx_tf.common import get_data_format from onnx_tf.common import get_perm_from_formats 
-from onnx_tf.common import logger -from .dilated_pooling import DilatedPooling +from onnx_tf.common import logger +from onnx_tf.common import sys_config from onnx_tf.common.pooling_helper import py_pool from onnx_tf.common.pooling_helper import calc_pads_same from onnx_tf.common.pooling_helper import calc_output_shape +from onnx_tf.common.tf_helper import tf_shape +from .dilated_pooling import DilatedPooling + class PoolMixin(object): @@ -15,7 +17,6 @@ class PoolMixin(object): @tf.autograph.experimental.do_not_convert() def pool(cls, node, input_dict, pooling_type, strict=True): x = input_dict[node.inputs[0]] - orig_x = x kernel_shape = node.attrs["kernel_shape"] @@ -35,8 +36,8 @@ def pool(cls, node, input_dict, pooling_type, strict=True): # SAME padding in Tensorflow if x.shape.is_fully_defined() and pads != [0] * spatial_size * 2: in_shape = x.get_shape() - same_paddings = calc_pads_same(in_shape[1:x_rank-1], kernel_shape, - strides, dilations, "SAME_UPPER") + same_paddings = calc_pads_same(in_shape[1:x_rank - 1], kernel_shape, + strides, dilations, "SAME_UPPER") if pads == same_paddings: pads = "SAME_UPPER" @@ -60,46 +61,47 @@ def pool(cls, node, input_dict, pooling_type, strict=True): exception.OP_UNSUPPORTED_EXCEPT(pooling_name + " with column major", "Tensorflow") - storage_format, _ = get_data_format(x_rank) - - need_trans = storage_format.startswith("NC") - if need_trans: - compute_format = "N" + storage_format[2:] + "C" - x = tf.transpose( - x, perm=get_perm_from_formats(storage_format, compute_format)) - - dp = DilatedPooling( - input=x, - kernel_shape=kernel_shape, - strides=strides, - dilations=dilations, - padding=pads, - ceil_mode=ceil_mode, - pooling_type=pooling_type, - count_include_pad=count_include_pad, - p=p) + x_dtype = x.dtype + # For max_pool and max_pool_with_argmax tensoflow don't support + # NCHW data format input in int8 or uint8 datatype, therefore + # need to cast to float16 in order to run with NCHW data format + need_cast = pooling_type in [ + 'MAX', 'MAX_WITH_ARGMAX' + ] and sys_config.device == 'CUDA' and x_dtype in [tf.int8, tf.uint8] + x = tf.cast(x, tf.float16) if need_cast else x + + dp = DilatedPooling(input=x, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + padding=pads, + ceil_mode=ceil_mode, + pooling_type=pooling_type, + count_include_pad=count_include_pad, + p=p) if not dp.is_supported(): if strict: logger.warning("Using the pooling op in compatibility mode. 
" - "This means your graph cannot be serialized.") + "This means your graph cannot be serialized.") result = tf.numpy_function(py_pool, [ - orig_x, kernel_shape, strides, dilations, pads, ceil_mode, - pooling_type, False - ], orig_x.dtype) - - if orig_x.shape.is_fully_defined(): - shape = orig_x.get_shape() - output_shape = shape[0:2] + calc_output_shape(shape[2:x_rank], - kernel_shape, strides, dilations, pads, ceil_mode) + x, kernel_shape, strides, dilations, pads, ceil_mode, pooling_type, + False + ], x.dtype) + + if x.shape.is_fully_defined(): + shape = x.get_shape() + output_shape = shape[0:2] + calc_output_shape( + shape[2:x_rank], kernel_shape, strides, dilations, pads, + ceil_mode) else: output_shape = [None] * x_rank result.set_shape(output_shape) return [result] else: - exception.OP_UNSUPPORTED_EXCEPT("strict == 0 and " + pooling_name + - " arguments not compatible", - "Tensorflow") + exception.OP_UNSUPPORTED_EXCEPT( + "strict == 0 and " + pooling_name + " arguments not compatible", + "Tensorflow") from absl import logging logging.set_verbosity(logging.INFO) @@ -107,20 +109,67 @@ def pool(cls, node, input_dict, pooling_type, strict=True): def dilated_pool(): return (dp.dilated_pool(), None) - def postprocess(pooled, argmax, perm): - return (tf.transpose(pooled, perm=perm) if need_trans else pooled, - tf.transpose(argmax, perm=perm) if need_trans and argmax - is not None else argmax) - # select correct op depending on the pooling type pooling_op = dilated_pool if pooling_type in ["MAX", "AVG", "LP"] else \ dp.dilated_maxpool_with_argmax - # select the correct transpose ops depending on the input storage format - perm = get_perm_from_formats(compute_format, storage_format) + def postprocess(pooled, argmax): + + def convert_NHWC_indices_to_NCHW_indices(argmax): + # i - index in NCHW + # I - index in NHWC + # C - number of channels + # b - batch = I // CHW + # c - channel = I % C + # H - height + # W - weight + # I = i - c(HW - 1) + (C - 1)(i - bCHW - cHW) + # i = (I + c(HW - 1) + (C - 1)(bCHW + cHW))/C + + # x_shape will always be in NCHW format here, + # because maxpool_with_argmax only support 2d input + x_shape = tf_shape(x) + N = x_shape[0] + C = x_shape[1] + H = x_shape[2] + W = x_shape[3] + HW = tf.math.multiply(H, W) + CHW = tf.math.multiply(C, HW) + argmax_b = tf.math.floordiv(argmax, CHW) + argmax_c = tf.math.floormod(argmax, C) + new_ind = tf.math.add( + argmax, tf.math.multiply(argmax_c, tf.math.subtract(HW, 1))) + new_ind = tf.math.add( + new_ind, + tf.math.multiply( + tf.math.subtract(C, 1), + tf.math.add(tf.math.multiply(argmax_b, CHW), + tf.math.multiply(argmax_c, HW)))) + new_ind = tf.math.floordiv(new_ind, C) + + # add batch dimension into the argmax index + batch_offsets = tf.math.multiply(tf.range(N, dtype=new_ind.dtype), CHW) + for _ in range(new_ind.shape.rank - 1): + batch_offsets = tf.expand_dims(batch_offsets, -1) + new_ind = tf.math.add(new_ind, batch_offsets) + + return new_ind + + if argmax is not None: + argmax = convert_NHWC_indices_to_NCHW_indices(argmax) + + # select the correct transpose ops depending on the input storage format + perm = get_perm_from_formats(dp.compute_format, dp.storage_format) + + pooled = tf.transpose(pooled, perm=perm) if dp.need_trans else pooled + pooled = tf.cast(pooled, x_dtype) if need_cast else pooled + argmax = tf.transpose( + argmax, perm=perm) if dp.need_trans and argmax is not None else argmax + + return pooled, argmax pooled, argmax = pooling_op() - pooled, argmax = postprocess(pooled, argmax, perm) + pooled, argmax = 
postprocess(pooled, argmax) result = [pooled] if argmax is None else [pooled, argmax] diff --git a/onnx_tf/handlers/backend/reduce_max.py b/onnx_tf/handlers/backend/reduce_max.py index 768de25ef..49f413c91 100644 --- a/onnx_tf/handlers/backend/reduce_max.py +++ b/onnx_tf/handlers/backend/reduce_max.py @@ -21,3 +21,7 @@ def version_11(cls, node, **kwargs): @classmethod def version_12(cls, node, **kwargs): return cls._common(node, **kwargs) + + @classmethod + def version_13(cls, node, **kwargs): + return cls._common(node, **kwargs) diff --git a/onnx_tf/handlers/backend/reduce_mean.py b/onnx_tf/handlers/backend/reduce_mean.py index cf3bf2a6c..71979287b 100644 --- a/onnx_tf/handlers/backend/reduce_mean.py +++ b/onnx_tf/handlers/backend/reduce_mean.py @@ -17,3 +17,7 @@ def version_1(cls, node, **kwargs): @classmethod def version_11(cls, node, **kwargs): return cls._common(node, **kwargs) + + @classmethod + def version_13(cls, node, **kwargs): + return cls._common(node, **kwargs) diff --git a/onnx_tf/handlers/backend/reduce_min.py b/onnx_tf/handlers/backend/reduce_min.py index 839cc3d3e..cc5729604 100644 --- a/onnx_tf/handlers/backend/reduce_min.py +++ b/onnx_tf/handlers/backend/reduce_min.py @@ -21,3 +21,7 @@ def version_11(cls, node, **kwargs): @classmethod def version_12(cls, node, **kwargs): return cls._common(node, **kwargs) + + @classmethod + def version_13(cls, node, **kwargs): + return cls._common(node, **kwargs) diff --git a/onnx_tf/handlers/backend/reduce_prod.py b/onnx_tf/handlers/backend/reduce_prod.py index 5d3ea7e09..d04f8edff 100644 --- a/onnx_tf/handlers/backend/reduce_prod.py +++ b/onnx_tf/handlers/backend/reduce_prod.py @@ -17,3 +17,7 @@ def version_1(cls, node, **kwargs): @classmethod def version_11(cls, node, **kwargs): return cls._common(node, **kwargs) + + @classmethod + def version_13(cls, node, **kwargs): + return cls._common(node, **kwargs) diff --git a/onnx_tf/handlers/backend/unpool_mixin.py b/onnx_tf/handlers/backend/unpool_mixin.py index efd54f3ba..81f94a725 100644 --- a/onnx_tf/handlers/backend/unpool_mixin.py +++ b/onnx_tf/handlers/backend/unpool_mixin.py @@ -1,7 +1,5 @@ import tensorflow as tf -from onnx_tf.common import get_data_format -from onnx_tf.common import get_perm_from_formats from onnx_tf.common.tf_helper import tf_shape @@ -12,7 +10,7 @@ class UnpoolMixin(object): def max_unpool(cls, node, input_dict): """ MaxUnpooling operation - """ + """ x = input_dict[node.inputs[0]] ind = input_dict[node.inputs[1]] if len(node.inputs) > 2: @@ -23,8 +21,6 @@ def max_unpool(cls, node, input_dict): kernel_shape = node.attrs["kernel_shape"] spatial_size = len(kernel_shape) - x_rank = spatial_size + 2 - storage_format, _ = get_data_format(x_rank) # if strides are not provided default is 1 along each spatial axis strides = node.attrs.get("strides", [1] * spatial_size) @@ -32,23 +28,10 @@ def max_unpool(cls, node, input_dict): input_shape = tf_shape(x) default_shape = cls._get_default_shape(input_shape, kernel_shape, strides) - - need_trans = storage_format != "NHWC" - if need_trans: - x = tf.transpose(x, perm=get_perm_from_formats(storage_format, "NHWC")) - ind = tf.transpose( - ind, perm=get_perm_from_formats(storage_format, "NHWC")) - - # default_shape to NHWC storage format - default_shape = [input_shape[0]] + default_shape + \ - [input_shape[1]] + default_shape = [input_shape[0]] + [input_shape[1]] + default_shape unpooled = cls._unpool(x, ind, default_shape) - if need_trans: - unpooled = tf.transpose( - unpooled, perm=get_perm_from_formats("NHWC", storage_format)) - 
if output_shape is not None: pads = cls._get_pads_from_output_shape(unpooled, output_shape) if pads is not None: @@ -66,11 +49,11 @@ def _get_default_shape(cls, input_shape, kernel_shape, strides): output_shape: stride along each spatial axis Return: default_shape: calculated default_shape - """ + """ default_shape = [] for d in range(len(kernel_shape)): - default_shape.append(( - input_shape[d + 2] - 1) * int(strides[d]) + int(kernel_shape[d])) + default_shape.append((input_shape[d + 2] - 1) * int(strides[d]) + + int(kernel_shape[d])) return default_shape @classmethod @@ -85,7 +68,7 @@ def _get_pads_from_output_shape(cls, unpool, output_shape): [x1_begin, x2_begin,.., x1_end, x2_end] where xi_... represent pads added to begin or end of axis i - """ + """ unpool_shape = tf.cast(tf.shape(unpool), dtype=tf.int32) new_shape = tf.cast(output_shape, dtype=tf.int32) @@ -113,13 +96,15 @@ def _pad_output(cls, unpool, pads, constant_values): constant_values: constant value to fill up the padded spaces Return: padded: padded tensor - """ + """ unpool_shape = unpool.get_shape() paddings = [] for d in range(len(unpool_shape)): paddings = paddings + [[pads[d], pads[d + len(unpool_shape)]]] - padded = tf.pad( - unpool, paddings, 'CONSTANT', constant_values=constant_values) + padded = tf.pad(unpool, + paddings, + 'CONSTANT', + constant_values=constant_values) return padded @classmethod @@ -133,25 +118,18 @@ def _unpool(cls, pool, ind, output_shape, scope='unpool'): output_shape: the shape of the output Return: unpool: unpooling tensor - """ + """ with tf.compat.v1.variable_scope(scope): input_shape = tf.shape(pool) flat_input_size = tf.reduce_prod(input_shape) - flat_output_shape = [ - output_shape[0], output_shape[1] * output_shape[2] * output_shape[3] - ] + flat_output_shape = [tf.reduce_prod(output_shape)] pool_ = tf.reshape(pool, [flat_input_size]) - batch_range = tf.reshape( - tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype), - shape=[input_shape[0], 1, 1, 1]) - b = tf.ones_like(ind) * batch_range - b1 = tf.reshape(b, [flat_input_size, 1]) ind_ = tf.reshape(ind, [flat_input_size, 1]) - ind_ = tf.concat([b1, ind_], 1) - ret = tf.scatter_nd( - ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64)) + ret = tf.scatter_nd(ind_, + pool_, + shape=tf.cast(flat_output_shape, tf.int64)) ret = tf.reshape(ret, output_shape) return ret diff --git a/onnx_tf/handlers/backend/upsample.py b/onnx_tf/handlers/backend/upsample.py index 6f48198cd..b775e8c15 100644 --- a/onnx_tf/handlers/backend/upsample.py +++ b/onnx_tf/handlers/backend/upsample.py @@ -1,6 +1,5 @@ import copy -import numpy as np import tensorflow as tf from onnx_tf.common import exception @@ -33,20 +32,28 @@ def args_check(cls, node, **kwargs): @classmethod def version_7(cls, node, **kwargs): x = kwargs["tensor_dict"][node.inputs[0]] - x_shape = x.get_shape().as_list() + x_shape = tf_shape(x) attrs = copy.deepcopy(node.attrs) scales = attrs["scales"] - new_height = np.floor(x_shape[2] * scales[2]) - new_weight = np.floor(x_shape[3] * scales[3]) - mode = attrs.get("mode", "nearest") - if mode.lower() == "bilinear" or mode.lower() == "linear": - mode = tf.image.ResizeMethod.BILINEAR - else: - mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR + assert_n_c_scale_is_one = tf.Assert( + tf.logical_and(tf.equal(scales[0], 1), tf.equal(scales[1], 1)), + [scales]) + + with tf.control_dependencies([assert_n_c_scale_is_one]): + h_w_scale = scales[2:] + h_w_shape = x_shape[2:] + new_h_w_shape = tf.cast(h_w_scale * tf.cast(h_w_shape, type(h_w_scale[0])), + 
tf.int32) + + mode = attrs.get("mode", "nearest") + if mode.lower() == "bilinear" or mode.lower() == "linear": + mode = tf.image.ResizeMethod.BILINEAR + else: + mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR - attrs["size"] = np.array((new_height, new_weight), dtype=np.int32) - attrs["method"] = mode + attrs["size"] = new_h_w_shape + attrs["method"] = mode return [ cls.make_tensor_from_onnx_node( diff --git a/onnx_tf/handlers/backend_handler.py b/onnx_tf/handlers/backend_handler.py index 40d644cfd..6ad8db57e 100644 --- a/onnx_tf/handlers/backend_handler.py +++ b/onnx_tf/handlers/backend_handler.py @@ -11,7 +11,7 @@ from onnx_tf.common import IS_PYTHON3 from onnx_tf.common import get_data_format from onnx_tf.common import get_perm_from_formats -from onnx_tf.common import supports_device +from onnx_tf.common import sys_config from .handler import Handler @@ -120,8 +120,7 @@ def c_first_cuda_only(cls, tf_func, inputs, attrs): :param attrs: Attributes. :return: Tensor. """ - support_cuda = supports_device("CUDA") - if not support_cuda: + if sys_config.device == 'CPU': return cls._tuck_transpose(tf_func, inputs, attrs) return cls._run_tf_func(tf_func, inputs, attrs) diff --git a/onnx_tf/opset_version.py b/onnx_tf/opset_version.py index 5dd0f2cca..87fc5b7cd 100644 --- a/onnx_tf/opset_version.py +++ b/onnx_tf/opset_version.py @@ -42,7 +42,7 @@ 'Div': [1, 6, 7], 'Dropout': [1, 6, 7, 10, 12, 13], 'DynamicQuantizeLinear': [11], - 'Einsum': [], + 'Einsum': [12], 'Elu': [1, 6], 'Equal': [1, 7, 11, 13], 'Erf': [9, 13], @@ -65,13 +65,13 @@ 'GreaterOrEqual': [12], 'HardSigmoid': [1, 6], 'Hardmax': [1, 11], - 'Identity': [1], + 'Identity': [1, 13], 'If': [1, 11, 13], 'ImageScaler': [1], 'Imputer': [], 'InstanceNormalization': [1, 6], 'IsInf': [10], - 'IsNaN': [9], + 'IsNaN': [9, 13], 'LRN': [1, 13], 'LSTM': [1, 7], 'LabelEncoder': [], @@ -124,10 +124,10 @@ 'ReduceL2': [1, 11], 'ReduceLogSum': [1, 11], 'ReduceLogSumExp': [1, 11], - 'ReduceMax': [1, 11, 12], - 'ReduceMean': [1, 11], - 'ReduceMin': [1, 11, 12], - 'ReduceProd': [1, 11], + 'ReduceMax': [1, 11, 12, 13], + 'ReduceMean': [1, 11, 13], + 'ReduceMin': [1, 11, 12, 13], + 'ReduceProd': [1, 11, 13], 'ReduceSum': [1, 11], 'ReduceSumSquare': [1, 11], 'Relu': [1, 6], diff --git a/test/backend/test_dynamic_shape.py b/test/backend/test_dynamic_shape.py index c3185261c..e380b3ca9 100644 --- a/test/backend/test_dynamic_shape.py +++ b/test/backend/test_dynamic_shape.py @@ -530,6 +530,56 @@ def test_max_pool_2d_dilations_ceil_pads(self): np.testing.assert_almost_equal(output["Y"], test_output) + def test_max_pool_with_argmax_2d_dilations_ceil_pads(self): + if legacy_opset_pre_ver(10): + raise unittest.SkipTest( + "ONNX version {} doesn't support dilations nor ceil mode.".format( + defs.onnx_opset_version())) + + kernel_shape = [3, 3] + strides = [2, 2] + dilations = [3, 3] + pads = [1, 1, 2, 2] + ceil_mode = True + + input_shape = [10, 3, 23, 23] + x = self._get_rnd_float32(shape=input_shape) - 2 + + node_def = helper.make_node("MaxPool", ["X"], ["Y", "Ind"], + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + pads=pads, + ceil_mode=ceil_mode) + + graph_def = helper.make_graph( + [node_def], + name="test_unknown_shape", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, + [None, None, None, None]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, + [None, None, None, None]), + helper.make_tensor_value_info("Ind", TensorProto.INT64, + [None, None, None, None]) + ]) + + tf_rep = 
onnx_graph_to_tensorflow_rep(graph_def) + output = tf_rep.run({"X": x}) + + test_output, test_ind = py_pool(x, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + padding=pads, + ceil_mode=ceil_mode, + pooling_type="MAX") + + np.testing.assert_almost_equal(output["Y"], test_output) + np.testing.assert_almost_equal(output["Ind"], test_ind) + def test_average_pool_2d(self): kernel_shape = [1, 2] strides = [1, 2] diff --git a/test/backend/test_node.py b/test/backend/test_node.py index 4c901cc0c..78fded8fe 100644 --- a/test/backend/test_node.py +++ b/test/backend/test_node.py @@ -3,13 +3,12 @@ from __future__ import print_function from __future__ import unicode_literals -import sys import math import unittest +from onnx import defs from onnx import helper from onnx import TensorProto -from onnx import defs import numpy as np import tensorflow as tf @@ -24,6 +23,10 @@ class TestNode(unittest.TestCase): """ Tests for nodes """ + def _get_device_list(self): + # Check does the environment support CUDA. + return ['CPU', 'CUDA'] if supports_device("CUDA") else ['CPU'] + def _get_rnd_float32(self, low=-1.0, high=1.0, shape=None): output = np.random.uniform(low, high, shape) if shape is None: @@ -259,13 +262,13 @@ def test_ceil(self): def test_celu(self): if legacy_opset_pre_ver(12): - raise unittest.SkipTest( - "ONNX version {} doesn't support Celu.".format( - defs.onnx_opset_version())) + raise unittest.SkipTest("ONNX version {} doesn't support Celu.".format( + defs.onnx_opset_version())) alpha = 2.0 node_def = helper.make_node("Celu", ["X"], ["Y"], alpha=alpha) - x = np.array([[[-1.0763247, 0.98948643, 0.22292195], - [ 0.1751388, -1.39814249, 1.44396422]]], dtype=np.float32) + x = np.array([[[-1.0763247, 0.98948643, 0.22292195], + [0.1751388, -1.39814249, 1.44396422]]], + dtype=np.float32) output = run_node(node_def, [x]) positive_input = np.maximum(0, x) negative_input = np.minimum(0, alpha * (np.exp(x / alpha) - 1)) @@ -395,39 +398,38 @@ def test_constant_of_shape(self): np.testing.assert_almost_equal(output["Y"], np.zeros(x, dtype=np.int32)) def test_conv(self): - device = "CUDA" if supports_device("CUDA") else "CPU" - - N, C, H, W = 4, 3, 5, 5 - x_shape = [N, C, H, W] - K, kH, kW = 6, 3, 3 - weight_shape = [K, C, kH, kW] - node_def = helper.make_node("Conv", ["X", "weights"], ["Y"], - pads=[1, 1, 1, 1], - kernel_shape=[kH, kW]) - - x = self._get_rnd_float32(shape=x_shape) - weights = self._get_rnd_float32(shape=weight_shape) - output = run_node(node_def, [x, weights], device=device) - - out_shape = [N, K, H, W] - test_output = np.zeros(out_shape) - for n in range(N): - for c in range(C): - for h in range(H): - for w in range(W): - for k in range(K): - for kh in range(kH): - for kw in range(kW): - h_in_range = (h - kH // 2 + kh) < H and (h - kH // 2 + - kh) >= 0 - w_in_range = (w - kW // 2 + kw) < W and (w - kW // 2 + - kw) >= 0 - if h_in_range and w_in_range: - test_output[n][k][h][w] += ( - x[n][c][h - kH // 2 + kh][w - kW // 2 + kw] * - weights[k][c][kh][kw]) - - np.testing.assert_almost_equal(output["Y"], test_output, decimal=4) + for device in self._get_device_list(): + N, C, H, W = 4, 3, 5, 5 + x_shape = [N, C, H, W] + K, kH, kW = 6, 3, 3 + weight_shape = [K, C, kH, kW] + node_def = helper.make_node("Conv", ["X", "weights"], ["Y"], + pads=[1, 1, 1, 1], + kernel_shape=[kH, kW]) + + x = self._get_rnd_float32(shape=x_shape) + weights = self._get_rnd_float32(shape=weight_shape) + output = run_node(node_def, [x, weights], device=device) + + out_shape = [N, K, H, W] + 
test_output = np.zeros(out_shape) + for n in range(N): + for c in range(C): + for h in range(H): + for w in range(W): + for k in range(K): + for kh in range(kH): + for kw in range(kW): + h_in_range = (h - kH // 2 + kh) < H and (h - kH // 2 + + kh) >= 0 + w_in_range = (w - kW // 2 + kw) < W and (w - kW // 2 + + kw) >= 0 + if h_in_range and w_in_range: + test_output[n][k][h][w] += ( + x[n][c][h - kH // 2 + kh][w - kW // 2 + kw] * + weights[k][c][kh][kw]) + + np.testing.assert_almost_equal(output["Y"], test_output, decimal=4) def test_conv_integer(self): if legacy_opset_pre_ver(10): @@ -435,135 +437,136 @@ def test_conv_integer(self): "ONNX version {} doesn't support ConvInteger.".format( defs.onnx_opset_version())) - # Test w_zero_point - x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( - (1, 1, 3, 3)) - w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) - w_zero_point = np.int8(1) - y = np.array([16, 20, 28, 32]).astype(np.int32).reshape((1, 1, 2, 2)) - - node = helper.make_node("ConvInteger", - ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], - kernel_shape=[2, 2], - pads=[0, 0, 0, 0], - dilations=[1, 1]) - output = run_node(node, [x, w, np.int8(0), w_zero_point]) - np.testing.assert_almost_equal(output["Y"], y) + for device in self._get_device_list(): + # Test w_zero_point + x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( + (1, 1, 3, 3)) + w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) + w_zero_point = np.int8(1) + y = np.array([16, 20, 28, 32]).astype(np.int32).reshape((1, 1, 2, 2)) + + node = helper.make_node("ConvInteger", + ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], + kernel_shape=[2, 2], + pads=[0, 0, 0, 0], + dilations=[1, 1]) + output = run_node(node, [x, w, np.int8(0), w_zero_point], device=device) + np.testing.assert_almost_equal(output["Y"], y) - # Test x_zero_point and w_zero_point - x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( - (1, 1, 3, 3)) - x_zero_point = np.int8(1) - w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) - w_zero_point = np.int8(1) - y = np.array([12, 16, 24, 28]).astype(np.int32).reshape((1, 1, 2, 2)) - - node = helper.make_node("ConvInteger", - ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], - kernel_shape=[2, 2], - pads=[0, 0, 0, 0], - dilations=[1, 1]) - output = run_node(node, [x, w, x_zero_point, w_zero_point]) - np.testing.assert_almost_equal(output["Y"], y) + # Test x_zero_point and w_zero_point + x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( + (1, 1, 3, 3)) + x_zero_point = np.int8(1) + w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) + w_zero_point = np.int8(1) + y = np.array([12, 16, 24, 28]).astype(np.int32).reshape((1, 1, 2, 2)) + + node = helper.make_node("ConvInteger", + ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], + kernel_shape=[2, 2], + pads=[0, 0, 0, 0], + dilations=[1, 1]) + output = run_node(node, [x, w, x_zero_point, w_zero_point], device=device) + np.testing.assert_almost_equal(output["Y"], y) - # Test w_zero_point as 1d tensor - x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( - (1, 1, 3, 3)) - w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) - w_zero_point = np.array([1]).astype(np.int8) - y = np.array([16, 20, 28, 32]).astype(np.int32).reshape((1, 1, 2, 2)) - - node = helper.make_node("ConvInteger", - ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], - kernel_shape=[2, 2], - pads=[0, 0, 0, 0], - dilations=[1, 1]) - output = 
run_node(node, [x, w, np.int8(0), w_zero_point]) - np.testing.assert_almost_equal(output["Y"], y) + # Test w_zero_point as 1d tensor + x = np.array([2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int8).reshape( + (1, 1, 3, 3)) + w = np.array([2, 2, 2, 2]).astype(np.int8).reshape((1, 1, 2, 2)) + w_zero_point = np.array([1]).astype(np.int8) + y = np.array([16, 20, 28, 32]).astype(np.int32).reshape((1, 1, 2, 2)) + + node = helper.make_node("ConvInteger", + ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], + kernel_shape=[2, 2], + pads=[0, 0, 0, 0], + dilations=[1, 1]) + output = run_node(node, [x, w, np.int8(0), w_zero_point], device=device) + np.testing.assert_almost_equal(output["Y"], y) - # Test w_zero_point as 1d tensor shape 2 - x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]).astype(np.int8).reshape( - (1, 1, 3, 3)) - w = np.array([2, 2, 2, 2, 2, 2, 2, 2]).astype(np.int8).reshape((2, 1, 2, 2)) - w_zero_point = np.array([1, 2]).astype(np.int8) - y = np.array([12, 16, 24, 28, 0, 0, 0, 0]).astype(np.int32).reshape( - (1, 2, 2, 2)) - - node = helper.make_node("ConvInteger", - ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], - kernel_shape=[2, 2], - pads=[0, 0, 0, 0], - dilations=[1, 1]) - output = run_node(node, [x, w, np.int8(0), w_zero_point]) - np.testing.assert_almost_equal(output["Y"], y) + # Test w_zero_point as 1d tensor shape 2 + x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]).astype(np.int8).reshape( + (1, 1, 3, 3)) + w = np.array([2, 2, 2, 2, 2, 2, 2, 2]).astype(np.int8).reshape( + (2, 1, 2, 2)) + w_zero_point = np.array([1, 2]).astype(np.int8) + y = np.array([12, 16, 24, 28, 0, 0, 0, 0]).astype(np.int32).reshape( + (1, 2, 2, 2)) + + node = helper.make_node("ConvInteger", + ["X", "W", "x_zero_point", "w_zero_point"], ["Y"], + kernel_shape=[2, 2], + pads=[0, 0, 0, 0], + dilations=[1, 1]) + output = run_node(node, [x, w, np.int8(0), w_zero_point], device=device) + np.testing.assert_almost_equal(output["Y"], y) def test_conv_transpose(self): - device = "CUDA" if supports_device("CUDA") else "CPU" - - pads = [1, 1] - node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"], - pads=pads) - x_shape = [1, 3, 4] - x = self._get_rnd_float32(shape=x_shape) - weight_shape = [3, 5, 2] - weights = self._get_rnd_float32(shape=weight_shape) - output = run_node(node_def, [x, weights], device=device) - - padh_left = weight_shape[2] - 1 - pads[0] - padh_right = weight_shape[2] - 1 - pads[1] - kh = weight_shape[2] - outh = x_shape[2] + padh_right + padh_right - (kh - 1) - - out_shape = [x_shape[0], weight_shape[1], outh] - - test_output = np.zeros(out_shape) - for b in range(0, x_shape[0]): - for m in range(0, weight_shape[1]): - for c in range(0, x_shape[1]): - for h in range(0, outh): - for k in range(h, h + kh): - if (k - padh_left >= 0): - test_output[b][m][h] += x[b][c][k - padh_left] * weights[c][m][ - kh + h - 1 - k] - - np.testing.assert_almost_equal(output["Y"], test_output, decimal=5) - - # test for spatial dimension of colnolution is 2 - pads = [1, 1, 1, 1] - node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"], - pads=pads) - x_shape = [1, 3, 4, 6] - x = self._get_rnd_float32(shape=x_shape) - weight_shape = [3, 5, 2, 2] - weights = self._get_rnd_float32(shape=weight_shape) - output = run_node(node_def, [x, weights], device=device) - - padh_left = weight_shape[2] - 1 - pads[0] - padh_right = weight_shape[2] - 1 - pads[1] - padw_left = weight_shape[3] - 1 - pads[2] - padw_right = weight_shape[3] - 1 - pads[3] - - kh = weight_shape[2] - kw = weight_shape[3] - outh = x_shape[2] + 
padh_right + padh_right - (kh - 1) - outw = x_shape[3] + padw_right + padw_right - (kw - 1) - - out_shape = [x_shape[0], weight_shape[1], outh, outw] - - test_output = np.zeros(out_shape) - for b in range(0, x_shape[0]): - for m in range(0, weight_shape[1]): - for c in range(0, x_shape[1]): - for h in range(0, outh): - for w in range(0, outw): - for k1 in range(h, h + kh): - for k2 in range(w, w + kw): - if (k1 - padh_left >= 0 and k2 - padw_left >= 0): - test_output[b][m][h][w] += x[b][c][k1 - padh_left][ - k2 - padw_left] * weights[c][m][kh + h - 1 - - k1][kw + w - 1 - k2] - - np.testing.assert_almost_equal(output["Y"], test_output, decimal=5) + for device in self._get_device_list(): + pads = [1, 1] + node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"], + pads=pads) + x_shape = [1, 3, 4] + x = self._get_rnd_float32(shape=x_shape) + weight_shape = [3, 5, 2] + weights = self._get_rnd_float32(shape=weight_shape) + output = run_node(node_def, [x, weights], device=device) + + padh_left = weight_shape[2] - 1 - pads[0] + padh_right = weight_shape[2] - 1 - pads[1] + kh = weight_shape[2] + outh = x_shape[2] + padh_right + padh_right - (kh - 1) + + out_shape = [x_shape[0], weight_shape[1], outh] + + test_output = np.zeros(out_shape) + for b in range(0, x_shape[0]): + for m in range(0, weight_shape[1]): + for c in range(0, x_shape[1]): + for h in range(0, outh): + for k in range(h, h + kh): + if (k - padh_left >= 0): + test_output[b][m][h] += x[b][c][ + k - padh_left] * weights[c][m][kh + h - 1 - k] + + np.testing.assert_almost_equal(output["Y"], test_output, decimal=5) + + # test for spatial dimension of colnolution is 2 + pads = [1, 1, 1, 1] + node_def = helper.make_node("ConvTranspose", ["X", "weights"], ["Y"], + pads=pads) + x_shape = [1, 3, 4, 6] + x = self._get_rnd_float32(shape=x_shape) + weight_shape = [3, 5, 2, 2] + weights = self._get_rnd_float32(shape=weight_shape) + output = run_node(node_def, [x, weights], device=device) + + padh_left = weight_shape[2] - 1 - pads[0] + padh_right = weight_shape[2] - 1 - pads[1] + padw_left = weight_shape[3] - 1 - pads[2] + padw_right = weight_shape[3] - 1 - pads[3] + + kh = weight_shape[2] + kw = weight_shape[3] + outh = x_shape[2] + padh_right + padh_right - (kh - 1) + outw = x_shape[3] + padw_right + padw_right - (kw - 1) + + out_shape = [x_shape[0], weight_shape[1], outh, outw] + + test_output = np.zeros(out_shape) + for b in range(0, x_shape[0]): + for m in range(0, weight_shape[1]): + for c in range(0, x_shape[1]): + for h in range(0, outh): + for w in range(0, outw): + for k1 in range(h, h + kh): + for k2 in range(w, w + kw): + if (k1 - padh_left >= 0 and k2 - padw_left >= 0): + test_output[b][m][h][w] += x[b][c][k1 - padh_left][ + k2 - padw_left] * weights[c][m][kh + h - 1 - + k1][kw + w - 1 - k2] + + np.testing.assert_almost_equal(output["Y"], test_output, decimal=5) def test_cosh(self): if legacy_opset_pre_ver(9): @@ -593,14 +596,16 @@ def test_cumsum(self): np.testing.assert_almost_equal(output["y"], y) def test_depth_to_space(self): - node_def = helper.make_node("DepthToSpace", ["X"], ["Y"], blocksize=2) - x_shape = [1, 12, 1, 1] - x = self._get_rnd_float32(shape=x_shape) - output = run_node(node_def, [x]) - x = np.transpose(x, (0, 2, 3, 1)) - y = np.reshape(np.swapaxes(x.reshape(1, 1, 1, 2, 2, 3), 2, 3), (1, 2, 2, 3)) - y = np.transpose(y, (0, 3, 1, 2)) - np.testing.assert_almost_equal(output["Y"], y, decimal=5) + for device in self._get_device_list(): + node_def = helper.make_node("DepthToSpace", ["X"], ["Y"], blocksize=2) 
+ x_shape = [1, 12, 1, 1] + x = self._get_rnd_float32(shape=x_shape) + output = run_node(node_def, [x], device=device) + x = np.transpose(x, (0, 2, 3, 1)) + y = np.reshape(np.swapaxes(x.reshape(1, 1, 1, 2, 2, 3), 2, 3), + (1, 2, 2, 3)) + y = np.transpose(y, (0, 3, 1, 2)) + np.testing.assert_almost_equal(output["Y"], y, decimal=5) def test_dequantize_linear(self): node_def = helper.make_node("DequantizeLinear", @@ -741,6 +746,19 @@ def test_dynamic_quantize_linear(self): np.testing.assert_almost_equal(output["Y_Scale"], y_scale) np.testing.assert_almost_equal(output["Y_Zero_Point"], y_zero_point) + def test_einsum(self): + if legacy_opset_pre_ver(12): + raise unittest.SkipTest( + "ONNX version {} doesn't support Einsum.".format( + defs.onnx_opset_version())) + equation = 'ij,jk->ik' #matmul + node_def = helper.make_node("Einsum", ["X", "Y"], ["Z"], equation=equation) + x = self._get_rnd_float32(shape=[3, 4]) + y = self._get_rnd_float32(shape=[4, 5]) + z = np.einsum(equation, x, y) + output = run_node(node_def, [x, y]) + np.testing.assert_almost_equal(output["Z"], z) + def test_elu(self): node_def = helper.make_node("Elu", ["X"], ["Y"]) x = self._get_rnd_float32(shape=[100]) @@ -763,8 +781,7 @@ def test_equal(self): x = np.arange(8).reshape((2, 2, 2)).astype(np.uint64) y = np.arange(8).reshape((2, 2, 2)).astype(np.uint64) - with np.testing.assert_raises(RuntimeError): - output = run_node(node_def, [x, y]) + self.assertRaises(RuntimeError, run_node, node_def, [x, y]) def test_erf(self): if legacy_opset_pre_ver(9): @@ -1250,6 +1267,14 @@ def test_global_max_pool(self): test_output[i1][i2][0][0] = max np.testing.assert_almost_equal(output["Y"], test_output) + def test_greater(self): + node_def = helper.make_node("Greater", ["X", "Y"], ["Z"]) + x = self._get_rnd_float32(shape=[5, 3, 3, 2]) + y = self._get_rnd_float32(shape=[3, 3, 1]) + output = run_node(node_def, [x, y]) + np.testing.assert_equal(output["Z"], np.greater(x, np.reshape(y, + [1, 3, 3, 1]))) + def test_less(self): node_def = helper.make_node("Less", ["X", "Y"], ["Z"]) x = self._get_rnd_float32(shape=[5, 3, 3, 2]) @@ -1295,41 +1320,42 @@ def test_lp_normalization(self): rtol=1e-3) def test_l_r_n(self): - # Each input value is divided by: - # - # (bias+(alpha/size)*sum(xi^2 for every xi in the local region))^beta - alpha = 2.0 - beta = 1.0 - bias = 5.0 - size = 3 - node_def = helper.make_node("LRN", ["X"], ["Y"], - alpha=alpha, - beta=beta, - bias=bias, - size=size) - x = self._get_rnd_float32(shape=[10, 2, 10, 10]) - output = run_node(node_def, [x]) - test_output = np.zeros([10, 10, 10, 2]) - x = np.transpose(x, axes=[0, 2, 3, 1]) - for i1 in range(0, 10): - for i2 in range(0, 10): - for j1 in range(0, 10): - for j2 in range(0, 2): - sqr_sum = 0. - # size of 3 means radius 1 in TF speak - # i.e. the immediate neighbouring values - # if "previous" neighbour exists - if j2 > 0: - sqr_sum += x[i1][i2][j1][j2 - 1] * x[i1][i2][j1][j2 - 1] - # current value - sqr_sum += x[i1][i2][j1][j2] * x[i1][i2][j1][j2] - # if "next" neighbour exists - if j2 < 2 - 1: - sqr_sum += x[i1][i2][j1][j2 + 1] * x[i1][i2][j1][j2 + 1] - test_output[i1][i2][j1][j2] = \ - x[i1][i2][j1][j2] / ((bias + (alpha * 1. 
/ size) * sqr_sum) ** beta) - test_output = np.transpose(test_output, axes=[0, 3, 1, 2]) - np.testing.assert_almost_equal(output["Y"], test_output) + for device in self._get_device_list(): + # Each input value is divided by: + # + # (bias+(alpha/size)*sum(xi^2 for every xi in the local region))^beta + alpha = 2.0 + beta = 1.0 + bias = 5.0 + size = 3 + node_def = helper.make_node("LRN", ["X"], ["Y"], + alpha=alpha, + beta=beta, + bias=bias, + size=size) + x = self._get_rnd_float32(shape=[10, 2, 10, 10]) + output = run_node(node_def, [x], device=device) + test_output = np.zeros([10, 10, 10, 2]) + x = np.transpose(x, axes=[0, 2, 3, 1]) + for i1 in range(0, 10): + for i2 in range(0, 10): + for j1 in range(0, 10): + for j2 in range(0, 2): + sqr_sum = 0. + # size of 3 means radius 1 in TF speak + # i.e. the immediate neighbouring values + # if "previous" neighbour exists + if j2 > 0: + sqr_sum += x[i1][i2][j1][j2 - 1] * x[i1][i2][j1][j2 - 1] + # current value + sqr_sum += x[i1][i2][j1][j2] * x[i1][i2][j1][j2] + # if "next" neighbour exists + if j2 < 2 - 1: + sqr_sum += x[i1][i2][j1][j2 + 1] * x[i1][i2][j1][j2 + 1] + test_output[i1][i2][j1][j2] = \ + x[i1][i2][j1][j2] / ((bias + (alpha * 1. / size) * sqr_sum) ** beta) + test_output = np.transpose(test_output, axes=[0, 3, 1, 2]) + np.testing.assert_almost_equal(output["Y"], test_output) def test_floor(self): node_def = helper.make_node("Floor", ["X"], ["Y"]) @@ -1579,7 +1605,7 @@ def test_loop(self): ['v1_final', 'v2_final', 'scan_output'], body=graph) try: - output = run_node(node_def, [M, cond, v1_initial, v2_initial]) + run_node(node_def, [M, cond, v1_initial, v2_initial]) raise AssertionError("Expected RuntimeError not raise when Loop inputs " + "M and cond are both not set at the same time") except RuntimeError as e: @@ -1706,58 +1732,65 @@ def _test_pooling(self, input_dtype=np.float32, p=None): - op = "MaxPool" if pooling_type.upper().startswith("MAX") else \ - "AveragePool" if pooling_type.upper() == "AVG" else "LpPool" - node_def_kwargs = { - "op_type": op, - "inputs": ["X"], - "outputs": ["Y"], - "kernel_shape": kernel_shape - } - - if strides is not None: - node_def_kwargs["strides"] = strides - if dilations is not None: - node_def_kwargs["dilations"] = dilations - if pads is not None: - node_def_kwargs["pads"] = pads - if auto_pad is not None: - node_def_kwargs["auto_pad"] = auto_pad - pads = auto_pad - if ceil_mode is not None: - node_def_kwargs["ceil_mode"] = ceil_mode - else: - ceil_mode = 0 - if count_include_pad is not None: - node_def_kwargs["count_include_pad"] = count_include_pad - if p is not None: - node_def_kwargs["p"] = p - - node_def = helper.make_node(**node_def_kwargs) - - if input_dtype == np.float32: - x = self._get_rnd_float32(shape=input_shape) - else: - x = self._get_rnd_int(low=np.iinfo(input_dtype).min, - high=np.iinfo(input_dtype).max, - shape=input_shape, - dtype=input_dtype) - - output = run_node(node_def, [x]) + for device in self._get_device_list(): + op = "MaxPool" if pooling_type.upper().startswith("MAX") else \ + "AveragePool" if pooling_type.upper() == "AVG" else "LpPool" + node_def_kwargs = { + "op_type": op, + "inputs": ["X"], + "outputs": ["Y"], + "kernel_shape": kernel_shape + } + + if strides is not None: + node_def_kwargs["strides"] = strides + if dilations is not None: + node_def_kwargs["dilations"] = dilations + if pads is not None: + node_def_kwargs["pads"] = pads + orig_pads = pads # save it for the 2nd loop + if auto_pad is not None: + node_def_kwargs["auto_pad"] = auto_pad + pads = auto_pad + 
orig_ceil_mode = ceil_mode # save it for the 2nd loop + if ceil_mode is not None: + node_def_kwargs["ceil_mode"] = ceil_mode + else: + ceil_mode = 0 + if count_include_pad is not None: + node_def_kwargs["count_include_pad"] = count_include_pad + if p is not None: + node_def_kwargs["p"] = p - test_output = py_pool(x, - kernel_shape=kernel_shape, - strides=strides, - dilations=dilations, - padding=pads, - ceil_mode=ceil_mode, - pooling_type=pooling_type, - include_indices=False, - p=p) + node_def = helper.make_node(**node_def_kwargs) - np.testing.assert_almost_equal(output["Y"], - test_output, - decimal=5 if pooling_type == "LP" else 7) + if input_dtype == np.float32: + x = self._get_rnd_float32(shape=input_shape) + else: + x = self._get_rnd_int(low=np.iinfo(input_dtype).min, + high=np.iinfo(input_dtype).max, + shape=input_shape, + dtype=input_dtype) + + output = run_node(node_def, [x], device=device) + + test_output = py_pool(x, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + padding=pads, + ceil_mode=ceil_mode, + pooling_type=pooling_type, + include_indices=False, + p=p) + + np.testing.assert_almost_equal(output["Y"], + test_output, + decimal=5 if pooling_type == "LP" else 7) + + # set pads and ceil_mode values back to the original values for the 2nd loop + pads = orig_pads + ceil_mode = orig_ceil_mode def test_max_pool_2d(self): kernel_shape = [1, 2] @@ -1833,10 +1866,6 @@ def test_max_pool_2d_dilations(self): kernel_shape = [3, 3] strides = [2, 2] dilations = [3, 3] - node_def = helper.make_node("MaxPool", ["X"], ["Y"], - kernel_shape=kernel_shape, - strides=strides, - dilations=dilations) input_shape = [10, 3, 24, 24] self._test_pooling(input_shape=input_shape, @@ -2038,33 +2067,34 @@ def test_max_pool_with_argmax_2d_dilations_ceil_pads(self): raise unittest.SkipTest( "ONNX version {} doesn't support dilations nor ceil mode.".format( defs.onnx_opset_version())) - - kernel_shape = [3, 3] - strides = [2, 2] - dilations = [3, 3] - pads = [1, 1, 2, 2] - ceil_mode = True - node_def = helper.make_node("MaxPool", ["X"], ["Y", "Ind"], - kernel_shape=kernel_shape, - strides=strides, - dilations=dilations, - pads=pads, - ceil_mode=ceil_mode) - - input_shape = [10, 1, 23, 23] - x = self._get_rnd_float32(shape=input_shape) - 2 - output = run_node(node_def, [x]) - - test_output, test_ind = py_pool(x, - kernel_shape=kernel_shape, - strides=strides, - dilations=dilations, - padding=pads, - ceil_mode=ceil_mode, - pooling_type="MAX") - - np.testing.assert_almost_equal(output["Y"], test_output) - np.testing.assert_almost_equal(output["Ind"], test_ind) + for device in self._get_device_list(): + kernel_shape = [3, 3] + strides = [2, 2] + dilations = [3, 3] + pads = [1, 1, 2, 2] + ceil_mode = True + node_def = helper.make_node("MaxPool", ["X"], ["Y", "Ind"], + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + pads=pads, + ceil_mode=ceil_mode) + + input_shape = [10, 3, 23, 23] + x = self._get_rnd_float32(shape=input_shape) - 2 + + output = run_node(node_def, [x], device=device) + + test_output, test_ind = py_pool(x, + kernel_shape=kernel_shape, + strides=strides, + dilations=dilations, + padding=pads, + ceil_mode=ceil_mode, + pooling_type="MAX") + + np.testing.assert_almost_equal(output["Y"], test_output) + np.testing.assert_almost_equal(output["Ind"], test_ind) def test_max_pool_with_argmax_3d(self): kernel_shape = [3, 3, 3] @@ -2089,37 +2119,36 @@ def test_max_pool_4d(self): self.assertRaises(RuntimeError, run_node, node_def, [x]) def test_max_unpool(self): - 
input_shape = [10, 10, 4, 4] - x = self._get_rnd_float32(shape=input_shape) + for device in self._get_device_list(): + input_shape = [10, 10, 4, 4] + x = self._get_rnd_float32(shape=input_shape) - X = helper.make_tensor_value_info('X', TensorProto.FLOAT, input_shape) - Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, input_shape) - - node_def = helper.make_node("MaxPool", ["X"], ["Pool", "Indices"], - kernel_shape=[2, 2], - strides=[2, 2]) - output_pool = run_node(node_def, [x]) - - node_def = helper.make_node("MaxUnpool", ["Pool", "Indices"], ["Y"], - kernel_shape=[2, 2], - strides=[2, 2]) - output_unpool = run_node(node_def, - [output_pool["Pool"], output_pool["Indices"]]) - - test_output = np.zeros(input_shape) - for i1 in range(0, input_shape[0]): - for i2 in range(0, input_shape[1]): - for i3 in range(0, input_shape[2], 2): - for i4 in range(0, input_shape[3], 2): - max_val = float('-inf') - for j1 in range(i3, i3 + 2): - for j2 in range(i4, i4 + 2): - if x[i1][i2][j1][j2] > max_val: - max_val = x[i1][i2][j1][j2] - max_ind = (j1, j2) - j1, j2 = max_ind - test_output[i1][i2][j1][j2] = max_val - np.testing.assert_almost_equal(output_unpool["Y"], test_output) + node_def = helper.make_node("MaxPool", ["X"], ["Pool", "Indices"], + kernel_shape=[2, 2], + strides=[2, 2]) + output_pool = run_node(node_def, [x], device=device) + + node_def = helper.make_node("MaxUnpool", ["Pool", "Indices"], ["Y"], + kernel_shape=[2, 2], + strides=[2, 2]) + output_unpool = run_node(node_def, + [output_pool["Pool"], output_pool["Indices"]], + device=device) + + test_output = np.zeros(input_shape) + for i1 in range(0, input_shape[0]): + for i2 in range(0, input_shape[1]): + for i3 in range(0, input_shape[2], 2): + for i4 in range(0, input_shape[3], 2): + max_val = float('-inf') + for j1 in range(i3, i3 + 2): + for j2 in range(i4, i4 + 2): + if x[i1][i2][j1][j2] > max_val: + max_val = x[i1][i2][j1][j2] + max_ind = (j1, j2) + j1, j2 = max_ind + test_output[i1][i2][j1][j2] = max_val + np.testing.assert_almost_equal(output_unpool["Y"], test_output) def test_average_pool_1d(self): kernel_shape = [3] @@ -3081,50 +3110,51 @@ def test_qlinearconv(self): raise unittest.SkipTest( "ONNX version {} doesn't support QLinearConv.".format( defs.onnx_opset_version())) + for device in self._get_device_list(): + # Test w_scale and w_zero_point as scalar + node_def = helper.make_node("QLinearConv", + inputs=[ + "x", "x_scale", "x_zero_point", "w", + "w_scale", "w_zero_point", "y_scale", + "y_zero_point" + ], + outputs=["Y"]) + x = np.array([ + [255, 174, 162, 25, 203, 168, 58], + [15, 59, 237, 95, 129, 0, 64], + [56, 242, 153, 221, 168, 12, 166], + [232, 178, 186, 195, 237, 162, 237], + [188, 39, 124, 77, 80, 102, 43], + [127, 230, 21, 83, 41, 40, 134], + [255, 154, 92, 141, 42, 148, 247], + ], + dtype=np.uint8).reshape((1, 1, 7, 7)) + x_scale = np.float32(0.00369204697) + x_zero_point = np.uint8(132) + + w = np.array([0], dtype=np.uint8).reshape((1, 1, 1, 1)) + w_scale = np.float32(0.00172794575) + w_zero_point = np.uint8(255) + + y = np.array([ + [0, 81, 93, 230, 52, 87, 197], + [240, 196, 18, 160, 126, 255, 191], + [199, 13, 102, 34, 87, 243, 89], + [23, 77, 69, 60, 18, 93, 18], + [67, 216, 131, 178, 175, 153, 212], + [128, 25, 234, 172, 214, 215, 121], + [0, 101, 163, 114, 213, 107, 8], + ], + dtype=np.uint8).reshape((1, 1, 7, 7)) + y_scale = np.float32(0.00162681262) + y_zero_point = np.uint8(123) - # Test w_scale and w_zero_point as scalar - node_def = helper.make_node("QLinearConv", - inputs=[ - "x", "x_scale", 
"x_zero_point", "w", - "w_scale", "w_zero_point", "y_scale", - "y_zero_point" - ], - outputs=["Y"]) - x = np.array([ - [255, 174, 162, 25, 203, 168, 58], - [15, 59, 237, 95, 129, 0, 64], - [56, 242, 153, 221, 168, 12, 166], - [232, 178, 186, 195, 237, 162, 237], - [188, 39, 124, 77, 80, 102, 43], - [127, 230, 21, 83, 41, 40, 134], - [255, 154, 92, 141, 42, 148, 247], - ], - dtype=np.uint8).reshape((1, 1, 7, 7)) - x_scale = np.float32(0.00369204697) - x_zero_point = np.uint8(132) - - w = np.array([0], dtype=np.uint8).reshape((1, 1, 1, 1)) - w_scale = np.float32(0.00172794575) - w_zero_point = np.uint8(255) - - y = np.array([ - [0, 81, 93, 230, 52, 87, 197], - [240, 196, 18, 160, 126, 255, 191], - [199, 13, 102, 34, 87, 243, 89], - [23, 77, 69, 60, 18, 93, 18], - [67, 216, 131, 178, 175, 153, 212], - [128, 25, 234, 172, 214, 215, 121], - [0, 101, 163, 114, 213, 107, 8], - ], - dtype=np.uint8).reshape((1, 1, 7, 7)) - y_scale = np.float32(0.00162681262) - y_zero_point = np.uint8(123) - - output = run_node(node_def, [ - x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, - y_zero_point - ]) - np.testing.assert_almost_equal(output["Y"], y) + output = run_node(node_def, [ + x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, + y_zero_point + ], + device=device) + np.testing.assert_almost_equal(output["Y"], y) def test_quantize_linear(self): node_def = helper.make_node("QuantizeLinear", @@ -3607,7 +3637,7 @@ def test_scatter_elements3(self): node_def = helper.make_node("ScatterElements", ["data", "indices", "updates"], ["outputs"]) with np.testing.assert_raises(tf.errors.InvalidArgumentError): - output = run_node(node_def, [data, indices, updates]) + run_node(node_def, [data, indices, updates]) def test_scatter_nd(self): if legacy_opset_pre_ver(11): @@ -3650,12 +3680,12 @@ def test_scatter_nd(self): dtype=np.int64) updates = np.array([37, 52, 30, 39], dtype=np.float32) with np.testing.assert_raises(tf.errors.InvalidArgumentError): - output = run_node(node_def, [data, indices, updates]) + run_node(node_def, [data, indices, updates]) indices = np.array([[0, 1], [-1, -1], [-2, -4]], dtype=np.int64) updates = np.array([[35, 36, 37, 38], [51, 52, 53, 54], [31, 32, 33, 34]], dtype=np.float32) with np.testing.assert_raises(tf.errors.InvalidArgumentError): - output = run_node(node_def, [data, indices, updates]) + run_node(node_def, [data, indices, updates]) def test_shape(self): node_def = helper.make_node("Shape", ["X"], ["Y"]) @@ -3783,15 +3813,16 @@ def test_softsign(self): np.testing.assert_almost_equal(output["Y"], x / (1 + np.abs(x))) def test_space_to_depth(self): - node_def = helper.make_node("SpaceToDepth", ["X"], ["Y"], blocksize=2) - x_shape = [1, 3, 2, 2] - x = self._get_rnd_float32(shape=x_shape) - output = run_node(node_def, [x]) - x = np.transpose(x, (0, 2, 3, 1)) - y = np.reshape(np.swapaxes(x.reshape(1, 1, 1, 1, 1, 12), 2, 3), - (1, 1, 1, 12)) - y = np.transpose(y, (0, 3, 1, 2)) - np.testing.assert_allclose(output["Y"], y, rtol=1e-3) + for device in self._get_device_list(): + node_def = helper.make_node("SpaceToDepth", ["X"], ["Y"], blocksize=2) + x_shape = [1, 3, 2, 2] + x = self._get_rnd_float32(shape=x_shape) + output = run_node(node_def, [x], device=device) + x = np.transpose(x, (0, 2, 3, 1)) + y = np.reshape(np.swapaxes(x.reshape(1, 1, 1, 1, 1, 12), 2, 3), + (1, 1, 1, 12)) + y = np.transpose(y, (0, 3, 1, 2)) + np.testing.assert_allclose(output["Y"], y, rtol=1e-3) def test_split(self): split = np.array([3, 3, 4]).astype(np.int64)