
Commit 775428b
fix(mge_transform): consider activation when fusing op
Peng Xiong authored and xpmemeda committed Jul 28, 2021
Parent: b18b9be
Showing 4 changed files with 13 additions and 4 deletions.
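The change in one sentence: when the transform passes collapse an elementwise ADD (plus a trailing ReLU/TanH) into a single fused opr, the activation already recorded on the ADD must be carried over instead of being silently reset to IDENTITY. A minimal sketch of the failure mode, using a stand-in Opr class rather than the real mge_context types:

# Stand-in opr carrying the same "activation" attribute the real
# mge_context oprs expose; purely illustrative, not the library's API.
class Opr:
    def __init__(self, name, activation="IDENTITY"):
        self.name = name
        self.activation = activation

# After the FUSE_ACTIVATION rule, the trailing relu lives on the ADD opr:
add_opr = Opr("ADD", activation="RELU")

# FUSE_FOR_FULLY_CONNECTED then replaces MatrixMul + ADD with one opr:
fused = Opr("FullyConnected")           # activation defaults to IDENTITY
fused.activation = add_opr.activation   # the one-line fix this commit adds
assert fused.activation == "RELU"       # previously the relu was dropped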
.github/workflows/ci.yml (4 additions, 4 deletions)

@@ -31,7 +31,7 @@ jobs:
   pytest-transform:
     runs-on: ubuntu-latest
     container:
-      image: xpdnbd/mgeconvert_ci:v0.1
+      image: enginesh233/mgeconvert_ci:v1.0
     defaults:
       run:
         shell: bash
@@ -43,7 +43,7 @@ jobs:
   pytest-cambricon:
     runs-on: ubuntu-latest
     container:
-      image: xpdnbd/mgeconvert_ci:v0.1
+      image: enginesh233/mgeconvert_ci:v1.0
     defaults:
       run:
         shell: bash
@@ -61,7 +61,7 @@
   pytest-caffe-and-onnx:
     runs-on: ubuntu-latest
     container:
-      image: xpdnbd/mgeconvert_ci:v0.1
+      image: enginesh233/mgeconvert_ci:v1.0
     defaults:
       run:
         shell: bash
@@ -73,7 +73,7 @@
   pytest-tflite:
     runs-on: ubuntu-latest
     container:
-      image: xpdnbd/mgeconvert_ci:v0.1
+      image: enginesh233/mgeconvert_ci:v1.0
     defaults:
       run:
         shell: bash
mgeconvert/mge_context/mge_transform.py (6 additions, 0 deletions)

@@ -245,6 +245,8 @@ def _fuse_activation(net):
             isinstance(op, ElemwiseOpr) and op.mode in ("RELU", "TANH")
         ):
             prev_op = op.inp_oprs[0]
+            if prev_op.activation != "IDENTITY":
+                continue
 
             # activation(relu/relu6/tanh) must be fused with previous opr
             activation = getattr(op, "mode", "IDENTITY")
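The new guard protects against double fusion: if the producing opr already carries a non-identity activation from an earlier pass, folding another RELU/TANH onto it would clobber the first one. A sketch of the intent (an assumed simplification, not the real _fuse_activation signature):

from types import SimpleNamespace

def try_fuse_activation(mode, prev_op):
    # Fold a standalone RELU/TANH opr into its producer, unless the
    # producer already absorbed an activation (mirrors the added guard).
    if prev_op.activation != "IDENTITY":
        return False
    prev_op.activation = mode
    return True

conv = SimpleNamespace(activation="IDENTITY")
assert try_fuse_activation("RELU", conv)      # fused: conv now carries RELU
assert not try_fuse_activation("TANH", conv)  # skipped: would clobber RELU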
@@ -728,6 +730,7 @@ def _fuse_for_conv_bias(opr):
             conv_node.op._opr,
             bias_node.inp_const[0][1],
         )
+        conv_bias.activation = add_opr.activation
         conv_bias.inp_vars = conv_node.op.inp_vars + bias_node.op.inp_vars[1:]
         conv_bias.out_vars = bias_node.op.out_vars
         conv_bias.inp_oprs = conv_node.op.inp_oprs
@@ -787,6 +790,7 @@ def _fuse_for_deconv_bias(opr):
             conv_node.op._opr,
             bias_node.inp_const[0][1],
         )
+        deconv_bias.activation = add_opr.activation
         deconv_bias.inp_vars = conv_node.op.inp_vars + bias_node.op.inp_vars[1:]
         deconv_bias.out_vars = bias_node.op.out_vars
         deconv_bias.inp_oprs = conv_node.op.inp_oprs
@@ -825,13 +829,15 @@ def _fuse_for_fully_connected(opr):
     bias_node = PatternNode("ADD", is_output=True)
     matrix_mul_node = PatternNode(Ops.MatrixMulOpr.__name__)
     bias_node.inp_oprs = [matrix_mul_node]
+
     add_opr = opr.out_oprs[0]
     if match(bias_node, add_opr):
         fully_connected = Ops.FullyConnectedOpr(
             "FullyConnected_" + bias_node.op.name,
             matrix_mul_node.op._opr,
             bias_node.inp_const[0][1],
         )
+        fully_connected.activation = add_opr.activation
         fully_connected.inp_vars = (
             matrix_mul_node.op.inp_vars + bias_node.op.inp_vars[1:]
         )
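The same one-line propagation appears in all three pattern fusers (_fuse_for_conv_bias, _fuse_for_deconv_bias, _fuse_for_fully_connected): whichever pattern consumes the ADD opr, the activation that FUSE_ACTIVATION stored on it survives on the fused opr instead of being reset to IDENTITY.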
test/test_transform.py (2 additions, 0 deletions)

@@ -88,7 +88,9 @@ def test_fuse_for_fully_connected():
     net = LinearOpr()
     dump_mge_model(net, net.data, "test_model.mge")
     net = TopologyNetwork("test_model.mge")
+    optimize_for_conversion(net, TransformerRule.FUSE_ACTIVATION)
     optimize_for_conversion(net, TransformerRule.FUSE_FOR_FULLY_CONNECTED)
+    assert net.all_oprs[-1].activation == "RELU"
     actual = list(type(opr).__name__ for opr in net.all_oprs)
     desired = ["Host2DeviceCopyOpr", "MatrixMulOpr", "FullyConnectedOpr"]
     assert actual == desired
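Note the rule ordering the test now exercises: FUSE_ACTIVATION runs first and folds the model's trailing F.relu into the ADD opr's activation attribute; FUSE_FOR_FULLY_CONNECTED then collapses MatrixMul + ADD into a FullyConnectedOpr, and the new assert verifies the RELU survived the collapse.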
test/utils.py (1 addition, 0 deletions)

@@ -112,6 +112,7 @@ def __init__(self):
     def forward(self, x):
         x = self.linear(x)
         x = self.linear_bias(x)
+        x = F.relu(x)
         return x
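The trailing F.relu added to the LinearOpr test model is what gives the fusion passes an activation to preserve; without it, the new assertion in test_transform.py would have nothing to check.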
