From eb3843dece80a1e8fb340b268e4ba915a7af7cf7 Mon Sep 17 00:00:00 2001 From: sseung Date: Fri, 13 Sep 2024 17:34:18 +0900 Subject: [PATCH] [onert/backend] Add LayerScopeTensors into TensorRegistry This PR adds LayerScopeTensors into TensorRegistry. ONE-DCO-1.0-Signed-off-by: seunghui youn draft : https://github.com/Samsung/ONE/pull/13486 for : https://github.com/Samsung/ONE/issues/13282 --- runtime/onert/backend/train/TensorRegistry.h | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/runtime/onert/backend/train/TensorRegistry.h b/runtime/onert/backend/train/TensorRegistry.h index 13932199a9d..e1b1c1ce8f7 100644 --- a/runtime/onert/backend/train/TensorRegistry.h +++ b/runtime/onert/backend/train/TensorRegistry.h @@ -18,8 +18,10 @@ #define __ONERT_BACKEND_TRAIN_TENSOR_REGISTRY__ #include +#include #include "DisposableTensorIndex.h" +#include "LayerScopeTensorIndex.h" #include "Tensor.h" namespace onert @@ -60,9 +62,36 @@ class TensorRegistry return _disposable_back_prop; } + std::shared_ptr getExtraTensor(const LayerScopeTensorIndex &index) + { + auto itr = _layerscope.find(index); + if (itr != _layerscope.end()) + return itr->second; + + return nullptr; + } + + void setExtraTensor(const LayerScopeTensorIndex &index, std::shared_ptr &tensor) + { + assert(tensor != nullptr); + auto itr = _layerscope.find(index); + if (itr != _layerscope.end()) + throw std::runtime_error{ + "Tried to set a layer scope tensor but another layer scope tensor already exists."}; + + _layerscope[index] = tensor; + } + + const std::unordered_map> & + layerscope_tensors() + { + return _layerscope; + } + private: // Disposable Tensors to be accumulated to BackPropTensor std::unordered_map> _disposable_back_prop; + std::unordered_map> _layerscope; }; } // namespace train