diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc
index 59fee712247..cba4eac8c31 100644
--- a/runtime/onert/backend/train/BackendContext.cc
+++ b/runtime/onert/backend/train/BackendContext.cc
@@ -179,9 +179,60 @@ FunctionMap BackendContext::gen()
   //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
   // }
 
+  planLayerScopeTensors(fn_map);
+  _tensor_builder->allocateLayerScope();
+
   return fn_map;
 }
 
+void BackendContext::planLayerScopeTensors(const FunctionMap &fn_map)
+{
+
+  const auto &ops = trainable_graph()->operations();
+
+  auto register_tensors = [this, &ops](ir::OperationIndex op_idx,
+                                       std::optional<LayerScopeTensors> &&tensors) {
+    if (not tensors.has_value())
+      return;
+
+    auto ls_tensors = tensors.value();
+    for (auto i = 0u; i < ls_tensors.size(); ++i)
+    {
+      LayerScopeTensorIndex tensor_idx(op_idx, i);
+      _tensor_builder->registerLayerScopeTensor(tensor_idx, ls_tensors[i]);
+
+      std::stringstream info;
+      info << op_idx << "_" << ops.at(op_idx).name();
+      VERBOSE() << "register (idx:" << tensor_idx << ") requested from " << info.str() << std::endl;
+    }
+    return;
+  };
+
+  for (auto &pair : fn_map)
+  {
+    auto &op_idx = pair.first;
+    auto &fn_seq = pair.second;
+
+    const ir::IOperation *op = &ops.at(op_idx);
+    const auto trainable_op = dynamic_cast<const ir::train::ITrainableOperation *>(op);
+    assert(trainable_op != nullptr);
+
+    if (not trainable_op->isRequiredForBackward())
+      continue;
+
+    VERBOSE(LayerScopeTensor) << "register tensor for " << trainable_op->name() << std::endl;
+
+    fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
+      register_tensors(op_idx, (&fn)->registerLayerScopeTensors());
+    });
+  }
+
+  const auto ctx_data = data();
+  TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};
+  tensor_planner.planLayerScopeTensors(_tensor_builder.get());
+  return;
+}
+
 void BackendContext::planForwardTensors()
 {
   const auto &tgraph = *trainable_graph();
diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h
index 8e343aee403..7017d7e3175 100644
--- a/runtime/onert/backend/train/BackendContext.h
+++ b/runtime/onert/backend/train/BackendContext.h
@@ -73,6 +73,7 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
 private:
   void planForwardTensors();
   void planBackwardTensors();
+  void planLayerScopeTensors(const FunctionMap &fn_map);
 
 public:
   std::shared_ptr<ExternalContext> external_context() { return _external_context; }
diff --git a/runtime/onert/backend/train/TensorBuilder.cc b/runtime/onert/backend/train/TensorBuilder.cc
index ee737222be2..5b58ed7b593 100644
--- a/runtime/onert/backend/train/TensorBuilder.cc
+++ b/runtime/onert/backend/train/TensorBuilder.cc
@@ -95,6 +95,27 @@ void TensorBuilder::registerDisposableBackwardTensorInfo(const DisposableTensorI
   _disposable_backprops.add(index);
 }
 
+void TensorBuilder::registerLayerScopeTensor(const LayerScopeTensorIndex &index,
+                                             std::shared_ptr<LayerScopeTensor> &tensor)
+{
+  const auto op_idx = index.op_index();
+
+  const auto pair = _operation_to_layerscope.find(op_idx);
+  if (pair == _operation_to_layerscope.end())
+  {
+    util::Set<LayerScopeTensorIndex> tensor_indices;
+    tensor_indices.add(index);
+    _operation_to_layerscope[op_idx] = tensor_indices;
+  }
+  else
+  {
+    assert(!pair->second.contains(index));
+    pair->second.add(index);
+  }
+
+  _tensor_reg->setLayerScopeTensor(index, tensor);
+}
+
 void TensorBuilder::notifyFirstUse(const ir::OperandIndex &index)
 {
   // TODO Support momory plan
@@ -155,6 +176,16 @@ void TensorBuilder::notifyDisposableBackPropLastUse(const DisposableTensorIndex
   _tensor_mgr->releaseDisposableBackPropPlan(index);
 }
 
+void TensorBuilder::notifyLayerScopeFirstUse(const LayerScopeTensorIndex &index)
+{
+  _tensor_mgr->claimLayerScopePlan(index);
+}
+
+void TensorBuilder::notifyLayerScopeLastUse(const LayerScopeTensorIndex &index)
+{
+  _tensor_mgr->releaseLayerScopePlan(index);
+}
+
 bool TensorBuilder::isRegistered(const ir::OperandIndex &index) const
 {
   return _tensor_info_map.find(index) != _tensor_info_map.end();
@@ -170,6 +201,29 @@ bool TensorBuilder::isRegisteredDisposableBackwardTensor(const DisposableTensorI
   return _disposable_backprops.contains(index);
 }
 
+bool TensorBuilder::isRegisteredLayerScopeTensor(const ir::OperationIndex &index) const
+{
+  const auto pair = _operation_to_layerscope.find(index);
+  return (pair != _operation_to_layerscope.end());
+}
+
+const util::Set<LayerScopeTensorIndex> &
+TensorBuilder::getRegisteredLayerScopeTensorIndex(const ir::OperationIndex &index) const
+{
+  const auto pair = _operation_to_layerscope.find(index);
+  assert(pair != _operation_to_layerscope.end());
+
+  return pair->second;
+}
+
+LayerScopeTensorLifeTime
+TensorBuilder::getLayerScopeTensorLifeTime(const LayerScopeTensorIndex &index) const
+{
+  const auto &ls_tensors = _tensor_reg->layerscope_tensors();
+  const auto &tensor = ls_tensors.at(index);
+  return tensor->lifetime();
+}
+
 void TensorBuilder::allocate(void)
 {
   _tensor_mgr->allocateNonConstTensors();
@@ -183,6 +237,8 @@ void TensorBuilder::allocateBackward(void)
   _tensor_mgr->allocateDisposableBackPropTensors();
 }
 
+void TensorBuilder::allocateLayerScope(void) { _tensor_mgr->allocateLayerScopeTensors(); }
+
 } // namespace train
 } // namespace backend
 } // namespace onert
