Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Introduce FuseMulWithFullyConnectedPass #13607

Merged
merged 14 commits into from
Aug 20, 2024
1 change: 1 addition & 0 deletions compiler/luci/pass/include/luci/CircleOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class CircleOptimizer final
FuseMeanWithMean,
FuseMulWithConv,
FuseMulWithDiv,
FuseMulWithFullyConnected,
FuseTransposeWithMean,
ResolveCustomOpAdd,
ResolveCustomOpBatchMatMul,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__
#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__

#include <logo/Pass.h>

namespace luci
{

/**
* @brief Class to fuse Mul into CircleFullyConnected
*/
struct FuseMulWithFullyConnectedPass final : public logo::Pass
{
const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; }

bool run(loco::Graph *g) final;
};

} // namespace luci

#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__
5 changes: 5 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FuseMulWithConvPass.h"
#include "luci/Pass/FuseMulWithDivPass.h"
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/FusePReluPass.h"
#include "luci/Pass/FuseGeluPass.h"
Expand Down Expand Up @@ -278,6 +279,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());

if (_options->query(Options::Algorithm::FuseMulWithFullyConnected))
{
phase.emplace_back(std::make_unique<FuseMulWithFullyConnectedPass>());
}
seanshpark marked this conversation as resolved.
Show resolved Hide resolved
if (_options->query(Options::Algorithm::CommonSubExpressionElimination))
{
phase.emplace_back(std::make_unique<luci::CommonSubExpressionEliminationPass>());
Expand Down
209 changes: 209 additions & 0 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/FuseMulWithFullyConnectedPass.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <cmath>

namespace
{

#define RETURN_FALSE_UNLESS(cond) \
if (not(cond)) \
return false;

inline bool is_scalar(luci::CircleConst *node)
{
return ((node->rank() == 1 || node->rank() == 0) && node->size<loco::DataType::FLOAT32>() == 1);
}

inline void update_with_scalar(luci::CircleConst *fused_node, luci::CircleConst *multiplication)
{
for (uint32_t i = 0; i < fused_node->size<loco::DataType::FLOAT32>(); i++)
{
fused_node->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(0);
}
}

inline void update_weights(luci::CircleConst *weights, luci::CircleConst *multiplication)
seanshpark marked this conversation as resolved.
Show resolved Hide resolved
{
// Scalar multiplication:
if (is_scalar(multiplication))
{
update_with_scalar(weights, multiplication);
}
// N-size multiplication:
else
{
// Go along channels, multiplication size is ensured to be compatible with channels.
auto count = weights->dim(0).value();
auto size = weights->dim(weights->rank() - 1).value();
float val;
for (uint32_t c = 0; c < count; c++)
{
val = multiplication->at<loco::DataType::FLOAT32>(c);
for (uint32_t i = 0; i < size; i++)
{
weights->at<loco::DataType::FLOAT32>(c * size + i) *= val;
}
}
}
}

inline void update_bias(luci::CircleConst *bias, luci::CircleConst *multiplication)
seanshpark marked this conversation as resolved.
Show resolved Hide resolved
{
// Scalar multiplication:
if (is_scalar(multiplication))
{
update_with_scalar(bias, multiplication);
}
// N-size multiplication:
else
{
// Go along channels, multiplication size is ensured to be compatible with channels.
for (uint32_t i = 0; i < bias->size<loco::DataType::FLOAT32>(); i++)
{
bias->at<loco::DataType::FLOAT32>(i) *= multiplication->at<loco::DataType::FLOAT32>(i);
}
}
}

/**
* Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant
*
* BEFORE
* |
* [CircleFullyConnected]
* |
* [CircleMul]
* |
*
* AFTER
* |
* [CircleFullyConnected] [CircleMul] (dead)
* |
*
*/
bool fuse_mul_with_fc(luci::CircleFullyConnected *fc)
{
// Sanity check:
RETURN_FALSE_UNLESS(fc);
// Allow only FLOAT32 data type:
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32);
// Allow only without activation functions as values are going to
// be multiplied before activation function.
RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE);
// Check for weights being Constant:
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights());
RETURN_FALSE_UNLESS(weights);
// Get Mul node:
auto fc_output = loco::succs(fc);
// Make sure that FullyConnected has only one child:
RETURN_FALSE_UNLESS(fc_output.size() == 1);
auto mul = dynamic_cast<luci::CircleMul *>(*fc_output.begin());
RETURN_FALSE_UNLESS(mul);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually go upward(to input).
For this case, begin with Mul and check one of the input is FC.
Can you please revise in this check order?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you traverse from Mul to FC, FC input of Mul should be created and replaced to new FC as exiting FC may have multiple successors.

Current check only works when there exist single connection of FC-Mul.

Copy link
Contributor Author

@jiwaszki jiwaszki Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see approach similar to mine on other pass which base on FC as well (with Add instead of Mul):

bool fuse_add_with_fc(luci::CircleFullyConnected *fc)

