Skip to content

Commit

Permalink
[onert] Replace btopolSortOperations with essentialBackwardOrder (#13375
Browse files Browse the repository at this point in the history
)

This commit replaces btopolSortOperations with essentialBackwardOrder.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 10, 2024
1 parent 2b44d6f commit 20e8e5e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
22 changes: 10 additions & 12 deletions runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,25 @@ void AddBackPropInitializers(const ir::train::TrainableGraph &tgraph, TensorRegi
unvisited.add(index);
});

for (const auto &op_index : tgraph.btopolSortOperations())
for (const auto &op_index : tgraph.essentialBackwardOrder())
{
assert(fn_map.find(op_index) != fn_map.end());

auto &tn_seq = fn_map.at(op_index);

// The function added lastest is executed first in a sequence during backwarding.
// The function added latest is executed first in a sequence during backwarding.
std::vector<BackPropTensor *> back_props;
const auto &op = tgraph.operation(op_index);
for (const auto &back_prop_index :
op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
{
if (op.isRequiredForBackward())
assert(op.isRequiredForBackward());
if (unvisited.contains(back_prop_index))
{
if (unvisited.contains(back_prop_index))
{
auto back_prop_tensor = tensor_reg.getBackPropTensor(back_prop_index);
assert(back_prop_tensor != nullptr);
back_props.emplace_back(back_prop_tensor);
unvisited.remove(back_prop_index);
}
auto back_prop_tensor = tensor_reg.getBackPropTensor(back_prop_index);
assert(back_prop_tensor != nullptr);
back_props.emplace_back(back_prop_tensor);
unvisited.remove(back_prop_index);
}
}
if (back_props.size() != 0)
Expand Down Expand Up @@ -138,7 +136,7 @@ backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
tensor_builder->notifyBackwardFirstUse(ind);
});

for (const auto &op_index : tgraph.btopolSortOperations())
for (const auto &op_index : tgraph.essentialBackwardOrder())
{
const auto back_prop_seq = getBackPropSeq(tgraph, op_index);
for (const auto &back_prop_index : back_prop_seq)
Expand All @@ -163,7 +161,7 @@ void BackendContext::planDisposableBackPropTensors()
auto tensor_builder = _tensor_builder;

std::vector<DisposableTensorIndex> prev_seq;
for (const auto &op_index : tgraph.btopolSortOperations())
for (const auto &op_index : tgraph.essentialBackwardOrder())
{
for (const auto &index : prev_seq)
{
Expand Down
4 changes: 1 addition & 3 deletions runtime/onert/core/src/compiler/ExecutorFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,7 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
Linear::dump(*lowered_graph, order);

// linearize for backwarding
auto backward_order = lowered_graph->trainable_graph().btopolSortOperations();
// get rid of all nodes not reachable from a node with trainable parameters
backward_order = lowered_graph->trainable_graph().truncateBackwardOrder(backward_order);
auto backward_order = lowered_graph->trainable_graph().essentialBackwardOrder();
VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
Linear::dump(*lowered_graph, backward_order);

Expand Down

0 comments on commit 20e8e5e

Please sign in to comment.