-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[luci/pass] Introduce FuseMulWithFullyConnectedPass #13607
Changes from 11 commits
40dcacf
f661561
85d9783
e3b354e
31e25ed
8b17f47
9e22b26
8977ef9
b085181
1aa79cc
1bb278d
79a2213
550e798
48defc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ | ||
#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ | ||
|
||
#include <logo/Pass.h> | ||
|
||
namespace luci | ||
{ | ||
|
||
/** | ||
* @brief Class to fuse Mul into CircleFullyConnected | ||
*/ | ||
struct FuseMulWithFullyConnectedPass final : public logo::Pass | ||
{ | ||
const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; } | ||
|
||
bool run(loco::Graph *g) final; | ||
}; | ||
|
||
} // namespace luci | ||
|
||
#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "luci/Pass/FuseMulWithFullyConnectedPass.h" | ||
|
||
#include <luci/IR/CircleNodes.h> | ||
#include <luci/Service/Nodes/CircleConst.h> | ||
#include <luci/Profile/CircleNodeOrigin.h> | ||
|
||
#include <cmath> | ||
|
||
namespace | ||
{ | ||
|
||
// Leave the enclosing bool-returning function with `false` when `cond` does not hold.
// NOTE: the expansion ends with its own ';', so a call site written without a trailing
// semicolon still compiles. It is a bare `if`, so do not follow a use with an `else`.
#define RETURN_FALSE_UNLESS(cond) \
  if (!(cond))                    \
    return false;
|
||
/**
 * @brief Tell whether `node` is a rank-0 or rank-1 constant holding exactly one FLOAT32 element
 *
 * The element count is read through the FLOAT32 accessor, and only after the rank check passed.
 */
inline bool is_single_element(const luci::CircleConst *node)
{
  const auto node_rank = node->rank();
  if (node_rank != 0 && node_rank != 1)
    return false;

  return node->size<loco::DataType::FLOAT32>() == 1;
}
|
||
/**
 * @brief Scale every FLOAT32 element of `fused_node` in place by element 0 of `multiplication`
 *
 * @param fused_node      constant whose values are overwritten with the scaled values
 * @param multiplication  constant providing the single scale factor (only index 0 is read)
 */
inline void update_with_single_element(luci::CircleConst *fused_node,
                                       const luci::CircleConst *multiplication)
{
  uint32_t idx = 0;
  while (idx < fused_node->size<loco::DataType::FLOAT32>())
  {
    auto &element = fused_node->at<loco::DataType::FLOAT32>(idx);
    element = element * multiplication->at<loco::DataType::FLOAT32>(0);
    ++idx;
  }
}
|
||
/**
 * @brief Build a copy of `weights` with the Mul constant folded into it
 *
 * A single-element `multiplication` scales every weight. Otherwise element `c` of
 * `multiplication` scales the `c`-th contiguous run of weights, where the number of runs is
 * dim(0) of the weights and the run length is their last dimension. The caller is expected
 * to have verified that `multiplication` provides one value per run.
 *
 * @return newly cloned constant; `weights` itself is not modified here
 */
luci::CircleConst *gen_fused_weights(luci::CircleConst *weights,
                                     const luci::CircleConst *multiplication)
{
  auto fused_weights = luci::clone(weights);

  // Scalar-like factor: same scale for the whole tensor.
  if (is_single_element(multiplication))
  {
    update_with_single_element(fused_weights, multiplication);
    return fused_weights;
  }

  // Channel-wise factor: one scale per leading-dimension slice.
  const auto channels = fused_weights->dim(0).value();
  const auto channel_len = fused_weights->dim(fused_weights->rank() - 1).value();
  for (uint32_t ch = 0; ch < channels; ++ch)
  {
    const float factor = multiplication->at<loco::DataType::FLOAT32>(ch);
    for (uint32_t k = 0; k < channel_len; ++k)
      fused_weights->at<loco::DataType::FLOAT32>(ch * channel_len + k) *= factor;
  }

  return fused_weights;
}
|
||
/**
 * @brief Build a copy of `bias` with the Mul constant folded into it
 *
 * A single-element `multiplication` scales every bias value. Otherwise bias element `i` is
 * scaled by element `i` of `multiplication`, so the caller must make sure `multiplication`
 * has at least as many elements as the bias.
 *
 * @return newly cloned constant; `bias` itself is not modified here
 */
luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication)
{
  auto fused_bias = luci::clone(bias);

  // Scalar-like factor: same scale for every bias value.
  if (is_single_element(multiplication))
  {
    update_with_single_element(fused_bias, multiplication);
    return fused_bias;
  }

  // Element-wise factor: bias[i] *= multiplication[i].
  uint32_t idx = 0;
  while (idx < fused_bias->size<loco::DataType::FLOAT32>())
  {
    fused_bias->at<loco::DataType::FLOAT32>(idx) *= multiplication->at<loco::DataType::FLOAT32>(idx);
    ++idx;
  }

  return fused_bias;
}
|
||
/** | ||
* Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant | ||
* | ||
* BEFORE | ||
* | | ||
* [CircleFullyConnected] | ||
* | | ||
* [CircleMul] | ||
* | | ||
* | ||
* AFTER | ||
* | | ||
* [CircleFullyConnected] [CircleMul] (dead) | ||
* | | ||
* | ||
*/ | ||
bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) | ||
{ | ||
// Sanity check: | ||
RETURN_FALSE_UNLESS(fc); | ||
// Allow only FLOAT32 data type: | ||
RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); | ||
// Allow only without activation functions as values are going to | ||
// be multiplied before activation function. | ||
RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE); | ||
// Check for weights being Constant: | ||
auto weights = dynamic_cast<luci::CircleConst *>(fc->weights()); | ||
RETURN_FALSE_UNLESS(weights); | ||
// Get Mul node: | ||
auto fc_output = loco::succs(fc); | ||
// Make sure that FullyConnected has only one child: | ||
RETURN_FALSE_UNLESS(fc_output.size() == 1); | ||
auto mul = dynamic_cast<luci::CircleMul *>(*fc_output.begin()); | ||
RETURN_FALSE_UNLESS(mul); | ||
// Allow Mul node only with FLOAT32 data type: | ||
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); | ||
// Get multiplication Constant (here: the second input besides weights): | ||
auto multiplication = mul->x() == fc ? dynamic_cast<luci::CircleConst *>(mul->y()) | ||
: dynamic_cast<luci::CircleConst *>(mul->x()); | ||
RETURN_FALSE_UNLESS(multiplication); | ||
// Get rank of multiplication: | ||
auto rank = multiplication->rank(); | ||
// Check that all dimensions are ones, checks broadcast capabilites. | ||
// Last dimesion of multiplication must be compatible with FC. | ||
// N-D case (N>1): | ||
if (multiplication->rank() > 1) | ||
{ | ||
// Check channel-wise broadcasting: | ||
for (uint32_t i = 0; i < rank - 1; i++) | ||
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1); | ||
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected | ||
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); | ||
} | ||
// 1-D or scalar case: | ||
else if (multiplication->rank() == 1) | ||
{ | ||
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1 || | ||
multiplication->size<loco::DataType::FLOAT32>() == weights->dim(0)); | ||
} | ||
else if (multiplication->rank() == 0) | ||
{ | ||
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1); | ||
} | ||
|
||
// Only supports: | ||
// (1) constant bias | ||
// (2) no bias | ||
auto bias = loco::must_cast<luci::CircleNode *>(fc->bias()); | ||
if (bias->opcode() == luci::CircleOpcode::CIRCLECONST) | ||
{ | ||
// Create new bias to be updated with values: | ||
auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()); | ||
RETURN_FALSE_UNLESS(const_bias) | ||
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); | ||
// Create new bias with updated values and replace: | ||
auto fused_bias = gen_fused_bias(const_bias, multiplication); | ||
fc->bias(fused_bias); | ||
} | ||
else if (bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) | ||
{ | ||
return false; | ||
} | ||
|
||
// Create new weights with updated values and replace: | ||
auto fused_weights = gen_fused_weights(weights, multiplication); | ||
fc->weights(fused_weights); | ||
|
||
// Set origin and copy Activation Function if exisitng: | ||
fc->fusedActivationFunction(mul->fusedActivationFunction()); | ||
luci::add_origin(fc, luci::get_origin(mul)); | ||
|
||
replace(mul).with(fc); | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
namespace luci | ||
{ | ||
|
||
/**
 * @brief Try to fuse a following Mul into every active FullyConnected node of the graph
 *
 * @param g  graph to optimize
 * @return true when at least one FullyConnected/Mul pair was fused
 */
bool FuseMulWithFullyConnectedPass::run(loco::Graph *g)
{
  bool changed = false;
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
    if (not fc)
      continue;

    // No dtype dispatch is needed here: fuse_mul_with_fc() itself rejects every
    // FullyConnected whose dtype is not FLOAT32.
    if (fuse_mul_with_fc(fc))
      changed = true;
  }

