From 723bc1afc138fa55d88fa175cd6b1c44265bda77 Mon Sep 17 00:00:00 2001 From: sseung Date: Fri, 13 Sep 2024 17:34:18 +0900 Subject: [PATCH] [onert/backend] Add LayerScopeTensors into TensorRegistry This PR adds LayerScopeTensors into TensorRegistry. ONE-DCO-1.0-Signed-off-by: seunghui youn draft : https://github.com/Samsung/ONE/pull/13486 for : https://github.com/Samsung/ONE/issues/13282 --- runtime/onert/backend/train/TensorRegistry.h | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/runtime/onert/backend/train/TensorRegistry.h b/runtime/onert/backend/train/TensorRegistry.h index 13932199a9d..2d1c034dbc4 100644 --- a/runtime/onert/backend/train/TensorRegistry.h +++ b/runtime/onert/backend/train/TensorRegistry.h @@ -18,8 +18,10 @@ #define __ONERT_BACKEND_TRAIN_TENSOR_REGISTRY__ #include +#include #include "DisposableTensorIndex.h" +#include "LayerScopeTensorIndex.h" #include "Tensor.h" namespace onert @@ -60,9 +62,37 @@ class TensorRegistry return _disposable_back_prop; } + std::shared_ptr getLayerScopeTensor(const LayerScopeTensorIndex &index) + { + auto itr = _layer_scope.find(index); + if (itr != _layer_scope.end()) + return itr->second; + + return nullptr; + } + + void setLayerScopeTensor(const LayerScopeTensorIndex &index, + std::shared_ptr &tensor) + { + assert(tensor != nullptr); + auto itr = _layer_scope.find(index); + if (itr != _layer_scope.end()) + throw std::runtime_error{ + "Tried to set a layer scope tensor but another layer scope tensor already exists."}; + + _layer_scope[index] = tensor; + } + + const std::unordered_map> & + layerscope_tensors() + { + return _layer_scope; + } + private: // Disposable Tensors to be accumulated to BackPropTensor std::unordered_map> _disposable_back_prop; + std::unordered_map> _layer_scope; }; } // namespace train