diff --git a/runtime/onert/backend/train/TensorBuilder.h b/runtime/onert/backend/train/TensorBuilder.h
index 1fa46855142..c53861ec519 100644
--- a/runtime/onert/backend/train/TensorBuilder.h
+++ b/runtime/onert/backend/train/TensorBuilder.h
@@ -18,10 +18,12 @@
 #define __ONERT_BACKEND_TRAIN_TENSOR_BUILDER_H__
 
 #include "DisposableTensorIndex.h"
+#include "LayerScopeTensorIndex.h"
 #include "TensorManager.h"
 #include "TensorRegistry.h"
 #include "util/Set.h"
 
+#include <optional>
 #include <unordered_map>
 
 namespace onert
@@ -55,6 +57,9 @@ class TensorBuilder
   void registerDisposableBackwardTensorInfo(const DisposableTensorIndex &index,
                                             const ir::OperandInfo &info);
 
+  void registerLayerScopeTensor(const LayerScopeTensorIndex &index,
+                                std::shared_ptr<LayerScopeTensor> &info);
+
   // TODO Support memory plan of all tensors
   void notifyFirstUse(const ir::OperandIndex &);
   void notifyLastUse(const ir::OperandIndex &);
@@ -62,13 +67,21 @@ class TensorBuilder
   void notifyBackwardLastUse(const ir::OperandIndex &);
   void notifyDisposableBackPropFirstUse(const DisposableTensorIndex &);
   void notifyDisposableBackPropLastUse(const DisposableTensorIndex &);
+  void notifyLayerScopeFirstUse(const LayerScopeTensorIndex &);
+  void notifyLayerScopeLastUse(const LayerScopeTensorIndex &);
 
   bool isRegistered(const ir::OperandIndex &) const;
   bool isRegisteredBackward(const ir::OperandIndex &) const;
   bool isRegisteredDisposableBackwardTensor(const DisposableTensorIndex &index) const;
+  bool isRegisteredLayerScopeTensor(const ir::OperationIndex &) const;
+
+  const util::Set<LayerScopeTensorIndex> &
+  getRegisteredLayerScopeTensorIndex(const ir::OperationIndex &) const;
+  LayerScopeTensorLifeTime getLayerScopeTensorLifeTime(const LayerScopeTensorIndex &) const;
 
   void allocate(void);
   void allocateBackward(void);
