From 9a209a89c98ca804a8d0b1d6227943152aa76b90 Mon Sep 17 00:00:00 2001 From: Peng Xiong Date: Thu, 22 Jul 2021 14:08:45 +0800 Subject: [PATCH] feat(tflite_converter): add slice op --- README.md | 6 +- mgeconvert/mge_context/mge_op.py | 5 ++ mgeconvert/mge_context/mge_transform.py | 69 +++++++++++++++++++ .../tflite_converter/tflite_converter.py | 1 + mgeconvert/tflite_converter/tflite_op.py | 25 +++++++ test/test_tflite.py | 7 ++ 6 files changed, 110 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b97e798..e80a01d 100644 --- a/README.md +++ b/README.md @@ -89,8 +89,8 @@ convert caffe -h |softmax| ✓ | ✓ | ✓ | ✓ | |leaky_relu| ✓ | × | × | ✓ | |sub| ✓ | ✓ | ✓ | ✓ | -|slice(subtensor)| ✓ | ✓ | ✓ | × | -|squeeze(axis_add_remove)| ✓ | ✓ | ✓ | × | +|slice(subtensor)| ✓ | ✓ | ✓ | ✓ | +|squeeze(axis_add_remove)| ✓ | ✓ | ✓ | ✓ | |tanh| ✓ | ✓ | ✓ | ✓ | |typecvt| ✓ | ✓ | ✓ | ✓ | -|transpose(dimshuffle)| ✓ | ✓ | ✓ | × | +|transpose(dimshuffle)| ✓ | ✓ | ✓ | ✓ | diff --git a/mgeconvert/mge_context/mge_op.py b/mgeconvert/mge_context/mge_op.py index d3d9356..b49d99c 100644 --- a/mgeconvert/mge_context/mge_op.py +++ b/mgeconvert/mge_context/mge_op.py @@ -448,3 +448,8 @@ class SoftmaxOpr(OpBase): class PadOpr(OpBase): name = "Pad" + + +class SqueezeOpr(OpBase): + name = "Squeeze" + squeeze_dims = [] # type: ignore[var-annotated] diff --git a/mgeconvert/mge_context/mge_transform.py b/mgeconvert/mge_context/mge_transform.py index 6cff4bf..8ea6b00 100644 --- a/mgeconvert/mge_context/mge_transform.py +++ b/mgeconvert/mge_context/mge_transform.py @@ -136,6 +136,8 @@ class TransformerRule(Enum): FUSE_FOR_DECONV_BIAS = 115 FUSE_FOR_FULLY_CONNECTED = 116 RESHAPE_BIAS_TO_1DIM = 117 + # for TFLite Converter + SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE = 200 TRANSFORMMAP: Dict[Enum, Callable] = {} @@ -412,6 +414,73 @@ def _deconv_shape_as_input(net): op.inp_vars = [shape_tensor, op.inp_vars[1], op.inp_vars[0]] +@_register_tranformation_rule(TransformerRule.SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE) +def _make_slice_as_inputs(net): + for op in net.all_oprs: + if not isinstance(op, Ops.SubtensorOpr): + continue + + ndim = op.inp_vars[0].ndim + + def make_input(axis, param, init_value): + # make inputs: begin, end and step. + ret = [init_value] * ndim # pylint: disable=cell-var-from-loop + for k, v in zip(axis, param): + ret[k] = v + ret = FakeSymbolVar( + sid=net.max_id, + name=op.name + "fake_input", # pylint: disable=cell-var-from-loop + shape=[len(ret)], + dtype=np.int32, + owner=op, # pylint: disable=cell-var-from-loop + byte_list=np.array(ret, np.int32).tobytes(), + ) + net.max_id += 1 + return net.get_var(ret) + + begins_tensor = make_input(op.axis, op.begin_param, 0) + ends_tensor = make_input(op.axis, op.end_param, np.iinfo(np.int32).max) + steps_tensor = make_input(op.axis, op.step_param, 1) + + op.inp_vars = [op.inp_vars[0], begins_tensor, ends_tensor, steps_tensor] + + # TFLite slice do not support squeeze axis, so insert a squeeze opr here. + # infer actual output shape of tflite slice + desired_out_shape = op.out_vars[0].shape + actual_out_shape = [1] * ndim + idx = 0 + for i in range(ndim): + if i in op.squeeze_axis: + continue + actual_out_shape[i] = desired_out_shape[idx] + idx += 1 + slice_out_symvar = FakeSymbolVar( + sid=net.max_id, + name=op.name + "fake_output", + shape=actual_out_shape, + dtype=op.out_vars[0].dtype, + owner=op, + ) + net.max_id += 1 + slice_op_output = net.get_var(slice_out_symvar) + old_out = op.out_vars + op.out_vars = [slice_op_output] + + squeeze = Ops.SqueezeOpr() + squeeze.squeeze_dims = op.squeeze_axis + squeeze.inp_vars = [slice_op_output] + squeeze.out_vars = old_out + squeeze.inp_oprs = [op] + squeeze.out_oprs = op.out_oprs + op.out_oprs = [squeeze] + squeeze.id = net.max_id + net.max_id += 1 + + idx = net._opr_ids.index(op.id) + 1 + net._opr_ids.insert(idx, squeeze.id) + net.all_oprs.insert(idx, squeeze) + + @_register_tranformation_rule(TransformerRule.PADDING_FOR_CONV) def _make_padding(net): def have_padding(opr): diff --git a/mgeconvert/tflite_converter/tflite_converter.py b/mgeconvert/tflite_converter/tflite_converter.py index 0326942..7e2bd88 100644 --- a/mgeconvert/tflite_converter/tflite_converter.py +++ b/mgeconvert/tflite_converter/tflite_converter.py @@ -55,6 +55,7 @@ class TFLiteConverter: TransformerRule.FUSE_FOR_LEAKY_RELU, TransformerRule.EXPAND_MUL_ADD3, TransformerRule.EXPAND_ADD_SIGMOID, + TransformerRule.SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE, ] def __init__(self, toponet, transformer_options=None, graph_name="graph"): diff --git a/mgeconvert/tflite_converter/tflite_op.py b/mgeconvert/tflite_converter/tflite_op.py index b667347..a4f72e0 100644 --- a/mgeconvert/tflite_converter/tflite_op.py +++ b/mgeconvert/tflite_converter/tflite_op.py @@ -31,6 +31,8 @@ ReshapeOpr, ResizeForwardOpr, SoftmaxOpr, + SqueezeOpr, + SubtensorOpr, Tensor, get_platform, ) @@ -55,6 +57,8 @@ ReshapeOptions, ResizeBilinearOptions, SoftmaxOptions, + SqueezeOptions, + StridedSliceOptions, SubOptions, TransposeConvOptions, TransposeOptions, @@ -441,3 +445,24 @@ def _leaky_relu(mge_opr, builder): LeakyReluOptions.LeakyReluOptionsAddAlpha(builder, mge_opr.negative_slope[0]) options = LeakyReluOptions.LeakyReluOptionsEnd(builder) return BuiltinOperator.LEAKY_RELU, BuiltinOptions.LeakyReluOptions, options + + +@_register_op(SubtensorOpr) +def _subtensor(_, builder): + StridedSliceOptions.StridedSliceOptionsStart(builder) + options = StridedSliceOptions.StridedSliceOptionsEnd(builder) + return BuiltinOperator.STRIDED_SLICE, BuiltinOptions.StridedSliceOptions, options + + +@_register_op(SqueezeOpr) +def _squeeze(mge_opr, builder): + SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector( + builder, len(mge_opr.squeeze_dims) + ) + for i in mge_opr.squeeze_dims: + builder.PrependInt32(i) + squeeze_dims = builder.EndVector(len(mge_opr.squeeze_dims)) + SqueezeOptions.SqueezeOptionsStart(builder) + SqueezeOptions.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims) + options = SqueezeOptions.SqueezeOptionsEnd(builder) + return BuiltinOperator.SQUEEZE, BuiltinOptions.SqueezeOptions, options diff --git a/test/test_tflite.py b/test/test_tflite.py index ed867bb..b6c752a 100644 --- a/test/test_tflite.py +++ b/test/test_tflite.py @@ -24,6 +24,7 @@ ReduceOpr, ReshapeOpr, SoftmaxOpr, + SubtensorOpr, TransposeOpr, XORNet, dump_mge_model, @@ -170,6 +171,12 @@ def test_transopse(): _test_convert_result(net.data, tmp_file, mge_result, max_error, nhwc2=False) +def test_slice(): + net = SubtensorOpr() + mge_result = dump_mge_model(net, net.data, tmp_file) + _test_convert_result(net.data, tmp_file, mge_result, max_error, False, False) + + def test_xornet(): if megengine.__version__ < "1.1.0": return