Skip to content

Commit

Permalink
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwaszki committed Aug 19, 2024
2 parents 7ea759a + 550e798 commit bda96d8
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,39 @@ bool fuse_mul_with_fc(luci::CircleMul *mul)
luci::CircleFullyConnected *fc = nullptr;
luci::CircleConst *multiplication = nullptr;
RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul));
// Make sure that FullyConnected has only one successor:
/**
* Make sure that FullyConnected has only one successor.
*
* If the FullyConnected output is connected to more nodes,
* this pass will replace node with new fused FullyConnected.
* Thus pass success will only introduce extra FullyConnected
* without reducing overall number of nodes.
* Which tends to increase model's size and degrades model's performance.
* Thus one successor is required to benefit from this pass.
*
* Example graph that illustrates the described scenario:
*
* BEFORE
* |
* [CircleFullyConnected]
* |
* +-------+----------------+
* | |
* | |
* [Other Node] [CircleMul]
* | |
*
* AFTER
* |
* [CircleFullyConnected]
* |
* +-------+-----------------------+
* | |
* | |
* [Other Node] [New CircleFullyConnected Fused with Mul]
* | |
*
*/
RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1);
// Allow only FLOAT32 data type:
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32);
Expand Down Expand Up @@ -194,18 +226,10 @@ bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto mul = dynamic_cast<luci::CircleMul *>(node);
if (not mul)
continue;

switch (mul->dtype())
if (auto mul = dynamic_cast<luci::CircleMul *>(node))
{
case loco::DataType::FLOAT32:
if (fuse_mul_with_fc(mul))
changed = true;
break;
default:
break;
if (fuse_mul_with_fc(mul))
changed = true;
}
}

Expand Down

0 comments on commit bda96d8

Please sign in to comment.