Skip to content

Commit

Permalink
fix tflite ConvTranspose2d
Browse files Browse the repository at this point in the history
  • Loading branch information
daisycx authored and xpmemeda committed Jul 16, 2021
1 parent 4c66900 commit e25af7c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 7 deletions.
5 changes: 5 additions & 0 deletions mgeconvert/tflite_converter/tflite_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def need_convert(mge_opr):
if tfl_opr_type not in self._opr_type_list:
self._opr_type_list.append(tfl_opr_type)

if hasattr(mge_opr, "type") and mge_opr.type == "ConvolutionBackwardData":
mge_opr.inp_vars = [mge_opr.inp_vars[0]] + list(
reversed(mge_opr.inp_vars[-2:])
) # shape, weight, input

# buffer and tensor
for var in mge_opr.inp_vars + mge_opr.out_vars:
if var in self._var2tensor:
Expand Down
29 changes: 22 additions & 7 deletions mgeconvert/tflite_converter/tflite_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,24 @@ def get_shape_param(tensor: Tensor, mge_opr: MgeOpr, disable_nhwc=False):
# NCHW to NHWC
# except the output of reshape
if not disable_nhwc:
shape = [
tensor.shape[0],
tensor.shape[2],
tensor.shape[3],
tensor.shape[1],
]
if (
hasattr(mge_opr, "type")
and mge_opr.type == "ConvolutionBackwardData"
and tensor.np_data is not None
):
shape = [
tensor.shape[1],
tensor.shape[2],
tensor.shape[3],
tensor.shape[0],
]
else:
shape = [
tensor.shape[0],
tensor.shape[2],
tensor.shape[3],
tensor.shape[1],
]
elif tensor.ndim > 4:
assert False, "ERROR: output ndim {0} is not supported now".format(tensor.ndim)

Expand All @@ -98,7 +110,10 @@ def get_shape_param(tensor: Tensor, mge_opr: MgeOpr, disable_nhwc=False):
value = tensor.np_data
if value is not None:
if value.ndim == 4:
value = value.transpose(0, 2, 3, 1)
if hasattr(mge_opr, "type") and mge_opr.type == "ConvolutionBackwardData":
value = value.transpose(1, 2, 3, 0)
else:
value = value.transpose(0, 2, 3, 1)
number_list = value.reshape(-1)

if len(number_list) > 0:
Expand Down
6 changes: 6 additions & 0 deletions test/test_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def test_conv2d():
_test_convert_result(net.data, tmp_file, mge_result, max_error)


def test_deconv2d():
net = ConvOpr("tflite_transpose")
mge_result = dump_mge_model(net, net.data, tmp_file)
_test_convert_result(net.data, tmp_file, mge_result, max_error)


def test_linear():
net = LinearOpr()
mge_result = dump_mge_model(net, net.data, tmp_file)
Expand Down
11 changes: 11 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ def __init__(self, mode):
np.random.random(self.transpose_conv[1].bias.shape).astype(np.float32)
)

self.tflite_transpose_conv = M.Sequential(
M.ConvTranspose2d(3, 5, (3, 4), stride=(3, 2), groups=1),
M.ConvTranspose2d(5, 3, (3, 3)),
)
self.tflite_transpose_conv[0].bias = mge.Parameter(
np.random.random(self.transpose_conv[0].bias.shape).astype(np.float32)
)
self.tflite_transpose_conv[1].bias = mge.Parameter(
np.random.random(self.transpose_conv[1].bias.shape).astype(np.float32)
)

def forward(self, x):
return getattr(self, self.mode + "_conv")(x)

Expand Down

0 comments on commit e25af7c

Please sign in to comment.