Skip to content

Commit

Permalink
feat(qat_pattern): add qat deconv pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
dingshaohua960303 committed Jun 6, 2022
1 parent 7554304 commit 3603e18
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
13 changes: 11 additions & 2 deletions mgeconvert/backend/ir_to_tflite/tflite_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,18 @@ def _deconv(mge_opr, builder):

@_register_op(ConcatOpr)
def _concat(mge_opr, builder):
if len(set([t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors])) != 1:
if (
mge_opr.inp_tensors[0].q_dtype == "int8"
and len({t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors}) != 1
):
logger.warning(
"tflite int8 concat doesn't support inputs outputs with different scale!"
)
if mge_opr.inp_tensors[0].q_dtype == "int16" and not all(
[t.zero_point == 0 for t in mge_opr.inp_tensors + mge_opr.out_tensors]
):
logger.warning(
"tflite concat doesn't support inputs outputs with different scale!"
"tflite int16 concat doesn't support inputs outputs with zero point != 0!"
)

ConcatenationOptions.ConcatenationOptionsStart(builder)
Expand Down
15 changes: 15 additions & 0 deletions mgeconvert/frontend/tm_to_ir/qat_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def gen_qat_conv_opr(module, conv_function_expr, qat_expr, irgraph, is_deconv=Fa
MatchAnyNode,
)

pat_deconv = (
QATModule._apply_fakequant_with_observer,
MatchAnyNode,
(F.nn.conv_transpose2d, InputNode, QATModule._apply_fakequant_with_observer),
MatchAnyNode,
)

pat_deconv_relu = (
QATModule._apply_fakequant_with_observer,
MatchAnyNode,
Expand Down Expand Up @@ -168,6 +175,13 @@ def qat_conv(module, expr, call_expr, net, _):
return op


@register_fusion_pattern(pat_deconv)
def qat_deconv(module, expr, call_expr, net, _):
conv = expr.inputs[1].expr
op = gen_qat_conv_opr(module, conv, call_expr, net, is_deconv=True)
return op


@register_fusion_pattern(pat_deconv_bias)
def qat_deconv_bias(module, expr, call_expr, irgraph, _):
conv = expr.inputs[1].expr
Expand Down Expand Up @@ -206,6 +220,7 @@ def qat_deconv_relu_bias(
pat_deconv_relu,
pat_conv_relu,
pat_conv,
pat_deconv,
pat_deconv_bias,
]

Expand Down

0 comments on commit 3603e18

Please sign in to comment.