diff --git a/include/rmm/mr/device/aligned_resource_adaptor.hpp b/include/rmm/mr/device/aligned_resource_adaptor.hpp index c3213afdf..a91056dfa 100644 --- a/include/rmm/mr/device/aligned_resource_adaptor.hpp +++ b/include/rmm/mr/device/aligned_resource_adaptor.hpp @@ -168,7 +168,8 @@ class aligned_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto cast = dynamic_cast const*>(&other); - return cast != nullptr && upstream_->is_equal(*cast->get_upstream()) && + if (cast == nullptr) { return false; } + return get_upstream_resource() == cast->get_upstream_resource() && alignment_ == cast->alignment_ && alignment_threshold_ == cast->alignment_threshold_; } diff --git a/include/rmm/mr/device/binning_memory_resource.hpp b/include/rmm/mr/device/binning_memory_resource.hpp index 9b0c04f3d..2649a17b8 100644 --- a/include/rmm/mr/device/binning_memory_resource.hpp +++ b/include/rmm/mr/device/binning_memory_resource.hpp @@ -149,13 +149,13 @@ class binning_memory_resource final : public device_memory_resource { * Chooses a memory_resource that allocates the smallest blocks at least as large as `bytes`. * * @param bytes Requested allocation size in bytes - * @return rmm::mr::device_memory_resource& memory_resource that can allocate the requested size. + * @return Get the resource reference for the requested size. */ - device_memory_resource* get_resource(std::size_t bytes) + rmm::device_async_resource_ref get_resource_ref(std::size_t bytes) { auto iter = resource_bins_.lower_bound(bytes); - return (iter != resource_bins_.cend()) ? iter->second - : static_cast(get_upstream()); + return (iter != resource_bins_.cend()) ? rmm::device_async_resource_ref{iter->second} + : get_upstream_resource(); } /** @@ -170,7 +170,7 @@ class binning_memory_resource final : public device_memory_resource { void* do_allocate(std::size_t bytes, cuda_stream_view stream) override { if (bytes <= 0) { return nullptr; } - return get_resource(bytes)->allocate(bytes, stream); + return get_resource_ref(bytes).allocate_async(bytes, stream); } /** @@ -183,8 +183,7 @@ class binning_memory_resource final : public device_memory_resource { */ void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override { - auto res = get_resource(bytes); - if (res != nullptr) { res->deallocate(ptr, bytes, stream); } + get_resource_ref(bytes).deallocate_async(ptr, bytes, stream); } Upstream* upstream_mr_; // The upstream memory_resource from which to allocate blocks. diff --git a/include/rmm/mr/device/failure_callback_resource_adaptor.hpp b/include/rmm/mr/device/failure_callback_resource_adaptor.hpp index 5a1a09b82..53bc572c2 100644 --- a/include/rmm/mr/device/failure_callback_resource_adaptor.hpp +++ b/include/rmm/mr/device/failure_callback_resource_adaptor.hpp @@ -183,8 +183,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto cast = dynamic_cast const*>(&other); - return cast != nullptr ? upstream_->is_equal(*cast->get_upstream()) - : upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } Upstream* upstream_; // the upstream resource used for satisfying allocation requests diff --git a/include/rmm/mr/device/fixed_size_memory_resource.hpp b/include/rmm/mr/device/fixed_size_memory_resource.hpp index 8318dbfa1..da9476b1e 100644 --- a/include/rmm/mr/device/fixed_size_memory_resource.hpp +++ b/include/rmm/mr/device/fixed_size_memory_resource.hpp @@ -156,7 +156,7 @@ class fixed_size_memory_resource */ free_list blocks_from_upstream(cuda_stream_view stream) { - void* ptr = get_upstream()->allocate(upstream_chunk_size_, stream); + void* ptr = get_upstream_resource().allocate_async(upstream_chunk_size_, stream); block_type block{ptr}; upstream_blocks_.push_back(block); @@ -211,7 +211,7 @@ class fixed_size_memory_resource lock_guard lock(this->get_mutex()); for (auto block : upstream_blocks_) { - get_upstream()->deallocate(block.pointer(), upstream_chunk_size_); + get_upstream_resource().deallocate(block.pointer(), upstream_chunk_size_); } upstream_blocks_.clear(); } diff --git a/include/rmm/mr/device/limiting_resource_adaptor.hpp b/include/rmm/mr/device/limiting_resource_adaptor.hpp index 174ada34b..aa2361d1f 100644 --- a/include/rmm/mr/device/limiting_resource_adaptor.hpp +++ b/include/rmm/mr/device/limiting_resource_adaptor.hpp @@ -162,8 +162,8 @@ class limiting_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto const* cast = dynamic_cast const*>(&other); - if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); } - return upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } // maximum bytes this allocator is allowed to allocate. diff --git a/include/rmm/mr/device/logging_resource_adaptor.hpp b/include/rmm/mr/device/logging_resource_adaptor.hpp index 885908937..61d00cafc 100644 --- a/include/rmm/mr/device/logging_resource_adaptor.hpp +++ b/include/rmm/mr/device/logging_resource_adaptor.hpp @@ -277,8 +277,8 @@ class logging_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto const* cast = dynamic_cast const*>(&other); - if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); } - return upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } // make_logging_adaptor needs access to private get_default_filename diff --git a/include/rmm/mr/device/pool_memory_resource.hpp b/include/rmm/mr/device/pool_memory_resource.hpp index 61ef75bd0..4cbdeef4a 100644 --- a/include/rmm/mr/device/pool_memory_resource.hpp +++ b/include/rmm/mr/device/pool_memory_resource.hpp @@ -503,7 +503,7 @@ class pool_memory_resource final if (size == 0) { return {}; } try { - void* ptr = get_upstream()->allocate_async(size, stream); + void* ptr = get_upstream_resource().allocate_async(size, stream); return std::optional{ *upstream_blocks_.emplace(static_cast(ptr), size, true).first}; } catch (std::exception const& e) { @@ -570,7 +570,7 @@ class pool_memory_resource final lock_guard lock(this->get_mutex()); for (auto block : upstream_blocks_) { - get_upstream()->deallocate(block.pointer(), block.size()); + get_upstream_resource().deallocate(block.pointer(), block.size()); } upstream_blocks_.clear(); #ifdef RMM_POOL_TRACK_ALLOCATIONS diff --git a/include/rmm/mr/device/statistics_resource_adaptor.hpp b/include/rmm/mr/device/statistics_resource_adaptor.hpp index 9febf9ef5..d072d5886 100644 --- a/include/rmm/mr/device/statistics_resource_adaptor.hpp +++ b/include/rmm/mr/device/statistics_resource_adaptor.hpp @@ -209,8 +209,8 @@ class statistics_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto cast = dynamic_cast const*>(&other); - return cast != nullptr ? upstream_->is_equal(*cast->get_upstream()) - : upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } counter bytes_; // peak, current and total allocated bytes diff --git a/include/rmm/mr/device/thread_safe_resource_adaptor.hpp b/include/rmm/mr/device/thread_safe_resource_adaptor.hpp index 5daf5f3f9..64a6f8ad5 100644 --- a/include/rmm/mr/device/thread_safe_resource_adaptor.hpp +++ b/include/rmm/mr/device/thread_safe_resource_adaptor.hpp @@ -119,11 +119,9 @@ class thread_safe_resource_adaptor final : public device_memory_resource { bool do_is_equal(device_memory_resource const& other) const noexcept override { if (this == &other) { return true; } - auto thread_safe_other = dynamic_cast const*>(&other); - if (thread_safe_other != nullptr) { - return upstream_->is_equal(*thread_safe_other->get_upstream()); - } - return upstream_->is_equal(other); + auto cast = dynamic_cast const*>(&other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } std::mutex mutable mtx; // mutex for thread safe access to upstream diff --git a/include/rmm/mr/device/tracking_resource_adaptor.hpp b/include/rmm/mr/device/tracking_resource_adaptor.hpp index d337da21c..c49674849 100644 --- a/include/rmm/mr/device/tracking_resource_adaptor.hpp +++ b/include/rmm/mr/device/tracking_resource_adaptor.hpp @@ -264,8 +264,8 @@ class tracking_resource_adaptor final : public device_memory_resource { { if (this == &other) { return true; } auto cast = dynamic_cast const*>(&other); - return cast != nullptr ? upstream_->is_equal(*cast->get_upstream()) - : upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } bool capture_stacks_; // whether or not to capture call stacks diff --git a/tests/device_check_resource_adaptor.hpp b/tests/device_check_resource_adaptor.hpp index 4f75b191c..fcb578fdf 100644 --- a/tests/device_check_resource_adaptor.hpp +++ b/tests/device_check_resource_adaptor.hpp @@ -64,8 +64,8 @@ class device_check_resource_adaptor final : public rmm::mr::device_memory_resour { if (this == &other) { return true; } auto const* cast = dynamic_cast(&other); - if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); } - return upstream_->is_equal(other); + if (cast == nullptr) { return upstream_->is_equal(other); } + return get_upstream_resource() == cast->get_upstream_resource(); } rmm::cuda_device_id device_id; diff --git a/tests/mr/device/adaptor_tests.cpp b/tests/mr/device/adaptor_tests.cpp index c05c9b641..a757a78b0 100644 --- a/tests/mr/device/adaptor_tests.cpp +++ b/tests/mr/device/adaptor_tests.cpp @@ -135,15 +135,6 @@ TYPED_TEST(AdaptorTest, Equality) } } -TYPED_TEST(AdaptorTest, GetUpstream) -{ - if constexpr (std::is_same_v) { - EXPECT_TRUE(this->mr->wrapped().get_upstream()->is_equal(this->cuda)); - } else { - EXPECT_TRUE(this->mr->get_upstream()->is_equal(this->cuda)); - } -} - TYPED_TEST(AdaptorTest, GetUpstreamResource) { rmm::device_async_resource_ref expected{this->cuda}; diff --git a/tests/mr/device/statistics_mr_tests.cpp b/tests/mr/device/statistics_mr_tests.cpp index 59c356b1e..8fd12f49b 100644 --- a/tests/mr/device/statistics_mr_tests.cpp +++ b/tests/mr/device/statistics_mr_tests.cpp @@ -127,7 +127,8 @@ TEST(StatisticsTest, PeakAllocations) TEST(StatisticsTest, MultiTracking) { - statistics_adaptor mr{rmm::mr::get_current_device_resource()}; + auto* orig_device_resource = rmm::mr::get_current_device_resource(); + statistics_adaptor mr{orig_device_resource}; rmm::mr::set_current_device_resource(&mr); std::vector> allocations; @@ -171,7 +172,7 @@ TEST(StatisticsTest, MultiTracking) EXPECT_EQ(inner_mr.get_allocations_counter().peak, 5); // Reset the current device resource - rmm::mr::set_current_device_resource(mr.get_upstream()); + rmm::mr::set_current_device_resource(orig_device_resource); } TEST(StatisticsTest, NegativeInnerTracking) diff --git a/tests/mr/device/tracking_mr_tests.cpp b/tests/mr/device/tracking_mr_tests.cpp index 65d2f955c..7c2532c60 100644 --- a/tests/mr/device/tracking_mr_tests.cpp +++ b/tests/mr/device/tracking_mr_tests.cpp @@ -101,7 +101,8 @@ TEST(TrackingTest, AllocationsLeftWithoutStacks) TEST(TrackingTest, MultiTracking) { - tracking_adaptor mr{rmm::mr::get_current_device_resource(), true}; + auto* orig_device_resource = rmm::mr::get_current_device_resource(); + tracking_adaptor mr{orig_device_resource, true}; rmm::mr::set_current_device_resource(&mr); std::vector> allocations; @@ -140,7 +141,7 @@ TEST(TrackingTest, MultiTracking) EXPECT_EQ(inner_mr.get_allocated_bytes(), 0); // Reset the current device resource - rmm::mr::set_current_device_resource(mr.get_upstream()); + rmm::mr::set_current_device_resource(orig_device_resource); } TEST(TrackingTest, NegativeInnerTracking)