diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 82f59346..d0ea9274 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -332,16 +332,17 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - if worker_class is not None and log_spilling is True: - raise ValueError( - "Cannot enable `log_spilling` when `worker_class` is specified. If " - "logging is needed, ensure `worker_class` is a subclass of " - "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of " - "`distributed.local_cuda_cluster.LoggedWorker`, and specify " - "`log_spilling=False`." - ) - if not isinstance(worker_class, Nanny): - worker_class = partial(Nanny, worker_class=worker_class) + if worker_class is not None: + if log_spilling is True: + raise ValueError( + "Cannot enable `log_spilling` when `worker_class` is specified. If " + "logging is needed, ensure `worker_class` is a subclass of " + "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of " + "`distributed.local_cuda_cluster.LoggedWorker`, and specify " + "`log_spilling=False`." + ) + if not issubclass(worker_class, Nanny): + worker_class = partial(Nanny, worker_class=worker_class) self.pre_import = pre_import