Skip to content

Commit

Permalink
fix(onnx_converter): fix conv converter of onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
dingshaohua960303 committed Sep 28, 2021
1 parent 9a209a8 commit 80c957b
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 20 deletions.
15 changes: 12 additions & 3 deletions mgeconvert/mge_context/mge_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from typing import List # pylint: disable=unused-import

from .mge_tensor import Tensor # pylint: disable=unused-import
from .mge_utils import get_mge_version, get_opr_type, get_shape, get_symvar_value
from .mge_utils import (
get_dep_vars,
get_mge_version,
get_opr_type,
get_shape,
get_symvar_value,
)

mge_version = get_mge_version()

Expand Down Expand Up @@ -141,7 +147,10 @@ class ConvolutionForwardOpr(MgeOpr):
def __init__(self, opr):
super().__init__(opr)
self.kernel_shape = get_shape(opr.inputs[1])
self.param_W = get_symvar_value(opr.inputs[1])
if len(get_dep_vars(opr.inputs[1], "Host2DeviceCopy")) == 0:
self.param_W = get_symvar_value(opr.inputs[1])
else:
self.param_W = None
self.data_format = self.params["format"]
self.dilation_w = self.params["dilate_w"]
self.dilation_h = self.params["dilate_h"]
Expand All @@ -158,7 +167,7 @@ def __init__(self, opr):

self.num_output = get_shape(opr.outputs[0])[1]
self.bias_term = False
self.group = self.param_W.shape[0] if self.param_W.ndim == 5 else 1
self.group = self.kernel_shape[0] if len(self.kernel_shape) == 5 else 1


class ConvBiasForwardOpr(ConvolutionForwardOpr):
Expand Down
55 changes: 39 additions & 16 deletions mgeconvert/onnx_converter/onnx_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,31 +368,54 @@ def convert(self):
opr = self._opr
attrs = self._get_attrs()
nodes = []
exclude_idx = [1] if attrs["group"] != 1 else []
exclude_idx = [0] if attrs["group"] != 1 else []
inputs = self._get_inputs(exclude_idx)
if isinstance(self._opr, ConvolutionBackwardDataOpr):
inputs = [inputs[1], inputs[0]]

outputs = self._get_outputs()
if attrs["group"] != 1:
inputs[1] = opr.name + "_filter_reshape_onnx"
flt = opr.param_W
w_idx = 0 if isinstance(self._opr, ConvolutionBackwardDataOpr) else 1
flt_shape = self._opr.inp_vars[w_idx].shape
flt_shape = [
flt.shape[0] * flt.shape[1],
flt.shape[2],
flt.shape[3],
flt.shape[4],
flt_shape[0] * flt_shape[1],
flt_shape[2],
flt_shape[3],
flt_shape[4],
]
flt_data = flt.reshape(flt_shape)
flt_tensor = onnx.helper.make_tensor_value_info(
inputs[1], mge2onnx_dtype_mapping[flt.dtype.type], flt_shape
)
flt_param = onnx.numpy_helper.from_array(flt_data, inputs[1])
self._net_sources.append(flt_tensor)
self._parameters.append(flt_param)

if opr.param_W is not None:
inputs[1] = opr.name + "_filter_reshape_onnx"
flt = opr.param_W
flt_data = flt.reshape(flt_shape)
flt_tensor = onnx.helper.make_tensor_value_info(
inputs[1], mge2onnx_dtype_mapping[flt.dtype.type], flt_shape
)
flt_param = onnx.numpy_helper.from_array(flt_data, inputs[1])
self._net_sources.append(flt_tensor)
self._parameters.append(flt_param)
else:
reshape_inputs = [inputs[1], opr.name + "shape_onnx"]
shape_tensor = onnx.helper.make_tensor_value_info(
reshape_inputs[1],
mge2onnx_dtype_mapping[np.int64],
(len(flt_shape),),
)
shape_param = onnx.numpy_helper.from_array(
np.array(flt_shape, dtype="int64"), reshape_inputs[1]
)
self._net_sources.append(shape_tensor)
self._parameters.append(shape_param)
reshape = onnx.helper.make_node(
"Reshape", reshape_inputs, [opr.name + "_filter_reshape_onnx"]
)
inputs[1] = opr.name + "_filter_reshape_onnx"
nodes.append(reshape)
onnx_op = "Conv"
if isinstance(self._opr, ConvolutionBackwardDataOpr):
onnx_op = "ConvTranspose"
inputs = [inputs[1], inputs[0]]
conv2d = onnx.helper.make_node(onnx_op, inputs, [outputs[0]], **attrs)
nodes.extend([conv2d])
nodes.append(conv2d)
return (nodes, self._net_sources, self._parameters)


