[onert] Unify generating training tensors and kernels #13656

Merged
79 changes: 37 additions & 42 deletions runtime/onert/backend/train/BackendContext.cc
@@ -136,15 +136,50 @@ getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
 }
 } // namespace
 
-backend::ITensorRegistry *BackendContext::genTensors()
+FunctionMap BackendContext::gen()
 {
   planForwardTensors();
   planBackwardTensors();
 
   _tensor_builder->allocate();
   _tensor_builder->allocateBackward();
 
-  return _tensor_registry.get();
+  auto fn_map = generateFunctionMap();
+
+  // Initialize TrainableTensors
+  trainable_graph()->operands().iterate(
+    [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+      if (external_operands().contains(ind) || !operand.isConstant())
+        return;
+
+      auto tensor = tensor_registry()->getNativeITensor(ind);
+      assert(tensor != nullptr);
+
+      VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
+
+      auto data = operand.shareData();
+      assert(data && data->base());
+      auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
+
+      if (trainable_tensor == nullptr)
+        throw std::runtime_error{"This tensor is not trainable tensor"};
+
+      trainable_tensor->fillBuffer(data);
+    });
+
+  // NOTE For memory optimization, we want to free some operand data
+  const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
+    .operands()
+    .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
+
+  // TODO Enable
+  // for (auto &&it : ret)
+  // {
+  //   auto &fn_seq = it.second;
+  //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
+  // }
+
+  return fn_map;
 }
 
 void BackendContext::planForwardTensors()
@@ -202,46 +237,6 @@ void BackendContext::planBackwardTensors()
   tensor_planner.planDisposableBackPropTensors(tensor_builder.get());
 }
 
-FunctionMap BackendContext::genKernels()
-{
-  auto ret = generateFunctionMap();
-
-  // Initialize TrainableTensors
-  trainable_graph()->operands().iterate(
-    [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
-      if (external_operands().contains(ind) || !operand.isConstant())
-        return;
-
-      auto tensor = tensor_registry()->getNativeITensor(ind);
-      assert(tensor != nullptr);
-
-      VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
-
-      auto data = operand.shareData();
-      assert(data && data->base());
-      auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
-
-      if (trainable_tensor == nullptr)
-        throw std::runtime_error{"This tensor is not trainable tensor"};
-
-      trainable_tensor->fillBuffer(data);
-    });
-
-  // NOTE For memory optimization, we want to free some operand data
-  const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
-    .operands()
-    .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
-
-  // TODO Enable
-  // for (auto &&it : ret)
-  // {
-  //   auto &fn_seq = it.second;
-  //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
-  // }
-
-  return ret;
-}
-
 FunctionMap BackendContext::generateFunctionMap()
 {
   train::FunctionMap ret;
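The hunks above are the heart of the PR: the separate `genTensors()` and `genKernels()` entry points collapse into a single `gen()` that plans and allocates tensors before building the kernel map, then fills constant tensors and frees the IR-side operand data. Below is a condensed sketch of that control flow; the types are simplified stand-ins for illustration, not the real onert headers.

```cpp
#include <map>
#include <memory>

// Stand-ins for onert's types, abbreviated for illustration only.
struct FunctionSequence
{
  // trainable kernels for a single operation
};
using OperationIndex = int;
using FunctionMap = std::map<OperationIndex, std::unique_ptr<FunctionSequence>>;

class BackendContextSketch
{
public:
  // One entry point replaces the old genTensors()/genKernels() pair.
  FunctionMap gen()
  {
    planForwardTensors();  // register and plan forward tensors
    planBackwardTensors(); // register and plan back-prop tensors
    allocate();            // back the planned tensors with real memory

    auto fn_map = generateFunctionMap(); // kernels reference the live tensors

    fillConstantTensors(); // copy constant operand data into trainable tensors
    releaseOperandData();  // free the IR-side copies of that data
    return fn_map;
  }

private:
  void planForwardTensors() {}
  void planBackwardTensors() {}
  void allocate() {}
  void fillConstantTensors() {}
  void releaseOperandData() {}
  FunctionMap generateFunctionMap() { return {}; }
};
```

The invariant this buys: by the time a `FunctionMap` leaves a backend context, every tensor that backend owns is already allocated and initialized.
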
4 changes: 1 addition & 3 deletions runtime/onert/backend/train/BackendContext.h
@@ -68,15 +68,13 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
   BackendContext &operator=(const BackendContext &) = delete;
 
 public:
-  backend::ITensorRegistry *genTensors() override;
+  FunctionMap gen() override;
 
 private:
   void planForwardTensors();
   void planBackwardTensors();
 
 public:
-  FunctionMap genKernels() override;
-
   std::shared_ptr<ExternalContext> external_context() { return _external_context; }
 
   const exec::train::optimizer::Optimizer *optimizer() const { return _optimizer.get(); }
3 changes: 1 addition & 2 deletions runtime/onert/core/include/backend/train/TrainableBackendContext.h
@@ -76,8 +76,7 @@ class TrainableBackendContext
 
   std::shared_ptr<ITensorRegistry> tensor_registry() { return _tensor_registry; }
 
-  virtual backend::ITensorRegistry *genTensors() = 0;
-  virtual FunctionMap genKernels() = 0;
+  virtual FunctionMap gen() = 0;
 
 private:
   const ITrainableBackend *_backend{nullptr};
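With the base interface narrowed to a single pure virtual, each trainable backend now supplies one override instead of coordinating two. A minimal sketch of a conforming implementation, reusing the simplified `FunctionMap` stand-in from the sketch above:

```cpp
#include <map>
#include <memory>

struct FunctionSequence { };
using FunctionMap = std::map<int, std::unique_ptr<FunctionSequence>>;

// Simplified stand-in for the real TrainableBackendContext base class.
struct TrainableBackendContextSketch
{
  virtual ~TrainableBackendContextSketch() = default;
  // Formerly two pure virtuals (genTensors and genKernels); now one.
  virtual FunctionMap gen() = 0;
};

struct MyBackendContext : TrainableBackendContextSketch
{
  FunctionMap gen() override
  {
    // Plan and allocate this backend's tensors first, then emplace one
    // function sequence per operation index.
    FunctionMap fn_map;
    fn_map.emplace(0, std::make_unique<FunctionSequence>());
    return fn_map;
  }
};
```
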
14 changes: 5 additions & 9 deletions runtime/onert/core/src/backend/builtin/train/BackendContext.cc
@@ -28,23 +28,19 @@ namespace builtin
 namespace train
 {
 
-backend::ITensorRegistry *BackendContext::genTensors()
+backend::train::FunctionMap BackendContext::gen()
 {
   // For now, there is no need to generate tensors for forwarding and backwarding.
   // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
   // `Permute`: Tensor generation is not required.
   // `IF`, `WHILE`: Not supported yet
-  return tensor_registry().get();
-}
 
-backend::train::FunctionMap BackendContext::genKernels()
-{
-  backend::train::FunctionMap ret;
+  backend::train::FunctionMap fn_map;
 
   for (auto &&op_ind : _tdata->op_order)
   {
     auto tn_seq = kernel_gen->generate(op_ind);
-    ret.emplace(op_ind, std::move(tn_seq));
+    fn_map.emplace(op_ind, std::move(tn_seq));
   }
 
   trainable_graph()->operands().iterate(
@@ -57,13 +53,13 @@ backend::train::FunctionMap BackendContext::genKernels()
     });
 
   // TODO Enable prepare()
-  // for (auto &&it : ret)
+  // for (auto &&it : fn_map)
   // {
   //   auto &fn_seq = it.second;
   //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
   // }
 
-  return ret;
+  return fn_map;
 }
 
 } // namespace train
4 changes: 1 addition & 3 deletions runtime/onert/core/src/backend/builtin/train/BackendContext.h
@@ -46,10 +46,8 @@ class BackendContext : public backend::train::TrainableBackendContext
   {
   }
 
-  backend::ITensorRegistry *genTensors() override;
-
 public:
-  backend::train::FunctionMap genKernels() override;
+  backend::train::FunctionMap gen() override;
 
   std::shared_ptr<ExternalContext> external_context() { return _external_context; }
 
53 changes: 35 additions & 18 deletions runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -268,6 +268,21 @@ std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext(
   return ordered_contexts;
 }
 
+void generateCodes(backend::train::FunctionMap &codes,
+                   const compiler::train::LoweredTrainableGraph *lowered_graph,
+                   compiler::train::TrainableCodeMap &code_map)
+{
+  for (auto &&[op_ind, tn_seq] : codes)
+  {
+    auto &op = lowered_graph->trainable_graph().operation(op_ind);
+    const auto backend = lowered_graph->lower_info().operation.at(op_ind);
+
+    assert(code_map.find(op_ind) == code_map.end());
+    code_map.insert(
+      {op_ind, compiler::train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)}});
+  }
+}
+
 } // namespace
 } // namespace onert
 
@@ -734,9 +749,17 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
   VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
   Linear::dump(*lowered_graph, backward_order);
 
-  for (auto &&pair : tbackend_contexts)
+  train::TrainableCodeMap code_map;
+  // Generate tensors and kernels
+  for (auto &&[backend, context] : tbackend_contexts)
   {
-    pair.second->genTensors();
+    // builtin backend's kernel generator requires access to tensors in other backends.
+    // So, the other backends must be generated first.
+    if (backend->config()->id() == "builtin")
+      continue;
+
+    auto fn_map = context->gen();
+    generateCodes(fn_map, lowered_graph.get(), code_map);
   }
 
   prepareMigrantTensors(*lowered_graph, tbackend_contexts);
@@ -754,6 +777,16 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     }
   }
 
+  // Generate tensors and kernels for only builtin backend
+  for (auto &&[backend, context] : tbackend_contexts)
+  {
+    if (backend->config()->id() == "builtin")
+    {
+      auto fn_map = context->gen();
+      generateCodes(fn_map, lowered_graph.get(), code_map);
+    }
+  }
+
   // Adjust the order of backends for the upcoming iteration
   auto ordered_contexts =
     onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts);
@@ -832,22 +865,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     }
   }));
 
-  train::TrainableCodeMap code_map;
-  // Generate kernels
-  for (auto &&pair : ordered_contexts)
-  {
-    auto codes = pair.second->genKernels();
-    for (auto &&[op_ind, tn_seq] : codes)
-    {
-      auto &op = lowered_graph->trainable_graph().operation(op_ind);
-      const auto backend = lowered_graph->lower_info().operation.at(op_ind);
-
-      assert(code_map.find(op_ind) == code_map.end());
-      code_map.insert(
-        {op_ind, train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)}});
-    }
-  }
-
   if (order.size() != code_map.size())
   {
     throw std::runtime_error("ExecutorFactory: Some kernels are not generated");
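
The factory-side consequence of the unification: because `gen()` couples tensor allocation to kernel generation, the factory can no longer run a tensor pass over all backends followed by a kernel pass, so it orders whole backends instead. As the comments in the diff note, the builtin backend's kernel generator reads tensors owned by other backends, so its `gen()` is deferred until the other backends have generated theirs and the migrant tensors are registered. A condensed sketch of the resulting two-phase pattern, with illustrative names rather than the onert API:

```cpp
#include <string>
#include <utility>
#include <vector>

struct Backend
{
  std::string id; // e.g. "train" or "builtin"
};

struct Context
{
  void genAndCollect() { /* context->gen() + generateCodes(...) */ }
};

void generateAll(std::vector<std::pair<Backend *, Context *>> &contexts)
{
  // Phase 1: every backend except "builtin" generates its tensors and kernels.
  for (auto &[backend, ctx] : contexts)
    if (backend->id != "builtin")
      ctx->genAndCollect();

  // ... between the phases, migrant tensors are wired up so builtin
  // kernels can reference tensors that live in other backends ...

  // Phase 2: builtin runs last, once the foreign tensors exist.
  for (auto &[backend, ctx] : contexts)
    if (backend->id == "builtin")
      ctx->genAndCollect();
}
```

The retained `order.size() != code_map.size()` check then guards the merged pipeline: if any backend fails to contribute kernels in either phase, executor construction throws instead of running an incomplete graph.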