Skip to content

Commit

Permalink
feat(tflite_converter): add slice op
Browse files Browse the repository at this point in the history
  • Loading branch information
Peng Xiong authored and xpmemeda committed Jul 28, 2021
1 parent 775428b commit 9a209a8
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 3 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ convert caffe -h
|softmax|||||
|leaky_relu|| × | × ||
|sub|||||
|slice(subtensor)|||| × |
|squeeze(axis_add_remove)|||| × |
|slice(subtensor)|||| |
|squeeze(axis_add_remove)|||| |
|tanh|||||
|typecvt|||||
|transpose(dimshuffle)|||| × |
|transpose(dimshuffle)|||| |
5 changes: 5 additions & 0 deletions mgeconvert/mge_context/mge_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,8 @@ class SoftmaxOpr(OpBase):

class PadOpr(OpBase):
name = "Pad"


class SqueezeOpr(OpBase):
name = "Squeeze"
squeeze_dims = [] # type: ignore[var-annotated]
69 changes: 69 additions & 0 deletions mgeconvert/mge_context/mge_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ class TransformerRule(Enum):
FUSE_FOR_DECONV_BIAS = 115
FUSE_FOR_FULLY_CONNECTED = 116
RESHAPE_BIAS_TO_1DIM = 117
# for TFLite Converter
SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE = 200


TRANSFORMMAP: Dict[Enum, Callable] = {}
Expand Down Expand Up @@ -412,6 +414,73 @@ def _deconv_shape_as_input(net):
op.inp_vars = [shape_tensor, op.inp_vars[1], op.inp_vars[0]]


@_register_tranformation_rule(TransformerRule.SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE)
def _make_slice_as_inputs(net):
for op in net.all_oprs:
if not isinstance(op, Ops.SubtensorOpr):
continue

ndim = op.inp_vars[0].ndim

def make_input(axis, param, init_value):
# make inputs: begin, end and step.
ret = [init_value] * ndim # pylint: disable=cell-var-from-loop
for k, v in zip(axis, param):
ret[k] = v
ret = FakeSymbolVar(
sid=net.max_id,
name=op.name + "fake_input", # pylint: disable=cell-var-from-loop
shape=[len(ret)],
dtype=np.int32,
owner=op, # pylint: disable=cell-var-from-loop
byte_list=np.array(ret, np.int32).tobytes(),
)
net.max_id += 1
return net.get_var(ret)

begins_tensor = make_input(op.axis, op.begin_param, 0)
ends_tensor = make_input(op.axis, op.end_param, np.iinfo(np.int32).max)
steps_tensor = make_input(op.axis, op.step_param, 1)

op.inp_vars = [op.inp_vars[0], begins_tensor, ends_tensor, steps_tensor]

# TFLite slice do not support squeeze axis, so insert a squeeze opr here.
# infer actual output shape of tflite slice
desired_out_shape = op.out_vars[0].shape
actual_out_shape = [1] * ndim
idx = 0
for i in range(ndim):
if i in op.squeeze_axis:
continue
actual_out_shape[i] = desired_out_shape[idx]
idx += 1
slice_out_symvar = FakeSymbolVar(
sid=net.max_id,
name=op.name + "fake_output",
shape=actual_out_shape,
dtype=op.out_vars[0].dtype,
owner=op,
)
net.max_id += 1
slice_op_output = net.get_var(slice_out_symvar)
old_out = op.out_vars
op.out_vars = [slice_op_output]

squeeze = Ops.SqueezeOpr()
squeeze.squeeze_dims = op.squeeze_axis
squeeze.inp_vars = [slice_op_output]
squeeze.out_vars = old_out
squeeze.inp_oprs = [op]
squeeze.out_oprs = op.out_oprs
op.out_oprs = [squeeze]
squeeze.id = net.max_id
net.max_id += 1

idx = net._opr_ids.index(op.id) + 1
net._opr_ids.insert(idx, squeeze.id)
net.all_oprs.insert(idx, squeeze)


@_register_tranformation_rule(TransformerRule.PADDING_FOR_CONV)
def _make_padding(net):
def have_padding(opr):
Expand Down
1 change: 1 addition & 0 deletions mgeconvert/tflite_converter/tflite_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TFLiteConverter:
TransformerRule.FUSE_FOR_LEAKY_RELU,
TransformerRule.EXPAND_MUL_ADD3,
TransformerRule.EXPAND_ADD_SIGMOID,
TransformerRule.SLICE_PARAMS_AS_INPUTS_AND_MAKE_SQUEEZE,
]

def __init__(self, toponet, transformer_options=None, graph_name="graph"):
Expand Down
25 changes: 25 additions & 0 deletions mgeconvert/tflite_converter/tflite_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
ReshapeOpr,
ResizeForwardOpr,
SoftmaxOpr,
SqueezeOpr,
SubtensorOpr,
Tensor,
get_platform,
)
Expand All @@ -55,6 +57,8 @@
ReshapeOptions,
ResizeBilinearOptions,
SoftmaxOptions,
SqueezeOptions,
StridedSliceOptions,
SubOptions,
TransposeConvOptions,
TransposeOptions,
Expand Down Expand Up @@ -441,3 +445,24 @@ def _leaky_relu(mge_opr, builder):
LeakyReluOptions.LeakyReluOptionsAddAlpha(builder, mge_opr.negative_slope[0])
options = LeakyReluOptions.LeakyReluOptionsEnd(builder)
return BuiltinOperator.LEAKY_RELU, BuiltinOptions.LeakyReluOptions, options


@_register_op(SubtensorOpr)
def _subtensor(_, builder):
StridedSliceOptions.StridedSliceOptionsStart(builder)
options = StridedSliceOptions.StridedSliceOptionsEnd(builder)
return BuiltinOperator.STRIDED_SLICE, BuiltinOptions.StridedSliceOptions, options


@_register_op(SqueezeOpr)
def _squeeze(mge_opr, builder):
SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
builder, len(mge_opr.squeeze_dims)
)
for i in mge_opr.squeeze_dims:
builder.PrependInt32(i)
squeeze_dims = builder.EndVector(len(mge_opr.squeeze_dims))
SqueezeOptions.SqueezeOptionsStart(builder)
SqueezeOptions.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims)
options = SqueezeOptions.SqueezeOptionsEnd(builder)
return BuiltinOperator.SQUEEZE, BuiltinOptions.SqueezeOptions, options
7 changes: 7 additions & 0 deletions test/test_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ReduceOpr,
ReshapeOpr,
SoftmaxOpr,
SubtensorOpr,
TransposeOpr,
XORNet,
dump_mge_model,
Expand Down Expand Up @@ -170,6 +171,12 @@ def test_transopse():
_test_convert_result(net.data, tmp_file, mge_result, max_error, nhwc2=False)


def test_slice():
net = SubtensorOpr()
mge_result = dump_mge_model(net, net.data, tmp_file)
_test_convert_result(net.data, tmp_file, mge_result, max_error, False, False)


def test_xornet():
if megengine.__version__ < "1.1.0":
return
Expand Down

0 comments on commit 9a209a8

Please sign in to comment.