Skip to content

Commit

Permalink
[DRAFT][compiler] Introduce EliminateDeadSubgraphPass
Browse files Browse the repository at this point in the history
This pr introduces EliminateDeadSubgraphPass for removing dead subgraph.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Aug 7, 2024
1 parent 36c7317 commit 5d77de3
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 1 deletion.
5 changes: 4 additions & 1 deletion compiler/circle2circle/src/Circle2Circle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ int entry(int argc, char **argv)
luci::change_outputs(graph, new_outputs);
}

// call luci optimizations for module
// call luci optimizations for module before optimizations for graph
optimizer.optimize(module.get());

for (size_t idx = 0; idx < module->size(); ++idx)
Expand All @@ -541,6 +541,9 @@ int entry(int argc, char **argv)
}
}

// call luci optimizations for module after optimizations for graph
optimizer.optimize(module.get());

// Export to output Circle file
luci::CircleExporter exporter;

Expand Down
7 changes: 7 additions & 0 deletions compiler/luci/lang/include/luci/IR/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ class Module final
*/
loco::Graph *graph(void) const;

/**
* @brief remove graph at index
*
* @note graph(0) is interpreted as a main graph and cannot be deleted
*/
void removeGraphByIndex(size_t idx);

/**
* @brief provide graph with an index
*
Expand Down
8 changes: 8 additions & 0 deletions compiler/luci/lang/src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ loco::Graph *Module::graph(void) const
return graph.get();
}

void Module::removeGraphByIndex(size_t idx)
{
if (idx >= _graphs.size() or idx == 0)
throw std::invalid_argument("Module: Invalid graph index to be deleted");

_graphs.erase(_graphs.begin() + idx);
}

loco::Graph *Module::graph(size_t idx) const
{
auto &graph = _graphs.at(idx);
Expand Down
34 changes: 34 additions & 0 deletions compiler/luci/lang/src/Module.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ TEST(ModuleTest, add)
ASSERT_EQ(g_ptr, m->graph(0));
}

TEST(ModuleTest, remove)
{
auto m = luci::make_module();
auto g1 = loco::make_graph();
auto g2 = loco::make_graph();
auto g3 = loco::make_graph();
auto g1_ptr = g1.get();
auto g2_ptr = g2.get();
auto g3_ptr = g3.get();

m->add(std::move(g1));
m->add(std::move(g2));
m->add(std::move(g3));

ASSERT_EQ(3, m->size());
ASSERT_EQ(g1_ptr, m->graph());
ASSERT_EQ(g1_ptr, m->graph(0));
ASSERT_EQ(g2_ptr, m->graph(1));
ASSERT_EQ(g3_ptr, m->graph(2));

// Let's delete graph at second position
m->removeGraphByIndex(1);
ASSERT_EQ(2, m->size());
ASSERT_EQ(g1_ptr, m->graph(0));
ASSERT_EQ(g3_ptr, m->graph(1));
}

TEST(ModuleTest, add_more)
{
auto m = luci::make_module();
Expand Down Expand Up @@ -65,6 +92,13 @@ TEST(ModuleTest, add_nullptr_NEG)
EXPECT_THROW(m->add(nullptr), std::invalid_argument);
}

TEST(ModuleTest, remove_index_overflow_NEG)
{
auto m = luci::make_module();

EXPECT_THROW(m->removeGraphByIndex(10), std::invalid_argument);
}

TEST(ModuleTest, graph_index_overflow_NEG)
{
auto m = luci::make_module();
Expand Down
45 changes: 45 additions & 0 deletions compiler/luci/pass/include/luci/Pass/EliminateDeadSubgraphPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__
#define __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__

#include <logo/Pass.h>
#include <luci/ModulePass.h>
#include <luci/IR/Module.h>

namespace luci
{

/**
* @brief Class to eliminate dead subgraph
*
*/
struct EliminateDeadSubgraphPass final : public luci::Pass
{
const char *name(void) const final { return "luci::EliminateDeadSubgraphPass"; }

bool run(luci::Module *m);
bool run(loco::Graph *)
{
// Do nothing
return false;
}
};

} // namespace luci

#endif // __LUCI_ELIMINATE_DEAD_SUBGRAPH_PASS_H__
3 changes: 3 additions & 0 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@

#include "luci/Pass/CircleShapeInferencePass.h"
#include "luci/Pass/CircleTypeInferencePass.h"
#include "luci/Pass/EliminateDeadSubgraphPass.h"

// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>
Expand Down Expand Up @@ -246,6 +247,8 @@ void CircleOptimizer::optimize(luci::Module *m) const
phase.emplace_back(std::make_unique<FuseBCQPass>());
}

phase.emplace_back(std::make_unique<luci::EliminateDeadSubgraphPass>());

ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
phase_runner.attach(&prog);
Expand Down
148 changes: 148 additions & 0 deletions compiler/luci/pass/src/EliminateDeadSubgraphPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/Pass/EliminateDeadSubgraphPass.h"

#include <luci/IR/CircleNodes.h>

#include <unordered_set>
#include <deque>

namespace luci
{

namespace
{

// Go through the current graph and check all other graphs reachable from it and save it.
// Note: The main idea for finding achievable graphs is that we can get into other graphs only
// from some operations (see the list below) and we check the graph numbers from these operations.
void checkGraph(loco::Graph *current_graph, std::deque<size_t> &reachable_graphs_indexes_q)
{
assert(current_graph != nullptr);

// 1 - Obtain all active nodes in current graph
// 2 - Go through all active nodes and check its types
// 3 - If it is possible to get to another graph from the operation (see the list below),
// then add the graph numbers to our queue

// 1 - Obtain all active nodes in current graph
// Let's enumerate nodes required to compute output nodes
auto active_nodes = loco::active_nodes(loco::output_nodes(current_graph));

// 2 - Go through all active nodes and check its types
// Nodes from we can obtain different subgraph:
// While, If, ...
// TODO: check all nodes which can be used to reach different subgraph
for (auto &node : active_nodes)
{
auto *circle_node = dynamic_cast<luci::CircleNode *>(node);
assert(circle_node != nullptr);

switch (circle_node->opcode())
{
case CircleOpcode::WHILE:
{
auto *while_node = dynamic_cast<luci::CircleWhile *>(circle_node);
assert(while_node != nullptr);
// Get body and cond graph indexes
int32_t body_graph_index = while_node->body_branch();
int32_t cond_graph_index = while_node->cond_branch();
assert(body_graph_index >= 0);
assert(cond_graph_index >= 0);
// Add indexes into queue
reachable_graphs_indexes_q.push_back(size_t(body_graph_index));
reachable_graphs_indexes_q.push_back(size_t(cond_graph_index));
}
break;
case CircleOpcode::IF:
{
auto *if_node = dynamic_cast<luci::CircleIf *>(circle_node);
assert(if_node != nullptr);
// Get then and else graph indexes
int32_t else_index = if_node->else_branch();
int32_t then_index = if_node->then_branch();
assert(else_index >= 0);
assert(then_index >= 0);
// Add indexes into queue
reachable_graphs_indexes_q.push_back(size_t(else_index));
reachable_graphs_indexes_q.push_back(size_t(then_index));
}
break;
default:
continue;
}
}
}

} // namespace

/**
* Eliminate dead subgraph.
* Note: dead means inaccessible from the main (with index zero) graph
**/
bool EliminateDeadSubgraphPass::run(luci::Module *m)
{
bool changed = false;

if (m->size() == 0)
throw std::invalid_argument("No any graphs");

// Nothing check
if (m->size() == 1)
return false;

std::unordered_set<size_t> reachable_indexes;

// Queue with reachable graphs indexes
std::deque<size_t> reachable_graphs_indexes_q;
// Insert main graph - with index zero
reachable_graphs_indexes_q.push_back(0);

while (reachable_graphs_indexes_q.empty() == false)
{
// Get first index from queue and remove it from queue
auto current_graph_index = reachable_graphs_indexes_q.front();
reachable_graphs_indexes_q.pop_front();

// If already check this graph - continue
if (reachable_indexes.find(current_graph_index) != reachable_indexes.end())
continue;

// Add current index to reachable set
reachable_indexes.insert(current_graph_index);

// Check current graph and add all graph indexes which can be reached from current graph
loco::Graph *graph = m->graph(current_graph_index);
assert(graph != nullptr);
checkGraph(graph, reachable_graphs_indexes_q);
}

assert(!reachable_indexes.empty());
// Let's remove all indexes which can not be reached from main graph
for (size_t i = 0; i < m->size(); ++i)
{
if (reachable_indexes.find(i) != reachable_indexes.end())
continue;

m->removeGraphByIndex(i);
changed = true;
}

return changed;
}

} // namespace luci
Loading

0 comments on commit 5d77de3

Please sign in to comment.