
Commit 5ca9a7c: trim

zetwhite committed Aug 1, 2024
1 parent 1d98d40

Showing 6 changed files with 32 additions and 56 deletions.
3 changes: 0 additions & 3 deletions runtime/onert/backend/train/Backend.h
@@ -61,9 +61,6 @@ class Backend : public ::onert::backend::Backend, public backend::train::ITraina

context->kernel_gen = std::make_shared<train::KernelGenerator>(
tgraph, tr, context->external_context(), context->optimizer());
-
-context->extra_tensor_gen = std::make_unique<train::ExtraTensorGenerator>(tgraph, tb, tr);
-
return context;
}

9 changes: 6 additions & 3 deletions runtime/onert/backend/train/BackendContext.cc
@@ -16,6 +16,7 @@

#include "BackendContext.h"

#include "ExtraTensorGenerator.h"
#include "TensorBuilder.h"
#include "KernelGenerator.h"
#include "ops/BackPropInitializer.h"
@@ -230,6 +231,8 @@ FunctionMap BackendContext::genKernels()
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

+ExtraTensorGenerator extra_tensor_gen(trainable_graph(), _tensor_builder, _tensor_registry);
+
const auto &ops = trainable_graph()->operations();

for (auto &pair : ret)
@@ -245,11 +248,11 @@
continue;

fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
-extra_tensor_gen->register_tensors(op_idx, (&fn)->requestExtraTensors());
+extra_tensor_gen.register_tensors(op_idx, (&fn)->requestExtraTensors());
});
}
-extra_tensor_gen->plan_tensors();
-extra_tensor_gen->allocate_tensors();
+extra_tensor_gen.plan();
+extra_tensor_gen.allocate();

return ret;
}
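
Kernel-side counterpart, as a sketch that is not part of this commit: the registration loop above assumes each trainable function reports the scratch tensors it needs from requestExtraTensors(). SomeTrainableFn, _input, and _workspace are illustrative names, and ExtraTensorRequests is assumed to be a vector-like container, as suggested by the reqs.size()/reqs[i] usage elsewhere in this commit.

  ExtraTensorRequests SomeTrainableFn::requestExtraTensors()
  {
    ExtraTensorRequests reqs;
    // Request a tensor shaped like _input; after allocation, the generator is
    // expected to publish the resulting ExtraTensor* through the given address.
    reqs.push_back(ExtraTensorRequest::createRequestLike(_input, &_workspace));
    return reqs;
  }
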
10 changes: 3 additions & 7 deletions runtime/onert/backend/train/BackendContext.h
@@ -20,7 +20,6 @@
#include <backend/train/TrainableBackendContext.h>

#include "ExternalContext.h"
#include "ExtraTensorGenerator.h"
#include "KernelGenerator.h"
#include "TensorBuilder.h"

@@ -56,12 +55,10 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr,
std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
std::unique_ptr<exec::train::optimizer::Optimizer> optimizer = nullptr,
-std::shared_ptr<KernelGenerator> kernel_gen = nullptr,
-std::unique_ptr<ExtraTensorGenerator> etensor_gen = nullptr)
+std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
: onert::backend::train::TrainableBackendContext(backend, std::move(tdata), tensor_registry),
-kernel_gen{kernel_gen}, extra_tensor_gen{std::move(etensor_gen)},
-_external_context(new ExternalContext), _tensor_builder{tensor_builder},
-_optimizer{std::move(optimizer)}
+kernel_gen{kernel_gen}, _external_context(new ExternalContext),
+_tensor_builder{tensor_builder}, _optimizer{std::move(optimizer)}
{
}
BackendContext(const BackendContext &) = delete;
@@ -94,7 +91,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
public:
// TODO Make it private
std::shared_ptr<KernelGenerator> kernel_gen;
-std::unique_ptr<ExtraTensorGenerator> extra_tensor_gen;

private:
// NOTE ruy context has a thread pool, and when multiple ruy contexts are created,
34 changes: 13 additions & 21 deletions runtime/onert/backend/train/ExtraTensorGenerator.cc
@@ -20,6 +20,7 @@

#include <ir/Operations.h>
#include <util/logging.h>
+#include <memory>

namespace onert
{
@@ -28,10 +29,13 @@ namespace backend
namespace train
{

-ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph &tgraph,
+ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
std::shared_ptr<TensorBuilder> &tensor_builder,
-std::shared_ptr<TensorRegistry> &tensor_registry)
-: _tgraph(tgraph), _tensor_builder(tensor_builder), _tensor_reg(tensor_registry){};
+std::shared_ptr<ITensorRegistry> &tensor_registry)
+: _tgraph(tgraph), _tensor_builder(tensor_builder)
+{
+_tensor_reg = std::dynamic_pointer_cast<TensorRegistry>(tensor_registry);
+}

void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
const ExtraTensorRequests &reqs)
@@ -41,7 +45,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
return;

_idx_to_requests[op_idx] = reqs;
-auto &operations = _tgraph.operations();
+auto &operations = _tgraph->operations();

for (size_t i = 0; i < reqs.size(); i++)
{
@@ -61,31 +65,23 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
return;
}

-void ExtraTensorGenerator::plan_tensors()
+void ExtraTensorGenerator::plan()
{
// forwarding order
-const auto f_order = _tgraph.topolSortOperations();
+const auto f_order = _tgraph->topolSortOperations();
for (const auto &op_index : f_order)
{
auto &reqs = _idx_to_requests[op_index];
-
for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
-if (lt == ExtraTensorLifeTime::FORWARD || lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
+if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
_tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
}
-
-for (auto i = 0u; i < reqs.size(); ++i)
-{
-auto &lt = reqs[i].lifetime;
-if (lt == ExtraTensorLifeTime::FORWARD)
-_tensor_builder->notifyLastUse(ExtraTensorIndex(op_index, i));
-}
}

// backwarding order
-const auto b_order = _tgraph.essentialBackwardOrder();
+const auto b_order = _tgraph->essentialBackwardOrder();
for (const auto &op_index : b_order)
{
auto &reqs = _idx_to_requests[op_index];
@@ -106,11 +102,7 @@ void ExtraTensorGenerator::plan_tensors()
}
}

-void ExtraTensorGenerator::allocate_tensors()
-{
-// allocate
-_tensor_builder->allocateExtra();
-}
+void ExtraTensorGenerator::allocate() { _tensor_builder->allocateExtra(); }

} // namespace train
} // namespace backend
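
Since the constructor now takes the generic ITensorRegistry and downcasts it, one possible hardening, shown only as a sketch that is not part of this commit (and requires <cassert>), is to check that the cast succeeded; std::dynamic_pointer_cast returns nullptr when the registry passed in is not the train backend's TensorRegistry:

  _tensor_reg = std::dynamic_pointer_cast<TensorRegistry>(tensor_registry);
  assert(_tensor_reg != nullptr && "ExtraTensorGenerator expects a train::TensorRegistry");
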
12 changes: 7 additions & 5 deletions runtime/onert/backend/train/ExtraTensorGenerator.h
@@ -34,17 +34,19 @@ class ExtraTensorGenerator
{
public:
ExtraTensorGenerator() = delete;
-ExtraTensorGenerator(const ir::train::TrainableGraph &tgraph,
+
+ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
std::shared_ptr<TensorBuilder> &tensor_builder,
-std::shared_ptr<TensorRegistry> &tensor_registry);
+std::shared_ptr<ITensorRegistry> &tensor_registry);
+
public:
// Since register is a reserved keyword, use 'register_tensors' instead of 'register'
void register_tensors(ir::OperationIndex idx, const ExtraTensorRequests &requests);
-void plan_tensors();
-void allocate_tensors();
+void plan();
+void allocate();

private:
-const ir::train::TrainableGraph &_tgraph;
+const ir::train::TrainableGraph *_tgraph;
std::shared_ptr<TensorBuilder> _tensor_builder;
std::shared_ptr<TensorRegistry> _tensor_reg;
std::unordered_map<ir::OperationIndex, ExtraTensorRequests> _idx_to_requests;
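
For reference, a caller-side sketch of the renamed interface, mirroring the BackendContext::genKernels() changes above; fn_seq and op_idx are assumed to come from iterating the generated FunctionMap, as in the elided loop of that file:

  ExtraTensorGenerator extra_tensor_gen(trainable_graph(), _tensor_builder, _tensor_registry);
  fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
    extra_tensor_gen.register_tensors(op_idx, fn.requestExtraTensors());
  });
  extra_tensor_gen.plan();     // marks first/last uses on the TensorBuilder
  extra_tensor_gen.allocate(); // calls allocateExtra() on the TensorBuilder
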
20 changes: 3 additions & 17 deletions runtime/onert/core/include/backend/train/ExtraTensorRequest.h
@@ -17,7 +17,7 @@
#ifndef __ONERT_BACKEND_EXTRA_TENSOR_REQUEST_H__
#define __ONERT_BACKEND_EXTRA_TENSOR_REQUEST_H__

-#include <backend/basic/Tensor.h>
+#include "backend/train/ExtraTensor.h"

namespace onert
{
@@ -26,23 +26,10 @@ namespace backend
namespace train
{

-class ExtraTensor final : public basic::Tensor
-{
-public:
-ExtraTensor() = delete;
-
-public:
-ExtraTensor(const ir::OperandInfo &info) : basic::Tensor(info, nullptr)
-{
-// DO NOTHING
-}
-};
-
enum class ExtraTensorLifeTime
{
-FORWARD, // live during forward()
-BACKWARD, // live during backward()
-FORWARD_TO_BACKWARD, // live from forward to backward()
+BACKWARD, // alive during backward()
+FORWARD_TO_BACKWARD, // alive from forward to backward()
};

class ExtraTensorRequest
@@ -58,7 +45,6 @@ class ExtraTensorRequest
static ExtraTensorRequest createRequestLike(const IPortableTensor *origin,
backend::train::ExtraTensor **addr)
{
-
assert(origin != nullptr);
assert(addr != nullptr);