  return changed;
}
|
||
} // namespace luci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We usually go upward (toward the input).
For this case, begin with `Mul` and check that one of its inputs is `FC`.
Can you please revise the pass to use this check order?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you traverse from `Mul` to `FC`, a new `FC` should be created to replace the `FC` input of `Mul`, as the existing `FC` may have multiple successors.
The current check only works when there is a single `FC`-`Mul` connection.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see an approach similar to mine in another pass that is also based on `FC` (with `Add` instead of `Mul`):
ONE/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
Line 43 in 719a6f7
ONE/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
Lines 58 to 61 in 719a6f7
I want to make sure I understand the motivation here. Here is my understanding of the graph and of the approach to the pass you described:
In the case of a single `FC`-`Mul` connection, I understand fusing `Mul` into `FC` to eliminate a node and reduce the overall number of operations. However, if `FC` has more operations connected to its output, the pass simply replaces one op with another -- which from a "math view" results in the same count of operations.
Is it a matter of existing optimizations in the targeted HW and/or more performant kernel implementations in ONE that `FC` is preferred here in place of `Mul`?
EDIT: I would understand a case where we look for `Mul` and check that the connected `FC` has only one child and that it is the `Mul` itself. I wonder if that approach would work better, as there are probably fewer `Mul` nodes to perform a check on than `FC` nodes (at least in popular models). I would need a better understanding of how ONE applies passes to judge that. @seanshpark is that something that can be considered instead of adding multiple `FC`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No. I don't think so.
I cannot understand your question point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`FuseAddWithFullyConnected` fuses Add with FC only when FC has a single successor (Add). That is intended behavior. If FC has multiple successors, FC layers are duplicated (as @jiwaszki said), which tends to increase model size and degrade model performance. I think `FuseMulWithFullyConnected` can follow the same strategy as `FuseAddWithFullyConnected`.
So, +1 for the current implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I DISAGREE. I DO NOT WANT MULTIPLE WAYS OF SEARCHING.
SOME DAY FuseAddWithFullyConnected SHOULD BE FIXED TOO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seanshpark How about the changes in 79a2213? They merge both ideas into one.
First, the pass follows the bottom-up approach, starting from `mul`:
ONE/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Line 110 in 79a2213
ONE/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Lines 116 to 120 in 79a2213
Second, it still has the intended behavior of fusing only when `fc` has exactly one successor:
ONE/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Lines 121 to 122 in 79a2213
@jinevening please also review the new commit when you have time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good to me. Please address @seanshpark 's comments.