diff --git a/Src/Base/AMReX_CArena.H b/Src/Base/AMReX_CArena.H
index ee081d88f82..92147ecfcf5 100644
--- a/Src/Base/AMReX_CArena.H
+++ b/Src/Base/AMReX_CArena.H
@@ -47,6 +47,18 @@ public:
     //! Allocate some memory.
     [[nodiscard]] void* alloc (std::size_t nbytes) final;
 
+    /**
+     * \brief Try to allocate in-place by extending the capacity of the given pointer.
+     */
+    [[nodiscard]] std::pair<void*,std::size_t>
+    alloc_in_place (void* pt, std::size_t szmin, std::size_t szmax) final;
+
+    /**
+     * \brief Try to shrink in-place.
+     */
+    [[nodiscard]] void*
+    shrink_in_place (void* pt, std::size_t sz) final;
+
     /**
      * \brief Free up allocated memory. Merge neighboring free memory chunks
      * into largest possible chunk.
@@ -87,6 +99,8 @@ public:
 
 protected:
 
+    void* alloc_protected (std::size_t nbytes);
+
     std::size_t freeUnused_protected () final;
 
     //! The nodes in our free list and block list.
diff --git a/Src/Base/AMReX_CArena.cpp b/Src/Base/AMReX_CArena.cpp
index 8c5c0dc7e18..fe0d9d1d19c 100644
--- a/Src/Base/AMReX_CArena.cpp
+++ b/Src/Base/AMReX_CArena.cpp
@@ -42,9 +42,13 @@
 void*
 CArena::alloc (std::size_t nbytes)
 {
     std::lock_guard<std::mutex> lock(carena_mutex);
-    nbytes = Arena::align(nbytes == 0 ? 1 : nbytes);
+    return alloc_protected(nbytes);
+}
 
+void*
+CArena::alloc_protected (std::size_t nbytes)
+{
     MemStat* stat = nullptr;
 #ifdef AMREX_TINY_PROFILING
     if (m_do_profiling) {
@@ -127,6 +131,78 @@ CArena::alloc (std::size_t nbytes)
     return vp;
 }
 
+std::pair<void*,std::size_t>
+CArena::alloc_in_place (void* pt, std::size_t szmin, std::size_t szmax)
+{
+    std::lock_guard<std::mutex> lock(carena_mutex);
+
+    std::size_t nbytes_max = Arena::align(szmax == 0 ? 1 : szmax);
+
+    if (pt != nullptr) { // Try to allocate in-place first
+        auto busy_it = m_busylist.find(Node(pt,nullptr,0));
+        AMREX_ALWAYS_ASSERT(busy_it != m_busylist.end());
+        AMREX_ASSERT(m_freelist.find(*busy_it) == m_freelist.end());
+
+        if (busy_it->size() >= szmax) {
+            return std::make_pair(pt, busy_it->size());
+        }
+
+        void* next_block = (char*)pt + busy_it->size();
+        auto next_it = m_freelist.find(Node(next_block,nullptr,0));
+        if (next_it != m_freelist.end() && busy_it->coalescable(*next_it)) {
+            std::size_t total_size = busy_it->size() + next_it->size();
+            if (total_size >= szmax) {
+                // Must use nbytes_max instead of szmax for alignment.
+                std::size_t new_size = std::min(total_size, nbytes_max);
+                std::size_t left_size = total_size - new_size;
+                if (left_size <= 64) {
+                    m_freelist.erase(next_it);
+                    new_size = total_size;
+                } else {
+                    auto& free_node = const_cast<Node&>(*next_it);
+                    free_node.block((char*)pt + new_size);
+                    free_node.size(left_size);
+                }
+                std::size_t extra_size = new_size - busy_it->size();
+#ifdef AMREX_TINY_PROFILING
+                if (m_do_profiling) {
+                    // xxxxx TODO: need to store the return value in *busy_it
+                    TinyProfiler::memory_alloc(extra_size, m_profiling_stats);
+                }
+#endif
+                m_actually_used += extra_size;
+                const_cast<Node&>(*busy_it).size(new_size);
+                return std::make_pair(pt, new_size);
+            } else if (total_size >= szmin) {
+                m_freelist.erase(next_it);
+                std::size_t extra_size = total_size - busy_it->size();
+#ifdef AMREX_TINY_PROFILING
+                if (m_do_profiling) {
+                    // xxxxx TODO: need to store the return value in *busy_it
+                    TinyProfiler::memory_alloc(extra_size, m_profiling_stats);
+                }
+#endif
+                m_actually_used += extra_size;
+                const_cast<Node&>(*busy_it).size(total_size);
+                return std::make_pair(pt, total_size);
+            }
+        }
+
+        if (busy_it->size() >= szmin) {
+            return std::make_pair(pt, busy_it->size());
+        }
+    }
+
+    void* newp = alloc_protected(nbytes_max);
+    return std::make_pair(newp, nbytes_max);
+}
+
+void*
+CArena::shrink_in_place (void* /*pt*/, std::size_t sz)
+{
+    return alloc(sz); // xxxxx TODO
+}
+
 void
 CArena::free (void* vp)
 {
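
For reviewers, a minimal sketch of how a growable buffer might call the new interface. The `grow` helper and its parameter names are hypothetical, not part of this diff. `alloc_in_place` returns the resulting pointer and the capacity actually obtained; when it cannot extend `pt` it falls back to a fresh allocation, so the caller is responsible for relocating its data:

```cpp
#include <AMReX_CArena.H>
#include <cstddef>
#include <cstring>

// Hypothetical caller (not in this diff): grow a buffer holding `used`
// bytes so its capacity is at least `szmin` bytes, preferring `szmax`.
void* grow (amrex::CArena& arena, void* p, std::size_t used,
            std::size_t szmin, std::size_t szmax)
{
    auto result = arena.alloc_in_place(p, szmin, szmax);
    void* q = result.first;      // result.second is the capacity obtained
    if (q != p && p != nullptr) {
        std::memcpy(q, p, used); // got a fresh block: relocate old data
        arena.free(p);
    }
    return q;
}
```

Note that `shrink_in_place` is still a stub in this patch (it simply forwards to `alloc`, per the `xxxxx TODO`), so a shrinking caller must be prepared to copy in the same way.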
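The interesting branch is the one that coalesces the busy node with the free node immediately behind it. Below is the size bookkeeping distilled into a standalone sketch, detached from the `Node`/free-list machinery; the 64-byte threshold is the patch's cutoff for a remainder too small to be worth keeping on the free list:

```cpp
#include <algorithm>
#include <cstddef>

// Distilled from alloc_in_place: given a busy block and the adjacent free
// block, decide how many bytes the busy block absorbs. Returns the busy
// block's new size; *leftover is what stays on the free list (0 means the
// free block is swallowed whole).
std::size_t absorb (std::size_t busy_sz, std::size_t free_sz,
                    std::size_t nbytes_max, std::size_t* leftover)
{
    std::size_t total = busy_sz + free_sz;
    std::size_t new_size = std::min(total, nbytes_max); // take no more than needed
    std::size_t left = total - new_size;
    if (left <= 64) {        // remainder too small to be a useful free chunk
        *leftover = 0;
        return total;        // absorb everything
    }
    *leftover = left;
    return new_size;
}
```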