From 22c40b9dce6b8e743f9a1a055aa13105e07067ad Mon Sep 17 00:00:00 2001 From: sseung Date: Wed, 11 Sep 2024 13:55:01 +0900 Subject: [PATCH] [onert/backend] Templatize MemoryPlannerFactory in train This PR templatize memory planner factory in train backend. MemoryPlannerFactory currently used for DisposableTensorIndex, but it will be also used for LayerScopeTensorIndex. ONE-DCO-1.0-Signed-off-by: seunghui youn draft : https://github.com/Samsung/ONE/pull/13486 for : https://github.com/Samsung/ONE/issues/13282 --- runtime/onert/backend/train/MemoryManager.cc | 4 ++-- runtime/onert/backend/train/MemoryManager.h | 2 ++ .../backend/train/MemoryPlannerFactory.cc | 22 ++++++++++++------- .../backend/train/MemoryPlannerFactory.h | 6 ++--- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/runtime/onert/backend/train/MemoryManager.cc b/runtime/onert/backend/train/MemoryManager.cc index 4902e2a7eaa..fd156fea231 100644 --- a/runtime/onert/backend/train/MemoryManager.cc +++ b/runtime/onert/backend/train/MemoryManager.cc @@ -61,13 +61,13 @@ DisposableMemoryManager::DisposableMemoryManager() : _mem_planner{createMemoryPl basic::IMemoryPlanner *DisposableMemoryManager::createMemoryPlanner() { auto planner_id = util::getConfigString(util::config::CPU_MEMORY_PLANNER); - return MemoryPlannerFactory::get().create(planner_id); + return MemoryPlannerFactory::get().create(planner_id); } basic::IMemoryPlanner * DisposableMemoryManager::createMemoryPlanner(const std::string planner_id) { - return MemoryPlannerFactory::get().create(planner_id); + return MemoryPlannerFactory::get().create(planner_id); } void DisposableMemoryManager::claimPlan(const DisposableTensorIndex &ind, uint32_t size) diff --git a/runtime/onert/backend/train/MemoryManager.h b/runtime/onert/backend/train/MemoryManager.h index 19a60e32deb..98e840bf7f7 100644 --- a/runtime/onert/backend/train/MemoryManager.h +++ b/runtime/onert/backend/train/MemoryManager.h @@ -67,6 +67,8 @@ class DisposableMemoryManager std::shared_ptr _mem_alloc; }; +// TODO: Add LayerScopeMemoryManager using MemoryPlannerFactory + } // namespace train } // namespace backend } // namespace onert diff --git a/runtime/onert/backend/train/MemoryPlannerFactory.cc b/runtime/onert/backend/train/MemoryPlannerFactory.cc index 157b2f6af91..258feafad35 100644 --- a/runtime/onert/backend/train/MemoryPlannerFactory.cc +++ b/runtime/onert/backend/train/MemoryPlannerFactory.cc @@ -16,6 +16,9 @@ #include "MemoryPlannerFactory.h" +#include "DisposableTensorIndex.h" +#include "LayerScopeTensorIndex.h" + namespace onert { namespace backend @@ -23,30 +26,33 @@ namespace backend namespace train { -MemoryPlannerFactory &MemoryPlannerFactory::get() +template MemoryPlannerFactory &MemoryPlannerFactory::get() { - static MemoryPlannerFactory instance; + static MemoryPlannerFactory instance; return instance; } -// TODO: Update to use template varialbe instead of DisposableTensorIndex -basic::IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key) +template +basic::IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key) { if (key == "FirstFit") { - return new FirstFitPlanner(); + return new FirstFitPlanner(); } else if (key == "Bump") { - return new BumpPlanner(); + return new BumpPlanner(); } else if (key == "WIC") { - return new WICPlanner(); + return new WICPlanner(); } - return new FirstFitPlanner(); // Default Planner + return new FirstFitPlanner(); // Default Planner } +template class MemoryPlannerFactory; +template class MemoryPlannerFactory; + } // namespace train } // namespace backend } // namespace onert diff --git a/runtime/onert/backend/train/MemoryPlannerFactory.h b/runtime/onert/backend/train/MemoryPlannerFactory.h index d1609e17559..7d005b01c2e 100644 --- a/runtime/onert/backend/train/MemoryPlannerFactory.h +++ b/runtime/onert/backend/train/MemoryPlannerFactory.h @@ -28,17 +28,17 @@ namespace backend namespace train { -class MemoryPlannerFactory +template class MemoryPlannerFactory { public: - static MemoryPlannerFactory &get(); + static MemoryPlannerFactory &get(); private: MemoryPlannerFactory() = default; public: // Currently, only the memory planner for DisposableTensor is supported - basic::IMemoryPlanner *create(const std::string &key); + basic::IMemoryPlanner *create(const std::string &key); }; } // namespace train