Expand Down
2 changes: 1 addition & 1 deletion mgeconvert/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
__version__ = "0.4.1"
__version__ = "0.4.2"

# required megengine version range
MEGENGINE_LOWER = "0.6.0"
40 changes: 40 additions & 0 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import os

import megengine
import megengine.functional as F
import megengine.hub
import numpy as np
import onnxruntime as ort
import pytest
from megengine.jit import trace
from mgeconvert.mge_context import TopologyNetwork, get_mge_version
from mgeconvert.onnx_converter.onnx_converter import OnnxConverter

Expand Down Expand Up @@ -65,6 +67,44 @@ def test_conv2d(mode):
_test_convert_result(net.data, tmp_file, mge_result, max_error)


def test_conv_functional():
    """Check ONNX export of a grouped conv whose filter is a runtime input.

    Builds a grouped convolution where the kernel arrives through
    ``F.conv2d`` as a network input (not a baked-in const parameter),
    dumps the traced graph, converts it for ONNX opsets 8-12, and compares
    onnxruntime's output against megengine's reference result.
    """
    # Grouped functional conv needs megengine >= 1.2.0; bail out otherwise.
    # NOTE: compare parsed integer tuples, not raw strings — lexicographic
    # string comparison wrongly treats "1.10.0" as older than "1.2.0".
    try:
        mge_ver = tuple(int(p) for p in megengine.__version__.split(".")[:2])
    except ValueError:
        # Non-numeric version component (e.g. a dev build tag): be
        # conservative and skip, matching the original behavior of returning.
        mge_ver = (0, 0)
    if mge_ver < (1, 2):
        return

    def convf(x, kernel):
        # Reshape the (batch, channel, kh, kw) "kernel" input into a
        # grouped-conv filter of shape (groups, 1, 1, kh, kw) with
        # groups = batch * channel, then undo the flattening on the output.
        batch = int(kernel.shape[0])
        channel = int(kernel.shape[1])
        bc = batch * channel
        x = x.reshape((1, bc, int(x.shape[2]), int(x.shape[3])))
        kernel = kernel.reshape(bc, 1, 1, int(kernel.shape[2]), int(kernel.shape[3]))
        out = F.conv2d(x, kernel, groups=bc)
        out = out.reshape(batch, channel, int(out.shape[2]), int(out.shape[3]))
        return out

    @trace(symbolic=True, capture_as_const=True)
    def inference(x, kernel):
        output = convf(x, kernel)
        return output

    inpx = np.random.random((1, 48, 100, 100)).astype("float32")
    inpk = np.random.random((1, 48, 64, 64)).astype("float32")
    expect = inference(megengine.tensor(inpx), megengine.tensor(inpk)).numpy()
    inference.dump(
        tmp_file + ".mge", arg_names=["x", "kernel"], optimize_for_inference=False,
    )
    net = TopologyNetwork(tmp_file + ".mge")
    for version in range(8, 13):
        converter = OnnxConverter(net, opset_version=version, graph_name="graph")
        model = converter.convert()
        with open(tmp_file + ".onnx", "wb") as fout:
            fout.write(model.SerializeToString())
        onnx_net = ort.InferenceSession(tmp_file + ".onnx")
        pred_onx = onnx_net.run(None, {"x": inpx, "kernel": inpk})[0]
        assert pred_onx.shape == expect.shape
        assert pred_onx.dtype == expect.dtype
        assert np.allclose(pred_onx, expect, atol=1e-6)


def test_linear():
net = LinearOpr()
mge_result = dump_mge_model(net, net.data, tmp_file)
Expand Down

0 comments on commit 80c957b

Please sign in to comment.