diff --git a/dask_kubernetes/helm.py b/dask_kubernetes/helm.py index 241939cf8..21f4228b0 100644 --- a/dask_kubernetes/helm.py +++ b/dask_kubernetes/helm.py @@ -34,6 +34,12 @@ class HelmCluster(Cluster): auth: List[ClusterAuth] (optional) Configuration methods to attempt in order. Defaults to ``[InCluster(), KubeConfig()]``. + scheduler_name: str (optional) + Name of the Dask scheduler deployment in the current release. + Defaults to "scheduler". + worker_name: str (optional) + Name of the Dask worker deployment in the current release. + Defaults to "worker". **kwargs: dict Additional keyword arguments to pass to Cluster @@ -69,6 +75,8 @@ def __init__( port_forward_cluster_ip=False, loop=None, asynchronous=False, + scheduler_name="scheduler", + worker_name="worker", ): self.release_name = release_name self.namespace = namespace or _namespace_default() @@ -88,6 +96,8 @@ def __init__( self._supports_scaling = True self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop + self.scheduler_name = scheduler_name + self.worker_name = worker_name super().__init__(asynchronous=asynchronous) if not self.asynchronous: @@ -111,7 +121,7 @@ async def _start(self): await super()._start() async def _get_scheduler_address(self): - service_name = f"{self.release_name}-scheduler" + service_name = f"{self.release_name}-{self.scheduler_name}" service = await self.core_api.read_namespaced_service( service_name, self.namespace ) @@ -143,7 +153,7 @@ async def _wait_for_workers(self): while True: n_workers = len(self.scheduler_info["workers"]) deployment = await self.apps_api.read_namespaced_deployment( - name=f"{self.release_name}-worker", namespace=self.namespace + name=f"{self.release_name}-{self.worker_name}", namespace=self.namespace ) deployment_replicas = deployment.spec.replicas if n_workers == deployment_replicas: @@ -224,7 +234,7 @@ def scale(self, n_workers): async def _scale(self, n_workers): await self.apps_api.patch_namespaced_deployment( - name=f"{self.release_name}-worker", + name=f"{self.release_name}-{self.worker_name}", namespace=self.namespace, body={ "spec": {