From 38c2a191868036edf4e4f153d018e380c7071891 Mon Sep 17 00:00:00 2001
From: caowengang
Date: Wed, 25 May 2022 14:52:53 +0800
Subject: [PATCH] fix(tflite): transpose Linear weight to NHWC

---
Reviewer notes (commentary only; text between the "---" line and the first
"diff --git" line is not part of the commit message):

* Adds TransformerRule.TRANSPOSE_LINEAR_WEIGHT_TO_NHWC (value 133) and its
  implementation, _convert_linear_weight_to_NHWC.
* For every LinearOpr whose first input opr is a FlattenOpr/ReshapeOpr that
  feeds only that Linear, and whose own first input tensor is 4-D, the Linear
  weight (opr.inp_tensors[1].np_data) is reshaped to
  [weight.shape[0], *inp_shape[1:]], transposed with axes [0, 2, 3, 1] and
  collapsed back to 2-D, then written back to weight.np_data in place.
* tracedmodule_to_tflite appends the new rule when disable_nhwc is falsy.
* NOTE(review): the rewrite presumes the reshape input shape is NCHW, that the
  weight's in-features equal prod(inp_shape[1:]), and that
  net.find_inp_oprs(opr) is non-empty -- confirm against the full source.
* NOTE(review): line breaks and context-line indentation in this patch were
  reconstructed from the hunk line counts and Python syntax -- verify against
  the target files before applying.

 mgeconvert/converter_ir/ir_transform.py | 31 +++++++++++++++++++++++++
 mgeconvert/converters/tm_to_tflite.py   |  2 ++
 2 files changed, 33 insertions(+)

diff --git a/mgeconvert/converter_ir/ir_transform.py b/mgeconvert/converter_ir/ir_transform.py
index 1b8a827..b4e33b9 100644
--- a/mgeconvert/converter_ir/ir_transform.py
+++ b/mgeconvert/converter_ir/ir_transform.py
@@ -52,6 +52,8 @@
 )
 from .ir_tensor import AxisOrder, IRTensor
 
+# pylint: disable=C0302
+
 
 class IRConfig:
     conv_prefer_same_pad_mode = False
@@ -116,6 +118,8 @@ class TransformerRule(Enum):
     ADD_FAKE_HSIGMOID_OUT = 131
     RENAME_CAFFE_LAYER_TENSOR = 132
 
+    TRANSPOSE_LINEAR_WEIGHT_TO_NHWC = 133
+
 
 def cmp_rules(a, b):
     if a.value < b.value:
@@ -1479,3 +1483,30 @@ def _remove_flatten(net: IRGraph):
             flatten_opr.inp_tensors[0].user_opr.append(opr)
             idx = net.all_oprs.index(flatten_opr)
             net.delete_ops(idx)
+
+
+@_register_tranformation_rule(TransformerRule.TRANSPOSE_LINEAR_WEIGHT_TO_NHWC)
+def _convert_linear_weight_to_NHWC(net: IRGraph):
+    for opr in net.all_oprs:
+        if not isinstance(opr, (LinearOpr)):
+            continue
+        inp_oprs = net.find_inp_oprs(opr)
+        if (
+            isinstance(inp_oprs[0], OpBase)
+            and isinstance(inp_oprs[0], (FlattenOpr, ReshapeOpr))
+            and len(net.find_out_oprs(inp_oprs[0])) == 1
+            and net.find_out_oprs(inp_oprs[0])[0] == opr
+        ):
+            reshape_opr = inp_oprs[0]
+            inp_shape = reshape_opr.inp_tensors[0].shape
+            if len(inp_shape) != 4:
+                continue
+            weight = opr.inp_tensors[1]
+            data = weight.np_data
+            # unflatten in-features to the pre-reshape dims, presumably [o_c, in_c, h, w]
+            data = np.reshape(data, [data.shape[0], *inp_shape[1:]])
+            # move axis 1 to the end, giving [o_c, h, w, in_c] under that assumption
+            data = np.transpose(data, [0, 2, 3, 1])
+            # collapse everything after axis 0 so the weight is 2-D again
+            data = np.reshape(data, [data.shape[0], -1])
+            weight.np_data = data
diff --git a/mgeconvert/converters/tm_to_tflite.py b/mgeconvert/converters/tm_to_tflite.py
index 23d039e..3b16941 100644
--- a/mgeconvert/converters/tm_to_tflite.py
+++ b/mgeconvert/converters/tm_to_tflite.py
@@ -81,6 +81,8 @@ def tracedmodule_to_tflite(
         TransformerRule.PAD_WIDTH_AS_INPUT,
         TransformerRule.EXPAND_ADD_RELU,
     ]
+    if not disable_nhwc:
+        transformer_options.append(TransformerRule.TRANSPOSE_LINEAR_WEIGHT_TO_NHWC)
     if mtk:
         # MTK devices only support batch_size 1
         set_platform("mtk")