+  void allocateLayerScope(void); // <- this have to called after
 
 private:
   const std::shared_ptr<TensorRegistry> _tensor_reg;
@@ -77,6 +90,7 @@ class TensorBuilder
   ir::OperandIndexMap<ir::OperandInfo> _backward_tensor_info_map;
   ir::OperandIndexMap<bool> _as_constants;
   util::Set<DisposableTensorIndex> _disposable_backprops;
+  ir::OperationIndexMap<util::Set<LayerScopeTensorIndex>> _operation_to_layerscope;
   const exec::train::optimizer::Optimizer *_optimizer;
 };
 
diff --git a/runtime/onert/backend/train/TensorPlanner.cc b/runtime/onert/backend/train/TensorPlanner.cc
index 724eab7d171..33d4048dce4 100644
--- a/runtime/onert/backend/train/TensorPlanner.cc
+++ b/runtime/onert/backend/train/TensorPlanner.cc
@@ -519,6 +519,48 @@ ir::OperandIndexSequence TensorPlanner::getOutgoingBackPropSeq(const ir::Operati
   return ret;
 }
 
+void TensorPlanner::planLayerScopeTensors(TensorBuilder *tensor_builder)
+{
+  // forwading order
+  const auto f_order = _tgraph.topolSortOperations();
+  for (const auto &op_index : f_order)
+  {
+    if (not tensor_builder->isRegisteredLayerScopeTensor(op_index))
+      continue;
+
+    auto indices = tensor_builder->getRegisteredLayerScopeTensorIndex(op_index);
+    for (const auto &idx : indices)
+    {
+      const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
+      if (lt == LayerScopeTensorLifeTime::FORWARD_TO_BACKWARD)
+        tensor_builder->notifyLayerScopeFirstUse(idx);
+    }
+  }
+
+  // backwarding order
+  const auto b_order = _tgraph.essentialBackwardOrder();
+  for (const auto &op_index : b_order)
+  {
+    if (not tensor_builder->isRegisteredLayerScopeTensor(op_index))
+      continue;
+
+    auto indices = tensor_builder->getRegisteredLayerScopeTensorIndex(op_index);
+    for (const auto &idx : indices)
+    {
+      const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
+      if (lt == LayerScopeTensorLifeTime::BACKWARD)
+        tensor_builder->notifyLayerScopeFirstUse(idx);
+    }
+    for (const auto &idx : indices)
+    {
+      const auto lt = tensor_builder->getLayerScopeTensorLifeTime(idx);
+      if (lt == LayerScopeTensorLifeTime::FORWARD_TO_BACKWARD ||
+          lt == LayerScopeTensorLifeTime::BACKWARD)
+        tensor_builder->notifyLayerScopeLastUse(idx);
+    }
+  }
+}
+
 } // namespace train
 } // namespace backend
 } // namespace onert
diff --git a/runtime/onert/backend/train/TensorPlanner.h b/runtime/onert/backend/train/TensorPlanner.h
index 61af802fda9..5bdb0d70803 100644
--- a/runtime/onert/backend/train/TensorPlanner.h
+++ b/runtime/onert/backend/train/TensorPlanner.h
@@ -45,6 +45,7 @@ class TensorPlanner
   void planBackPropTensors(TensorBuilder *tensor_builder);
   void planGradientTensors(TensorBuilder *tensor_builder);
   void planDisposableBackPropTensors(TensorBuilder *tensor_builder);
