[onert] Apply usedef chains for training (#13462)
This commit applies training usedefs to partial trainable graphs before they are passed to backends.
   - Add a method to TrainableGraph that updates the trainable graph's dependency information using the training usedefs
   - Call the method for partial trainable graphs (see the sketch below)

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
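
At a glance, the new call pattern looks like this (a minimal sketch distilled from the hunks below; the loop mirrors the TrainingCompiler.cc addition, and error handling is omitted):

  for (auto &&[subg_index, subg] : trainable_subgraphs)
  {
    // Rebuild training use-def chains and prune backward nodes that are never used
    subg->updateGraphDependency();
    // Re-check graph consistency after the dependency update
    subg->verify();
  }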
ragmani authored Jul 26, 2024
1 parent cdd5a91 commit 86c70ad
Showing 4 changed files with 85 additions and 1 deletion.
3 changes: 3 additions & 0 deletions runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -161,6 +161,9 @@ class TrainableGraph : public IGraph
truncateBackwardOrder(std::vector<ir::OperationIndex> backward_order,
std::function<bool(const ir::OperationIndex &)> truncating_cond) const;

public:
void updateGraphDependency();

private:
Graph _graph;
Operands _backward_operands;
15 changes: 14 additions & 1 deletion runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -695,6 +695,20 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
external_operands.remove(index);
}

const auto backend = pair.first;
// NOTE The builtin backend does not yet support initializing UseDefs for training
// because its graph does not have a loss operation.
// Without a loss operation, we cannot call btopolSortOperations() or
// getEssentialBackwardOrder().
// TODO Replace this condition with a check for whether a loss op exists
if (backend->config()->id() != "builtin")
{
// Initialize training def-uses
tgraph->updateGraphDependency();

tgraph->verify();
}

// Set trainable context data
backend::train::TrainableContextData tdata;
tdata.tgraph = std::move(tgraph);
@@ -706,7 +720,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
tdata.optim_info = training_info.optimizerInfo();

// TODO Remove dynamic_cast
const auto backend = pair.first;
const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend);
if (!tbackend)
{
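
The TODO in the hunk above asks to replace the backend-id comparison with a real test for a loss operation. One possible shape for that predicate, reusing TrainableGraph::getLossIndex() from the last file in this commit (hasLoss is a hypothetical helper, not part of the change; the prediction count would come from the model's outputs):

  // Returns true if any prediction output has a registered loss operand,
  // which implies a loss operation was inserted into the trainable graph.
  bool hasLoss(const onert::ir::train::TrainableGraph &tgraph, uint32_t num_predictions)
  {
    for (uint32_t i = 0; i < num_predictions; ++i)
    {
      if (tgraph.getLossIndex(onert::ir::IOIndex{i}).valid())
        return true;
    }
    return false;
  }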
9 changes: 9 additions & 0 deletions runtime/onert/core/src/compiler/train/TrainingCompiler.cc
@@ -166,6 +166,15 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value()));
}

for (auto &&[subg_index, subg] : trainable_subgraphs)
{
subg->updateGraphDependency();
subg->verify();

dot_dumper.dump(*subg,
nnfw::misc::str("after_initializing_training_usedefs-", subg_index.value()));
}

// Change input shape according to batch_size
for (auto &&pair : trainable_subgraphs)
{
59 changes: 59 additions & 0 deletions runtime/onert/core/src/ir/train/TrainableGraph.cc
Expand Up @@ -17,6 +17,7 @@
#include "ir/train/TrainableGraph.h"

#include "ir/OperandIndexMap.h"
#include "UseDefGenerator.h"
#include "util/Utils.h"
#include "util/Set.h"
#include "../verifier/Verifier.h"
@@ -26,6 +27,52 @@
#include <map>
#include <misc/polymorphic_downcast.h>

namespace
{

using namespace onert;
using namespace onert::ir;
using namespace onert::ir::train;

void disableUnusedBackwardNodes(const UseDefChains &training_usedefs, TrainableGraph &tgraph)
{
// Disable backward nodes that will be unused
const auto border = tgraph.btopolSortOperations();
for (const auto &op_index : border)
{
const auto &node = tgraph.operations().at(op_index);
const auto &candidates =
(node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
const bool is_backward_op_used =
std::any_of(candidates.begin(), candidates.end(), [&](const OperandIndex &operand) {
const auto training_op_index = TrainingOperationIndex{op_index, false};
const auto forwarding_index = TrainingOperandIndex{operand, true};
const auto &forwarding_uses = training_usedefs.at(forwarding_index).getTrainingUses();
const auto backwarding_index = TrainingOperandIndex{operand, false};
const auto &backwarding_uses = training_usedefs.at(backwarding_index).getTrainingUses();
return forwarding_uses.find(training_op_index) != forwarding_uses.end() ||
backwarding_uses.find(training_op_index) != backwarding_uses.end();
});

// NOTE A backward op does not define any incoming operand in backwarding
const auto &inputs = node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
const bool is_backward_op_def =
std::any_of(inputs.begin(), inputs.end(), [&](const OperandIndex &input) {
const auto training_op_index = TrainingOperationIndex{op_index, false};
const auto outcoming_index = TrainingOperandIndex{input, false};
const auto &backwarding_defs = training_usedefs.at(outcoming_index).getTrainingDefs();
return backwarding_defs.find(training_op_index) != backwarding_defs.end();
});

if (is_backward_op_used || is_backward_op_def)
tgraph.enableBackward(op_index);
else
tgraph.disableBackward(op_index);
}
}

} // namespace

namespace onert
{
namespace ir
@@ -332,6 +379,18 @@ OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const
return (itr == _losses.end()) ? OperandIndex{} : itr->second;
}

void TrainableGraph::updateGraphDependency()
{
_graph.verify();

// Initialize training usedefs
setTrainingUseDefs(UseDefGenerator{*this}());

disableUnusedBackwardNodes(_training_defuses, *this);

verifyTrainingUseDefs();
}

} // namespace train
} // namespace ir
} // namespace onert
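
A note on the index convention used by disableUnusedBackwardNodes() above: every operand and operation appears in the use-def chains twice, keyed by a boolean that distinguishes the forwarding tensor from its backwarding (gradient) counterpart. A minimal sketch of the lookups, assuming an already-built UseDefChains named usedefs (the flag-name comments are an assumption; the diff constructs the indices positionally):

  // Two views of the same operand: forward activation vs. its gradient
  const auto fwd_operand = TrainingOperandIndex{operand_index, true};   // forwarding
  const auto bwd_operand = TrainingOperandIndex{operand_index, false};  // backwarding

  // The backward node of an operation is keyed with the flag set to false
  const auto bwd_op = TrainingOperationIndex{op_index, false};

  // The backward node stays enabled only if some chain still uses it
  const auto &fwd_uses = usedefs.at(fwd_operand).getTrainingUses();
  const auto &bwd_uses = usedefs.at(bwd_operand).getTrainingUses();
  const bool used = fwd_uses.count(bwd_op) != 0 || bwd_uses.count(bwd_op) != 0;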
