Skip to content

Commit

Permalink
[onert/backend] Templatize MemoryPlannerFactory in train
Browse files Browse the repository at this point in the history
This PR templatize memory planner factory in train backend.
MemoryPlannerFactory currently used for DisposableTensorIndex, but it will be also used for LayerScopeTensorIndex.

ONE-DCO-1.0-Signed-off-by: seunghui youn <[email protected]>
draft : Samsung#13486
for : Samsung#13282
  • Loading branch information
zetwhite committed Sep 11, 2024
1 parent 0f8808e commit 22c40b9
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
4 changes: 2 additions & 2 deletions runtime/onert/backend/train/MemoryManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ DisposableMemoryManager::DisposableMemoryManager() : _mem_planner{createMemoryPl
basic::IMemoryPlanner<DisposableTensorIndex> *DisposableMemoryManager::createMemoryPlanner()
{
auto planner_id = util::getConfigString(util::config::CPU_MEMORY_PLANNER);
return MemoryPlannerFactory::get().create(planner_id);
return MemoryPlannerFactory<DisposableTensorIndex>::get().create(planner_id);
}

basic::IMemoryPlanner<DisposableTensorIndex> *
DisposableMemoryManager::createMemoryPlanner(const std::string planner_id)
{
return MemoryPlannerFactory::get().create(planner_id);
return MemoryPlannerFactory<DisposableTensorIndex>::get().create(planner_id);
}

void DisposableMemoryManager::claimPlan(const DisposableTensorIndex &ind, uint32_t size)
Expand Down
2 changes: 2 additions & 0 deletions runtime/onert/backend/train/MemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class DisposableMemoryManager
std::shared_ptr<basic::Allocator> _mem_alloc;
};

// TODO: Add LayerScopeMemoryManager using MemoryPlannerFactory<LayerScopeTensorIndex>

} // namespace train
} // namespace backend
} // namespace onert
Expand Down
22 changes: 14 additions & 8 deletions runtime/onert/backend/train/MemoryPlannerFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,43 @@

#include "MemoryPlannerFactory.h"

#include "DisposableTensorIndex.h"
#include "LayerScopeTensorIndex.h"

namespace onert
{
namespace backend
{
namespace train
{

MemoryPlannerFactory &MemoryPlannerFactory::get()
template <typename Index> MemoryPlannerFactory<Index> &MemoryPlannerFactory<Index>::get()
{
static MemoryPlannerFactory instance;
static MemoryPlannerFactory<Index> instance;
return instance;
}

// TODO: Update to use template varialbe instead of DisposableTensorIndex
basic::IMemoryPlanner<DisposableTensorIndex> *MemoryPlannerFactory::create(const std::string &key)
template <typename Index>
basic::IMemoryPlanner<Index> *MemoryPlannerFactory<Index>::create(const std::string &key)
{
if (key == "FirstFit")
{
return new FirstFitPlanner<DisposableTensorIndex>();
return new FirstFitPlanner<Index>();
}
else if (key == "Bump")
{
return new BumpPlanner<DisposableTensorIndex>();
return new BumpPlanner<Index>();
}
else if (key == "WIC")
{
return new WICPlanner<DisposableTensorIndex>();
return new WICPlanner<Index>();
}
return new FirstFitPlanner<DisposableTensorIndex>(); // Default Planner
return new FirstFitPlanner<Index>(); // Default Planner
}

template class MemoryPlannerFactory<DisposableTensorIndex>;
template class MemoryPlannerFactory<LayerScopeTensorIndex>;

} // namespace train
} // namespace backend
} // namespace onert
6 changes: 3 additions & 3 deletions runtime/onert/backend/train/MemoryPlannerFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ namespace backend
namespace train
{

class MemoryPlannerFactory
template <typename Index> class MemoryPlannerFactory
{
public:
static MemoryPlannerFactory &get();
static MemoryPlannerFactory<Index> &get();

private:
MemoryPlannerFactory() = default;

public:
// Currently, only the memory planner for DisposableTensor is supported
basic::IMemoryPlanner<DisposableTensorIndex> *create(const std::string &key);
basic::IMemoryPlanner<Index> *create(const std::string &key);
};

} // namespace train
Expand Down

0 comments on commit 22c40b9

Please sign in to comment.