From 3603e187da4b8193111ae941d3f151fc6281af4c Mon Sep 17 00:00:00 2001
From: dingshaohua
Date: Thu, 26 May 2022 14:23:53 +0800
Subject: [PATCH] feat(qat_pattern): add qat deconv pattern

---
 mgeconvert/backend/ir_to_tflite/tflite_op.py | 13 +++++++++++--
 mgeconvert/frontend/tm_to_ir/qat_pattern.py  | 15 +++++++++++++++
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/mgeconvert/backend/ir_to_tflite/tflite_op.py b/mgeconvert/backend/ir_to_tflite/tflite_op.py
index d0dfe77..9d6057b 100644
--- a/mgeconvert/backend/ir_to_tflite/tflite_op.py
+++ b/mgeconvert/backend/ir_to_tflite/tflite_op.py
@@ -426,9 +426,18 @@ def _deconv(mge_opr, builder):
 
 
 @_register_op(ConcatOpr)
 def _concat(mge_opr, builder):
-    if len(set([t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors])) != 1:
+    if (
+        mge_opr.inp_tensors[0].q_dtype == "int8"
+        and len({t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors}) != 1
+    ):
+        logger.warning(
+            "tflite int8 concat doesn't support inputs outputs with different scale!"
+        )
+    if mge_opr.inp_tensors[0].q_dtype == "int16" and not all(
+        [t.zero_point == 0 for t in mge_opr.inp_tensors + mge_opr.out_tensors]
+    ):
         logger.warning(
-            "tflite concat doesn't support inputs outputs with different scale!"
+            "tflite int16 concat doesn't support inputs outputs with zero point != 0!"
         )
     ConcatenationOptions.ConcatenationOptionsStart(builder)
diff --git a/mgeconvert/frontend/tm_to_ir/qat_pattern.py b/mgeconvert/frontend/tm_to_ir/qat_pattern.py
index 4f2cd51..7a4d19b 100644
--- a/mgeconvert/frontend/tm_to_ir/qat_pattern.py
+++ b/mgeconvert/frontend/tm_to_ir/qat_pattern.py
@@ -118,6 +118,13 @@ def gen_qat_conv_opr(module, conv_function_expr, qat_expr, irgraph, is_deconv=Fa
     MatchAnyNode,
 )
 
+pat_deconv = (
+    QATModule._apply_fakequant_with_observer,
+    MatchAnyNode,
+    (F.nn.conv_transpose2d, InputNode, QATModule._apply_fakequant_with_observer),
+    MatchAnyNode,
+)
+
 pat_deconv_relu = (
     QATModule._apply_fakequant_with_observer,
     MatchAnyNode,
@@ -168,6 +175,13 @@ def qat_conv(module, expr, call_expr, net, _):
     return op
 
 
+@register_fusion_pattern(pat_deconv)
+def qat_deconv(module, expr, call_expr, net, _):
+    conv = expr.inputs[1].expr
+    op = gen_qat_conv_opr(module, conv, call_expr, net, is_deconv=True)
+    return op
+
+
 @register_fusion_pattern(pat_deconv_bias)
 def qat_deconv_bias(module, expr, call_expr, irgraph, _):
     conv = expr.inputs[1].expr
@@ -206,6 +220,7 @@ def qat_deconv_relu_bias(
     pat_deconv_relu,
     pat_conv_relu,
     pat_conv,
+    pat_deconv,
     pat_deconv_bias,
 ]
 
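
Usage sketch (hypothetical, not part of the patch): the new pat_deconv / qat_deconv pair covers a bare QAT ConvTranspose2d, i.e. one with no bias add and no trailing ReLU, which the existing pat_deconv_bias, pat_deconv_relu and pat_deconv_relu_bias patterns do not match. The snippet below shows one way such a graph could be produced and pushed through the converter. The model class, tensor shapes and output file name are made up, and the MegEngine and mgeconvert entry points are written from their public documentation rather than verified against this revision, so treat it as an illustration, not a test.

import numpy as np
import megengine as mge
import megengine.module as M
import mgeconvert
from megengine.quantization import ema_fakequant_qconfig, quantize_qat
from megengine.traced_module import trace_module


class DeconvNet(M.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        # bias=False and no ReLU, so only the bare deconv pattern can apply
        self.deconv = M.ConvTranspose2d(8, 4, 3, stride=2, bias=False)

    def forward(self, x):
        return self.deconv(x)


# switch the float model to QAT so conv_transpose2d gets wrapped in fake-quant
net = quantize_qat(DeconvNet(), qconfig=ema_fakequant_qconfig)
data = mge.tensor(np.random.random((1, 8, 16, 16)).astype("float32"))
traced = trace_module(net, data)

# during the tm_to_ir pass, pat_deconv lets the fake-quant + conv_transpose2d
# subgraph fuse into a single deconv opr before the tflite backend runs
mgeconvert.tracedmodule_to_tflite(traced, output="deconv_qat.tflite")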