Skip to content

Commit

Permalink
fix(tflite): transpose Linear weight to NHWC
Browse files Browse the repository at this point in the history
  • Loading branch information
caowengang authored and CaoWGG committed Jun 6, 2022
1 parent 9e96da5 commit 38c2a19
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
31 changes: 31 additions & 0 deletions mgeconvert/converter_ir/ir_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
)
from .ir_tensor import AxisOrder, IRTensor

# pylint: disable=C0302


class IRConfig:
conv_prefer_same_pad_mode = False
Expand Down Expand Up @@ -116,6 +118,8 @@ class TransformerRule(Enum):
ADD_FAKE_HSIGMOID_OUT = 131
RENAME_CAFFE_LAYER_TENSOR = 132

TRANSPOSE_LINEAR_WEIGHT_TO_NHWC = 133


def cmp_rules(a, b):
if a.value < b.value:
Expand Down Expand Up @@ -1479,3 +1483,30 @@ def _remove_flatten(net: IRGraph):
flatten_opr.inp_tensors[0].user_opr.append(opr)
idx = net.all_oprs.index(flatten_opr)
net.delete_ops(idx)


@_register_tranformation_rule(TransformerRule.TRANSPOSE_LINEAR_WEIGHT_TO_NHWC)
def _convert_linear_weight_to_NHWC(net: IRGraph):
for opr in net.all_oprs:
if not isinstance(opr, (LinearOpr)):
continue
inp_oprs = net.find_inp_oprs(opr)
if (
isinstance(inp_oprs[0], OpBase)
and isinstance(inp_oprs[0], (FlattenOpr, ReshapeOpr))
and len(net.find_out_oprs(inp_oprs[0])) == 1
and net.find_out_oprs(inp_oprs[0])[0] == opr
):
reshape_opr = inp_oprs[0]
inp_shape = reshape_opr.inp_tensors[0].shape
if len(inp_shape) != 4:
continue
weight = opr.inp_tensors[1]
data = weight.np_data
# convert weight to [o_c, in_c, h, w]
data = np.reshape(data, [data.shape[0], *inp_shape[1:]])
# convert weight to [o_c, h, w, in_c]
data = np.transpose(data, [0, 2, 3, 1])
# flatten weight
data = np.reshape(data, [data.shape[0], -1])
weight.np_data = data
2 changes: 2 additions & 0 deletions mgeconvert/converters/tm_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def tracedmodule_to_tflite(
TransformerRule.PAD_WIDTH_AS_INPUT,
TransformerRule.EXPAND_ADD_RELU,
]
if not disable_nhwc:
transformer_options.append(TransformerRule.TRANSPOSE_LINEAR_WEIGHT_TO_NHWC)
if mtk:
# MTK devices only support batch_size 1
set_platform("mtk")
Expand Down

0 comments on commit 38c2a19

Please sign in to comment.