Skip to content

Commit

Permalink
Make deployment names configurable in HelmCluster (#275)
Browse files Browse the repository at this point in the history
Not sure if others would find this useful but in our case we deploy our own helm chart that's based on, but not identical to, the standard Dask. In particular the deployment names are slightly different, so the hard-coded versions in `HelmCluster` don't work for us.

For now I'm just monkey patching the functions with hard-coded deployment names but it doesn't seem like it'd hurt to make it configurable 🙂.
  • Loading branch information
bnaul authored Oct 20, 2020
1 parent 451f845 commit ccb9864
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions dask_kubernetes/helm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class HelmCluster(Cluster):
auth: List[ClusterAuth] (optional)
Configuration methods to attempt in order. Defaults to
``[InCluster(), KubeConfig()]``.
scheduler_name: str (optional)
Name of the Dask scheduler deployment in the current release.
Defaults to "scheduler".
worker_name: str (optional)
Name of the Dask worker deployment in the current release.
Defaults to "worker".
**kwargs: dict
Additional keyword arguments to pass to Cluster
Expand Down Expand Up @@ -69,6 +75,8 @@ def __init__(
port_forward_cluster_ip=False,
loop=None,
asynchronous=False,
scheduler_name="scheduler",
worker_name="worker",
):
self.release_name = release_name
self.namespace = namespace or _namespace_default()
Expand All @@ -88,6 +96,8 @@ def __init__(
self._supports_scaling = True
self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
self.loop = self._loop_runner.loop
self.scheduler_name = scheduler_name
self.worker_name = worker_name

super().__init__(asynchronous=asynchronous)
if not self.asynchronous:
Expand All @@ -111,7 +121,7 @@ async def _start(self):
await super()._start()

async def _get_scheduler_address(self):
service_name = f"{self.release_name}-scheduler"
service_name = f"{self.release_name}-{self.scheduler_name}"
service = await self.core_api.read_namespaced_service(
service_name, self.namespace
)
Expand Down Expand Up @@ -143,7 +153,7 @@ async def _wait_for_workers(self):
while True:
n_workers = len(self.scheduler_info["workers"])
deployment = await self.apps_api.read_namespaced_deployment(
name=f"{self.release_name}-worker", namespace=self.namespace
name=f"{self.release_name}-{self.worker_name}", namespace=self.namespace
)
deployment_replicas = deployment.spec.replicas
if n_workers == deployment_replicas:
Expand Down Expand Up @@ -224,7 +234,7 @@ def scale(self, n_workers):

async def _scale(self, n_workers):
await self.apps_api.patch_namespaced_deployment(
name=f"{self.release_name}-worker",
name=f"{self.release_name}-{self.worker_name}",
namespace=self.namespace,
body={
"spec": {
Expand Down

0 comments on commit ccb9864

Please sign in to comment.