Skip to content

Commit

Permalink
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwaszki committed Aug 14, 2024
2 parents 27dec03 + 79a2213 commit 7ea759a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 31 deletions.
37 changes: 17 additions & 20 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

#include "luci/Pass/FuseMulWithFullyConnectedPass.h"

#include "helpers/NodeFiller.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <cmath>

namespace
{

Expand Down Expand Up @@ -107,10 +107,19 @@ luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleCon
* |
*
*/
bool fuse_mul_with_fc(luci::CircleFullyConnected *fc)
bool fuse_mul_with_fc(luci::CircleMul *mul)
{
// Sanity check:
RETURN_FALSE_UNLESS(fc);
RETURN_FALSE_UNLESS(mul);
// Allow Mul node only with FLOAT32 data type:
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
// Check if any FC node connects to Mul.
// Find the pattern of Mul(FC, CircleConst):
luci::CircleFullyConnected *fc = nullptr;
luci::CircleConst *multiplication = nullptr;
RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul));
// Make sure that FullyConnected has only one successor:
RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1);
// Allow only FLOAT32 data type:
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32);
// Allow only without activation functions as values are going to
Expand All @@ -119,18 +128,6 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc)
// Check for weights being Constant:
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
RETURN_FALSE_UNLESS(weights);
// Get Mul node:
auto fc_output = loco::succs(fc);
// Make sure that FullyConnected has only one child:
RETURN_FALSE_UNLESS(fc_output.size() == 1);
auto mul = dynamic_cast<luci::CircleMul *>(*fc_output.begin());
RETURN_FALSE_UNLESS(mul);
// Allow Mul node only with FLOAT32 data type:
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
// Get multiplication Constant (here: the second input besides weights):
auto multiplication = mul->x() == fc ? dynamic_cast<luci::CircleConst *>(mul->y())
: dynamic_cast<luci::CircleConst *>(mul->x());
RETURN_FALSE_UNLESS(multiplication);
// Get rank of multiplication:
auto rank = multiplication->rank();
// Check that all dimensions are ones, checks broadcast capabilites.
Expand Down Expand Up @@ -197,14 +194,14 @@ bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
if (not fc)
auto mul = dynamic_cast<luci::CircleMul *>(node);
if (not mul)
continue;

switch (fc->dtype())
switch (mul->dtype())
{
case loco::DataType::FLOAT32:
if (fuse_mul_with_fc(fc))
if (fuse_mul_with_fc(mul))
changed = true;
break;
default:
Expand Down
65 changes: 54 additions & 11 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,31 @@ using namespace luci::test;
/**
* Graph for this test
*
* BEFORE
* BEFORE (without extra_fc_successor)
*
* [FC]
* |
* [Mul w/ Relu]
*
* AFTER
* BEFORE (with extra_fc_successor)
*
* [FC]
* |
* |-------------------
* | |
* | |
* [Mul w/ Relu] [other FC]
*
* AFTER (if pass applied)
*
* [FC w/ Relu] (weights and bias updated)
*
*/
class FCMulGraphlet
{
public:
void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias)
void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias,
bool extra_successor)
{
_fc = g->nodes()->create<luci::CircleFullyConnected>();

Expand Down Expand Up @@ -79,6 +89,22 @@ class FCMulGraphlet
_fc->shape({1, DIM_ONE});
_fc->name("fc");

if (extra_successor)
{
_extra_succ = g->nodes()->create<luci::CircleFullyConnected>();
// Set previous FC as input to bump number of successors for it:
_extra_succ->input(_fc);
std::vector<float> weights_val(DIM_ONE * DIM_TWO);
_extra_f =
luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val);
_extra_succ->weights(_extra_f);
_extra_succ->bias(nullptr);
_extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE);
_extra_succ->dtype(loco::DataType::FLOAT32);
_extra_succ->shape({1, DIM_ONE});
_extra_succ->name("extra_fc");
}

std::vector<float> mul_values;

if (is_mul_scalar)
Expand Down Expand Up @@ -128,15 +154,18 @@ class FCMulGraphlet
luci::CircleConst *_fc_f = nullptr;
luci::CircleNode *_fc_b = nullptr;
luci::CircleConst *_mul_c = nullptr;
luci::CircleFullyConnected *_extra_succ = nullptr;
luci::CircleConst *_extra_f = nullptr;
};

class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet
{
public:
void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias)
void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias,
bool extra_successor)
{
TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE});
FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias);
FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias, extra_successor);

_fc->input(input());

Expand All @@ -155,7 +184,8 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test

TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor)
{
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */);
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */,
false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

Expand Down Expand Up @@ -184,7 +214,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor)

TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar)
{
g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */);
g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */,
false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

Expand Down Expand Up @@ -213,7 +244,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar)

TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias)
{
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */);
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */,
false /* extra_successor */);

EXPECT_EQ(true, pass.run(g.g()));

Expand All @@ -238,7 +270,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias)

TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */);
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */,
false /* extra_successor */);

// Bias cannot be fused as it's passed as feature map.
g.to_fm_bias();
Expand All @@ -248,16 +281,26 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG)

TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG)
{
g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */);
g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */,
false /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_null_weights_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */);
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */,
false /* extra_successor */);

g.fc()->weights(nullptr);

EXPECT_EQ(false, pass.run(g.g()));
}

TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_extra_successor_NEG)
{
g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */,
true /* extra_successor */);

EXPECT_EQ(false, pass.run(g.g()));
}

0 comments on commit 7ea759a

Please sign in to comment.