From a0d6933a465d7cd04b1aa82e819131f041c8bf70 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Wed, 7 Aug 2024 14:17:11 +0300 Subject: [PATCH] [DRAFT][compiler] Introduce EliminateDeadSubgraphPass This pr introduces EliminateDeadSubgraphPass for removing dead subgraph. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- compiler/circle2circle/src/Circle2Circle.cpp | 5 +- compiler/luci/lang/include/luci/IR/Module.h | 7 + compiler/luci/lang/src/Module.cpp | 8 + compiler/luci/lang/src/Module.test.cpp | 34 ++++ .../luci/Pass/EliminateDeadSubgraphPass.h | 45 ++++++ compiler/luci/pass/src/CircleOptimizer.cpp | 3 + .../pass/src/EliminateDeadSubgraphPass.cpp | 148 ++++++++++++++++++ .../src/EliminateDeadSubgraphPass.test.cpp | 133 ++++++++++++++++ 8 files changed, 382 insertions(+), 1 deletion(-) create mode 100644 compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h create mode 100644 compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp create mode 100644 compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp index 80d775aa86e..99b10c5fbe2 100644 --- a/compiler/circle2circle/src/Circle2Circle.cpp +++ b/compiler/circle2circle/src/Circle2Circle.cpp @@ -518,7 +518,7 @@ int entry(int argc, char **argv) luci::change_outputs(graph, new_outputs); } - // call luci optimizations for module + // call luci optimizations for module before optimizations for graph optimizer.optimize(module.get()); for (size_t idx = 0; idx < module->size(); ++idx) @@ -541,6 +541,9 @@ int entry(int argc, char **argv) } } + // call luci optimizations for module after optimizations for graph + optimizer.optimize(module.get()); + // Export to output Circle file luci::CircleExporter exporter; diff --git a/compiler/luci/lang/include/luci/IR/Module.h b/compiler/luci/lang/include/luci/IR/Module.h index 75cf67905e7..dd4df3340be 100644 --- a/compiler/luci/lang/include/luci/IR/Module.h +++ b/compiler/luci/lang/include/luci/IR/Module.h @@ -51,6 +51,13 @@ class Module final */ loco::Graph *graph(void) const; + /** + * @brief remove graph at index + * + * @note graph(0) is interpreted as a main graph and cannot be deleted + */ + void removeGraphByIndex(size_t idx); + /** * @brief provide graph with an index * diff --git a/compiler/luci/lang/src/Module.cpp b/compiler/luci/lang/src/Module.cpp index 80ef61910f2..16c5a8277d8 100644 --- a/compiler/luci/lang/src/Module.cpp +++ b/compiler/luci/lang/src/Module.cpp @@ -35,6 +35,14 @@ loco::Graph *Module::graph(void) const return graph.get(); } +void Module::removeGraphByIndex(size_t idx) +{ + if (idx >= _graphs.size() or idx == 0) + throw std::invalid_argument("Module: Invalid graph index to be deleted"); + + _graphs.erase(_graphs.begin() + idx); +} + loco::Graph *Module::graph(size_t idx) const { auto &graph = _graphs.at(idx); diff --git a/compiler/luci/lang/src/Module.test.cpp b/compiler/luci/lang/src/Module.test.cpp index a5973e52dad..16c93250eb3 100644 --- a/compiler/luci/lang/src/Module.test.cpp +++ b/compiler/luci/lang/src/Module.test.cpp @@ -37,6 +37,33 @@ TEST(ModuleTest, add) ASSERT_EQ(g_ptr, m->graph(0)); } +TEST(ModuleTest, remove) +{ + auto m = luci::make_module(); + auto g1 = loco::make_graph(); + auto g2 = loco::make_graph(); + auto g3 = loco::make_graph(); + auto g1_ptr = g1.get(); + auto g2_ptr = g2.get(); + auto g3_ptr = g3.get(); + + m->add(std::move(g1)); + m->add(std::move(g2)); + m->add(std::move(g3)); + + ASSERT_EQ(3, m->size()); + ASSERT_EQ(g1_ptr, m->graph()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g2_ptr, m->graph(1)); + ASSERT_EQ(g3_ptr, m->graph(2)); + + // Let's delete graph at second position + m->removeGraphByIndex(1); + ASSERT_EQ(2, m->size()); + ASSERT_EQ(g1_ptr, m->graph(0)); + ASSERT_EQ(g3_ptr, m->graph(1)); +} + TEST(ModuleTest, add_more) { auto m = luci::make_module(); @@ -65,6 +92,13 @@ TEST(ModuleTest, add_nullptr_NEG) EXPECT_THROW(m->add(nullptr), std::invalid_argument); } +TEST(ModuleTest, remove_index_overflow_NEG) +{ + auto m = luci::make_module(); + + EXPECT_THROW(m->removeGraphByIndex(10), std::invalid_argument); +} + TEST(ModuleTest, graph_index_overflow_NEG) { auto m = luci::make_module(); diff --git a/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h b/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h new file mode 100644 index 00000000000..2b3c115ea1a --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ +#define __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ + +#include +#include +#include + +namespace luci +{ + +/** + * @brief Class to eliminate dead subgraph + * + */ +struct EliminateDeadSubgraphPass final : public luci::Pass +{ + const char *name(void) const final { return "luci::EliminateDeadSubgraphPass"; } + + bool run(luci::Module *m); + bool run(loco::Graph *graph) + { + // Do nothing + return false; + } +}; + +} // namespace luci + +#endif // __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index a9ac64c9ffe..d4db0e7dac6 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -101,6 +101,7 @@ #include "luci/Pass/CircleShapeInferencePass.h" #include "luci/Pass/CircleTypeInferencePass.h" +#include "luci/Pass/EliminateDeadSubgraphPass.h" // logo passes #include @@ -246,6 +247,8 @@ void CircleOptimizer::optimize(luci::Module *m) const phase.emplace_back(std::make_unique()); } + phase.emplace_back(std::make_unique()); + ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart); PhaseRunner phase_runner{m}; phase_runner.attach(&prog); diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp new file mode 100644 index 00000000000..da7396f08a2 --- /dev/null +++ b/compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/EliminateDeadSubgraphPass.h" + +#include + +#include +#include + +namespace luci +{ + +namespace +{ + +// Go through the current graph and check all other graphs reachable from it and save it. +// Note: The main idea for finding achievable graphs is that we can get into other graphs only +// from some operations (see the list below) and we check the graph numbers from these operations. +void checkGraph(loco::Graph *current_graph, std::deque &reachable_graphs_indexes_q) +{ + assert(current_graph != nullptr); + + // 1 - Obtain all active nodes in current graph + // 2 - Go through all active nodes and check its types + // 3 - If it is possible to get to another graph from the operation (see the list below), + // then add the graph numbers to our queue + + // 1 - Obtain all active nodes in current graph + // Let's enumerate nodes required to compute output nodes + auto active_nodes = loco::active_nodes(loco::output_nodes(current_graph)); + + // 2 - Go through all active nodes and check its types + // Nodes from we can obtain different subgraph: + // While, If, ... + // TODO: check all nodes which can be used to reach different subgraph + for (auto &node : active_nodes) + { + auto *circle_node = dynamic_cast(node); + assert(circle_node != nullptr); + + switch (circle_node->opcode()) + { + case CircleOpcode::WHILE: + { + auto *while_node = dynamic_cast(circle_node); + assert(while_node != nullptr); + // Get body and cond graph indexes + int32_t body_graph_index = while_node->body_branch(); + int32_t cond_graph_index = while_node->cond_branch(); + assert(body_graph_index >= 0); + assert(cond_graph_index >= 0); + // Add indexes into queue + reachable_graphs_indexes_q.push_back(size_t(body_graph_index)); + reachable_graphs_indexes_q.push_back(size_t(cond_graph_index)); + } + break; + case CircleOpcode::IF: + { + auto *if_node = dynamic_cast(circle_node); + assert(if_node != nullptr); + // Get then and else graph indexes + int32_t else_index = if_node->else_branch(); + int32_t then_index = if_node->then_branch(); + assert(else_index >= 0); + assert(then_index >= 0); + // Add indexes into queue + reachable_graphs_indexes_q.push_back(size_t(else_index)); + reachable_graphs_indexes_q.push_back(size_t(then_index)); + } + break; + default: + continue; + } + } +} + +} // namespace + +/** + * Eliminate dead subgraph. + * Note: dead means inaccessible from the main (with index zero) graph + **/ +bool EliminateDeadSubgraphPass::run(luci::Module *m) +{ + bool changed = false; + + if (m->size() == 0) + throw std::invalid_argument("No any graphs"); + + // Nothing check + if (m->size() == 1) + return false; + + std::unordered_set reachable_indexes; + + // Queue with reachable graphs indexes + std::deque reachable_graphs_indexes_q; + // Insert main graph - with index zero + reachable_graphs_indexes_q.push_back(0); + + while (reachable_graphs_indexes_q.empty() == false) + { + // Get first index from queue and remove it from queue + auto current_graph_index = reachable_graphs_indexes_q.front(); + reachable_graphs_indexes_q.pop_front(); + + // If already check this graph - continue + if (reachable_indexes.find(current_graph_index) != reachable_indexes.end()) + continue; + + // Add current index to reachable set + reachable_indexes.insert(current_graph_index); + + // Check current graph and add all graph indexes which can be reached from current graph + loco::Graph *graph = m->graph(current_graph_index); + assert(graph != nullptr); + checkGraph(graph, reachable_graphs_indexes_q); + } + + assert(!reachable_indexes.empty()); + // Let's remove all indexes which can not be reached from main graph + for (size_t i = 0; i < m->size(); ++i) + { + if (reachable_indexes.find(i) != reachable_indexes.end()) + continue; + + m->removeGraphByIndex(i); + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp new file mode 100644 index 00000000000..d9b6af899e4 --- /dev/null +++ b/compiler/luci/pass/src/EliminateDeadSubgraphPass.test.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/EliminateDeadSubgraphPass.h" +#include "luci/IR/CircleNodes.h" +#include "luci/IR/Module.h" + +#include + +namespace +{ + +class EliminateDeadSubgraphPassTest : public ::testing::Test +{ +public: + EliminateDeadSubgraphPassTest() + { + auto main_g = loco::make_graph(); + _main_graph = main_g.get(); + _module.add(std::move(main_g)); + + auto graph_1 = loco::make_graph(); + _graph_1 = graph_1.get(); + _module.add(std::move(graph_1)); + + auto graph_2 = loco::make_graph(); + _graph_2 = graph_2.get(); + _module.add(std::move(graph_2)); + + // This graph is unreachable + auto graph_3 = loco::make_graph(); + _graph_3 = graph_3.get(); + _module.add(std::move(graph_3)); + + // For main graph + { + auto input_main_node = _main_graph->nodes()->create(); + auto if_node = _main_graph->nodes()->create(1, 1); + if_node->input(0, input_main_node); + if_node->then_branch(1); + if_node->else_branch(2); + auto output_main_node = _main_graph->nodes()->create(); + output_main_node->from(if_node); + + auto graph_input = _main_graph->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _main_graph->outputs()->create(); + output_main_node->index(graph_output->index()); + } + + // For first graph + { + auto input_main_node = _graph_1->nodes()->create(); + auto output_main_node = _graph_1->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_1->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_1->outputs()->create(); + output_main_node->index(graph_output->index()); + } + + // For second graph + { + auto input_main_node = _graph_2->nodes()->create(); + auto output_main_node = _graph_2->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_2->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_2->outputs()->create(); + output_main_node->index(graph_output->index()); + } + + // For third (dead) graph + { + auto input_main_node = _graph_3->nodes()->create(); + auto output_main_node = _graph_3->nodes()->create(); + output_main_node->from(input_main_node); + + auto graph_input = _graph_3->inputs()->create(); + input_main_node->index(graph_input->index()); + + auto graph_output = _graph_3->outputs()->create(); + output_main_node->index(graph_output->index()); + } + } + +protected: + luci::Module _module; + loco::Graph *_main_graph = nullptr; + loco::Graph *_graph_1 = nullptr; + loco::Graph *_graph_2 = nullptr; + loco::Graph *_graph_3 = nullptr; +}; + +} // namespace + +TEST_F(EliminateDeadSubgraphPassTest, remove_dead_subgraph) +{ + luci::EliminateDeadSubgraphPass pass; + + // Before removing dead nodes it is has 4 graphs + ASSERT_EQ(_module.size(), 4); + + ASSERT_TRUE(pass.run(&_module)); + + // After remove one dead graph - result is 3 + ASSERT_EQ(_module.size(), 3); +} + +TEST_F(EliminateDeadSubgraphPassTest, no_graphs_NEG) +{ + luci::EliminateDeadSubgraphPass pass; + auto m = luci::make_module(); + ASSERT_ANY_THROW(pass.run(m.get())); +}