From 80c957b432dedf8d4b1588250c67fb1ae80cfeed Mon Sep 17 00:00:00 2001
From: shaohua ding
Date: Mon, 13 Sep 2021 13:46:33 +0800
Subject: [PATCH] fix(onnx_converter): fix conv converter of onnx

---
 mgeconvert/mge_context/mge_op.py     | 15 ++++++--
 mgeconvert/onnx_converter/onnx_op.py | 55 ++++++++++++++++++++--------
 mgeconvert/version.py                |  2 +-
 test/test_onnx.py                    | 40 ++++++++++++++++++++
 4 files changed, 92 insertions(+), 20 deletions(-)

diff --git a/mgeconvert/mge_context/mge_op.py b/mgeconvert/mge_context/mge_op.py
index b49d99c..6dda4da 100644
--- a/mgeconvert/mge_context/mge_op.py
+++ b/mgeconvert/mge_context/mge_op.py
@@ -11,7 +11,13 @@
 from typing import List  # pylint: disable=unused-import
 
 from .mge_tensor import Tensor  # pylint: disable=unused-import
-from .mge_utils import get_mge_version, get_opr_type, get_shape, get_symvar_value
+from .mge_utils import (
+    get_dep_vars,
+    get_mge_version,
+    get_opr_type,
+    get_shape,
+    get_symvar_value,
+)
 
 mge_version = get_mge_version()
 
@@ -141,7 +147,10 @@ class ConvolutionForwardOpr(MgeOpr):
     def __init__(self, opr):
         super().__init__(opr)
         self.kernel_shape = get_shape(opr.inputs[1])
-        self.param_W = get_symvar_value(opr.inputs[1])
+        if len(get_dep_vars(opr.inputs[1], "Host2DeviceCopy")) == 0:
+            self.param_W = get_symvar_value(opr.inputs[1])
+        else:
+            self.param_W = None
         self.data_format = self.params["format"]
         self.dilation_w = self.params["dilate_w"]
         self.dilation_h = self.params["dilate_h"]
@@ -158,7 +167,7 @@ def __init__(self, opr):
 
         self.num_output = get_shape(opr.outputs[0])[1]
         self.bias_term = False
-        self.group = self.param_W.shape[0] if self.param_W.ndim == 5 else 1
+        self.group = self.kernel_shape[0] if len(self.kernel_shape) == 5 else 1
 
 
 class ConvBiasForwardOpr(ConvolutionForwardOpr):
diff --git a/mgeconvert/onnx_converter/onnx_op.py b/mgeconvert/onnx_converter/onnx_op.py
index b350757..28c2d25 100644
--- a/mgeconvert/onnx_converter/onnx_op.py
+++ b/mgeconvert/onnx_converter/onnx_op.py
@@ -368,31 +368,54 @@ def convert(self):
         opr = self._opr
         attrs = self._get_attrs()
         nodes = []
-        exclude_idx = [1] if attrs["group"] != 1 else []
+        exclude_idx = [0] if attrs["group"] != 1 else []
         inputs = self._get_inputs(exclude_idx)
+        if isinstance(self._opr, ConvolutionBackwardDataOpr):
+            inputs = [inputs[1], inputs[0]]
+
         outputs = self._get_outputs()
         if attrs["group"] != 1:
-            inputs[1] = opr.name + "_filter_reshape_onnx"
-            flt = opr.param_W
+            w_idx = 0 if isinstance(self._opr, ConvolutionBackwardDataOpr) else 1
+            flt_shape = self._opr.inp_vars[w_idx].shape
             flt_shape = [
-                flt.shape[0] * flt.shape[1],
-                flt.shape[2],
-                flt.shape[3],
-                flt.shape[4],
+                flt_shape[0] * flt_shape[1],
+                flt_shape[2],
+                flt_shape[3],
+                flt_shape[4],
             ]
-            flt_data = flt.reshape(flt_shape)
-            flt_tensor = onnx.helper.make_tensor_value_info(
-                inputs[1], mge2onnx_dtype_mapping[flt.dtype.type], flt_shape
-            )
-            flt_param = onnx.numpy_helper.from_array(flt_data, inputs[1])
-            self._net_sources.append(flt_tensor)
-            self._parameters.append(flt_param)
+
+            if opr.param_W is not None:
+                inputs[1] = opr.name + "_filter_reshape_onnx"
+                flt = opr.param_W
+                flt_data = flt.reshape(flt_shape)
+                flt_tensor = onnx.helper.make_tensor_value_info(
+                    inputs[1], mge2onnx_dtype_mapping[flt.dtype.type], flt_shape
+                )
+                flt_param = onnx.numpy_helper.from_array(flt_data, inputs[1])
+                self._net_sources.append(flt_tensor)
+                self._parameters.append(flt_param)
+            else:
+                reshape_inputs = [inputs[1], opr.name + "shape_onnx"]
+                shape_tensor = onnx.helper.make_tensor_value_info(
+                    reshape_inputs[1],
+                    mge2onnx_dtype_mapping[np.int64],
+                    (len(flt_shape),),
+                )
+                shape_param = onnx.numpy_helper.from_array(
+                    np.array(flt_shape, dtype="int64"), reshape_inputs[1]
+                )
+                self._net_sources.append(shape_tensor)
+                self._parameters.append(shape_param)
+                reshape = onnx.helper.make_node(
+                    "Reshape", reshape_inputs, [opr.name + "_filter_reshape_onnx"]
+                )
+                inputs[1] = opr.name + "_filter_reshape_onnx"
+                nodes.append(reshape)
         onnx_op = "Conv"
         if isinstance(self._opr, ConvolutionBackwardDataOpr):
             onnx_op = "ConvTranspose"
-            inputs = [inputs[1], inputs[0]]
         conv2d = onnx.helper.make_node(onnx_op, inputs, [outputs[0]], **attrs)
-        nodes.extend([conv2d])
+        nodes.append(conv2d)
         return (nodes, self._net_sources, self._parameters)
 
 
diff --git a/mgeconvert/version.py b/mgeconvert/version.py
index 9134700..f4a54af 100644
--- a/mgeconvert/version.py
+++ b/mgeconvert/version.py
@@ -6,7 +6,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 # required megengine version range
 MEGENGINE_LOWER = "0.6.0"
diff --git a/test/test_onnx.py b/test/test_onnx.py
index dad1ac4..da51cb4 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -9,10 +9,12 @@
 import os
 
 import megengine
+import megengine.functional as F
 import megengine.hub
 import numpy as np
 import onnxruntime as ort
 import pytest
+from megengine.jit import trace
 
 from mgeconvert.mge_context import TopologyNetwork, get_mge_version
 from mgeconvert.onnx_converter.onnx_converter import OnnxConverter
@@ -65,6 +67,44 @@ def test_conv2d(mode):
     _test_convert_result(net.data, tmp_file, mge_result, max_error)
 
 
+def test_conv_functional():
+    if megengine.__version__ < "1.2.0":
+        return
+
+    def convf(x, kernel):
+        batch = int(kernel.shape[0])
+        channel = int(kernel.shape[1])
+        bc = batch * channel
+        x = x.reshape((1, bc, int(x.shape[2]), int(x.shape[3])))
+        kernel = kernel.reshape(bc, 1, 1, int(kernel.shape[2]), int(kernel.shape[3]))
+        out = F.conv2d(x, kernel, groups=bc)
+        out = out.reshape(batch, channel, int(out.shape[2]), int(out.shape[3]))
+        return out
+
+    @trace(symbolic=True, capture_as_const=True)
+    def inference(x, kernel):
+        output = convf(x, kernel)
+        return output
+
+    inpx = np.random.random((1, 48, 100, 100)).astype("float32")
+    inpk = np.random.random((1, 48, 64, 64)).astype("float32")
+    expect = inference(megengine.tensor(inpx), megengine.tensor(inpk)).numpy()
+    inference.dump(
+        tmp_file + ".mge", arg_names=["x", "kernel"], optimize_for_inference=False,
+    )
+    net = TopologyNetwork(tmp_file + ".mge")
+    for version in range(8, 13):
+        converter = OnnxConverter(net, opset_version=version, graph_name="graph")
+        model = converter.convert()
+        with open(tmp_file + ".onnx", "wb") as fout:
+            fout.write(model.SerializeToString())
+        onnx_net = ort.InferenceSession(tmp_file + ".onnx")
+        pred_onx = onnx_net.run(None, {"x": inpx, "kernel": inpk})[0]
+        assert pred_onx.shape == expect.shape
+        assert pred_onx.dtype == expect.dtype
+        assert np.allclose(pred_onx, expect, atol=1e-6)
+
+
 def test_linear():
     net = LinearOpr()
     mge_result = dump_mge_model(net, net.data, tmp_file)
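
Note on the onnx_op.py hunk: the new else branch covers grouped convolutions whose filter is a runtime input rather than a stored constant (opr.param_W is None), so the 5-D MegEngine filter (groups, out_per_group, in_per_group, kh, kw) cannot be reshaped on the host; instead a Reshape node folds it into ONNX's 4-D (groups * out_per_group, in_per_group, kh, kw) layout inside the graph. The snippet below is a minimal standalone sketch of that idea, not code from mgeconvert; make_grouped_filter_reshape and the tensor names are hypothetical.

# Sketch: build an in-graph Reshape that folds a 5-D grouped-conv filter to 4-D,
# mirroring the fallback the patch adds for dynamic (non-constant) filters.
# Helper and tensor names here are illustrative, not part of mgeconvert.
import numpy as np
from onnx import helper, numpy_helper


def make_grouped_filter_reshape(filter_name, flt_shape_5d):
    groups, out_per_group, in_per_group, kh, kw = flt_shape_5d
    target = np.array([groups * out_per_group, in_per_group, kh, kw], dtype="int64")

    shape_name = filter_name + "_shape_onnx"
    reshaped_name = filter_name + "_filter_reshape_onnx"

    # The target shape becomes an int64 initializer, since Reshape (opset >= 5)
    # takes its shape as a second input rather than as an attribute.
    shape_init = numpy_helper.from_array(target, shape_name)
    reshape_node = helper.make_node(
        "Reshape", [filter_name, shape_name], [reshaped_name]
    )
    return reshape_node, shape_init, reshaped_name


node, init, out_name = make_grouped_filter_reshape("kernel", (48, 1, 1, 64, 64))
print(out_name, list(init.dims))  # kernel_filter_reshape_onnx [4]

The added test_conv_functional test exercises exactly this path: F.conv2d with groups=bc and a traced kernel argument yields a grouped conv whose filter comes from a Host2DeviceCopy input, so ConvolutionForwardOpr.param_W is None and the converter must emit the Reshape node.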