diff --git a/dask_cuda/cli.py b/dask_cuda/cli.py index a8c6d972..8101f020 100644 --- a/dask_cuda/cli.py +++ b/dask_cuda/cli.py @@ -13,7 +13,7 @@ from distributed.utils import import_term from .cuda_worker import CUDAWorker -from .utils import print_cluster_config +from .utils import CommaSeparatedChoice, print_cluster_config logger = logging.getLogger(__name__) @@ -164,6 +164,16 @@ def cuda(): incompatible with RMM pools and managed memory, trying to enable both will result in failure.""", ) +@click.option( + "--set-rmm-allocator-for-libs", + "rmm_allocator_external_lib_list", + type=CommaSeparatedChoice(["cupy", "torch"]), + default=None, + show_default=True, + help=""" + Set RMM as the allocator for external libraries. Provide a comma-separated + list of libraries to set, e.g., "torch,cupy".""", +) @click.option( "--rmm-release-threshold", default=None, @@ -351,6 +361,7 @@ def worker( rmm_maximum_pool_size, rmm_managed_memory, rmm_async, + rmm_allocator_external_lib_list, rmm_release_threshold, rmm_log_directory, rmm_track_allocations, @@ -425,6 +436,7 @@ def worker( rmm_maximum_pool_size, rmm_managed_memory, rmm_async, + rmm_allocator_external_lib_list, rmm_release_threshold, rmm_log_directory, rmm_track_allocations, diff --git a/dask_cuda/cuda_worker.py b/dask_cuda/cuda_worker.py index 3e03ed29..30c14450 100644 --- a/dask_cuda/cuda_worker.py +++ b/dask_cuda/cuda_worker.py @@ -47,6 +47,7 @@ def __init__( rmm_maximum_pool_size=None, rmm_managed_memory=False, rmm_async=False, + rmm_allocator_external_lib_list=None, rmm_release_threshold=None, rmm_log_directory=None, rmm_track_allocations=False, @@ -231,6 +232,7 @@ def del_pid_file(): release_threshold=rmm_release_threshold, log_directory=rmm_log_directory, track_allocations=rmm_track_allocations, + external_lib_list=rmm_allocator_external_lib_list, ), PreImport(pre_import), CUDFSetup(spill=enable_cudf_spill, spill_stats=cudf_spill_stats), diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index c037223b..7a24df43 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -143,6 +143,11 @@ class LocalCUDACluster(LocalCluster): The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also incompatible with RMM pools and managed memory. Trying to enable both will result in an exception. + rmm_allocator_external_lib_list: str, list or None, default None + List of external libraries for which to set RMM as the allocator. + Supported options are: ``["torch", "cupy"]``. Can be a comma-separated string + (like ``"torch,cupy"``) or a list of strings (like ``["torch", "cupy"]``). + If ``None``, no external libraries will use RMM as their allocator. rmm_release_threshold: int, str or None, default None When ``rmm.async is True`` and the pool size grows beyond this value, unused memory held by the pool will be released at the next synchronization point. @@ -231,6 +236,7 @@ def __init__( rmm_maximum_pool_size=None, rmm_managed_memory=False, rmm_async=False, + rmm_allocator_external_lib_list=None, rmm_release_threshold=None, rmm_log_directory=None, rmm_track_allocations=False, @@ -265,6 +271,19 @@ def __init__( n_workers = len(CUDA_VISIBLE_DEVICES) if n_workers < 1: raise ValueError("Number of workers cannot be less than 1.") + + if rmm_allocator_external_lib_list is not None: + if isinstance(rmm_allocator_external_lib_list, str): + rmm_allocator_external_lib_list = [ + v.strip() for v in rmm_allocator_external_lib_list.split(",") + ] + elif not isinstance(rmm_allocator_external_lib_list, list): + raise ValueError( + "rmm_allocator_external_lib_list must be either a comma-separated " + "string or a list of strings. Examples: 'torch,cupy' " + "or ['torch', 'cupy']" + ) + # Set nthreads=1 when parsing mem_limit since it only depends on n_workers logger = logging.getLogger(__name__) self.memory_limit = parse_memory_limit( @@ -284,6 +303,8 @@ def __init__( self.rmm_managed_memory = rmm_managed_memory self.rmm_async = rmm_async self.rmm_release_threshold = rmm_release_threshold + self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list + if rmm_pool_size is not None or rmm_managed_memory or rmm_async: try: import rmm # noqa F401 @@ -437,6 +458,7 @@ def new_worker_spec(self): release_threshold=self.rmm_release_threshold, log_directory=self.rmm_log_directory, track_allocations=self.rmm_track_allocations, + external_lib_list=self.rmm_allocator_external_lib_list, ), PreImport(self.pre_import), CUDFSetup(self.enable_cudf_spill, self.cudf_spill_stats), diff --git a/dask_cuda/plugins.py b/dask_cuda/plugins.py index 122f93ff..cd1928af 100644 --- a/dask_cuda/plugins.py +++ b/dask_cuda/plugins.py @@ -1,5 +1,6 @@ import importlib import os +from typing import Callable, Dict from distributed import WorkerPlugin @@ -39,6 +40,7 @@ def __init__( release_threshold, log_directory, track_allocations, + external_lib_list, ): if initial_pool_size is None and maximum_pool_size is not None: raise ValueError( @@ -61,6 +63,7 @@ def __init__( self.logging = log_directory is not None self.log_directory = log_directory self.rmm_track_allocations = track_allocations + self.external_lib_list = external_lib_list def setup(self, worker=None): if self.initial_pool_size is not None: @@ -123,6 +126,70 @@ def setup(self, worker=None): mr = rmm.mr.get_current_device_resource() rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr)) + if self.external_lib_list is not None: + for lib in self.external_lib_list: + enable_rmm_memory_for_library(lib) + + +def enable_rmm_memory_for_library(lib_name: str) -> None: + """Enable RMM memory pool support for a specified third-party library. + + This function allows the given library to utilize RMM's memory pool if it supports + integration with RMM. The library name is passed as a string argument, and if the + library is compatible, its memory allocator will be configured to use RMM. + + Parameters + ---------- + lib_name : str + The name of the third-party library to enable RMM memory pool support for. + Supported libraries are "cupy" and "torch". + + Raises + ------ + ValueError + If the library name is not supported or does not have RMM integration. + ImportError + If the required library is not installed. + """ + + # Mapping of supported libraries to their respective setup functions + setup_functions: Dict[str, Callable[[], None]] = { + "torch": _setup_rmm_for_torch, + "cupy": _setup_rmm_for_cupy, + } + + if lib_name not in setup_functions: + supported_libs = ", ".join(setup_functions.keys()) + raise ValueError( + f"The library '{lib_name}' is not supported for RMM integration. " + f"Supported libraries are: {supported_libs}." + ) + + # Call the setup function for the specified library + setup_functions[lib_name]() + + +def _setup_rmm_for_torch() -> None: + try: + import torch + except ImportError as e: + raise ImportError("PyTorch is not installed.") from e + + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + +def _setup_rmm_for_cupy() -> None: + try: + import cupy + except ImportError as e: + raise ImportError("CuPy is not installed.") from e + + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + class PreImport(WorkerPlugin): def __init__(self, libraries): diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index ff4dbbae..74596fe2 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -9,6 +9,7 @@ from multiprocessing import cpu_count from typing import Optional +import click import numpy as np import pynvml import toolz @@ -764,3 +765,13 @@ def get_rmm_memory_resource_stack(mr) -> list: if isinstance(mr, rmm.mr.StatisticsResourceAdaptor): return mr.allocation_counts["current_bytes"] return None + + +class CommaSeparatedChoice(click.Choice): + def convert(self, value, param, ctx): + values = [v.strip() for v in value.split(",")] + for v in values: + if v not in self.choices: + choices_str = ", ".join(f"'{c}'" for c in self.choices) + self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})") + return values