diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 6979cdd9342..8d30287cbf5 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -35,6 +35,7 @@ ucp = None host_array = None device_array = None +as_numba_device_array = None def synchronize_stream(stream=0): @@ -47,7 +48,7 @@ def synchronize_stream(stream=0): def init_once(): - global ucp, host_array, device_array + global ucp, host_array, device_array, as_numba_device_array if ucp is not None: return @@ -100,6 +101,16 @@ def device_array(n): "In order to send/recv CUDA arrays, Numba or RMM is required" ) + # Find the function, `as_numba_device_array()` + try: + import numba.cuda + + as_numba_device_array = lambda a: numba.cuda.as_cuda_array(a) + except ImportError: + + def as_numba_device_array(n): + raise RuntimeError("In order to send/recv CUDA arrays, Numba is required") + pool_size_str = dask.config.get("rmm.pool-size") if pool_size_str is not None: pool_size = parse_bytes(pool_size_str)