Skip to content
This repository has been archived by the owner on Feb 1, 2020. It is now read-only.

Commit

Permalink
[OPT] Improve PreComputePrune When Output Is Pruned (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Oct 17, 2017
1 parent 74549d6 commit a96133e
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions src/compiler/precompute_prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
// number of edges that are not variable
int non_var_edge = 0;

auto replace_pruned_entry = [&] (const NodeEntry& e) {
if (!entry_var.count(e)) {
if (!e.node->is_variable()) {
++non_var_edge;
}
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
return nnvm::NodeEntry{var, 0, 0};
} else {
return nnvm::NodeEntry{entry_var.at(e), 0, 0};
}
};

DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
bool can_be_pruned = true;
if (n->is_variable()) {
Expand All @@ -47,20 +66,7 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
if (!e.node->is_variable()) {
++non_var_edge;
}
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name;
if (e.node->num_outputs() != 1) {
var->attrs.name += "_output" + std::to_string(e.index);
}
entry_var.emplace(e, var);
CHECK(!unique_name.count(var->attrs.name));
unique_name.insert(var->attrs.name);
}
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
e = replace_pruned_entry(e);
}
}
}
Expand All @@ -71,6 +77,12 @@ nnvm::Graph PrecomputePrune(nnvm::Graph src) {
return src;
}

for (auto& e : src.outputs) {
if (pruned.count(e.node.get())) {
e = replace_pruned_entry(e);
}
}

nnvm::Graph pre_graph;
pre_graph.outputs.reserve(entry_var.size());
std::vector<std::string> output_names;
Expand Down

0 comments on commit a96133e

Please sign in to comment.