diff --git a/runtime/onert/backend/train/ExtraTensorGenerator.cc b/runtime/onert/backend/train/ExtraTensorGenerator.cc index 0cbe8b8b1f9..6e2cf2aa978 100644 --- a/runtime/onert/backend/train/ExtraTensorGenerator.cc +++ b/runtime/onert/backend/train/ExtraTensorGenerator.cc @@ -51,7 +51,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, ExtraTens { // register tensor ExtraTensorIndex tensor_idx(op_idx, i); - _tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info); + _tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info()); std::stringstream op_info; op_info << op_idx << "_" << operations.at(op_idx).name(); @@ -60,7 +60,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, ExtraTens // return registered tensor auto generated_tensor = _tensor_reg->getExtraTensor(tensor_idx); - *reqs[i].address = generated_tensor; + reqs[i].update_address(generated_tensor); } return; } @@ -74,7 +74,7 @@ void ExtraTensorGenerator::plan() auto &reqs = _idx_to_requests[op_index]; for (auto i = 0u; i < reqs.size(); ++i) { - auto < = reqs[i].lifetime; + const auto < = reqs[i].lifetime(); if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD) _tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i)); } @@ -88,14 +88,14 @@ void ExtraTensorGenerator::plan() for (auto i = 0u; i < reqs.size(); ++i) { - auto < = reqs[i].lifetime; + const auto < = reqs[i].lifetime(); if (lt == ExtraTensorLifeTime::BACKWARD) _tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i)); } for (auto i = 0u; i < reqs.size(); ++i) { - auto < = reqs[i].lifetime; + const auto < = reqs[i].lifetime(); if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD || lt == ExtraTensorLifeTime::BACKWARD) _tensor_builder->notifyLastUse(ExtraTensorIndex(op_index, i)); } diff --git a/runtime/onert/core/include/backend/train/ExtraTensorRequest.h b/runtime/onert/core/include/backend/train/ExtraTensorRequest.h index 6a430802fb8..cf6be5980cb 100644 --- a/runtime/onert/core/include/backend/train/ExtraTensorRequest.h +++ b/runtime/onert/core/include/backend/train/ExtraTensorRequest.h @@ -37,13 +37,13 @@ class ExtraTensorRequest public: ExtraTensorRequest(ir::OperandInfo info, ExtraTensorLifeTime lt, - backend::train::ExtraTensor **addr) - : info(info), lifetime(lt), address(addr) + ExtraTensor **addr) + : _info(info), _lifetime(lt), _address(addr) { } static ExtraTensorRequest createLike(const IPortableTensor *origin, - backend::train::ExtraTensor** addr) + ExtraTensor** addr) { assert(origin != nullptr); assert(addr != nullptr); @@ -52,9 +52,25 @@ class ExtraTensorRequest } public: - const ir::OperandInfo info; - const ExtraTensorLifeTime lifetime; - backend::train::ExtraTensor ** const address; + const ir::OperandInfo& info() const + { + return _info; + } + + ExtraTensorLifeTime lifetime() const + { + return _lifetime; + } + + void update_address(ExtraTensor* tensor) + { + *_address = tensor; + } + +private: + ir::OperandInfo _info; + ExtraTensorLifeTime _lifetime; + backend::train::ExtraTensor ** const _address; }; using ExtraTensorRequests = std::vector;