diff --git a/mgeconvert/mge_context/mge_transform.py b/mgeconvert/mge_context/mge_transform.py index b849018..a5350c8 100644 --- a/mgeconvert/mge_context/mge_transform.py +++ b/mgeconvert/mge_context/mge_transform.py @@ -251,12 +251,12 @@ def _fuse_activation(net): activation = "RELU6" if isinstance(op, Relu6Opr) else activation prev_op.activation = activation prev_op.out_vars = op.out_vars - if len(op.out_oprs) > 0: - idx = op.out_oprs[0].inp_oprs.index(op) - op.out_oprs[0].inp_oprs[idx] = prev_op - prev_op.out_oprs = [op.out_oprs[0]] - else: - prev_op.out_oprs = [] + + for post_op in op.out_oprs: + idx = post_op.inp_oprs.index(op) + post_op.inp_oprs[idx] = prev_op + if post_op not in prev_op.out_oprs: + prev_op.out_oprs.append(post_op) delete_intended.append(net._opr_ids.index(op_id)) @@ -873,4 +873,12 @@ def _replace_opr(net, matches: List[Ops.MgeOpr]): max_idx = max(net._opr_ids.index(i.id) for i in opr.inp_oprs) net._opr_ids.insert(max_idx + 1, opr.id) net.all_oprs.insert(max_idx + 1, opr) - net.all_oprs = list(filter(lambda opr: not opr.skip, net.all_oprs)) + new_idxs = [] + new_oprs = [] + for idx, opr in zip(net._opr_ids, net.all_oprs): + if opr.skip: + continue + new_idxs.append(idx) + new_oprs.append(opr) + net._opr_ids = new_idxs + net.all_oprs = new_oprs