+  void planLayerScopeTensors(TensorBuilder *tensor_builder);
 
 private:
   ir::OperandIndexSequence getOutgoingBackPropSeq(const ir::OperationIndex &op_index,
diff --git a/runtime/onert/backend/train/ops/BinaryArithmeticLayer.cc b/runtime/onert/backend/train/ops/BinaryArithmeticLayer.cc
index 3c4ce2f7ce1..cd3dc2e7f52 100644
--- a/runtime/onert/backend/train/ops/BinaryArithmeticLayer.cc
+++ b/runtime/onert/backend/train/ops/BinaryArithmeticLayer.cc
@@ -55,11 +55,22 @@ void BinaryArithmeticLayer::configureBackward(IPortableTensor *back_prop_lhs,
 
   if (activation != ir::Activation::NONE)
   {
-    _act_back_prop_output = std::make_unique<Tensor>(_output->get_info());
-    _act_back_prop_output->setBuffer(std::make_shared<basic::Allocator>(_output->total_size()));
+    _act_back_prop_output = std::make_shared<LayerScopeTensor>(_back_prop_output->get_info());
   }
 }
 
+std::optional<LayerScopeTensors> BinaryArithmeticLayer::registerLayerScopeTensors()
+{
+  LayerScopeTensors tensors;
+
+  if (_act_back_prop_output != nullptr)
+  {
+    tensors.push_back(_act_back_prop_output);
+  }
+
+  return std::optional<LayerScopeTensors>(tensors);
+}
+
 void BinaryArithmeticLayer::forward(bool) { cpu::ops::BinaryArithmeticLayer::run(); }
 
 void BinaryArithmeticLayer::backward()
diff --git a/runtime/onert/backend/train/ops/BinaryArithmeticLayer.h b/runtime/onert/backend/train/ops/BinaryArithmeticLayer.h
index 60d6e8be1cc..0a3e8ae58d6 100644
--- a/runtime/onert/backend/train/ops/BinaryArithmeticLayer.h
+++ b/runtime/onert/backend/train/ops/BinaryArithmeticLayer.h
@@ -50,6 +50,7 @@ class BinaryArithmeticLayer : public ::onert::exec::train::ITrainableFunction,
   void configureBackward(IPortableTensor *back_prop_lhs, IPortableTensor *back_prop_rhs,
                          const IPortableTensor *back_prop_output, const ir::Activation activation,
                          const ArithmeticType arithmetic_type);
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override;
   void forward(bool training) override;
   void backward() override;
 
@@ -60,7 +61,7 @@ class BinaryArithmeticLayer : public ::onert::exec::train::ITrainableFunction,
   ArithmeticType _arithmetic_type;
   ir::Activation _activation;
 
-  std::unique_ptr<Tensor> _act_back_prop_output;
+  std::shared_ptr<LayerScopeTensor> _act_back_prop_output;
 };
 
 } // namespace ops
diff --git a/runtime/onert/backend/train/ops/ConvolutionLayer.cc b/runtime/onert/backend/train/ops/ConvolutionLayer.cc
index 41ff7fd1c43..f53a9932a00 100644
--- a/runtime/onert/backend/train/ops/ConvolutionLayer.cc
+++ b/runtime/onert/backend/train/ops/ConvolutionLayer.cc
@@ -31,7 +31,7 @@ namespace
 using namespace onert;
 
 template <typename Tensor>
-std::unique_ptr<Tensor> createTransposedWeights(const backend::IPortableTensor *origin_weights)
+std::shared_ptr<Tensor> createTransposedWeights(const backend::IPortableTensor *origin_weights)
 {
   const auto &origin_shape = origin_weights->getShape();
   assert(origin_shape.rank() == 4);
@@ -42,7 +42,7 @@ std::unique_ptr<Tensor> createTransposedWeights(const backend::IPortableTensor *
     ir::Shape{origin_shape.dim(1), origin_shape.dim(2), origin_shape.dim(3), origin_shape.dim(0)};
   transposed_info.shape(transposed_shape);
 
-  return std::make_unique<Tensor>(transposed_info);
+  return std::make_shared<Tensor>(transposed_info);
 }
 
 } // namespace
