From 40dcacf41b541c93f84916a9cb0f75a0a9a59904 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 6 Aug 2024 17:38:56 +0200 Subject: [PATCH 01/14] [luci/pass] Introduce FuseMulWithFullyConnectedPass This commit introduce FuseMulWithFullyConnectedPass which will fuse Mul to previous FullyConnected if possible. ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz --- .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../luci/Pass/FuseMulWithFullyConnectedPass.h | 37 +++ compiler/luci/pass/src/CircleOptimizer.cpp | 5 + .../src/FuseMulWithFullyConnectedPass.cpp | 209 ++++++++++++++++ .../FuseMulWithFullyConnectedPass.test.cpp | 223 ++++++++++++++++++ 5 files changed, 475 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h create mode 100644 compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp create mode 100644 compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 9cbd26f0da5..8a1eb6d4f78 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -49,6 +49,7 @@ class CircleOptimizer final FuseMeanWithMean, FuseMulWithConv, FuseMulWithDiv, + FuseMulWithFullyConnected, FuseTransposeWithMean, ResolveCustomOpAdd, ResolveCustomOpBatchMatMul, diff --git a/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h new file mode 100644 index 00000000000..718039f1c69 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseMulWithFullyConnectedPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ +#define __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse Mul into CircleFullyConnected + */ +struct FuseMulWithFullyConnectedPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseMulWithFullyConnectedPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_MUL_WITH_FULLYCONNECTED_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 840c8dd25dd..27cf27e63fd 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -48,6 +48,7 @@ #include "luci/Pass/FuseMeanWithMeanPass.h" #include "luci/Pass/FuseMulWithConvPass.h" #include "luci/Pass/FuseMulWithDivPass.h" +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" @@ -278,6 +279,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const phase.emplace_back(std::make_unique()); phase.emplace_back(std::make_unique()); + if (_options->query(Options::Algorithm::FuseMulWithFullyConnected)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::CommonSubExpressionElimination)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp new file mode 100644 index 00000000000..388c6359138 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" + +#include +#include +#include + +#include + +namespace +{ + +#define RETURN_FALSE_UNLESS(cond) \ + if (not(cond)) \ + return false; + +inline bool is_scalar(luci::CircleConst *node) +{ + return ((node->rank() == 1 || node->rank() == 0) && node->size() == 1); +} + +inline void update_with_scalar(luci::CircleConst *fused_node, luci::CircleConst *multiplication) +{ + for (uint32_t i = 0; i < fused_node->size(); i++) + { + fused_node->at(i) *= multiplication->at(0); + } +} + +inline void update_weights(luci::CircleConst *weights, luci::CircleConst *multiplication) +{ + // Scalar multiplication: + if (is_scalar(multiplication)) + { + update_with_scalar(weights, multiplication); + } + // N-size multiplication: + else + { + // Go along channels, multiplication size is ensured to be compatible with channels. + auto count = weights->dim(0).value(); + auto size = weights->dim(weights->rank() - 1).value(); + float val; + for (uint32_t c = 0; c < count; c++) + { + val = multiplication->at(c); + for (uint32_t i = 0; i < size; i++) + { + weights->at(c * size + i) *= val; + } + } + } +} + +inline void update_bias(luci::CircleConst *bias, luci::CircleConst *multiplication) +{ + // Scalar multiplication: + if (is_scalar(multiplication)) + { + update_with_scalar(bias, multiplication); + } + // N-size multiplication: + else + { + // Go along channels, multiplication size is ensured to be compatible with channels. + for (uint32_t i = 0; i < bias->size(); i++) + { + bias->at(i) *= multiplication->at(i); + } + } +} + +/** + * Fuse Mul to FullyConnected if the multiplied value is a channel(last dimension)-wise constant + * + * BEFORE + * | + * [CircleFullyConnected] + * | + * [CircleMul] + * | + * + * AFTER + * | + * [CircleFullyConnected] [CircleMul] (dead) + * | + * + */ +bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) +{ + // Sanity check: + RETURN_FALSE_UNLESS(fc); + // Allow only FLOAT32 data type: + RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); + // Allow only without activation functions as values are going to + // be multiplied before activation function. + RETURN_FALSE_UNLESS(fc->fusedActivationFunction() == luci::FusedActFunc::NONE); + // Check for weights being Constant: + auto weights = dynamic_cast(fc->weights()); + RETURN_FALSE_UNLESS(weights); + // Get Mul node: + auto fc_output = loco::succs(fc); + // Make sure that FullyConnected has only one child: + RETURN_FALSE_UNLESS(fc_output.size() == 1); + auto mul = dynamic_cast(*fc_output.begin()); + RETURN_FALSE_UNLESS(mul); + // Allow Mul node only with FLOAT32 data type: + RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); + // Get multiplication Constant (here: the second input besides weights): + auto multiplication = mul->x() == fc ? dynamic_cast(mul->y()) + : dynamic_cast(mul->x()); + RETURN_FALSE_UNLESS(multiplication); + // Get rank of multiplication: + auto rank = multiplication->rank(); + RETURN_FALSE_UNLESS(rank != 0); + // Check that all dimensions are ones, checks broadcast capabilites. + // Last dimesion of multiplication must be compatible with FC. + // N-D case (N>1): + if (multiplication->rank() > 1) + { + // Check channel-wise broadcasting: + for (uint32_t i = 0; i < rank - 1; i++) + RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1); + // Check the last dimesion of Mul is the same with the first dimension of FullyConnected + RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); + } + // Scalar case: + else if (multiplication->rank() == 1 || multiplication->rank() == 0) + { + RETURN_FALSE_UNLESS(multiplication->size() != 0); + } + + // Only supports: + // (1) constant bias + // (2) no bias + auto bias = loco::must_cast(fc->bias()); + RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST or + bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + // Create new bias to be updated with values: + auto const_bias = dynamic_cast(fc->bias()); + RETURN_FALSE_UNLESS(const_bias) + RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); + + auto fused_bias = luci::clone(const_bias); + // Create new weights to be updated with values: + auto fused_weights = luci::clone(weights); + + // Update bias accordingly: + update_bias(fused_bias, multiplication); + // Update weights accordingly: + update_weights(fused_weights, multiplication); + + // Replace weights and bias: + fc->weights(fused_weights); + fc->bias(fused_bias); + + // Set origin and copy Activation Function if exisitng: + fc->fusedActivationFunction(mul->fusedActivationFunction()); + luci::add_origin(fc, luci::get_origin(mul)); + + replace(mul).with(fc); + + return true; +} + +} // namespace + +namespace luci +{ + +bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast(node); + if (not fc) + continue; + + switch (fc->dtype()) + { + case loco::DataType::FLOAT32: + if (fuse_mul_with_fc(fc)) + changed = true; + break; + default: + break; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp new file mode 100644 index 00000000000..0043b44bfb9 --- /dev/null +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseMulWithFullyConnectedPass.h" +#include "helpers/CreateCircleConst.h" + +#include +#include + +#include + +#define DIM_ONE 8 +#define DIM_TWO 4 +#define MUL_VAL 2.0f + +namespace +{ + +using namespace luci::test; + +/** + * Graph for this test + * + * BEFORE + * + * [FC] + * | + * [Mul w/ Relu] + * + * AFTER + * + * [FC w/ Relu] (weights and bias updated) + * + */ +class FCMulGraphlet +{ +public: + FCMulGraphlet() = default; + + void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar) + { + std::vector weights_val(DIM_ONE * DIM_TWO); + for (uint32_t i = 0; i < DIM_ONE * DIM_TWO; i++) + weights_val.at(i) = i; + + _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + + std::vector bias_val(DIM_ONE); + for (uint32_t i = 0; i < DIM_ONE; i++) + bias_val.at(i) = i; + + _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); + + _fc = g->nodes()->create(); + _fc->weights(_fc_f); + _fc->bias(_fc_b); + _fc->fusedActivationFunction(fc_activation); + _fc->dtype(loco::DataType::FLOAT32); + _fc->shape({1, DIM_ONE}); + _fc->name("fc"); + + std::vector mul_values; + + if (is_mul_scalar) + { + mul_values.push_back(static_cast(MUL_VAL)); + _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, mul_values); + } + else + { + for (uint32_t i = 0; i < DIM_ONE; i++) + { + mul_values.push_back(static_cast(i)); + } + _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1, 1, 1, DIM_ONE}, mul_values); + } + + _mul = g->nodes()->create(); + _mul->x(_fc); + _mul->y(_mul_c); + _mul->fusedActivationFunction(luci::FusedActFunc::RELU); + _mul->dtype(loco::DataType::FLOAT32); + if (is_mul_scalar) + { + _mul->shape({1}); + } + else + { + _mul->shape({1, DIM_ONE}); + } + _mul->name("mul"); + } + +public: + luci::CircleFullyConnected *fc() { return _fc; } + + void to_fm_bias(void) + { + assert(_fc != nullptr); + + auto new_fc = _fc->graph()->nodes()->create(); + _fc->bias(new_fc); + } + +protected: + luci::CircleFullyConnected *_fc = nullptr; + luci::CircleMul *_mul = nullptr; + luci::CircleConst *_fc_f = nullptr; + luci::CircleConst *_fc_b = nullptr; + luci::CircleConst *_mul_c = nullptr; +}; + +class FuseAddWithFCTestGraph : public TestIOGraph, public FCMulGraphlet +{ +public: + FuseAddWithFCTestGraph() = default; + + void init(luci::FusedActFunc fc_activation = luci::FusedActFunc::NONE, bool is_mul_scalar = false) + { + TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); + FCMulGraphlet::init(g(), fc_activation, is_mul_scalar); + + _fc->input(input()); + + output()->from(_mul); + } +}; + +class FuseMulWithFullyConnectedPassTest : public ::testing::Test +{ +public: + FuseAddWithFCTestGraph g; + luci::FuseMulWithFullyConnectedPass pass; +}; + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) +{ + g.init(luci::FusedActFunc::NONE, false); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(i * offset, weights->at(offset)); + } + } + + auto bias = loco::must_cast(g.fc()->bias()); + for (uint32_t i = 0; i < bias->size(); i++) + { + EXPECT_EQ(i * i, bias->at(i)); + } +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) +{ + g.init(luci::FusedActFunc::NONE, true); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(MUL_VAL * offset, weights->at(offset)); + } + } + + auto bias = loco::must_cast(g.fc()->bias()); + for (uint32_t i = 0; i < bias->size(); i++) + { + EXPECT_EQ(MUL_VAL * i, bias->at(i)); + } +} + +TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) +{ + g.init(); + + // Bias cannot be fused as it's passed as feature map. + g.to_fm_bias(); + + EXPECT_EQ(false, pass.run(g.g())); +} + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) +{ + g.init(luci::FusedActFunc::RELU); + + EXPECT_EQ(false, pass.run(g.g())); +} +} // namespace From f661561a5a361124fe17f9856841bbad3123109a Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Wed, 7 Aug 2024 16:27:48 +0200 Subject: [PATCH 02/14] Change constness of args, move tests and move FuseMulWithFC after FuseMulWithDiv --- compiler/luci/pass/src/CircleOptimizer.cpp | 8 ++++---- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 8 ++++---- .../luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp | 3 ++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 27cf27e63fd..246d4f36e78 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -279,10 +279,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const phase.emplace_back(std::make_unique()); phase.emplace_back(std::make_unique()); - if (_options->query(Options::Algorithm::FuseMulWithFullyConnected)) - { - phase.emplace_back(std::make_unique()); - } if (_options->query(Options::Algorithm::CommonSubExpressionElimination)) { phase.emplace_back(std::make_unique()); @@ -315,6 +311,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseMulWithFullyConnected)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index 388c6359138..795a8af0237 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -29,12 +29,12 @@ namespace if (not(cond)) \ return false; -inline bool is_scalar(luci::CircleConst *node) +inline bool is_scalar(const luci::CircleConst *node) { return ((node->rank() == 1 || node->rank() == 0) && node->size() == 1); } -inline void update_with_scalar(luci::CircleConst *fused_node, luci::CircleConst *multiplication) +inline void update_with_scalar(luci::CircleConst *fused_node, const luci::CircleConst *multiplication) { for (uint32_t i = 0; i < fused_node->size(); i++) { @@ -42,7 +42,7 @@ inline void update_with_scalar(luci::CircleConst *fused_node, luci::CircleConst } } -inline void update_weights(luci::CircleConst *weights, luci::CircleConst *multiplication) +inline void update_weights(luci::CircleConst *weights, const luci::CircleConst *multiplication) { // Scalar multiplication: if (is_scalar(multiplication)) @@ -67,7 +67,7 @@ inline void update_weights(luci::CircleConst *weights, luci::CircleConst *multip } } -inline void update_bias(luci::CircleConst *bias, luci::CircleConst *multiplication) +inline void update_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) { // Scalar multiplication: if (is_scalar(multiplication)) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 0043b44bfb9..7db58b69909 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -146,6 +146,8 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test luci::FuseMulWithFullyConnectedPass pass; }; +} // namespace + TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) { g.init(luci::FusedActFunc::NONE, false); @@ -220,4 +222,3 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) EXPECT_EQ(false, pass.run(g.g())); } -} // namespace From 85d9783685973a647cd9d71407940284695f5fc7 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Wed, 7 Aug 2024 16:29:46 +0200 Subject: [PATCH 03/14] Fix codestyle --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index 795a8af0237..7c6f60ed7c9 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -34,7 +34,8 @@ inline bool is_scalar(const luci::CircleConst *node) return ((node->rank() == 1 || node->rank() == 0) && node->size() == 1); } -inline void update_with_scalar(luci::CircleConst *fused_node, const luci::CircleConst *multiplication) +inline void update_with_scalar(luci::CircleConst *fused_node, + const luci::CircleConst *multiplication) { for (uint32_t i = 0; i < fused_node->size(); i++) { From e3b354e36b7b747c9918d69cf26ceec2bd976fe2 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Thu, 8 Aug 2024 15:10:05 +0200 Subject: [PATCH 04/14] Remove default arguments --- .../src/FuseMulWithFullyConnectedPass.test.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 7db58b69909..681e2d5aa8d 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -48,8 +48,6 @@ using namespace luci::test; class FCMulGraphlet { public: - FCMulGraphlet() = default; - void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar) { std::vector weights_val(DIM_ONE * DIM_TWO); @@ -123,12 +121,10 @@ class FCMulGraphlet luci::CircleConst *_mul_c = nullptr; }; -class FuseAddWithFCTestGraph : public TestIOGraph, public FCMulGraphlet +class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet { public: - FuseAddWithFCTestGraph() = default; - - void init(luci::FusedActFunc fc_activation = luci::FusedActFunc::NONE, bool is_mul_scalar = false) + void init(luci::FusedActFunc fc_activation, bool is_mul_scalar) { TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); FCMulGraphlet::init(g(), fc_activation, is_mul_scalar); @@ -142,7 +138,7 @@ class FuseAddWithFCTestGraph : public TestIOGraph, public FCMulGraphlet class FuseMulWithFullyConnectedPassTest : public ::testing::Test { public: - FuseAddWithFCTestGraph g; + FuseMulWithFCTestGraph g; luci::FuseMulWithFullyConnectedPass pass; }; @@ -208,7 +204,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) { - g.init(); + g.init(luci::FusedActFunc::NONE, false); // Bias cannot be fused as it's passed as feature map. g.to_fm_bias(); @@ -218,7 +214,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) { - g.init(luci::FusedActFunc::RELU); + g.init(luci::FusedActFunc::RELU, false); EXPECT_EQ(false, pass.run(g.g())); } From 31e25edc0339e00aa356e54e8643f953ba0c98bd Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Fri, 9 Aug 2024 16:16:15 +0200 Subject: [PATCH 05/14] Refactor solution and apply comments --- .../src/FuseMulWithFullyConnectedPass.cpp | 35 +++++++++---------- .../FuseMulWithFullyConnectedPass.test.cpp | 2 +- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index 7c6f60ed7c9..ef96635579c 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -43,47 +43,52 @@ inline void update_with_scalar(luci::CircleConst *fused_node, } } -inline void update_weights(luci::CircleConst *weights, const luci::CircleConst *multiplication) +luci::CircleConst *gen_fused_weights(luci::CircleConst *weights, + const luci::CircleConst *multiplication) { + auto fused_weights = luci::clone(weights); // Scalar multiplication: if (is_scalar(multiplication)) { - update_with_scalar(weights, multiplication); + update_with_scalar(fused_weights, multiplication); } // N-size multiplication: else { // Go along channels, multiplication size is ensured to be compatible with channels. - auto count = weights->dim(0).value(); - auto size = weights->dim(weights->rank() - 1).value(); + auto count = fused_weights->dim(0).value(); + auto size = fused_weights->dim(fused_weights->rank() - 1).value(); float val; for (uint32_t c = 0; c < count; c++) { val = multiplication->at(c); for (uint32_t i = 0; i < size; i++) { - weights->at(c * size + i) *= val; + fused_weights->at(c * size + i) *= val; } } } + return fused_weights; } -inline void update_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) +luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) { + auto fused_bias = luci::clone(bias); // Scalar multiplication: if (is_scalar(multiplication)) { - update_with_scalar(bias, multiplication); + update_with_scalar(fused_bias, multiplication); } // N-size multiplication: else { // Go along channels, multiplication size is ensured to be compatible with channels. - for (uint32_t i = 0; i < bias->size(); i++) + for (uint32_t i = 0; i < fused_bias->size(); i++) { - bias->at(i) *= multiplication->at(i); + fused_bias->at(i) *= multiplication->at(i); } } + return fused_bias; } /** @@ -128,7 +133,6 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) RETURN_FALSE_UNLESS(multiplication); // Get rank of multiplication: auto rank = multiplication->rank(); - RETURN_FALSE_UNLESS(rank != 0); // Check that all dimensions are ones, checks broadcast capabilites. // Last dimesion of multiplication must be compatible with FC. // N-D case (N>1): @@ -157,14 +161,9 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) RETURN_FALSE_UNLESS(const_bias) RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); - auto fused_bias = luci::clone(const_bias); - // Create new weights to be updated with values: - auto fused_weights = luci::clone(weights); - - // Update bias accordingly: - update_bias(fused_bias, multiplication); - // Update weights accordingly: - update_weights(fused_weights, multiplication); + // Create new weights and bias with updated values: + auto fused_bias = gen_fused_bias(const_bias, multiplication); + auto fused_weights = gen_fused_weights(weights, multiplication); // Replace weights and bias: fc->weights(fused_weights); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 681e2d5aa8d..ed38f2db380 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -75,7 +75,7 @@ class FCMulGraphlet if (is_mul_scalar) { mul_values.push_back(static_cast(MUL_VAL)); - _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {1}, mul_values); + _mul_c = luci::create_const_node(g, loco::DataType::FLOAT32, {}, mul_values); } else { From 8b17f47a439719f6aa5a6522f405308972f7a5bf Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Fri, 9 Aug 2024 17:03:58 +0200 Subject: [PATCH 06/14] Add handling of no bias case to pass --- .../src/FuseMulWithFullyConnectedPass.cpp | 29 +++++----- .../FuseMulWithFullyConnectedPass.test.cpp | 56 ++++++++++++++----- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index ef96635579c..19171707f6d 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -153,21 +153,24 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Only supports: // (1) constant bias // (2) no bias - auto bias = loco::must_cast(fc->bias()); - RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST or - bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) - // Create new bias to be updated with values: - auto const_bias = dynamic_cast(fc->bias()); - RETURN_FALSE_UNLESS(const_bias) - RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); - - // Create new weights and bias with updated values: - auto fused_bias = gen_fused_bias(const_bias, multiplication); - auto fused_weights = gen_fused_weights(weights, multiplication); + auto bias = dynamic_cast(fc->bias()); + if (bias != nullptr) + { + RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST or + bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + // Create new bias to be updated with values: + auto const_bias = dynamic_cast(fc->bias()); + RETURN_FALSE_UNLESS(const_bias) + RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); + + // Create new bias with updated values and replace: + auto fused_bias = gen_fused_bias(const_bias, multiplication); + fc->bias(fused_bias); + } - // Replace weights and bias: + // Create new weights with updated values and replace: + auto fused_weights = gen_fused_weights(weights, multiplication); fc->weights(fused_weights); - fc->bias(fused_bias); // Set origin and copy Activation Function if exisitng: fc->fusedActivationFunction(mul->fusedActivationFunction()); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index ed38f2db380..a05b4f80ee8 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -48,23 +48,27 @@ using namespace luci::test; class FCMulGraphlet { public: - void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar) + void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) { + _fc = g->nodes()->create(); + std::vector weights_val(DIM_ONE * DIM_TWO); for (uint32_t i = 0; i < DIM_ONE * DIM_TWO; i++) weights_val.at(i) = i; _fc_f = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + _fc->weights(_fc_f); - std::vector bias_val(DIM_ONE); - for (uint32_t i = 0; i < DIM_ONE; i++) - bias_val.at(i) = i; + if (use_bias) + { + std::vector bias_val(DIM_ONE); + for (uint32_t i = 0; i < DIM_ONE; i++) + bias_val.at(i) = i; - _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); + _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); + _fc->bias(_fc_b); + } - _fc = g->nodes()->create(); - _fc->weights(_fc_f); - _fc->bias(_fc_b); _fc->fusedActivationFunction(fc_activation); _fc->dtype(loco::DataType::FLOAT32); _fc->shape({1, DIM_ONE}); @@ -124,10 +128,10 @@ class FCMulGraphlet class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet { public: - void init(luci::FusedActFunc fc_activation, bool is_mul_scalar) + void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) { TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); - FCMulGraphlet::init(g(), fc_activation, is_mul_scalar); + FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias); _fc->input(input()); @@ -146,7 +150,7 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) { - g.init(luci::FusedActFunc::NONE, false); + g.init(luci::FusedActFunc::NONE, false, true); EXPECT_EQ(true, pass.run(g.g())); @@ -175,7 +179,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) { - g.init(luci::FusedActFunc::NONE, true); + g.init(luci::FusedActFunc::NONE, true, true); EXPECT_EQ(true, pass.run(g.g())); @@ -202,9 +206,33 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) } } +TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_no_bias) +{ + g.init(luci::FusedActFunc::NONE, false, false); + + EXPECT_EQ(true, pass.run(g.g())); + + auto fc = dynamic_cast(g.output()->from()); + EXPECT_NE(nullptr, fc); + EXPECT_EQ(nullptr, fc->bias()); + + auto weights = loco::must_cast(g.fc()->weights()); + auto weights_n = weights->dim(0).value(); + auto weights_m = weights->dim(1).value(); + uint32_t offset = 0; + for (uint32_t i = 0; i < weights_n; i++) + { + for (uint32_t j = 0; j < weights_m; j++) + { + offset = i * weights_m + j; + EXPECT_EQ(i * offset, weights->at(offset)); + } + } +} + TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) { - g.init(luci::FusedActFunc::NONE, false); + g.init(luci::FusedActFunc::NONE, false, true); // Bias cannot be fused as it's passed as feature map. g.to_fm_bias(); @@ -214,7 +242,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) { - g.init(luci::FusedActFunc::RELU, false); + g.init(luci::FusedActFunc::RELU, false, true); EXPECT_EQ(false, pass.run(g.g())); } From 9e22b260a8e06cc1a624b764199c2e5bcc785d81 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Mon, 12 Aug 2024 13:12:43 +0200 Subject: [PATCH 07/14] Apply comments, refactor tests and add proper handling of OUTPUTEXCLUDE --- .../src/FuseMulWithFullyConnectedPass.cpp | 17 ++++++----- .../FuseMulWithFullyConnectedPass.test.cpp | 30 +++++++++++-------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index 19171707f6d..d12d129fef3 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -136,7 +136,7 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check that all dimensions are ones, checks broadcast capabilites. // Last dimesion of multiplication must be compatible with FC. // N-D case (N>1): - if (multiplication->rank() > 1) + if (multiplication->rank() >= 1) { // Check channel-wise broadcasting: for (uint32_t i = 0; i < rank - 1; i++) @@ -145,28 +145,29 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); } // Scalar case: - else if (multiplication->rank() == 1 || multiplication->rank() == 0) + else if (multiplication->rank() == 0) { - RETURN_FALSE_UNLESS(multiplication->size() != 0); + RETURN_FALSE_UNLESS(multiplication->size() == 1); } // Only supports: // (1) constant bias // (2) no bias - auto bias = dynamic_cast(fc->bias()); - if (bias != nullptr) + auto bias = loco::must_cast(fc->bias()); + if (bias->opcode() == luci::CircleOpcode::CIRCLECONST) { - RETURN_FALSE_UNLESS(bias->opcode() == luci::CircleOpcode::CIRCLECONST or - bias->opcode() == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) // Create new bias to be updated with values: auto const_bias = dynamic_cast(fc->bias()); RETURN_FALSE_UNLESS(const_bias) RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32); - // Create new bias with updated values and replace: auto fused_bias = gen_fused_bias(const_bias, multiplication); fc->bias(fused_bias); } + else if (bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE) + { + return false; + } // Create new weights with updated values and replace: auto fused_weights = gen_fused_weights(weights, multiplication); diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index a05b4f80ee8..527fb7d9531 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -66,8 +66,13 @@ class FCMulGraphlet bias_val.at(i) = i; _fc_b = luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE}, bias_val); - _fc->bias(_fc_b); } + else + { + // Create CircleOutputExclude -- no bias + _fc_b = g->nodes()->create(); + } + _fc->bias(_fc_b); _fc->fusedActivationFunction(fc_activation); _fc->dtype(loco::DataType::FLOAT32); @@ -101,7 +106,7 @@ class FCMulGraphlet } else { - _mul->shape({1, DIM_ONE}); + _mul->shape({1, 1, 1, DIM_ONE}); } _mul->name("mul"); } @@ -121,7 +126,7 @@ class FCMulGraphlet luci::CircleFullyConnected *_fc = nullptr; luci::CircleMul *_mul = nullptr; luci::CircleConst *_fc_f = nullptr; - luci::CircleConst *_fc_b = nullptr; + luci::CircleNode *_fc_b = nullptr; luci::CircleConst *_mul_c = nullptr; }; @@ -148,9 +153,9 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test } // namespace -TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) +TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) { - g.init(luci::FusedActFunc::NONE, false, true); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); EXPECT_EQ(true, pass.run(g.g())); @@ -177,9 +182,9 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_not_scalar) } } -TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) +TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) { - g.init(luci::FusedActFunc::NONE, true, true); + g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */); EXPECT_EQ(true, pass.run(g.g())); @@ -206,15 +211,16 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_is_scalar) } } -TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_no_bias) +TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) { - g.init(luci::FusedActFunc::NONE, false, false); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */); EXPECT_EQ(true, pass.run(g.g())); auto fc = dynamic_cast(g.output()->from()); EXPECT_NE(nullptr, fc); - EXPECT_EQ(nullptr, fc->bias()); + auto no_bias = dynamic_cast(fc->bias()); + ASSERT_NE(nullptr, no_bias); auto weights = loco::must_cast(g.fc()->weights()); auto weights_n = weights->dim(0).value(); @@ -232,7 +238,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_without_activation_mul_no_bias) TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) { - g.init(luci::FusedActFunc::NONE, false, true); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); // Bias cannot be fused as it's passed as feature map. g.to_fm_bias(); @@ -242,7 +248,7 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) { - g.init(luci::FusedActFunc::RELU, false, true); + g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */); EXPECT_EQ(false, pass.run(g.g())); } From 8977ef9620ada75fe696bf7e18201d4f9a24a99e Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Mon, 12 Aug 2024 13:37:24 +0200 Subject: [PATCH 08/14] Handle rank 0 and 1 --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index d12d129fef3..bceb86e2f5e 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -136,7 +136,7 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check that all dimensions are ones, checks broadcast capabilites. // Last dimesion of multiplication must be compatible with FC. // N-D case (N>1): - if (multiplication->rank() >= 1) + if (multiplication->rank() > 1) { // Check channel-wise broadcasting: for (uint32_t i = 0; i < rank - 1; i++) @@ -144,7 +144,12 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check the last dimesion of Mul is the same with the first dimension of FullyConnected RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); } - // Scalar case: + // 1-D or scalar case: + else if (multiplication->rank() == 1) + { + RETURN_FALSE_UNLESS(multiplication->size() == 1 || + multiplication->size() == weights->dim(0)); + } else if (multiplication->rank() == 0) { RETURN_FALSE_UNLESS(multiplication->size() == 1); From b0851814408fbddf8b5fdb38a23339a29b19aa83 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 13 Aug 2024 13:52:27 +0200 Subject: [PATCH 09/14] Update names from scalar to single element --- .../pass/src/FuseMulWithFullyConnectedPass.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index bceb86e2f5e..c2951d7fda2 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -29,12 +29,12 @@ namespace if (not(cond)) \ return false; -inline bool is_scalar(const luci::CircleConst *node) +inline bool is_single_element(const luci::CircleConst *node) { return ((node->rank() == 1 || node->rank() == 0) && node->size() == 1); } -inline void update_with_scalar(luci::CircleConst *fused_node, +inline void update_with_single_element(luci::CircleConst *fused_node, const luci::CircleConst *multiplication) { for (uint32_t i = 0; i < fused_node->size(); i++) @@ -47,10 +47,10 @@ luci::CircleConst *gen_fused_weights(luci::CircleConst *weights, const luci::CircleConst *multiplication) { auto fused_weights = luci::clone(weights); - // Scalar multiplication: - if (is_scalar(multiplication)) + // Single element multiplication: + if (is_single_element(multiplication)) { - update_with_scalar(fused_weights, multiplication); + update_with_single_element(fused_weights, multiplication); } // N-size multiplication: else @@ -74,10 +74,10 @@ luci::CircleConst *gen_fused_weights(luci::CircleConst *weights, luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleConst *multiplication) { auto fused_bias = luci::clone(bias); - // Scalar multiplication: - if (is_scalar(multiplication)) + // Single element multiplication: + if (is_single_element(multiplication)) { - update_with_scalar(fused_bias, multiplication); + update_with_single_element(fused_bias, multiplication); } // N-size multiplication: else From 1aa79ccb6192741c573fb85fbf18177479ec6d9e Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 13 Aug 2024 13:52:45 +0200 Subject: [PATCH 10/14] Update tests --- .../pass/src/FuseMulWithFullyConnectedPass.test.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 527fb7d9531..821f4ff3d5c 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -102,7 +102,7 @@ class FCMulGraphlet _mul->dtype(loco::DataType::FLOAT32); if (is_mul_scalar) { - _mul->shape({1}); + _mul->shape({1, DIM_ONE}); } else { @@ -252,3 +252,12 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) EXPECT_EQ(false, pass.run(g.g())); } + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_null_weights_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + + g.fc()->weights(nullptr); + + EXPECT_EQ(false, pass.run(g.g())); +} From 1bb278d49f59912206f82e8037c29c7a3b87edd3 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 13 Aug 2024 15:44:32 +0200 Subject: [PATCH 11/14] Fix codestyle --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index c2951d7fda2..cd9face54d6 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -35,7 +35,7 @@ inline bool is_single_element(const luci::CircleConst *node) } inline void update_with_single_element(luci::CircleConst *fused_node, - const luci::CircleConst *multiplication) + const luci::CircleConst *multiplication) { for (uint32_t i = 0; i < fused_node->size(); i++) { From 79a2213d063bbf85e0209a5ec06dc5030cc7c80c Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Wed, 14 Aug 2024 13:48:58 +0200 Subject: [PATCH 12/14] Search from mul, update tests --- .../src/FuseMulWithFullyConnectedPass.cpp | 37 +++++------ .../FuseMulWithFullyConnectedPass.test.cpp | 65 +++++++++++++++---- 2 files changed, 71 insertions(+), 31 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index cd9face54d6..c724f832b94 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -16,12 +16,12 @@ #include "luci/Pass/FuseMulWithFullyConnectedPass.h" +#include "helpers/NodeFiller.h" + #include #include #include -#include - namespace { @@ -107,10 +107,19 @@ luci::CircleConst *gen_fused_bias(luci::CircleConst *bias, const luci::CircleCon * | * */ -bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) +bool fuse_mul_with_fc(luci::CircleMul *mul) { // Sanity check: - RETURN_FALSE_UNLESS(fc); + RETURN_FALSE_UNLESS(mul); + // Allow Mul node only with FLOAT32 data type: + RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); + // Check if any FC node connects to Mul. + // Find the pattern of Mul(FC, CircleConst): + luci::CircleFullyConnected *fc = nullptr; + luci::CircleConst *multiplication = nullptr; + RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul)); + // Make sure that FullyConnected has only one successor: + RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1); // Allow only FLOAT32 data type: RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); // Allow only without activation functions as values are going to @@ -119,18 +128,6 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check for weights being Constant: auto weights = dynamic_cast(fc->weights()); RETURN_FALSE_UNLESS(weights); - // Get Mul node: - auto fc_output = loco::succs(fc); - // Make sure that FullyConnected has only one child: - RETURN_FALSE_UNLESS(fc_output.size() == 1); - auto mul = dynamic_cast(*fc_output.begin()); - RETURN_FALSE_UNLESS(mul); - // Allow Mul node only with FLOAT32 data type: - RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32); - // Get multiplication Constant (here: the second input besides weights): - auto multiplication = mul->x() == fc ? dynamic_cast(mul->y()) - : dynamic_cast(mul->x()); - RETURN_FALSE_UNLESS(multiplication); // Get rank of multiplication: auto rank = multiplication->rank(); // Check that all dimensions are ones, checks broadcast capabilites. @@ -197,14 +194,14 @@ bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto fc = dynamic_cast(node); - if (not fc) + auto mul = dynamic_cast(node); + if (not mul) continue; - switch (fc->dtype()) + switch (mul->dtype()) { case loco::DataType::FLOAT32: - if (fuse_mul_with_fc(fc)) + if (fuse_mul_with_fc(mul)) changed = true; break; default: diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp index 821f4ff3d5c..a4f9d6bf087 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.test.cpp @@ -34,13 +34,22 @@ using namespace luci::test; /** * Graph for this test * - * BEFORE + * BEFORE (without extra_fc_successor) * * [FC] * | * [Mul w/ Relu] * - * AFTER + * BEFORE (with extra_fc_successor) + * + * [FC] + * | + * |------------------- + * | | + * | | + * [Mul w/ Relu] [other FC] + * + * AFTER (if pass applied) * * [FC w/ Relu] (weights and bias updated) * @@ -48,7 +57,8 @@ using namespace luci::test; class FCMulGraphlet { public: - void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) + void init(loco::Graph *g, luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) { _fc = g->nodes()->create(); @@ -79,6 +89,22 @@ class FCMulGraphlet _fc->shape({1, DIM_ONE}); _fc->name("fc"); + if (extra_successor) + { + _extra_succ = g->nodes()->create(); + // Set previous FC as input to bump number of successors for it: + _extra_succ->input(_fc); + std::vector weights_val(DIM_ONE * DIM_TWO); + _extra_f = + luci::create_const_node(g, loco::DataType::FLOAT32, {DIM_ONE, DIM_TWO}, weights_val); + _extra_succ->weights(_extra_f); + _extra_succ->bias(nullptr); + _extra_succ->fusedActivationFunction(luci::FusedActFunc::NONE); + _extra_succ->dtype(loco::DataType::FLOAT32); + _extra_succ->shape({1, DIM_ONE}); + _extra_succ->name("extra_fc"); + } + std::vector mul_values; if (is_mul_scalar) @@ -128,15 +154,18 @@ class FCMulGraphlet luci::CircleConst *_fc_f = nullptr; luci::CircleNode *_fc_b = nullptr; luci::CircleConst *_mul_c = nullptr; + luci::CircleFullyConnected *_extra_succ = nullptr; + luci::CircleConst *_extra_f = nullptr; }; class FuseMulWithFCTestGraph : public TestIOGraph, public FCMulGraphlet { public: - void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias) + void init(luci::FusedActFunc fc_activation, bool is_mul_scalar, bool use_bias, + bool extra_successor) { TestIOGraph::init({1, DIM_TWO}, {1, DIM_ONE}); - FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias); + FCMulGraphlet::init(g(), fc_activation, is_mul_scalar, use_bias, extra_successor); _fc->input(input()); @@ -155,7 +184,8 @@ class FuseMulWithFullyConnectedPassTest : public ::testing::Test TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -184,7 +214,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_tensor) TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) { - g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, true /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -213,7 +244,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_mul_scalar) TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, false /* use_bias */, + false /* extra_successor */); EXPECT_EQ(true, pass.run(g.g())); @@ -238,7 +270,8 @@ TEST_F(FuseMulWithFullyConnectedPassTest, fc_no_bias) TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); // Bias cannot be fused as it's passed as feature map. g.to_fm_bias(); @@ -248,16 +281,26 @@ TEST_F(FuseMulWithFullyConnectedPassTest, bias_feature_map_NEG) TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_activation_NEG) { - g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::RELU, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); EXPECT_EQ(false, pass.run(g.g())); } TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_null_weights_NEG) { - g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */); + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + false /* extra_successor */); g.fc()->weights(nullptr); EXPECT_EQ(false, pass.run(g.g())); } + +TEST_F(FuseMulWithFullyConnectedPassTest, fc_with_extra_successor_NEG) +{ + g.init(luci::FusedActFunc::NONE, false /* is_mul_scalar */, true /* use_bias */, + true /* extra_successor */); + + EXPECT_EQ(false, pass.run(g.g())); +} From 550e798f3f88c2a2edc9043e50765411280a6a30 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Mon, 19 Aug 2024 09:49:39 +0200 Subject: [PATCH 13/14] Annotate requirement of one successor and refactor checks --- .../src/FuseMulWithFullyConnectedPass.cpp | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index c724f832b94..d4fb75953ed 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -118,7 +118,39 @@ bool fuse_mul_with_fc(luci::CircleMul *mul) luci::CircleFullyConnected *fc = nullptr; luci::CircleConst *multiplication = nullptr; RETURN_FALSE_UNLESS(luci::fill(&fc, &multiplication).with_commutative_args_of(mul)); - // Make sure that FullyConnected has only one successor: + /** + * Make sure that FullyConnected has only one successor. + * + * If the FullyConnected output is connected to more nodes, + * this pass will replace node with new fused FullyConnected. + * Thus pass success will only introduce extra FullyConnected + * without reducing overall number of nodes. + * Which tends to increase model's size and degrades model's performance. + * Thus one successor is required to benefit from this pass. + * + * Example graph that illustrates the described scenario: + * + * BEFORE + * | + * [CircleFullyConnected] + * | + * +-------+----------------+ + * | | + * | | + * [Other Node] [CircleMul] + * | | + * + * AFTER + * | + * [CircleFullyConnected] + * | + * +-------+-----------------------+ + * | | + * | | + * [Other Node] [New CircleFullyConnected Fused with Mul] + * | | + * + */ RETURN_FALSE_UNLESS(loco::succs(fc).size() == 1); // Allow only FLOAT32 data type: RETURN_FALSE_UNLESS(fc->dtype() == loco::DataType::FLOAT32); @@ -194,18 +226,10 @@ bool FuseMulWithFullyConnectedPass::run(loco::Graph *g) bool changed = false; for (auto node : loco::active_nodes(loco::output_nodes(g))) { - auto mul = dynamic_cast(node); - if (not mul) - continue; - - switch (mul->dtype()) + if (auto mul = dynamic_cast(node)) { - case loco::DataType::FLOAT32: - if (fuse_mul_with_fc(mul)) - changed = true; - break; - default: - break; + if (fuse_mul_with_fc(mul)) + changed = true; } } From 48defc27b2f11cc0da1d0f25a86daf3e69a924a9 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Mon, 19 Aug 2024 10:24:16 +0200 Subject: [PATCH 14/14] Fix graph in explanation --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index d4fb75953ed..3049862e216 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -142,9 +142,12 @@ bool fuse_mul_with_fc(luci::CircleMul *mul) * * AFTER * | - * [CircleFullyConnected] - * | - * +-------+-----------------------+ + * +-----------------------+ + * | | + * | | + * [CircleFullyConnected] | + * | | + * +-------+ | * | | * | | * [Other Node] [New CircleFullyConnected Fused with Mul]