// Get add node
auto fc_output = loco::succs(fc);
if (fc_output.size() != 1)
return false;

I want to make sure I understand the motivation here. Here is my understanding of a graph and the approach to pass you described:
image

In case of such pattern of single connection of FC-Mul, I understand the fusion of Mul into FC to eliminate node and reduce overall number of operations. However, if FC has more operations connected to it's output, the pass is simply replacing one op for another -- which from "math view" is resulting in the same count of operations.

Is that a matter of existing optimizations in targeted HW and/or more performant kernel implementations in ONE that FC is preferred here in place of Mul?

EDIT: I would understand a case where we look for Mul and check that connected FC has only one child and it's Mul itself. I wonder if that approach would work better as probably there are less Mul nodes to perform a check on than FC nodes (at least in popular models). I would need to have a better understanding of ONE method to applying passes to judge that. @seanshpark is that something that can be considered instead of adding multiple FC?

Copy link
Contributor

@seanshpark seanshpark Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that a matter of existing optimizations in targeted HW and/or more performant kernel implementations in ONE that FC is preferred here in place of Mul?

No. I don't think so.

is that something that can be considered instead of adding multiple FC?

I cannot understand your question point.

Copy link
Contributor

@jinevening jinevening Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see approach similar to mine on other pass which base on FC as well (with Add instead of Mul):

FuseAddWithFullyConnected fuses add with FC only when FC has a single successor (add). That is an intended behavior. If FC has multiple successors, FC layers are duplicated (as @jiwaszki said), which tends to increase model size and degrade model performance. I think FuseMulWithFullyConnected can follow the same strategy with FuseAddWithFullyConnected.

So, +1 for the current implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I DISAGREE. I DO NOT WANT MULTIPLE WAYS OF SEARCHING.
SOME DAY FuseAddWithFullyConnected SHOULD BE FIXED TOO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanshpark How about the changes in 79a2213 ? It merges both ideas into one.
First, it follows the bottom-up approach, starting from mul:

bool fuse_mul_with_fc(luci::CircleMul *mul)

// Check if any FC node connects to Mul.
// Find the pattern of Mul(FC, CircleConst):
luci::CircleFullyConnected *fc = nullptr;
luci::CircleConst *multiplication = nullptr;
RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul));

Second, still has the intended behavior of fusing only when there is only one successor to fc:

// Make sure that FullyConnected has only one successor:
RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1);

@jinevening please also review the new commit when you have time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me. Please address @seanshpark 's comments.

// Allow Mul node only with FLOAT32 data type:
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
// Get multiplication Constant (here: the second input besides weights):
auto multiplication = mul->x() == fc ? dynamic_cast<luci::CircleConst *>(mul->y())
: dynamic_cast<luci::CircleConst *>(mul->x());
RETURN_FALSE_UNLESS(multiplication);
// Get rank of multiplication:
auto rank = multiplication->rank();
RETURN_FALSE_UNLESS(rank != 0);
jinevening marked this conversation as resolved.
Show resolved Hide resolved
// Check that all dimensions are ones, checks broadcast capabilites.
// Last dimesion of multiplication must be compatible with FC.
// N-D case (N>1):
if (multiplication->rank() > 1)
{
// Check channel-wise broadcasting:
for (uint32_t i = 0; i < rank - 1; i++)
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1);
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0));
}
// Scalar case:
else if (multiplication->rank() == 1 || multiplication->rank() == 0)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() != 0);
}
jinevening marked this conversation as resolved.
Show resolved Hide resolved

// Only supports:
// (1) constant bias
// (2) no bias
auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());
RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST or
bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
// Create new bias to be updated with values:
auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias());
RETURN_FALSE_UNLESS(const_bias)
jinevening marked this conversation as resolved.
Show resolved Hide resolved
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);

auto fused_bias = luci::clone(const_bias);
// Create new weights to be updated with values:
auto fused_weights = luci::clone(weights);

// Update bias accordingly:
update_bias(fused_bias, multiplication);
// Update weights accordingly:
update_weights(fused_weights, multiplication);
jinevening marked this conversation as resolved.
Show resolved Hide resolved

// Replace weights and bias:
fc->weights(fused_weights);
fc->bias(fused_bias);

// Set origin and copy Activation Function if exisitng:
fc->fusedActivationFunction(mul->fusedActivationFunction());
luci::add_origin(fc, luci::get_origin(mul));

replace(mul).with(fc);

return true;
}

} // namespace

namespace luci
{

bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
if (not fc)
continue;

switch (fc->dtype())
{
case loco::DataType::FLOAT32:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype check is duplicate with fuse_mul_with_fc().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) why use switch case with dtype here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's in fact redundant check, leftover from refactoring. Removed and further simplified.

if (fuse_mul_with_fc(fc))
changed = true;
break;
default:
break;
}
}

return changed;
}

} // namespace luci
Loading