@@ -79,27 +79,31 @@ void ConvolutionLayer::configureBackward(const IPortableTensor *weights,
   if (_dilationHeightFactor != 1 || _dilationWidthFactor != 1)
     throw std::runtime_error("train ConvolutionLayer: Unsupported dilation yet");
 
-  // TODO Optimize transposed tensors
-  _transposed_weights = createTransposedWeights<Tensor>(weights);
-  _transposed_weights->setBuffer(
-    std::make_shared<basic::Allocator>(_transposed_weights->total_size()));
+  _transposed_weights = createTransposedWeights<LayerScopeTensor>(weights);
 
-  _conv_back_prop_output = std::make_unique<Tensor>(back_prop_output->get_info());
-  _conv_back_prop_output->setBuffer(
-    std::make_shared<basic::Allocator>(_conv_back_prop_output->total_size()));
+  _conv_back_prop_output = std::make_shared<LayerScopeTensor>(back_prop_output->get_info());
 
-  _transposed_grad_weights = createTransposedWeights<Tensor>(weights);
-  _transposed_grad_weights->setBuffer(
-    std::make_shared<basic::Allocator>(_transposed_grad_weights->total_size()));
+  _transposed_grad_weights = createTransposedWeights<LayerScopeTensor>(weights);
 
   if (activation != ir::Activation::NONE)
   {
-    _act_back_prop_output = std::make_unique<Tensor>(_back_prop_output->get_info());
-    _act_back_prop_output->setBuffer(
-      std::make_shared<basic::Allocator>(_act_back_prop_output->total_size()));
+    _act_back_prop_output = std::make_unique<LayerScopeTensor>(_back_prop_output->get_info());
   }
 }
 
+std::optional<LayerScopeTensors> ConvolutionLayer::registerLayerScopeTensors()
+{
+  LayerScopeTensors tensors = {_transposed_weights, _conv_back_prop_output,
+                               _transposed_grad_weights};
+
+  if (_act_back_prop_output != nullptr)
+  {
+    tensors.push_back(_act_back_prop_output);
+  }
+
+  return std::optional<LayerScopeTensors>(tensors);
+}
+
 void ConvolutionLayer::forward(bool) { cpu::ops::ConvolutionLayer::run(); }
 
 void ConvolutionLayer::backward()
 {
diff --git a/runtime/onert/backend/train/ops/ConvolutionLayer.h b/runtime/onert/backend/train/ops/ConvolutionLayer.h
index ef11f68bf57..1177fb26f1f 100644
--- a/runtime/onert/backend/train/ops/ConvolutionLayer.h
+++ b/runtime/onert/backend/train/ops/ConvolutionLayer.h
@@ -41,6 +41,7 @@ class ConvolutionLayer : public ::onert::exec::train::ITrainableFunction,
   void configureBackward(const IPortableTensor *weights, IPortableTensor *back_prop_input,
                          IPortableTensor *grad_weights, IPortableTensor *grad_bias,
                          const IPortableTensor *back_prop_output, const ir::Activation activation);
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override;
   void forward(bool training) override;
   void backward() override;
 
@@ -54,10 +55,10 @@ class ConvolutionLayer : public ::onert::exec::train::ITrainableFunction,
   const IPortableTensor *_back_prop_output;
 
   // TODO Consider if these tensors should be built in TensorBuilder
-  std::unique_ptr<Tensor> _transposed_weights;
-  std::unique_ptr<Tensor> _conv_back_prop_output;
-  std::unique_ptr<Tensor> _act_back_prop_output;
-  std::unique_ptr<Tensor> _transposed_grad_weights;
+  std::shared_ptr<LayerScopeTensor> _transposed_weights;
+  std::shared_ptr<LayerScopeTensor> _conv_back_prop_output;
+  std::shared_ptr<LayerScopeTensor> _transposed_grad_weights;
+  std::shared_ptr<LayerScopeTensor> _act_back_prop_output;
 };
 
 } // namespace ops
diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
index 9d35655b26f..cf1407923de 100644
--- a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
+++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
@@ -28,7 +28,7 @@ namespace
 
 using namespace onert;
 
-std::unique_ptr<Tensor>
+std::shared_ptr<LayerScopeTensor>
 createTransposedTensor(const backend::IPortableTensor *origin_tensor)
 {
   const auto &origin_shape = origin_tensor->getShape();
@@ -38,7 +38,7 @@ createTransposedTensor(const backend::IPortableTensor *origin_tensor)
   auto transposed_shape = ir::Shape{origin_shape.dim(1), origin_shape.dim(0)};
   transposed_info.shape(transposed_shape);
 
-  return std::make_unique<Tensor>(transposed_info);
+  return std::make_shared<LayerScopeTensor>(transposed_info);
 }
 
 } // namespace
@@ -86,23 +86,29 @@ void FullyConnectedLayer::configureBackward(
       "train FullyConnectedLayer: Input other ranks than 2 are not supported."};
 
   _transposed_weights = createTransposedTensor(weights);
