Skip to content

Commit

Permalink
Remove code paths that depend on RMM_STATIC_CUDART
Browse files Browse the repository at this point in the history
  • Loading branch information
robertmaynard committed Sep 4, 2024
1 parent a42d36d commit 8a3db58
Showing 1 changed file with 6 additions and 22 deletions.
28 changes: 6 additions & 22 deletions include/rmm/detail/dynamic_load_runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,18 @@ struct dynamic_load_runtime {
template <typename signature>
static std::optional<signature> function(const char* func_name)
{
auto* runtime = get_cuda_runtime_handle();
auto* handle = ::dlsym(runtime, func_name);
// query if the function has already been loaded
auto* handle = ::dlsym(RTLD_DEFAULT, func_name);
if(!handle) {
auto* runtime = get_cuda_runtime_handle();
handle = ::dlsym(runtime, func_name);
}
if (!handle) { return std::nullopt; }
auto* function_ptr = reinterpret_cast<signature>(handle);
return std::optional<signature>(function_ptr);
}
};

#if defined(RMM_STATIC_CUDART)
// clang-format off
#define RMM_CUDART_API_WRAPPER(name, signature) \
template <typename... Args> \
static cudaError_t name(Args... args) \
{ \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Waddress\"") \
static_assert(static_cast<signature>(::name), \
"Failed to find #name function with arguments #signature"); \
_Pragma("GCC diagnostic pop") \
return ::name(args...); \
}
// clang-format on
#else
#define RMM_CUDART_API_WRAPPER(name, signature) \
template <typename... Args> \
static cudaError_t name(Args... args) \
Expand All @@ -100,7 +89,6 @@ struct dynamic_load_runtime {
if (func) { return (*func)(args...); } \
RMM_FAIL("Failed to find #name function in libcudart.so"); \
}
#endif

#if CUDART_VERSION >= 11020 // 11.2 introduced cudaMallocAsync
/**
Expand All @@ -113,14 +101,10 @@ struct dynamic_load_runtime {
struct async_alloc {
static bool is_supported()
{
#if defined(RMM_STATIC_CUDART)
static bool runtime_supports_pool = (CUDART_VERSION >= 11020);
#else
static bool runtime_supports_pool =
dynamic_load_runtime::function<dynamic_load_runtime::function_sig<void*, cudaStream_t>>(
"cudaFreeAsync")
.has_value();
#endif

static auto driver_supports_pool{[] {
int cuda_pool_supported{};
Expand Down

0 comments on commit 8a3db58

Please sign in to comment.