diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 9cbd26f0da5..8a1eb6d4f78 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -49,6 +49,7 @@ class CircleOptimizer final FuseMeanWithMean, FuseMulWithConv, FuseMulWithDiv, + FuseMulWithFullyConnected, FuseTransposeWithMean, ResolveCustomOpAdd, ResolveCustomOpBatchMatMul, diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h new file mode 100644 index 00000000000..718039f1c69 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ +#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul into CircleFullyConnected + */ +struct FuseMulWithFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 840c8dd25dd..246d4f36e78 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -48,6 +48,7 @@ #include "luci/Pass/FuseMeanWithMeanPass.h" #include "luci/Pass/FuseMulWithConvPass.h" #include "luci/Pass/FuseMulWithDivPass.h" +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" @@ -310,6 +311,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseMulWithFullyConnected)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp new file mode 100644 index 00000000000..3049862e216 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" + +#include "helpers/NodeFiller.h" + +#include +#include +#include + +namespace +{ + +#define RETURN_FALSE_UNLESS(cond) \ + if (not(cond)) \ + return false; + +inline bool is_single_element(const luci::CircleConst *node) +{ + return ((node->rank() == 1 || node->rank() == 0) && node->size() == 1); +} + +inline void update_with_single_element(luci::CircleConst *fused_node, + const luci::CircleConst *multiplication) +{ + for (uint32_t i = 0; i < fused_node->size(); i++) + { + fused_node->at(i) *= multiplication->at(0); + } +} + +luci::CircleConst *gen_fused_weights(luci::CircleConst *weights, + const luci::CircleConst *multiplication) +{ + auto fused_weights = luci::clone(weights); + // Single element multiplication: + if (is_single_element(multiplication)) + { + update_with_single_element(fused_weights, multiplication); + } + // N-size multiplication: + else + { + // Go along channels, multiplication size is ensured to be compatible with channels. + auto count = fused_weights->dim(0).value(); + auto size = fused_weights->dim(fused_weights->rank() - 1).value(); + float val; + for (uint32_t c = 0; c < count; c++) + { + val = multiplication->at(c); + for (uint32_t i = 0; i < size; i++) + { + fused_weights->at(c * size + i) *= val; + } + } + } + return fused_weights; +} + +luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) +{ + auto fused_bias = luci::clone(bias); + // Single element multiplication: + if (is_single_element(multiplication)) + { + update_with_single_element(fused_bias, multiplication); + } + // N-size multiplication: + else + { + // Go along channels, multiplication size is ensured to be compatible with channels. + for (uint32_t i = 0; i < fused_bias->size(); i++) + { + fused_bias->at(i) *= multiplication->at(i); + } + } + return fused_bias; +} + +/** + * Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant + * + * BEFORE + * | + * [CircleFullyConnected] + * | + * [CircleMul] + * | + * + * AFTER + * | + * [CircleFullyConnected] [CircleMul] (dead) + * | + * + */ +bool fuse_mul_with_fc(luci::CircleMul *mul) +{ + // Sanity check: + RETURN_FALSE_UNLESS(mul); + // Allow Mul node only with FLOAT32 data type: + RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); + // Check if any FC node connects to Mul. + // Find the pattern of Mul(FC, CircleConst): + luci::CircleFullyConnected *fc = nullptr; + luci::CircleConst *multiplication = nullptr; + RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul)); + /** + * Make sure that FullyConnected has only one successor. + * + * If the FullyConnected output is connected to more nodes, + * this pass will replace node with new fused FullyConnected. + * Thus pass success will only introduce extra FullyConnected + * without reducing overall number of nodes. + * Which tends to increase model's size and degrades model's performance. + * Thus one successor is required to benefit from this pass. + * + * Example graph that illustrates the described scenario: + * + * BEFORE + * | + * [CircleFullyConnected] + * | + * +-------+----------------+ + * | | + * | | + * [Other Node] [CircleMul] + * | | + * + * AFTER + * | + * +-----------------------+ + * | | + * | | + * [CircleFullyConnected] | + * | | + * +-------+ | + * | | + * | | + * [Other Node] [New CircleFullyConnected Fused with Mul] + * | | + * + */ + RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1); + // Allow only FLOAT32 data type: + RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); + // Allow only without activation functions as values are going to + // be multiplied before activation function. + RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE); + // Check for weights being Constant: + auto weights = dynamic_cast(fc->weights()); + RETURN_FALSE_UNLESS(weights); + // Get rank of multiplication: + auto rank = multiplication->rank(); + // Check that all dimensions are ones, checks broadcast capabilites. + // Last dimesion of multiplication must be compatible with FC. + // N-D case (N>1): + if (multiplication->rank() > 1) + { + // Check channel-wise broadcasting: + for (uint32_t i = 0; i < rank - 1; i++) + RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1); + // Check the last dimesion of Mul is the same with the first dimension of FullyConnected + RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); + } + // 1-D or scalar case: + else if (multiplication->rank() == 1) + { + RETURN_FALSE_UNLESS(multiplication->size() == 1 || + multiplication->size() == weights->dim(0)); + } + else if (multiplication->rank() == 0) + { + RETURN_FALSE_UNLESS(multiplication->size() == 1); + } + + // Only supports: + // (1) constant bias + // (2) no bias + auto bias = loco::must_cast(fc->bias()); + if (bias->opcode() == luci::CircleOpcode::CIRCLECONST) + { + // Create new bias to be updated with values: + auto const_bias = dynamic_cast(fc->bias()); + RETURN_FALSE_UNLESS(const_bias) + RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); + // Create new bias with updated values and replace: + auto fused_bias = gen_fused_bias(const_bias, multiplication); + fc->bias(fused_bias); + } + else if (bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + { + return false; + } + + // Create new weights with updated values and replace: + auto fused_weights = gen_fused_weights(weights, multiplication); + fc->weights(fused_weights); + + // Set origin and copy Activation Function if exisitng: + fc->fusedActivationFunction(mul->fusedActivationFunction()); + luci::add_origin(fc, luci::get_origin(mul)); + + replace(mul).with(fc); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + if (auto mul = dynamic_cast(node)) + { + if (fuse_mul_with_fc(mul)) + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp new file mode 100644 index 00000000000..a4f9d6bf087 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -0,0 +1,306 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" +#include "helpers/CreateCircleConst.h" + +#include +#include + +#include + +#define DIM_ONE 8 +#define DIM_TWO 4 +#define MUL_VAL 2.0f + +namespace +{ + +using namespace luci::test; + +/** + * Graph for this test + * + * BEFORE (without extra_fc_successor) + * + * [FC] + * | + * [Mul w/ Relu] + * + * BEFORE (with extra_fc_successor) + * + * [FC] + * | + * |------------------- + * | | + * | | + * [Mul w/ Relu] [other FC] + * + * AFTER (if pass applied) + * + * [FC w/ Relu] (weights and bias updated) + * + */ +class FCMulGraphlet +{ +public: + void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) + { + _fc = g->nodes()->create(); + + std::vector weights_val(DIM_ONE * DIM_TWO); + for (uint32_t i = 0; i < DIM_ONE * DIM_TWO; i++) + weights_val.at(i) = i; + + _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + _fc->weights(_fc_f); + + if (use_bias) + { + std::vector bias_val(DIM_ONE); + for (uint32_t i = 0; i < DIM_ONE; i++) + bias_val.at(i) = i; + + _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); + } + else + { + // Create CircleOutputExclude -- no bias + _fc_b = g->nodes()->create(); + } + _fc->bias(_fc_b); + + _fc->fusedActivationFunction(fc_activation); + _fc->dtype(loco::DataType::FLOAT32); + _fc->shape({1, DIM_ONE}); + _fc->name("fc"); + + if (extra_successor) + { + _extra_succ = g->nodes()->create(); + // Set previous FC as input to bump number of successors for it: + _extra_succ->input(_fc); + std::vector weights_val(DIM_ONE * DIM_TWO); + _extra_f = + luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + _extra_succ->weights(_extra_f); + _extra_succ->bias(nullptr); + _extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE); + _extra_succ->dtype(loco::DataType::FLOAT32); + _extra_succ->shape({1, DIM_ONE}); + _extra_succ->name("extra_fc"); + } + + std::vector mul_values; + + if (is_mul_scalar) + { + mul_values.push_back(static_cast(MUL_VAL)); + _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {}, mul_values); + } + else + { + for (uint32_t i = 0; i < DIM_ONE; i++) + { + mul_values.push_back(static_cast(i)); + } + _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 1, 1, DIM_ONE}, mul_values); + } + + _mul = g->nodes()->create(); + _mul->x(_fc); + _mul->y(_mul_c); + _mul->fusedActivationFunction(luci::FusedActFunc::RELU); + _mul->dtype(loco::DataType::FLOAT32); + if (is_mul_scalar) + { + _mul->shape({1, DIM_ONE}); + } + else + { + _mul->shape({1, 1, 1, DIM_ONE}); + } + _mul->name("mul"); + } + +public: + luci::CircleFullyConnected *fc() { return _fc; } + + void to_fm_bias(void) + { + assert(_fc != nullptr); + + auto new_fc = _fc->graph()->nodes()->create(); + _fc->bias(new_fc); + } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_fc_f = nullptr; + luci::CircleNode *_fc_b = nullptr; + luci::CircleConst *_mul_c = nullptr; + luci::CircleFullyConnected *_extra_succ = nullptr; + luci::CircleConst *_extra_f = nullptr; +}; + +class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet +{ +public: + void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) + { + TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); + FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias, extra_successor); + + _fc->input(input()); + + output()->from(_mul); + } +}; + +class FuseMulWithFullyConnectedPassTest : public ::testing::Test +{ +public: + FuseMulWithFCTestGraph g; + luci::FuseMulWithFullyConnectedPass pass; +}; + +} // namespace + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(i * offset, weights->at(offset)); + } + } + + auto bias = loco::must_cast(g.fc()->bias()); + for (uint32_t i = 0; i < bias->size(); i++) + { + EXPECT_EQ(i * i, bias->at(i)); + } +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) +{ + g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(MUL_VAL * offset, weights->at(offset)); + } + } + + auto bias = loco::must_cast(g.fc()->bias()); + for (uint32_t i = 0; i < bias->size(); i++) + { + EXPECT_EQ(MUL_VAL * i, bias->at(i)); + } +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */, + false /* extra_successor */); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + auto no_bias = dynamic_cast(fc->bias()); + ASSERT_NE(nullptr, no_bias); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(i * offset, weights->at(offset)); + } + } +} + +TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); + + // Bias cannot be fused as it's passed as feature map. + g.to_fm_bias(); + + EXPECT_EQ(false, pass.run(g.g())); +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) +{ + g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); + + EXPECT_EQ(false, pass.run(g.g())); +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_null_weights_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); + + g.fc()->weights(nullptr); + + EXPECT_EQ(false, pass.run(g.g())); +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_extra_successor_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + true /* extra_successor */); + + EXPECT_EQ(false, pass.run(g.g())); +}