-  _transposed_weights->setBuffer(std::make_shared<basic::Allocator>(weights->total_size()));
 
   _transposed_input = createTransposedTensor(input);
-  _transposed_input->setBuffer(std::make_shared<basic::Allocator>(input->total_size()));
 
   _transposed_back_prop_output = createTransposedTensor(back_prop_output);
-  _transposed_back_prop_output->setBuffer(
-    std::make_shared<basic::Allocator>(back_prop_output->total_size()));
 
   if (activation != ir::Activation::NONE)
   {
-    _act_back_prop_output = std::make_unique<Tensor>(_back_prop_output->get_info());
-    _act_back_prop_output->setBuffer(
-      std::make_shared<basic::Allocator>(_back_prop_output->total_size()));
+    _act_back_prop_output = std::make_shared<LayerScopeTensor>(_back_prop_output->get_info());
   }
 }
 
+std::optional<LayerScopeTensors> FullyConnectedLayer::registerLayerScopeTensors()
+{
+  LayerScopeTensors tensors = {_transposed_weights, _transposed_input,
+                               _transposed_back_prop_output};
+  if (_act_back_prop_output != nullptr)
+  {
+    tensors.push_back(_act_back_prop_output);
+  }
+
+  return tensors;
+}
+
 void FullyConnectedLayer::forward(bool) { cpu::ops::FullyConnectedLayer::run(); }
 
 void FullyConnectedLayer::backward()
diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.h b/runtime/onert/backend/train/ops/FullyConnectedLayer.h
index 190bfbffe42..44fe5ab7c88 100644
--- a/runtime/onert/backend/train/ops/FullyConnectedLayer.h
+++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.h
@@ -46,6 +46,7 @@ class FullyConnectedLayer : public exec::train::ITrainableFunction,
                          const IPortableTensor *back_prop_output, ir::Activation activation,
                          ir::FullyConnectedWeightsFormat weights_format);
 
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override;
   void forward(bool training) override;
   void backward() override;
 
@@ -58,11 +59,10 @@ class FullyConnectedLayer : public exec::train::ITrainableFunction,
   IPortableTensor *_back_prop_input;
   const IPortableTensor *_back_prop_output;
 
-  // TODO Optimize memory
-  std::unique_ptr<Tensor> _transposed_weights;
-  std::unique_ptr<Tensor> _transposed_input;
-  std::unique_ptr<Tensor> _transposed_back_prop_output;
-  std::unique_ptr<Tensor> _act_back_prop_output;
+  std::shared_ptr<LayerScopeTensor> _transposed_weights;
+  std::shared_ptr<LayerScopeTensor> _transposed_input;
+  std::shared_ptr<LayerScopeTensor> _transposed_back_prop_output;
+  std::shared_ptr<LayerScopeTensor> _act_back_prop_output;
 };
 
 } // namespace ops
diff --git a/runtime/onert/backend/train/ops/PoolLayer.cc b/runtime/onert/backend/train/ops/PoolLayer.cc
index 098389d8f10..8a8c05adeed 100644
--- a/runtime/onert/backend/train/ops/PoolLayer.cc
+++ b/runtime/onert/backend/train/ops/PoolLayer.cc
@@ -24,6 +24,8 @@
 #include <cker/train/operation/MaxPool.h>
 #include <cker/train/operation/ReLU.h>
 
+#include <optional>
+
 namespace onert
 {
 namespace backend
@@ -43,8 +45,8 @@ class MaxPool2D final : public TrainingKernelRegistry
   const IPortableTensor *_output;
   nnfw::cker::PoolParams _op_params;
 
-  std::unique_ptr<Tensor> _act_back_prop_output;
-  std::unique_ptr<Tensor> _arg_max_index;
+  std::shared_ptr<LayerScopeTensor> _act_back_prop_output;
+  std::shared_ptr<LayerScopeTensor> _arg_max_index;
 
 public:
   MaxPool2D(const uint32_t paddingLeft, const uint32_t, const uint32_t paddingTop, const uint32_t,
@@ -66,20 +68,31 @@ class MaxPool2D final : public TrainingKernelRegistry
                              &_op_params.float_activation_max);
     }
 
