From 36c7317811b8fd39da7fd3790869b06d28bd8806 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Tue, 6 Aug 2024 15:54:40 +0300 Subject: [PATCH 1/5] [DRAFT][compiler] Introduce FuseGRUPass This pr introduces FuseGRUPass for fusing gru pattern into single CircleGRU op. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- compiler/circle2circle/src/Circle2Circle.cpp | 3 + .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../luci/pass/include/luci/Pass/FuseGRUPass.h | 39 + compiler/luci/pass/src/CircleOptimizer.cpp | 5 + compiler/luci/pass/src/FuseGRUPass.cpp | 674 ++++++++++++++++++ compiler/luci/pass/src/FuseGRUPass.test.cpp | 418 +++++++++++ 6 files changed, 1140 insertions(+) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseGRUPass.h create mode 100644 compiler/luci/pass/src/FuseGRUPass.cpp create mode 100644 compiler/luci/pass/src/FuseGRUPass.test.cpp diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index 757c368f31d..80d775aa86e 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -130,6 +130,7 @@ int entry(int argc, char **argv) "This will fuse BatchNorm operators of pre-activations to Convolution operator"); add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator"); add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator"); + add_switch(arser, "--fuse_gru", "This will fuse operators to GRU operator"); add_switch(arser, "--fuse_rsqrt", "This will fuse operators to Rsqrt operator"); add_switch(arser, "--remove_duplicate_const", "This will remove all duplicate constant nodes"); add_switch(arser, "--remove_fakequant", "This will remove FakeQuant operators"); @@ -334,6 +335,8 @@ int entry(int argc, char **argv) options->enable(Algorithms::FusePRelu); if (arser.get("--fuse_gelu")) options->enable(Algorithms::FuseGelu); + if (arser.get("--fuse_gru")) + options->enable(Algorithms::FuseGRU); if 
(arser.get("--fuse_rsqrt")) options->enable(Algorithms::FuseRsqrt); if (arser.get("--fuse_transpose_with_mean")) diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index 9cbd26f0da5..a7807f01bd1 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -76,6 +76,7 @@ class CircleOptimizer final FuseActivationFunction, FusePRelu, FuseGelu, + FuseGRU, FuseRsqrt, ShuffleWeightTo16x1Float32, RemoveRedundantTranspose, diff --git a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h new file mode 100644 index 00000000000..152dc427d95 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_GRU_PASS_H__ +#define __LUCI_FUSE_GRU_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleGRU + * + * For detailed subgraph pattern to be fused, please check its implementation. 
+ */ +struct FuseGRUPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseGRUPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_GRU_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 840c8dd25dd..a9ac64c9ffe 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -51,6 +51,7 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" +#include "luci/Pass/FuseGRUPass.h" #include "luci/Pass/FuseRsqrtPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" @@ -370,6 +371,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique()); } + if (_options->query(Options::Algorithm::FuseGRU)) + { + phase.emplace_back(std::make_unique()); + } if (_options->query(Options::Algorithm::FuseRsqrt)) { phase.emplace_back(std::make_unique()); diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp new file mode 100644 index 00000000000..2f1f2d341ef --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -0,0 +1,674 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "luci/Pass/FuseGRUPass.h" +#include "helpers/NodeFiller.h" + +#include + +#include +#include + +#include + +#include + +// Helper to fuse GRU +namespace +{ + +class GRUPatternBase +{ +public: + GRUPatternBase(luci::CircleNode *candidate) { _pattern_last_node = candidate; } + + virtual ~GRUPatternBase() = default; + +public: + virtual bool matched() = 0; + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleConst *_weight_ih = nullptr; + luci::CircleConst *_bias_ih = nullptr; + luci::CircleConst *_weight_hh = nullptr; + luci::CircleConst *_bias_hh = nullptr; + + luci::CircleConst *_hidden_input = nullptr; + + luci::CircleConst *_less_const = nullptr; + + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleNode *_pattern_last_node = nullptr; +}; + +/** + * Below diagram shows GRU pattern to fuse. + * Note: this pattern for GRU with `return_sequences=False` + * - the below pattern will be replaced with one GRU + * Main Graph: + * [In] [CircleConst] [CircleConst] [CircleConst] [CircleConst] + * | | | | | + * V | | | | + * [CircleWhile]<----------------------------------------------------- + * | + * V + * [CircleWhileOut] + * | + * V + * [Out] + * + * Condition Graph: + * [In] [CircleConst] (scalar int32 value) + * | | + * V | + * [Less]------ + * | + * V + * [Out] + * + * Body Graph must contain: + * - 2 CircleFullyConnected nodes; + * - 3 CircleMul nodes; + * - 2 CircleLogistic nodes; + * - 2 CircleSplit nodes; + * - 6 CircleAdd nodes; + * - 1 CircleGather node; + * - 1 CircleReshape node; + * - 1 CircleSub node; + * - 1 CircleTanh node; + * - 6 CircleSplitOut nodes; + * - 5 CircleInput nodes; + * - 5 CircleOutput nodes; + * + * Body Graph: + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | 
[Add_1] -------+----+---------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ +class GRUPattern1 final : public GRUPatternBase +{ +public: + GRUPattern1(luci::CircleWhileOut *candidate) : GRUPatternBase(candidate) + { + assert(candidate); + _while_out_node = candidate; + } + +public: + bool matched() override; +}; + +bool GRUPattern1::matched() +{ + // 0 - check while node + _while_node = dynamic_cast(_while_out_node->input()); + if (_while_node == nullptr) + return false; + + // 1 - check condition graph: only one Less operation + // with scalar int const value + { + const auto cond_graph = _while_node->cond_graph(); + + const auto cond_nodes = loco::active_nodes(loco::output_nodes(cond_graph)); + if (cond_nodes.size() != 4) + return false; + luci::CircleLess *less_node = nullptr; + for (auto node : cond_nodes) + { + less_node = dynamic_cast(node); + if (less_node != nullptr) + break; + } + + // doesn't find Less node + if (less_node == nullptr) + return false; + + luci::CircleNode *less_input; + if (not luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)) + return false; + + if (_less_const->dtype() != loco::DataType::S32) + return false; + + if (_less_const->size() != 1) + return false; + + assert(_less_const->at(0) > 0); + } + + // 2 - Check while's input nodes + // Save hidden state input node + { + if (_while_node->input_count() != 5) + return false; + + // Save input node + _ifm = dynamic_cast(_while_node->input(4)); + if (_ifm == nullptr) + return false; + + _hidden_input = 
dynamic_cast(_while_node->input(3)); + if (_hidden_input == nullptr) + return false; + } + + // 3 - check body graph + { + const auto body_graph = _while_node->body_graph(); + + if (loco::input_nodes(body_graph).size() != 5) + return false; + + if (loco::output_nodes(body_graph).size() != 5) + return false; + + const auto body_nodes = loco::active_nodes(loco::output_nodes(body_graph)); + + // Save all nodes according its types + std::vector fc_nodes; + std::vector split_nodes; + std::vector logistic_nodes; + std::vector mul_nodes; + std::vector add_nodes; + std::vector sub_nodes; + std::vector reshape_nodes; + std::vector gather_nodes; + std::vector tanh_nodes; + std::vector split_out_nodes; + + for (auto node : body_nodes) + { + auto circle_node = dynamic_cast(node); + switch (circle_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + case luci::CircleOpcode::CIRCLEINPUT: + case luci::CircleOpcode::CIRCLEOUTPUT: + case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE: + break; + case luci::CircleOpcode::FULLY_CONNECTED: + fc_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SPLIT: + split_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::LOGISTIC: + logistic_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::MUL: + mul_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::ADD: + add_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SUB: + sub_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::RESHAPE: + reshape_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::GATHER: + gather_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::TANH: + tanh_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + split_out_nodes.push_back(dynamic_cast(circle_node)); + break; + default: + return false; + } + } + 
+ // Check number of nodes + if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or + split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or + reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or + split_out_nodes.size() != 6) + return false; + + // Check structure + // TODO: add more checks + { + // 1 - Check Split ops + // Both has FC nodes as input + // Axis is const + for (auto node : split_nodes) + { + if (dynamic_cast(node->split_dim()) == nullptr or + dynamic_cast(node->input()) == nullptr) + return false; + } + + // 2 - Check Logistic ops + // Add is input node for both nodes + for (auto node : logistic_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + + // 3 - Check Sub + // Const - is first input node + // Logistic - is second input node + for (auto node : sub_nodes) + { + if (dynamic_cast(node->y()) == nullptr or + dynamic_cast(node->x()) == nullptr) + return false; + } + + // 4 - Check Add + // Mul or Const or Input or Split ops can be input nodes + // Mul - 3 times as input + // Const - 2 times as input + // Input - 2 times as input + // Split - 5 times as input + { + int num_mul = 0; + int num_const = 0; + int num_input = 0; + int num_split = 0; + for (auto node : add_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + 
default: + return false; + } + } + if (num_mul != 3 or num_split != 5 or num_const != 2 or num_input != 2) + return false; + } + } + + // 5 - Check Mul + // Logistic or Tanh or Sub or Input or Split ops can be input nodes + // Logistic - 2 times as input + // Tanh - 1 times as input + // Sub - 1 times as input + // Split - 1 times as input + // Input - 1 times as input + { + int num_logistic = 0; + int num_tanh = 0; + int num_sub = 0; + int num_split = 0; + int num_input = 0; + for (auto node : mul_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + } + if (num_logistic != 2 or num_tanh != 1 or num_sub != 1 or num_split != 1 or num_input != 1) + return false; + } + + // 6 - Check Gather + // Gather has two CircleInput as input + { + for (auto node : gather_nodes) + { + if (dynamic_cast(node->indices()) == nullptr) + return false; + + if (dynamic_cast(node->params()) == nullptr) + return false; + } + } + + // 7 - Check Tanh + // Input is CircleAdd + { + for (auto node : tanh_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + } + + // Find input and hidden FC weights and biases + for (auto node : body_nodes) + { + auto *fc_node = dynamic_cast(node); + if (fc_node == nullptr) + 
continue; + + const auto input_node = dynamic_cast(fc_node->input()); + if (input_node == nullptr) + return false; + + // For input hidden FullyConnected - input node is CircleInput node + if (dynamic_cast(input_node) != nullptr) + { + _weight_ih = dynamic_cast(fc_node->weights()); + _bias_ih = dynamic_cast(fc_node->bias()); + } + // For hidden hidden FullyConnected - input node is CircleGather node + else if (dynamic_cast(input_node) != nullptr) + { + _weight_hh = dynamic_cast(fc_node->weights()); + _bias_hh = dynamic_cast(fc_node->bias()); + } + else + { + return false; + } + } + + if (_weight_ih == nullptr or _weight_hh == nullptr) + return false; + } + + return true; +} + +class FuseGRU final +{ +public: + FuseGRU(const GRUPatternBase *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleGRU *create_circle_gru(loco::Graph *graph); + +private: + const GRUPatternBase *_p; +}; + +template +void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned) +{ + assert(T == node->dtype()); + assert(T == cloned->dtype()); + + const auto size = node->size(); + cloned->size(size); + for (uint32_t i = 0; i < size; i++) + cloned->at(i) = node->at(i); +} + +luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph) +{ + auto cloned = graph->nodes()->create(); + + if (cloned != nullptr) + { + // dtype/shape + cloned->dtype(node->dtype()); + cloned->rank(node->rank()); + + // values + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + copy_values(node, cloned); + break; + + case loco::DataType::U8: + copy_values(node, cloned); + break; + + case loco::DataType::S8: + copy_values(node, cloned); + break; + + case loco::DataType::S16: + copy_values(node, cloned); + break; + + case loco::DataType::S32: + copy_values(node, cloned); + break; + + case loco::DataType::S64: + copy_values(node, cloned); + break; + + case loco::DataType::BOOL: + copy_values(node, cloned); + break; + + default: + assert(false); + } + } + + 
return cloned; +} + +luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) +{ + assert(graph); + + auto weight_ih_cloned = clone_circleconst(_p->_weight_ih, graph); + luci::copy_common_attributes(_p->_weight_ih, weight_ih_cloned); + + auto weight_hh_cloned = clone_circleconst(_p->_weight_hh, graph); + luci::copy_common_attributes(_p->_weight_hh, weight_hh_cloned); + + luci::CircleNode *bias_ih_cloned = nullptr; + if (_p->_bias_ih != nullptr) + { + bias_ih_cloned = clone_circleconst(_p->_bias_ih, graph); + luci::copy_common_attributes(_p->_bias_ih, bias_ih_cloned); + } + else + { + bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + luci::CircleNode *bias_hh_cloned = nullptr; + if (_p->_bias_hh != nullptr) + { + bias_hh_cloned = clone_circleconst(_p->_bias_hh, graph); + luci::copy_common_attributes(_p->_bias_hh, bias_hh_cloned); + } + else + { + bias_hh_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); + luci::copy_common_attributes(_p->_hidden_input, hidden_input_cloned); + + auto less_const_cloned = clone_circleconst(_p->_less_const, graph); + luci::copy_common_attributes(_p->_less_const, less_const_cloned); + + // Create and configure new CircleGRU operation. 
+ auto circle_gru = _p->_while_node->graph()->nodes()->create(); + circle_gru->input(_p->_ifm); + circle_gru->hidden_hidden(weight_hh_cloned); + circle_gru->hidden_input(weight_ih_cloned); + circle_gru->hidden_hidden_bias(bias_hh_cloned); + circle_gru->hidden_input_bias(bias_ih_cloned); + circle_gru->state(hidden_input_cloned); + + // Note: Now support only returnSequences = false + circle_gru->returnSequences(false); + circle_gru->name("FusedCircleGRU"); + + return circle_gru; +} + +void FuseGRU::apply() +{ + auto graph = _p->_pattern_last_node->graph(); + + auto gru_out = create_circle_gru(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p->_while_node), luci::get_origin(_p->_while_out_node), + luci::get_origin(_p->_weight_hh), luci::get_origin(_p->_weight_ih)}; + + luci::add_origin(gru_out, luci::composite_origin(origin_vec)); + + replace(_p->_pattern_last_node).with(gru_out); +} + +} // namespace + +namespace +{ + +bool fuse_gru(luci::CircleWhileOut *while_out_node) +{ + assert(while_out_node); + + // check first pattern + GRUPattern1 pattern(while_out_node); + if (pattern.matched()) + { + FuseGRU fuse(&pattern); + fuse.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseGRUPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto while_out_node = dynamic_cast(node); + if (not while_out_node) + continue; + + if (fuse_gru(while_out_node)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp new file mode 100644 index 00000000000..93909ea673f --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class GRUGraphlet +{ +public: + GRUGraphlet() = default; + + void init(loco::Graph *g) + { + _while_node = g->nodes()->create(5, 5); + _while_out_node = g->nodes()->create(); + _hidden_node = g->nodes()->create(); + _hidden_node->dtype(loco::DataType::FLOAT32); + _time_node = g->nodes()->create(); + _time_node->dtype(loco::DataType::FLOAT32); + _state_node = g->nodes()->create(); + _state_node->dtype(loco::DataType::FLOAT32); + + _body_graph = loco::make_graph(); + _cond_graph = loco::make_graph(); + + _less_node = _cond_graph->nodes()->create(); + _less_const_node = _cond_graph->nodes()->create(); + _less_const_node->dtype(loco::DataType::S32); + _less_const_node->size(1); + _less_const_node->at(0) = 1; + + _add_node_1 = _body_graph->nodes()->create(); + _add_node_2 = _body_graph->nodes()->create(); + _add_node_3 = _body_graph->nodes()->create(); + _add_node_4 = _body_graph->nodes()->create(); + _add_node_5 = _body_graph->nodes()->create(); + _add_node_6 = _body_graph->nodes()->create(); + + _fc_node_1 = _body_graph->nodes()->create(); + _fc_node_2 = _body_graph->nodes()->create(); + _fc_weight_1 = _body_graph->nodes()->create(); + _fc_weight_1->dtype(loco::DataType::FLOAT32); + _fc_weight_2 = _body_graph->nodes()->create(); + 
_fc_weight_2->dtype(loco::DataType::FLOAT32); + _fc_bias_1 = _body_graph->nodes()->create(); + _fc_bias_1->dtype(loco::DataType::FLOAT32); + _fc_bias_2 = _body_graph->nodes()->create(); + _fc_bias_2->dtype(loco::DataType::FLOAT32); + + _split_const = _body_graph->nodes()->create(); + _split_const->dtype(loco::DataType::S32); + + _logistic_node_1 = _body_graph->nodes()->create(); + _logistic_node_2 = _body_graph->nodes()->create(); + + _gather_node = _body_graph->nodes()->create(); + + _mul_node_1 = _body_graph->nodes()->create(); + _mul_node_2 = _body_graph->nodes()->create(); + _mul_node_3 = _body_graph->nodes()->create(); + + _tanh_node = _body_graph->nodes()->create(); + _sub_node = _body_graph->nodes()->create(); + + _split_node_1 = _body_graph->nodes()->create(); + _split_node_2 = _body_graph->nodes()->create(); + _split_out_node_1 = _body_graph->nodes()->create(); + _split_out_node_2 = _body_graph->nodes()->create(); + _split_out_node_3 = _body_graph->nodes()->create(); + _split_out_node_4 = _body_graph->nodes()->create(); + _split_out_node_5 = _body_graph->nodes()->create(); + _split_out_node_6 = _body_graph->nodes()->create(); + + _reshape_node = _body_graph->nodes()->create(); + + auto graph_input_cond_graph = _cond_graph->inputs()->create(); + _cond_input_node = _cond_graph->nodes()->create(); + _cond_input_node->index(graph_input_cond_graph->index()); + + auto graph_output_cond_graph = _cond_graph->outputs()->create(); + _cond_output_node = _cond_graph->nodes()->create(); + _cond_output_node->index(graph_output_cond_graph->index()); + + auto graph_input_body_graph_1 = _body_graph->inputs()->create(); + _body_input_node_1 = _body_graph->nodes()->create(); + _body_input_node_1->index(graph_input_body_graph_1->index()); + + auto graph_input_body_graph_2 = _body_graph->inputs()->create(); + _body_input_node_2 = _body_graph->nodes()->create(); + _body_input_node_2->index(graph_input_body_graph_2->index()); + + auto graph_input_body_graph_3 = 
_body_graph->inputs()->create(); + _body_input_node_3 = _body_graph->nodes()->create(); + _body_input_node_3->index(graph_input_body_graph_3->index()); + + auto graph_input_body_graph_4 = _body_graph->inputs()->create(); + _body_input_node_4 = _body_graph->nodes()->create(); + _body_input_node_4->index(graph_input_body_graph_4->index()); + + auto graph_input_body_graph_5 = _body_graph->inputs()->create(); + _body_input_node_5 = _body_graph->nodes()->create(); + _body_input_node_5->index(graph_input_body_graph_5->index()); + + auto graph_output_body_graph_1 = _body_graph->outputs()->create(); + _body_output_node_1 = _body_graph->nodes()->create(); + _body_output_node_1->index(graph_output_body_graph_1->index()); + + auto graph_output_body_graph_2 = _body_graph->outputs()->create(); + _body_output_node_2 = _body_graph->nodes()->create(); + _body_output_node_2->index(graph_output_body_graph_2->index()); + + auto graph_output_body_graph_3 = _body_graph->outputs()->create(); + _body_output_node_3 = _body_graph->nodes()->create(); + _body_output_node_3->index(graph_output_body_graph_3->index()); + + auto graph_output_body_graph_4 = _body_graph->outputs()->create(); + _body_output_node_4 = _body_graph->nodes()->create(); + _body_output_node_4->index(graph_output_body_graph_4->index()); + + auto graph_output_body_graph_5 = _body_graph->outputs()->create(); + _body_output_node_5 = _body_graph->nodes()->create(); + _body_output_node_5->index(graph_output_body_graph_5->index()); + } + + void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } + +protected: + luci::CircleWhile *_while_node; + luci::CircleWhileOut *_while_out_node; + luci::CircleConst *_time_node; + luci::CircleConst *_state_node; + luci::CircleConst *_hidden_node; + + luci::CircleInput *_cond_input_node; + luci::CircleLess *_less_node; + luci::CircleConst *_less_const_node; + luci::CircleOutput *_cond_output_node; + + luci::CircleInput *_body_input_node_1; + luci::CircleInput 
*_body_input_node_2; + luci::CircleInput *_body_input_node_3; + luci::CircleInput *_body_input_node_4; + luci::CircleInput *_body_input_node_5; + + luci::CircleOutput *_body_output_node_1; + luci::CircleOutput *_body_output_node_2; + luci::CircleOutput *_body_output_node_3; + luci::CircleOutput *_body_output_node_4; + luci::CircleOutput *_body_output_node_5; + + luci::CircleAdd *_add_node_1; + luci::CircleAdd *_add_node_2; + luci::CircleAdd *_add_node_3; + luci::CircleAdd *_add_node_4; + luci::CircleAdd *_add_node_5; + luci::CircleAdd *_add_node_6; + + luci::CircleMul *_mul_node_1; + luci::CircleMul *_mul_node_2; + luci::CircleMul *_mul_node_3; + + luci::CircleSub *_sub_node; + luci::CircleTanh *_tanh_node; + luci::CircleReshape *_reshape_node; + luci::CircleGather *_gather_node; + luci::CircleLogistic *_logistic_node_1; + luci::CircleLogistic *_logistic_node_2; + luci::CircleSplit *_split_node_1; + luci::CircleSplit *_split_node_2; + + luci::CircleSplitOut *_split_out_node_1; + luci::CircleSplitOut *_split_out_node_2; + luci::CircleSplitOut *_split_out_node_3; + luci::CircleSplitOut *_split_out_node_4; + luci::CircleSplitOut *_split_out_node_5; + luci::CircleSplitOut *_split_out_node_6; + + luci::CircleFullyConnected *_fc_node_1; + luci::CircleFullyConnected *_fc_node_2; + + luci::CircleConst *_split_const; + luci::CircleConst *_fc_weight_1; + luci::CircleConst *_fc_bias_1; + luci::CircleConst *_fc_weight_2; + luci::CircleConst *_fc_bias_2; + + std::unique_ptr _cond_graph; + std::unique_ptr _body_graph; +}; + +class FuseGRUTestGraph1 : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestGraph1() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + 
_while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_1->y(_split_const); + _add_node_2->x(_body_input_node_2); + _add_node_2->y(_split_const); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _gather_node->indices(_body_input_node_1); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_1->split_dim(_split_const); + _split_node_2->input(_fc_node_2); + _split_node_2->split_dim(_split_const); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + _sub_node->x(_split_const); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +class FuseGRUTestNegGraph : public 
TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + invalid_less_const_type(); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_2->x(_body_input_node_2); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_2->input(_fc_node_2); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); 
+ _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +} // namespace + +TEST(FuseGRUPassTest, fuse_pattern1) +{ + FuseGRUTestGraph1 g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGRUPassTest, fuse_NEG) +{ + FuseGRUTestNegGraph g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} From 9538317884dc1ba7c637dc8a461a90e098206be5 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Wed, 7 Aug 2024 14:17:11 +0300 Subject: [PATCH 2/5] [DRAFT][compiler] Introduce EliminateDeadSubgraphPass This pr introduces EliminateDeadSubgraphPass for removing dead subgraph. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- compiler/circle2circle/src/Circle2Circle.cpp | 5 +- compiler/luci/lang/include/luci/IR/Module.h | 7 + compiler/luci/lang/src/Module.cpp | 8 + compiler/luci/lang/src/Module.test.cpp | 34 ++++ .../luci/Pass/EliminateDeadSubgraphPass.h | 45 ++++++ compiler/luci/pass/src/CircleOptimizer.cpp | 3 + .../pass/src/EliminateDeadSubgraphPass.cpp | 145 ++++++++++++++++++ .../src/EliminateDeadSubgraphPass.test.cpp | 133 ++++++++++++++++ 8 files changed, 379 insertions(+), 1 deletion(-) create mode 100644 compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h create mode 100644 compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp create mode 100644 compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index 80d775aa86e..99b10c5fbe2 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -518,7 +518,7 @@ int entry(int argc, char **argv) luci::change_outputs(graph, new_outputs); } - // call luci optimizations for module + // call luci optimizations for module before optimizations for graph 
optimizer.optimize(module.get()); for (size_t idx = 0; idx < module->size(); ++idx) @@ -541,6 +541,9 @@ int entry(int argc, char **argv) } } + // call luci optimizations for module after optimizations for graph + optimizer.optimize(module.get()); + // Export to output Circle file luci::CircleExporter exporter; diff --git a/compiler/luci/lang/include/luci/IR/Module.h b/compiler/luci/lang/include/luci/IR/Module.h index 75cf67905e7..dd4df3340be 100644 --- a/compiler/luci/lang/include/luci/IR/Module.h +++ b/compiler/luci/lang/include/luci/IR/Module.h @@ -51,6 +51,13 @@ class Module final */ loco::Graph *graph(void) const; + /** + * @brief remove graph at index + * + * @note graph(0) is interpreted as a main graph and cannot be deleted + */ + void removeGraphByIndex(size_t idx); + /** * @brief provide graph with an index * diff --git a/compiler/luci/lang/src/Module.cpp b/compiler/luci/lang/src/Module.cpp index 80ef61910f2..16c5a8277d8 100644 --- a/compiler/luci/lang/src/Module.cpp +++ b/compiler/luci/lang/src/Module.cpp @@ -35,6 +35,14 @@ loco::Graph *Module::graph(void) const return graph.get(); } +void Module::removeGraphByIndex(size_t idx) +{ + if (idx >= _graphs.size() or idx == 0) + throw std::invalid_argument("Module: Invalid graph index to be deleted"); + + _graphs.erase(_graphs.begin() + idx); +} + loco::Graph *Module::graph(size_t idx) const { auto &graph = _graphs.at(idx); diff --git a/compiler/luci/lang/src/Module.test.cpp b/compiler/luci/lang/src/Module.test.cpp index a5973e52dad..16c93250eb3 100644 --- a/compiler/luci/lang/src/Module.test.cpp +++ b/compiler/luci/lang/src/Module.test.cpp @@ -37,6 +37,33 @@ TEST(ModuleTest, add) ASSERT_EQ(g_ptr, m->graph(0)); } +TEST(ModuleTest, remove) +{ + auto m = luci::make_module(); + auto g1 = loco::make_graph(); + auto g2 = loco::make_graph(); + auto g3 = loco::make_graph(); + auto g1_ptr = g1.get(); + auto g2_ptr = g2.get(); + auto g3_ptr = g3.get(); + + m->add(std::move(g1)); + m->add(std::move(g2)); + 
m->add(std::move(g3)); + + ASSERT_EQ(3, m->size()); + ASSERT_EQ(g1_ptr, m->graph()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g2_ptr, m->graph(1)); + ASSERT_EQ(g3_ptr, m->graph(2)); + + // Let's delete graph at second position + m->removeGraphByIndex(1); + ASSERT_EQ(2, m->size()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g3_ptr, m->graph(1)); +} + TEST(ModuleTest, add_more) { auto m = luci::make_module(); @@ -65,6 +92,13 @@ TEST(ModuleTest, add_nullptr_NEG) EXPECT_THROW(m->add(nullptr), std::invalid_argument); } +TEST(ModuleTest, remove_index_overflow_NEG) +{ + auto m = luci::make_module(); + + EXPECT_THROW(m->removeGraphByIndex(10), std::invalid_argument); +} + TEST(ModuleTest, graph_index_overflow_NEG) { auto m = luci::make_module(); diff --git a/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h b/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h new file mode 100644 index 00000000000..1b7e39192f7 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ +#define __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ + +#include +#include +#include + +namespace luci +{ + +/** + * @brief Class to eliminate dead subgraph + * + */ +struct EliminateDeadSubgraphPass final : public luci::Pass +{ + const char *name(void) const final { return "luci::EliminateDeadSubgraphPass"; } + + bool run(luci::Module *m); + bool run(loco::Graph *) + { + // Do nothing + return false; + } +}; + +} // namespace luci + +#endif // __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index a9ac64c9ffe..d4db0e7dac6 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -101,6 +101,7 @@ #include "luci/Pass/CircleShapeInferencePass.h" #include "luci/Pass/CircleTypeInferencePass.h" +#include "luci/Pass/EliminateDeadSubgraphPass.h" // logo passes #include @@ -246,6 +247,8 @@ void CircleOptimizer::optimize(luci::Module *m) const phase.emplace_back(std::make_unique()); } + phase.emplace_back(std::make_unique()); + ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart); PhaseRunner phase_runner{m}; phase_runner.attach(&prog); diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp new file mode 100644 index 00000000000..03550fda210 --- /dev/null +++ b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/EliminateDeadSubgraphPass.h" + +#include + +#include +#include + +namespace luci +{ + +namespace +{ + +// Go through the current graph and check all other graphs reachable from it and save it. +// Note: The main idea for finding achievable graphs is that we can get into other graphs only +// from some operations (see the list below) and we check the graph numbers from these operations. +void checkGraph(loco::Graph *current_graph, std::deque &reachable_graphs_indexes_q) +{ + assert(current_graph != nullptr); + + // 1 - Obtain all active nodes in current graph + // 2 - Go through all active nodes and check its types + // 3 - If it is possible to get to another graph from the operation (see the list below), + // then add the graph numbers to our queue + + // 1 - Obtain all active nodes in current graph + // Let's enumerate nodes required to compute output nodes + auto active_nodes = loco::active_nodes(loco::output_nodes(current_graph)); + + // 2 - Go through all active nodes and check its types + // Nodes from we can obtain different subgraph: + // While, If, ... 
+ // TODO: check all nodes which can be used to reach different subgraph + for (auto &node : active_nodes) + { + auto *circle_node = dynamic_cast(node); + assert(circle_node != nullptr); + + switch (circle_node->opcode()) + { + case CircleOpcode::WHILE: + { + auto *while_node = dynamic_cast(circle_node); + assert(while_node != nullptr); + // Get body and cond graph indexes + int32_t body_graph_index = while_node->body_branch(); + int32_t cond_graph_index = while_node->cond_branch(); + assert(body_graph_index >= 0); + assert(cond_graph_index >= 0); + // Add indexes into queue + reachable_graphs_indexes_q.push_back(size_t(body_graph_index)); + reachable_graphs_indexes_q.push_back(size_t(cond_graph_index)); + } + break; + case CircleOpcode::IF: + { + auto *if_node = dynamic_cast(circle_node); + assert(if_node != nullptr); + // Get then and else graph indexes + int32_t else_index = if_node->else_branch(); + int32_t then_index = if_node->then_branch(); + assert(else_index >= 0); + assert(then_index >= 0); + // Add indexes into queue + reachable_graphs_indexes_q.push_back(size_t(else_index)); + reachable_graphs_indexes_q.push_back(size_t(then_index)); + } + break; + default: + continue; + } + } +} + +} // namespace + +/** + * Eliminate dead subgraph. 
+ * Note: dead means inaccessible from the main (with index zero) graph
+ **/
+bool EliminateDeadSubgraphPass::run(luci::Module *m)
+{
+  // Nothing to eliminate: the module has no graph besides (at most) the main one
+  if (m->size() <= 1)
+    return false;
+
+  // Indexes of the graphs reachable from the main graph
+  std::unordered_set<size_t> reachable_indexes;
+
+  // Queue of graph indexes still to be visited, seeded with the main graph (index zero)
+  std::deque<size_t> reachable_graphs_indexes_q;
+  reachable_graphs_indexes_q.push_back(0);
+
+  while (!reachable_graphs_indexes_q.empty())
+  {
+    const auto current_graph_index = reachable_graphs_indexes_q.front();
+    reachable_graphs_indexes_q.pop_front();
+
+    // insert() reports whether the index is new; skip graphs that were already visited
+    if (!reachable_indexes.insert(current_graph_index).second)
+      continue;
+
+    // Queue every graph index referenced from the current graph
+    loco::Graph *graph = m->graph(current_graph_index);
+    assert(graph != nullptr);
+    checkGraph(graph, reachable_graphs_indexes_q);
+  }
+
+  assert(!reachable_indexes.empty());
+
+  // Remove every graph that can not be reached from the main graph.
+  // Walk from the last index down to 1: removeGraphByIndex() shifts the following
+  // graphs one position down, so an ascending loop would compare shifted positions
+  // with the original indexes kept in reachable_indexes. Index 0 (main graph) is
+  // always reachable and is never removed.
+  // NOTE(review): branch indexes stored in While/If nodes are not remapped here;
+  // confirm whether a reachable graph may follow a dead one.
+  bool changed = false;
+  for (size_t i = m->size() - 1; i > 0; --i)
+  {
+    if (reachable_indexes.count(i) != 0)
+      continue;
+    m->removeGraphByIndex(i);
+    changed = true;
+  }
+  return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp
new file mode 100644
index 00000000000..d9b6af899e4
--- /dev/null
+++ b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/EliminateDeadSubgraphPass.h" +#include "luci/IR/CircleNodes.h" +#include "luci/IR/Module.h" + +#include + +namespace +{ + +class EliminateDeadSubgraphPassTest : public ::testing::Test +{ +public: + EliminateDeadSubgraphPassTest() + { + auto main_g = loco::make_graph(); + _main_graph = main_g.get(); + _module.add(std::move(main_g)); + + auto graph_1 = loco::make_graph(); + _graph_1 = graph_1.get(); + _module.add(std::move(graph_1)); + + auto graph_2 = loco::make_graph(); + _graph_2 = graph_2.get(); + _module.add(std::move(graph_2)); + + // This graph is unreachable + auto graph_3 = loco::make_graph(); + _graph_3 = graph_3.get(); + _module.add(std::move(graph_3)); + + // For main graph + { + auto input_main_node = _main_graph->nodes()->create(); + auto if_node = _main_graph->nodes()->create(1, 1); + if_node->input(0, input_main_node); + if_node->then_branch(1); + if_node->else_branch(2); + auto output_main_node = _main_graph->nodes()->create(); + output_main_node->from(if_node); + + auto graph_input = _main_graph->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _main_graph->outputs()->create(); + output_main_node->index(graph_output->index()); + } + + // For first graph + { + auto input_main_node = _graph_1->nodes()->create(); + auto output_main_node = _graph_1->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_1->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_1->outputs()->create(); + 
output_main_node->index(graph_output->index()); + } + + // For second graph + { + auto input_main_node = _graph_2->nodes()->create(); + auto output_main_node = _graph_2->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_2->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_2->outputs()->create(); + output_main_node->index(graph_output->index()); + } + + // For third (dead) graph + { + auto input_main_node = _graph_3->nodes()->create(); + auto output_main_node = _graph_3->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_3->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_3->outputs()->create(); + output_main_node->index(graph_output->index()); + } + } + +protected: + luci::Module _module; + loco::Graph *_main_graph = nullptr; + loco::Graph *_graph_1 = nullptr; + loco::Graph *_graph_2 = nullptr; + loco::Graph *_graph_3 = nullptr; +}; + +} // namespace + +TEST_F(EliminateDeadSubgraphPassTest, remove_dead_subgraph) +{ + luci::EliminateDeadSubgraphPass pass; + + // Before removing dead nodes it is has 4 graphs + ASSERT_EQ(_module.size(), 4); + + ASSERT_TRUE(pass.run(&_module)); + + // After remove one dead graph - result is 3 + ASSERT_EQ(_module.size(), 3); +} + +TEST_F(EliminateDeadSubgraphPassTest, no_graphs_NEG) +{ + luci::EliminateDeadSubgraphPass pass; + auto m = luci::make_module(); + ASSERT_ANY_THROW(pass.run(m.get())); +} From 5eb8f9e7378e4c79ff2f80deb8b4e0b1b785ee83 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 12 Aug 2024 14:06:19 +0300 Subject: [PATCH 3/5] [recipe] Add recipe for Decomposed GRU. This pr adds recipe for decomposed GRU. 
ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../Net_Decomposed_GRU_000/test.recipe | 912 ++++++++++++++++++ 1 file changed, 912 insertions(+) create mode 100644 res/TensorFlowLiteRecipes/Net_Decomposed_GRU_000/test.recipe diff --git a/res/TensorFlowLiteRecipes/Net_Decomposed_GRU_000/test.recipe b/res/TensorFlowLiteRecipes/Net_Decomposed_GRU_000/test.recipe new file mode 100644 index 00000000000..a1edc5eb3ff --- /dev/null +++ b/res/TensorFlowLiteRecipes/Net_Decomposed_GRU_000/test.recipe @@ -0,0 +1,912 @@ +operand { + name: "serving_default_x:0" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 2 + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "strided_slice_2" + type: INT32 + shape { + dim: 3 + } + filler { + tag: "explicit" + arg: "-1" + arg: "0" + arg: "0" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "strided_slice_21" + type: INT32 + shape { + dim: 3 + } + filler { + tag: "explicit" + arg: "0" + arg: "1" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "strided_slice_22" + type: INT32 + shape { + dim: 3 + } + filler { + tag: "explicit" + arg: "1" + arg: "1" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "TensorArrayV2_1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 1 + } + filler { + tag: "explicit" + arg: "0" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "time" + type: INT32 + shape { + } + filler { + tag: "explicit" + arg: "0" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "strided_slice" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "sequential/gru/zeros" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + filler { + tag: "explicit" + arg: "0" + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "while" + 
type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "while1" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "while2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "while3" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "while4" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 2 + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operand { + name: "StatefulPartitionedCall:0" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false +} +operation { + type: "While" + input: "time" + input: "time" + input: "TensorArrayV2_1" + input: "sequential/gru/zeros" + input: "serving_default_x:0" + output: "while" + output: "while1" + output: "while2" + output: "while3" + output: "while4" + while_options { + cond_subgraph_index: 1 + body_subgraph_index: 2 + } +} +operation { + type: "StridedSlice" + input: "while2" + input: "strided_slice_2" + input: "strided_slice_21" + input: "strided_slice_22" + output: "StatefulPartitionedCall:0" + strided_slice_options { + begin_mask: 6 + end_mask: 6 + ellipsis_mask: 0 + new_axis_mask: 0 + shrink_axis_mask: 1 + } +} +input: "serving_default_x:0" +output: "StatefulPartitionedCall:0" +graph { + operand { + name: "arg0" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg1" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg3" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + 
is_variable: false + } + operand { + name: "arg4" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 2 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "strided_slice1" + type: INT32 + shape { + } + filler { + tag: "explicit" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/Less" + type: BOOL + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operation { + type: "Less" + input: "arg1" + input: "strided_slice1" + output: "while/Less" + } + input: "arg0" + input: "arg1" + input: "arg2" + input: "arg3" + input: "arg4" + output: "while/Less" +} +graph { + operand { + name: "arg0" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg1" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg3" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "arg4" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 2 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "strided_slice_23" + type: INT32 + shape { + dim: 3 + } + filler { + tag: "explicit" + arg: "1" + arg: "1" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "strided_slice2" + type: INT32 + shape { + } + filler { + tag: "explicit" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/MatMul_11" + type: FLOAT32 + shape { + dim: 3 + dim: 1 + } + filler { + tag: "explicit" + arg: "-0.450822" + arg: "0.837692" + arg: "-0.308273" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: 
"while/MatMul" + type: FLOAT32 + shape { + dim: 3 + dim: 2 + } + filler { + tag: "explicit" + arg: "-0.857677" + arg: "-0.786605" + arg: "-0.816151" + arg: "0.0673127" + arg: "0.850263" + arg: "0.351641" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/sub/x" + type: FLOAT32 + shape { + } + filler { + tag: "explicit" + arg: "1" + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add_4" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/MatMul_12" + type: FLOAT32 + shape { + dim: 1 + dim: 3 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split_1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split_11" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split_12" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/TensorArrayV2Read/TensorListGetItem;time" + type: FLOAT32 + shape { + dim: 1 + dim: 2 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/MatMul1" + type: FLOAT32 + shape { + dim: 1 + dim: 3 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/split2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add" + type: FLOAT32 + shape { + dim: 
1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/Sigmoid" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/mul_1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/sub" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add_1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/Sigmoid_1" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/mul" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add_2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/Tanh" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/mul_2" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add_3" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/TensorArrayV2Write/TensorListSetItem" + type: FLOAT32 + shape { + dim: 1 + dim: 1 + dim: 1 + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operand { + name: "while/add_5" + type: INT32 + shape { + } + quant { + quantized_dimension: 0 + } + is_variable: false + } + operation { + type: "Add" + input: "arg1" + input: "strided_slice2" + output: "while/add_4" + add_options { + activation: NONE + } + } + 
operation { + type: "FullyConnected" + input: "arg3" + input: "while/MatMul_11" + input: "" + output: "while/MatMul_12" + fullyconnected_options { + activation: NONE + keep_num_dims: false + } + } + operation { + type: "Split" + input: "strided_slice2" + input: "while/MatMul_12" + output: "while/split_1" + output: "while/split_11" + output: "while/split_12" + split_options { + num_splits: 3 + } + } + operation { + type: "Gather" + input: "arg4" + input: "arg1" + output: "while/TensorArrayV2Read/TensorListGetItem;time" + gather_options { + axis: 0 + } + } + operation { + type: "FullyConnected" + input: "while/TensorArrayV2Read/TensorListGetItem;time" + input: "while/MatMul" + input: "" + output: "while/MatMul1" + fullyconnected_options { + activation: NONE + keep_num_dims: false + } + } + operation { + type: "Split" + input: "strided_slice2" + input: "while/MatMul1" + output: "while/split" + output: "while/split1" + output: "while/split2" + split_options { + num_splits: 3 + } + } + operation { + type: "Add" + input: "while/split" + input: "while/split_1" + output: "while/add" + add_options { + activation: NONE + } + } + operation { + type: "Logistic" + input: "while/add" + output: "while/Sigmoid" + } + operation { + type: "Mul" + input: "while/Sigmoid" + input: "arg3" + output: "while/mul_1" + mul_options { + activation: NONE + } + } + operation { + type: "Sub" + input: "while/sub/x" + input: "while/Sigmoid" + output: "while/sub" + sub_options { + activation: NONE + } + } + operation { + type: "Add" + input: "while/split1" + input: "while/split_11" + output: "while/add_1" + add_options { + activation: NONE + } + } + operation { + type: "Logistic" + input: "while/add_1" + output: "while/Sigmoid_1" + } + operation { + type: "Mul" + input: "while/Sigmoid_1" + input: "while/split_12" + output: "while/mul" + mul_options { + activation: NONE + } + } + operation { + type: "Add" + input: "while/split2" + input: "while/mul" + output: "while/add_2" + add_options { + 
activation: NONE + } + } + operation { + type: "Tanh" + input: "while/add_2" + output: "while/Tanh" + } + operation { + type: "Mul" + input: "while/sub" + input: "while/Tanh" + output: "while/mul_2" + mul_options { + activation: NONE + } + } + operation { + type: "Add" + input: "while/mul_1" + input: "while/mul_2" + output: "while/add_3" + add_options { + activation: NONE + } + } + operation { + type: "Reshape" + input: "while/add_3" + input: "strided_slice_23" + output: "while/TensorArrayV2Write/TensorListSetItem" + } + operation { + type: "Add" + input: "arg0" + input: "strided_slice2" + output: "while/add_5" + add_options { + activation: NONE + } + } + input: "arg0" + input: "arg1" + input: "arg2" + input: "arg3" + input: "arg4" + output: "while/add_5" + output: "while/add_4" + output: "while/TensorArrayV2Write/TensorListSetItem" + output: "while/add_3" + output: "arg4" +} From ac88f3d036b0146caf39a1c90b5359c91cab45ec Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 12 Aug 2024 14:18:16 +0300 Subject: [PATCH 4/5] [circle2circle-dredd-recipe-test] Add fuse_gru test This commit adds fuse_gru test to circle2circle-dredd-recipe-test. 
ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- compiler/circle2circle-dredd-recipe-test/test.lst | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler/circle2circle-dredd-recipe-test/test.lst b/compiler/circle2circle-dredd-recipe-test/test.lst index 4bf6a80d65a..e2a09f6e6d5 100644 --- a/compiler/circle2circle-dredd-recipe-test/test.lst +++ b/compiler/circle2circle-dredd-recipe-test/test.lst @@ -43,6 +43,7 @@ Add(Net_Conv_Mul_003 PASS fuse_mul_with_conv) Add(Net_Conv_PReluGraph_000 PASS fuse_prelu) Add(Net_Conv_QuantDequant_000 PASS remove_quantdequant) Add(Net_Conv_Relu6_000 PASS fuse_activation_function) +Add(Net_Decomposed_GRU_000 PASS fuse_gru) Add(Net_Duplicate_Weights_000 PASS remove_duplicate_const) Add(Net_DwConv_BN_000 PASS fuse_batchnorm_with_dwconv) Add(Net_DwConv_BN_001 PASS fuse_batchnorm_with_dwconv) From bc903d1d6c19b3110ad5d4e7cb80507d27c0a64a Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 12 Aug 2024 14:18:26 +0300 Subject: [PATCH 5/5] small fixes --- compiler/circle2circle/src/Circle2Circle.cpp | 2 +- .../pass/src/EliminateDeadSubgraphPass.cpp | 23 ++++++++----------- .../src/EliminateDeadSubgraphPass.test.cpp | 2 +- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index 99b10c5fbe2..0965618048e 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -518,7 +518,7 @@ int entry(int argc, char **argv) luci::change_outputs(graph, new_outputs); } - // call luci optimizations for module before optimizations for graph + // call luci optimizations for module optimizer.optimize(module.get()); for (size_t idx = 0; idx < module->size(); ++idx) diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp index 03550fda210..5653c1b6e9f 100644 --- a/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp +++ 
b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp @@ -28,15 +28,15 @@ namespace { // Go through the current graph and check all other graphs reachable from it and save it. -// Note: The main idea for finding achievable graphs is that we can get into other graphs only -// from some operations (see the list below) and we check the graph numbers from these operations. +// Note: The main idea for finding achievable graphs is that we can reach other graphs only +// from some operations (see the list below) and we check the graph indexes from these operations. void checkGraph(loco::Graph *current_graph, std::deque &reachable_graphs_indexes_q) { assert(current_graph != nullptr); // 1 - Obtain all active nodes in current graph // 2 - Go through all active nodes and check its types - // 3 - If it is possible to get to another graph from the operation (see the list below), + // 3 - If it is possible to reach another graph from the current operation (see the list below), // then add the graph numbers to our queue // 1 - Obtain all active nodes in current graph @@ -49,37 +49,34 @@ void checkGraph(loco::Graph *current_graph, std::deque &reachable_graphs // TODO: check all nodes which can be used to reach different subgraph for (auto &node : active_nodes) { - auto *circle_node = dynamic_cast(node); - assert(circle_node != nullptr); + auto *circle_node = loco::must_cast(node); switch (circle_node->opcode()) { case CircleOpcode::WHILE: { - auto *while_node = dynamic_cast(circle_node); - assert(while_node != nullptr); + auto *while_node = loco::must_cast(circle_node); // Get body and cond graph indexes int32_t body_graph_index = while_node->body_branch(); int32_t cond_graph_index = while_node->cond_branch(); assert(body_graph_index >= 0); assert(cond_graph_index >= 0); // Add indexes into queue - reachable_graphs_indexes_q.push_back(size_t(body_graph_index)); - reachable_graphs_indexes_q.push_back(size_t(cond_graph_index)); + 
reachable_graphs_indexes_q.push_back(static_cast(body_graph_index)); + reachable_graphs_indexes_q.push_back(static_cast(cond_graph_index)); } break; case CircleOpcode::IF: { - auto *if_node = dynamic_cast(circle_node); - assert(if_node != nullptr); + auto *if_node = loco::must_cast(circle_node); // Get then and else graph indexes int32_t else_index = if_node->else_branch(); int32_t then_index = if_node->then_branch(); assert(else_index >= 0); assert(then_index >= 0); // Add indexes into queue - reachable_graphs_indexes_q.push_back(size_t(else_index)); - reachable_graphs_indexes_q.push_back(size_t(then_index)); + reachable_graphs_indexes_q.push_back(static_cast(else_index)); + reachable_graphs_indexes_q.push_back(static_cast(then_index)); } break; default: diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp index d9b6af899e4..950066178e9 100644 --- a/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp +++ b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp @@ -129,5 +129,5 @@ TEST_F(EliminateDeadSubgraphPassTest, no_graphs_NEG) { luci::EliminateDeadSubgraphPass pass; auto m = luci::make_module(); - ASSERT_ANY_THROW(pass.run(m.get())); + ASSERT_FALSE(pass.run(m.get())); }