From 5ca9a7c7d08e1c0da86f0a4604c86b8dc9ce47de Mon Sep 17 00:00:00 2001
From: sseung
Date: Thu, 1 Aug 2024 17:39:11 +0900
Subject: [PATCH] trim

---
 runtime/onert/backend/train/Backend.h         |  3 --
 runtime/onert/backend/train/BackendContext.cc |  9 +++--
 runtime/onert/backend/train/BackendContext.h  | 10 ++----
 .../backend/train/ExtraTensorGenerator.cc     | 34 +++++++------------
 .../backend/train/ExtraTensorGenerator.h      | 12 ++++---
 .../backend/train/ExtraTensorRequest.h        | 20 ++---------
 6 files changed, 32 insertions(+), 56 deletions(-)

diff --git a/runtime/onert/backend/train/Backend.h b/runtime/onert/backend/train/Backend.h
index 896355574e8..9ba689d0343 100644
--- a/runtime/onert/backend/train/Backend.h
+++ b/runtime/onert/backend/train/Backend.h
@@ -61,9 +61,6 @@ class Backend : public ::onert::backend::Backend, public backend::train::ITraina
     context->kernel_gen = std::make_shared<KernelGenerator>(
       tgraph, tr, context->external_context(), context->optimizer());
-
-    context->extra_tensor_gen = std::make_unique<ExtraTensorGenerator>(tgraph, tb, tr);
-
     return context;
   }
diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc
index fbfe8475ba8..8d0d6f22f00 100644
--- a/runtime/onert/backend/train/BackendContext.cc
+++ b/runtime/onert/backend/train/BackendContext.cc
@@ -16,6 +16,7 @@
 
 #include "BackendContext.h"
 
+#include "ExtraTensorGenerator.h"
 #include "TensorBuilder.h"
 #include "KernelGenerator.h"
 #include "ops/BackPropInitializer.h"
@@ -230,6 +231,8 @@ FunctionMap BackendContext::genKernels()
   //     fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
   //   }
 
+  ExtraTensorGenerator extra_tensor_gen(trainable_graph(), _tensor_builder, _tensor_registry);
+
   const auto &ops = trainable_graph()->operations();
 
   for (auto &pair : ret)
@@ -245,11 +248,11 @@ FunctionMap BackendContext::genKernels()
       continue;
 
     fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
-      extra_tensor_gen->register_tensors(op_idx, (&fn)->requestExtraTensors());
+      extra_tensor_gen.register_tensors(op_idx, (&fn)->requestExtraTensors());
     });
   }
-  extra_tensor_gen->plan_tensors();
-  extra_tensor_gen->allocate_tensors();
+  extra_tensor_gen.plan();
+  extra_tensor_gen.allocate();
 
   return ret;
 }
diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h
index e30a2bd3240..6c36a1e8924 100644
--- a/runtime/onert/backend/train/BackendContext.h
+++ b/runtime/onert/backend/train/BackendContext.h
@@ -20,7 +20,6 @@
 #include <backend/train/TrainableBackendContext.h>
 
 #include "ExternalContext.h"
-#include "ExtraTensorGenerator.h"
 #include "KernelGenerator.h"
 #include "TensorBuilder.h"
 
@@ -56,12 +55,10 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
                  std::shared_ptr<ITensorRegistry> tensor_registry = nullptr,
                  std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
                  std::unique_ptr<exec::train::optimizer::Optimizer> optimizer = nullptr,
-                 std::shared_ptr<KernelGenerator> kernel_gen = nullptr,
-                 std::unique_ptr<ExtraTensorGenerator> etensor_gen = nullptr)
+                 std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
     : onert::backend::train::TrainableBackendContext(backend, std::move(tdata), tensor_registry),
-      kernel_gen{kernel_gen}, extra_tensor_gen{std::move(etensor_gen)},
-      _external_context(new ExternalContext), _tensor_builder{tensor_builder},
-      _optimizer{std::move(optimizer)}
+      kernel_gen{kernel_gen}, _external_context(new ExternalContext),
+      _tensor_builder{tensor_builder}, _optimizer{std::move(optimizer)}
   {
   }
   BackendContext(const BackendContext &) = delete;
@@ -94,7 +91,6 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
 public:
   // TODO Make it private
   std::shared_ptr<KernelGenerator> kernel_gen;
-  std::unique_ptr<ExtraTensorGenerator> extra_tensor_gen;
 
 private:
   // NOTE ruy context has a thread pool, and when multiple ruy contexts are created,
diff --git a/runtime/onert/backend/train/ExtraTensorGenerator.cc b/runtime/onert/backend/train/ExtraTensorGenerator.cc
index 834af50ce6b..37735b078f7 100644
--- a/runtime/onert/backend/train/ExtraTensorGenerator.cc
+++ b/runtime/onert/backend/train/ExtraTensorGenerator.cc
@@ -20,6 +20,7 @@
 
 #include 
 #include 
+#include <memory>
 
 namespace onert
 {
@@ -28,10 +29,13 @@ namespace backend
 namespace train
 {
 
-ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph &tgraph,
+ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
                                            std::shared_ptr<TensorBuilder> &tensor_builder,
-                                           std::shared_ptr<TensorRegistry> &tensor_registry)
-  : _tgraph(tgraph), _tensor_builder(tensor_builder), _tensor_reg(tensor_registry){};
+                                           std::shared_ptr<ITensorRegistry> &tensor_registry)
+  : _tgraph(tgraph), _tensor_builder(tensor_builder)
+{
+  _tensor_reg = std::dynamic_pointer_cast<TensorRegistry>(tensor_registry);
+}
 
 void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
                                             const ExtraTensorRequests &reqs)
@@ -41,7 +45,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
     return;
 
   _idx_to_requests[op_idx] = reqs;
-  auto &operations = _tgraph.operations();
+  auto &operations = _tgraph->operations();
 
   for (size_t i = 0; i < reqs.size(); i++)
   {
@@ -61,31 +65,23 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
   return;
 }
 
-void ExtraTensorGenerator::plan_tensors()
+void ExtraTensorGenerator::plan()
 {
   // forwarding order
-  const auto f_order = _tgraph.topolSortOperations();
+  const auto f_order = _tgraph->topolSortOperations();
   for (const auto &op_index : f_order)
   {
     auto &reqs = _idx_to_requests[op_index];
-
     for (auto i = 0u; i < reqs.size(); ++i)
    {
       auto lt = reqs[i].lifetime;
-      if (lt == ExtraTensorLifeTime::FORWARD || lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
+      if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
         _tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
     }
-
-    for (auto i = 0u; i < reqs.size(); ++i)
-    {
-      auto lt = reqs[i].lifetime;
-      if (lt == ExtraTensorLifeTime::FORWARD)
-        _tensor_builder->notifyLastUse(ExtraTensorIndex(op_index, i));
-    }
   }
 
   // backwarding order
-  const auto b_order = _tgraph.essentialBackwardOrder();
+  const auto b_order = _tgraph->essentialBackwardOrder();
   for (const auto &op_index : b_order)
   {
     auto &reqs = _idx_to_requests[op_index];
@@ -106,11 +102,7 @@ void ExtraTensorGenerator::plan_tensors()
   }
 }
 
-void ExtraTensorGenerator::allocate_tensors()
-{
-  // allocate
-  _tensor_builder->allocateExtra();
-}
+void ExtraTensorGenerator::allocate() { _tensor_builder->allocateExtra(); }
 
 } // namespace train
 } // namespace backend
diff --git a/runtime/onert/backend/train/ExtraTensorGenerator.h b/runtime/onert/backend/train/ExtraTensorGenerator.h
index e6d0132e7e3..bf0341dd0e9 100644
--- a/runtime/onert/backend/train/ExtraTensorGenerator.h
+++ b/runtime/onert/backend/train/ExtraTensorGenerator.h
@@ -34,17 +34,19 @@ class ExtraTensorGenerator
 {
 public:
   ExtraTensorGenerator() = delete;
-  ExtraTensorGenerator(const ir::train::TrainableGraph &tgraph,
+
+  ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
                        std::shared_ptr<TensorBuilder> &tensor_builder,
-                       std::shared_ptr<TensorRegistry> &tensor_registry);
+                       std::shared_ptr<ITensorRegistry> &tensor_registry);
 
 public:
+  // Since 'register' is a reserved keyword, use 'register_tensors' instead of 'register'
   void register_tensors(ir::OperationIndex idx, const ExtraTensorRequests &requests);
-  void plan_tensors();
-  void allocate_tensors();
+  void plan();
+  void allocate();
 
 private:
-  const ir::train::TrainableGraph &_tgraph;
+  const ir::train::TrainableGraph *_tgraph;
   std::shared_ptr<TensorBuilder> _tensor_builder;
   std::shared_ptr<TensorRegistry> _tensor_reg;
   std::unordered_map<ir::OperationIndex, ExtraTensorRequests> _idx_to_requests;
diff --git a/runtime/onert/core/include/backend/train/ExtraTensorRequest.h b/runtime/onert/core/include/backend/train/ExtraTensorRequest.h
index 3275a45178c..84a7f10797a 100644
--- a/runtime/onert/core/include/backend/train/ExtraTensorRequest.h
+++ b/runtime/onert/core/include/backend/train/ExtraTensorRequest.h
@@ -17,7 +17,7 @@
 #ifndef __ONERT_BACKEND_EXTRA_TENSOR_REQUEST_H__
 #define __ONERT_BACKEND_EXTRA_TENSOR_REQUEST_H__
 
-#include <backend/basic/Tensor.h>
+#include "backend/train/ExtraTensor.h"
 
 namespace onert
 {
@@ -26,23 +26,10 @@ namespace backend
 namespace train
 {
 
-class ExtraTensor final : public basic::Tensor
-{
-public:
-  ExtraTensor() = delete;
-
-public:
-  ExtraTensor(const ir::OperandInfo &info) : basic::Tensor(info, nullptr)
-  {
-    // DO NOTHING
-  }
-};
-
 enum class ExtraTensorLifeTime
 {
-  FORWARD,             // live during forward()
-  BACKWARD,            // live during backward()
-  FORWARD_TO_BACKWARD, // live from forward to backward()
+  BACKWARD,            // alive during backward()
+  FORWARD_TO_BACKWARD, // alive from forward to backward()
 };
 
 class ExtraTensorRequest
@@ -58,7 +45,6 @@ class ExtraTensorRequest
   static ExtraTensorRequest createRequestLike(const IPortableTensor *origin,
                                               backend::train::ExtraTensor **addr)
   {
-    assert(origin != nullptr);
     assert(addr != nullptr);
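
Reviewer note, not part of the patch: below is a minimal kernel-side sketch of how the trimmed request API is meant to be consumed. DummyLayer and its members are hypothetical stand-ins for a real ops/*Layer, and the sketch assumes ExtraTensorRequests behaves like a std::vector<ExtraTensorRequest> declared alongside ExtraTensorRequest; only createRequestLike, ExtraTensor and the ExtraTensorLifeTime values come from the headers touched above.

#include <backend/IPortableTensor.h>
#include <backend/train/ExtraTensorRequest.h>

namespace sketch
{

// Hypothetical trainable kernel (illustration only).
class DummyLayer
{
public:
  // BackendContext::genKernels() reaches this through
  // ITrainableFunction::requestExtraTensors() and hands the result to
  // ExtraTensorGenerator::register_tensors().
  onert::backend::train::ExtraTensorRequests requestExtraTensors()
  {
    onert::backend::train::ExtraTensorRequests reqs;
    // Request a workspace shaped like the input tensor. The generator later
    // writes the allocated ExtraTensor's address back through &_workspace,
    // once plan() and allocate() have run.
    reqs.push_back(onert::backend::train::ExtraTensorRequest::createRequestLike(
      _input, &_workspace));
    return reqs;
  }

private:
  const onert::backend::IPortableTensor *_input = nullptr; // would be set in configure()
  onert::backend::train::ExtraTensor *_workspace = nullptr;
};

} // namespace sketch

The double-pointer argument is what lets a kernel stay unaware of TensorBuilder: it only records where the planned tensor should be delivered, while genKernels() collects the requests from every generated function and then calls plan() and allocate() once.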