-    _arg_max_index = std::make_unique<Tensor>(_output->get_info());
-    _arg_max_index->setBuffer(std::make_shared<basic::Allocator>(_output->total_size()));
+    _arg_max_index = std::make_shared<LayerScopeTensor>(
+      _output->get_info(), LayerScopeTensorLifeTime::FORWARD_TO_BACKWARD);
 
     if (activation != ir::Activation::NONE)
     {
-      _act_back_prop_output = std::make_unique<Tensor>(_output->get_info());
-      _act_back_prop_output->setBuffer(std::make_shared<basic::Allocator>(_output->total_size()));
+      _act_back_prop_output = std::make_shared<LayerScopeTensor>(_output->get_info());
     }
   };
 
   ~MaxPool2D() {}
 
 public:
-  void forward(const IPortableTensor *in, IPortableTensor *out)
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override
+  {
+    LayerScopeTensors tensors = {_arg_max_index};
+    if (_act_back_prop_output != nullptr)
+    {
+      tensors.push_back(_act_back_prop_output);
+    }
+
+    return std::optional<LayerScopeTensors>(tensors);
+  }
+
+public:
+  void forward(const IPortableTensor *in, IPortableTensor *out) override
   {
     auto out_shape = getShape(out);
     auto out_data = getBuffer(out);
@@ -90,7 +103,7 @@ class MaxPool2D final : public TrainingKernelRegistry
                          out_data, getBuffer(arg_max_index));
   }
 
-  void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in)
+  void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in) override
   {
     // activation backward
     try
@@ -110,7 +123,7 @@ class MaxPool2D final : public TrainingKernelRegistry
                           getBuffer(arg_max_index), getShape(back_prop_in),
                           getBuffer(back_prop_in));
   }
-};
+}; // namespace ops
 
 class AveragePool2D final : public TrainingKernelRegistry
 {
@@ -152,7 +165,7 @@ class AveragePool2D final : public TrainingKernelRegistry
   ~AveragePool2D() {}
 
 public:
-  void forward(const IPortableTensor *in, IPortableTensor *out)
+  void forward(const IPortableTensor *in, IPortableTensor *out) override
   {
     auto out_shape = getShape(out);
     auto out_data = getBuffer(out);
@@ -162,7 +175,7 @@ class AveragePool2D final : public TrainingKernelRegistry
                              out_data);
   }
 
-  void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in)
+  void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in) override
   {
     // activation backward
     try
@@ -181,6 +194,9 @@ class AveragePool2D final : public TrainingKernelRegistry
                               getBuffer(back_prop_out), getShape(back_prop_in),
                               getBuffer(back_prop_in));
   }
+
+public:
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override { return std::nullopt; }
 };
 
 } // namespace
@@ -225,6 +241,11 @@ void PoolLayer::configureBackward(const uint32_t paddingLeft, const uint32_t pad
   }
 }
 
+std::optional<LayerScopeTensors> PoolLayer::registerLayerScopeTensors()
+{
+  return _kernel->registerLayerScopeTensors();
+}
+
 void PoolLayer::forward(bool training)
 {
   if (training)
diff --git a/runtime/onert/backend/train/ops/PoolLayer.h b/runtime/onert/backend/train/ops/PoolLayer.h
index 2b0c9e2a00b..b1ed9006a42 100644
--- a/runtime/onert/backend/train/ops/PoolLayer.h
+++ b/runtime/onert/backend/train/ops/PoolLayer.h
@@ -38,6 +38,8 @@ class TrainingKernelRegistry
 public:
   virtual void forward(const IPortableTensor *in, IPortableTensor *out) = 0;
   virtual void backward(const IPortableTensor *back_prop_out, IPortableTensor *back_prop_in) = 0;
+  virtual std::optional<LayerScopeTensors> registerLayerScopeTensors() = 0;
+
   TrainingKernelRegistry() = default;
   virtual ~TrainingKernelRegistry() = default;
 };
@@ -62,6 +64,7 @@ class PoolLayer : public ::onert::exec::train::ITrainableFunction, public cpu::o
                          IPortableTensor *output, IPortableTensor *back_prop_input,
                          const IPortableTensor *back_prop_output);
 
+  std::optional<LayerScopeTensors> registerLayerScopeTensors() override;
   void forward(bool training) override;
   void backward() override;