diff --git a/dask_kubernetes/conftest.py b/dask_kubernetes/conftest.py
index 9d279b39d..b84e375a7 100644
--- a/dask_kubernetes/conftest.py
+++ b/dask_kubernetes/conftest.py
@@ -1,5 +1,6 @@
 import pytest
+import contextlib
 import pathlib
 import os
 import subprocess
@@ -17,6 +18,30 @@
 check_dependency("docker")
 
 
+@contextlib.contextmanager
+def set_env(**environ):
+    """
+    Temporarily set the process environment variables.
+
+    >>> with set_env(PLUGINS_DIR=u'test/plugins'):
+    ...     "PLUGINS_DIR" in os.environ
+    True
+
+    >>> "PLUGINS_DIR" in os.environ
+    False
+
+    :type environ: dict[str, unicode]
+    :param environ: Environment variables to set
+    """
+    old_environ = dict(os.environ)
+    os.environ.update(environ)
+    try:
+        yield
+    finally:
+        os.environ.clear()
+        os.environ.update(old_environ)
+
+
 @pytest.fixture()
 def kopf_runner(k8s_cluster):
     yield KopfRunner(["run", "-m", "dask_kubernetes.operator", "--verbose"])
@@ -40,10 +65,9 @@ def k8s_cluster(request, docker_image):
         image=image,
     )
     kind_cluster.create()
-    os.environ["KUBECONFIG"] = str(kind_cluster.kubeconfig_path)
     kind_cluster.load_docker_image(docker_image)
-    yield kind_cluster
-    del os.environ["KUBECONFIG"]
+    with set_env(KUBECONFIG=str(kind_cluster.kubeconfig_path)):
+        yield kind_cluster
 
     if not request.config.getoption("keep_cluster"):
         kind_cluster.delete()
diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py
index db47a75a7..6a6cb27fd 100644
--- a/dask_kubernetes/operator/controller/controller.py
+++ b/dask_kubernetes/operator/controller/controller.py
@@ -8,14 +8,11 @@
 
 import aiohttp
 import kopf
-import kubernetes_asyncio as kubernetes
 from importlib_metadata import entry_points
-from kubernetes_asyncio.client import ApiException
+import kr8s
+from kr8s.asyncio.objects import APIObject, Pod, Service
 
-from dask_kubernetes.common.auth import ClusterAuth
 from dask_kubernetes.common.networking import get_scheduler_address
-from dask_kubernetes.aiopykube import HTTPClient, KubeConfig
-from dask_kubernetes.aiopykube.dask import DaskCluster
 from distributed.core import rpc
 
 _ANNOTATION_NAMESPACES_TO_IGNORE = (
@@ -35,6 +32,48 @@
     PLUGINS.append(ep.load())
 
 
+class DaskCluster(APIObject):
+    version = "kubernetes.dask.org/v1"
+    endpoint = "daskclusters"
+    kind = "DaskCluster"
+    plural = "daskclusters"
+    singular = "daskcluster"
+    namespaced = True
+
+    # TODO make scalable
+    # scalable = True
+    # # Dot notation not yet supported in kr8s, patching cluster replicas not yet supported in controller
+    # scalable_spec = "worker.replicas"
+
+
+class DaskWorkerGroup(APIObject):
+    version = "kubernetes.dask.org/v1"
+    endpoint = "daskworkergroups"
+    kind = "DaskWorkerGroup"
+    plural = "daskworkergroups"
+    singular = "daskworkergroup"
+    namespaced = True
+    scalable = True
+
+
+class DaskAutoscaler(APIObject):
+    version = "kubernetes.dask.org/v1"
+    endpoint = "daskautoscalers"
+    kind = "DaskAutoscaler"
+    plural = "daskautoscalers"
+    singular = "daskautoscaler"
+    namespaced = True
+
+
+class DaskJob(APIObject):
+    version = "kubernetes.dask.org/v1"
+    endpoint = "daskjobs"
+    kind = "DaskJob"
+    plural = "daskjobs"
+    singular = "daskjob"
+    namespaced = True
+
+
 class SchedulerCommError(Exception):
     """Raised when unable to communicate with a scheduler."""
 
@@ -218,9 +257,6 @@ def build_cluster_spec(name, worker_spec, scheduler_spec, annotations, labels):
 
 @kopf.on.startup()
 async def startup(settings: kopf.OperatorSettings, **kwargs):
-    # Authenticate with k8s
-    await ClusterAuth.load_first()
-
     # Set server and client timeouts to reconnect from time to time.
     # In rare occasions the connection might go idle we will no longer receive any events.
     # These timeouts should help in those cases.
@@ -258,79 +294,64 @@ async def daskcluster_create_components(
 ):
     """When the DaskCluster status.phase goes into Created create the cluster components."""
     logger.info("Creating Dask cluster components.")
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        api = kubernetes.client.CoreV1Api(api_client)
-        custom_api = kubernetes.client.CustomObjectsApi(api_client)
-
-        annotations = _get_annotations(meta)
-        labels = _get_labels(meta)
-        scheduler_spec = spec.get("scheduler", {})
-        if "metadata" in scheduler_spec:
-            if "annotations" in scheduler_spec["metadata"]:
-                annotations.update(**scheduler_spec["metadata"]["annotations"])
-            if "labels" in scheduler_spec["metadata"]:
-                labels.update(**scheduler_spec["metadata"]["labels"])
-        data = build_scheduler_pod_spec(
-            name, scheduler_spec.get("spec"), annotations, labels
-        )
-        kopf.adopt(data)
-        pod = await api.list_namespaced_pod(
-            namespace=namespace,
-            label_selector=f"dask.org/component=scheduler,dask.org/cluster-name={name}",
-        )
-        if not pod.items:
-            await api.create_namespaced_pod(
-                namespace=namespace,
-                body=data,
-            )
-            logger.info(
-                f"Scheduler pod {data['metadata']['name']} created in {namespace}."
-            )
+    api = await kr8s.asyncio.api()
+    annotations = _get_annotations(meta)
+    labels = _get_labels(meta)
+    scheduler_spec = spec.get("scheduler", {})
+    if "metadata" in scheduler_spec:
+        if "annotations" in scheduler_spec["metadata"]:
+            annotations.update(**scheduler_spec["metadata"]["annotations"])
+        if "labels" in scheduler_spec["metadata"]:
+            labels.update(**scheduler_spec["metadata"]["labels"])
+    data = build_scheduler_pod_spec(
+        name, scheduler_spec.get("spec"), annotations, labels
+    )
+    kopf.adopt(data)
+    pods = await api.get(
+        "pods",
+        namespace=namespace,
+        label_selector=f"dask.org/component=scheduler,dask.org/cluster-name={name}",
+    )
+    if len(pods) == 0:
+        scheduler_pod = await Pod(data)
+        await scheduler_pod.create()
+        logger.info(f"Scheduler pod {data['metadata']['name']} created in {namespace}.")
 
-        data = build_scheduler_service_spec(
-            name, scheduler_spec.get("service"), annotations, labels
-        )
-        kopf.adopt(data)
-        service = await api.list_namespaced_service(
-            namespace=namespace,
-            label_selector=f"dask.org/component=scheduler,dask.org/cluster-name={name}",
-        )
-        if not pod.items:
-            await api.create_namespaced_service(
-                namespace=namespace,
-                body=data,
-            )
+    data = build_scheduler_service_spec(
+        name, scheduler_spec.get("service"), annotations, labels
+    )
+    kopf.adopt(data)
+    services = await api.get(
+        "services",
+        namespace=namespace,
+        label_selector=f"dask.org/component=scheduler,dask.org/cluster-name={name}",
+    )
+    if len(services) == 0:
+        scheduler_service = await Service(data)
+        await scheduler_service.create()
         logger.info(
             f"Scheduler service {data['metadata']['name']} created in {namespace}."
         )
 
-        worker_spec = spec.get("worker", {})
-        annotations = _get_annotations(meta)
-        labels = _get_labels(meta)
-        if "metadata" in worker_spec:
-            if "annotations" in worker_spec["metadata"]:
-                annotations.update(**worker_spec["metadata"]["annotations"])
-            if "labels" in worker_spec["metadata"]:
-                labels.update(**worker_spec["metadata"]["labels"])
-        data = build_default_worker_group_spec(name, worker_spec, annotations, labels)
-        worker_group = await custom_api.list_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskworkergroups",
-            namespace=namespace,
-            label_selector=f"dask.org/component=workergroup,dask.org/cluster-name={name}",
-        )
-        if not worker_group["items"]:
-            await custom_api.create_namespaced_custom_object(
-                group="kubernetes.dask.org",
-                version="v1",
-                plural="daskworkergroups",
-                namespace=namespace,
-                body=data,
-            )
-            logger.info(
-                f"Worker group {data['metadata']['name']} created in {namespace}."
-            )
+    worker_spec = spec.get("worker", {})
+    annotations = _get_annotations(meta)
+    labels = _get_labels(meta)
+    if "metadata" in worker_spec:
+        if "annotations" in worker_spec["metadata"]:
+            annotations.update(**worker_spec["metadata"]["annotations"])
+        if "labels" in worker_spec["metadata"]:
+            labels.update(**worker_spec["metadata"]["labels"])
+    data = build_default_worker_group_spec(name, worker_spec, annotations, labels)
+
+    worker_groups = await api.get(
+        "daskworkergroups",
+        namespace=namespace,
+        label_selector=f"dask.org/component=workergroup,dask.org/cluster-name={name}",
+    )
+    if len(worker_groups) == 0:
+        dask_worker_group = await DaskWorkerGroup(data)
+        await dask_worker_group.create()
+        logger.info(f"Worker group {data['metadata']['name']} created in {namespace}.")
 
     patch.status["phase"] = "Pending"
 
 
@@ -347,41 +368,25 @@ async def handle_scheduler_service_status(
     else:
         phase = "Running"
 
-    api = HTTPClient(KubeConfig.from_env())
-    cluster = await DaskCluster.objects(api, namespace=namespace).get_by_name(
-        labels["dask.org/cluster-name"]
+    cluster = await DaskCluster.get(
+        labels["dask.org/cluster-name"], namespace=namespace
     )
     await cluster.patch({"status": {"phase": phase}})
 
 
 @kopf.on.create("daskworkergroup.kubernetes.dask.org")
-async def daskworkergroup_create(spec, name, namespace, logger, **kwargs):
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        api = kubernetes.client.CustomObjectsApi(api_client)
-        cluster = await api.get_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskclusters",
-            namespace=namespace,
-            name=spec["cluster"],
-        )
-        new_spec = dict(spec)
-        kopf.adopt(new_spec, owner=cluster)
-        api.api_client.set_default_header(
-            "content-type", "application/merge-patch+json"
-        )
-        await api.patch_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskworkergroups",
-            namespace=namespace,
-            name=name,
-            body=new_spec,
-        )
-        logger.info(f"Successfully adopted by {spec['cluster']}")
+async def daskworkergroup_create(body, spec, name, namespace, logger, **kwargs):
+    cluster = await DaskCluster.get(spec["cluster"], namespace=namespace)
+    new_spec = dict(spec)
+    kopf.adopt(new_spec, owner=cluster.raw)
+
+    worker_group = await DaskWorkerGroup(body)
+    await worker_group.patch(new_spec)
+    logger.info(f"Successfully adopted by {cluster.name}")
 
     del kwargs["new"]
     await daskworkergroup_replica_update(
+        body=body,
         spec=spec,
         name=name,
         namespace=namespace,
@@ -394,6 +399,7 @@ async def daskworkergroup_create(spec, name, namespace, logger, **kwargs):
 async def retire_workers(
     n_workers, scheduler_service_name, worker_group_name, namespace, logger
 ):
+    api = await kr8s.asyncio.api()
     # Try gracefully retiring via the HTTP API
     dashboard_address = await get_scheduler_address(
         scheduler_service_name,
@@ -438,13 +444,12 @@ async def retire_workers(
     logger.info(
         f"Scaling {worker_group_name} failed via the Dask RPC, falling back to LIFO scaling"
     )
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        api = kubernetes.client.CoreV1Api(api_client)
-        workers = await api.list_namespaced_pod(
-            namespace=namespace,
-            label_selector=f"dask.org/workergroup-name={worker_group_name}",
-        )
-        return [w["metadata"]["name"] for w in workers.items[:-n_workers]]
+    workers = await api.get(
+        "pods",
+        label_selector=f"dask.org/workergroup-name={worker_group_name}",
+        namespace=namespace,
+    )
+    return [w.name for w in workers[:-n_workers]]
 
 
 async def get_desired_workers(scheduler_service_name, namespace, logger):
@@ -485,86 +490,63 @@ async def get_desired_workers(scheduler_service_name, namespace, logger):
 async def daskworkergroup_replica_update(
     name, namespace, meta, spec, new, body, logger, **kwargs
 ):
+    api = await kr8s.asyncio.api()
     cluster_name = spec["cluster"]
 
     # Replica updates can come in quick succession and the changes must be applied atomically to ensure
     # the number of workers ends in the correct state
     async with worker_group_scale_locks[f"{namespace}/{name}"]:
-        async with kubernetes.client.api_client.ApiClient() as api_client:
-            customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-            corev1api = kubernetes.client.CoreV1Api(api_client)
-
-            try:
-                cluster = await customobjectsapi.get_namespaced_custom_object(
-                    group="kubernetes.dask.org",
-                    version="v1",
-                    plural="daskclusters",
+        cluster = await DaskCluster.get(cluster_name, namespace=namespace)
+
+        workers = await api.get(
+            "pods",
+            namespace=namespace,
+            label_selector=f"dask.org/workergroup-name={name}",
+        )
+        current_workers = len(
+            [w for w in workers if w.status["phase"] != "Terminating"]
+        )
+        desired_workers = new
+        workers_needed = desired_workers - current_workers
+        labels = _get_labels(meta)
+        annotations = _get_annotations(meta)
+        worker_spec = spec["worker"]
+        if "metadata" in worker_spec:
+            if "annotations" in worker_spec["metadata"]:
+                annotations.update(**worker_spec["metadata"]["annotations"])
+            if "labels" in worker_spec["metadata"]:
+                labels.update(**worker_spec["metadata"]["labels"])
+        if workers_needed > 0:
+            for _ in range(workers_needed):
+                data = build_worker_pod_spec(
+                    worker_group_name=name,
                     namespace=namespace,
-                    name=cluster_name,
+                    cluster_name=cluster_name,
+                    uuid=uuid4().hex[:10],
+                    spec=worker_spec["spec"],
+                    annotations=annotations,
+                    labels=labels,
                 )
-            except ApiException as e:
-                if e.status == 404:
-                    # No need to scale if worker group is deleted, pods will be cleaned up
-                    return
-                else:
-                    raise e
-
-            cluster_labels = cluster.get("metadata", {}).get("labels", {})
-
-            workers = await corev1api.list_namespaced_pod(
+                kopf.adopt(data, owner=body)
+                kopf.label(data, labels=cluster.labels)
+                worker_pod = await Pod(data)
+                await worker_pod.create()
+            logger.info(f"Scaled worker group {name} up to {desired_workers} workers.")
+        if workers_needed < 0:
+            worker_ids = await retire_workers(
+                n_workers=-workers_needed,
+                scheduler_service_name=f"{cluster_name}-scheduler",
+                worker_group_name=name,
                 namespace=namespace,
-                label_selector=f"dask.org/workergroup-name={name}",
+                logger=logger,
             )
-            current_workers = len(
-                [w for w in workers.items if w.status.phase != "Terminating"]
+            logger.info(f"Workers to close: {worker_ids}")
+            for wid in worker_ids:
+                worker = await Pod.get(wid, namespace=namespace)
+                await worker.delete()
+            logger.info(
+                f"Scaled worker group {name} down to {desired_workers} workers."
             )
-            desired_workers = new
-            workers_needed = desired_workers - current_workers
-            labels = _get_labels(meta)
-            annotations = _get_annotations(meta)
-            worker_spec = spec["worker"]
-            if "metadata" in worker_spec:
-                if "annotations" in worker_spec["metadata"]:
-                    annotations.update(**worker_spec["metadata"]["annotations"])
-                if "labels" in worker_spec["metadata"]:
-                    labels.update(**worker_spec["metadata"]["labels"])
-            if workers_needed > 0:
-                for _ in range(workers_needed):
-                    data = build_worker_pod_spec(
-                        worker_group_name=name,
-                        namespace=namespace,
-                        cluster_name=cluster_name,
-                        uuid=uuid4().hex[:10],
-                        spec=worker_spec["spec"],
-                        annotations=annotations,
-                        labels=labels,
-                    )
-                    kopf.adopt(data, owner=body)
-                    kopf.label(data, labels=cluster_labels)
-                    await corev1api.create_namespaced_pod(
-                        namespace=namespace,
-                        body=data,
-                    )
-                logger.info(
-                    f"Scaled worker group {name} up to {desired_workers} workers."
-                )
-            if workers_needed < 0:
-                worker_ids = await retire_workers(
-                    n_workers=-workers_needed,
-                    scheduler_service_name=f"{cluster_name}-scheduler",
-                    worker_group_name=name,
-                    namespace=namespace,
-                    logger=logger,
-                )
-                logger.info(f"Workers to close: {worker_ids}")
-                for wid in worker_ids:
-                    await corev1api.delete_namespaced_pod(
-                        name=wid,
-                        namespace=namespace,
-                    )
-                logger.info(
-                    f"Scaled worker group {name} down to {desired_workers} workers."
-                )
 
 
 @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
@@ -587,62 +569,49 @@ async def daskjob_create_components(
     spec, name, namespace, logger, patch, meta, **kwargs
 ):
     logger.info("Creating Dask job components.")
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-        corev1api = kubernetes.client.CoreV1Api(api_client)
-
-        cluster_name = f"{name}"
-        labels = _get_labels(meta)
-        annotations = _get_annotations(meta)
-        cluster_spec = spec["cluster"]
-        if "metadata" in cluster_spec:
-            if "annotations" in cluster_spec["metadata"]:
-                annotations.update(**cluster_spec["metadata"]["annotations"])
-            if "labels" in cluster_spec["metadata"]:
-                labels.update(**cluster_spec["metadata"]["labels"])
-        cluster_spec = build_cluster_spec(
-            cluster_name,
-            cluster_spec["spec"]["worker"],
-            cluster_spec["spec"]["scheduler"],
-            annotations,
-            labels,
-        )
-        kopf.adopt(cluster_spec)
-        await customobjectsapi.create_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskclusters",
-            namespace=namespace,
-            body=cluster_spec,
-        )
-        logger.info(
-            f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}."
-        )
-
-        labels = _get_labels(meta)
-        annotations = _get_annotations(meta)
-        job_spec = spec["job"]
-        if "metadata" in job_spec:
-            if "annotations" in job_spec["metadata"]:
-                annotations.update(**job_spec["metadata"]["annotations"])
-            if "labels" in job_spec["metadata"]:
-                labels.update(**job_spec["metadata"]["labels"])
-        job_pod_spec = build_job_pod_spec(
-            job_name=name,
-            cluster_name=cluster_name,
-            namespace=namespace,
-            spec=job_spec["spec"],
-            annotations=annotations,
-            labels=labels,
-        )
-        kopf.adopt(job_pod_spec)
-        await corev1api.create_namespaced_pod(
-            namespace=namespace,
-            body=job_pod_spec,
-        )
-        patch.status["clusterName"] = cluster_name
-        patch.status["jobStatus"] = "ClusterCreated"
-        patch.status["jobRunnerPodName"] = get_job_runner_pod_name(name)
+    cluster_name = f"{name}"
+    labels = _get_labels(meta)
+    annotations = _get_annotations(meta)
+    cluster_spec = spec["cluster"]
+    if "metadata" in cluster_spec:
+        if "annotations" in cluster_spec["metadata"]:
+            annotations.update(**cluster_spec["metadata"]["annotations"])
+        if "labels" in cluster_spec["metadata"]:
+            labels.update(**cluster_spec["metadata"]["labels"])
+    cluster_spec = build_cluster_spec(
+        cluster_name,
+        cluster_spec["spec"]["worker"],
+        cluster_spec["spec"]["scheduler"],
+        annotations,
+        labels,
+    )
+    kopf.adopt(cluster_spec)
+    cluster = await DaskCluster(cluster_spec)
+    await cluster.create()
+    logger.info(f"Cluster {cluster.name} for job {name} created in {namespace}.")
+
+    labels = _get_labels(meta)
+    annotations = _get_annotations(meta)
+    job_spec = spec["job"]
+    if "metadata" in job_spec:
+        if "annotations" in job_spec["metadata"]:
+            annotations.update(**job_spec["metadata"]["annotations"])
+        if "labels" in job_spec["metadata"]:
+            labels.update(**job_spec["metadata"]["labels"])
+    job_pod_spec = build_job_pod_spec(
+        job_name=name,
+        cluster_name=cluster_name,
+        namespace=namespace,
+        spec=job_spec["spec"],
+        annotations=annotations,
+        labels=labels,
+    )
+    kopf.adopt(job_pod_spec)
+    job_pod = await Pod(job_pod_spec)
+    await job_pod.create()
+    patch.status["clusterName"] = cluster_name
+    patch.status["jobStatus"] = "ClusterCreated"
+    patch.status["jobRunnerPodName"] = get_job_runner_pod_name(name)
 
 
 @kopf.on.field(
@@ -651,24 +620,17 @@ async def daskjob_create_components(
     labels={"dask.org/component": "job-runner"},
     new="Running",
 )
-async def handle_runner_status_change_running(meta, namespace, logger, **kwargs):
+async def handle_runner_status_change_running(labels, namespace, logger, **kwargs):
     logger.info("Job now in running")
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-        api_client.set_default_header("content-type", "application/merge-patch+json")
-        await customobjectsapi.patch_namespaced_custom_object_status(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskjobs",
-            namespace=namespace,
-            name=meta["labels"]["dask.org/cluster-name"],
-            body={
-                "status": {
-                    "jobStatus": "Running",
-                    "startTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
-                }
-            },
-        )
+    job = await DaskJob.get(labels.get("dask.org/cluster-name"), namespace=namespace)
+    await job.patch(
+        {
+            "jobStatus": "Running",
+            "startTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
+        },
+        subresource="status",
+    )
+    await job.refresh()
 
 
 @kopf.on.field(
@@ -679,29 +641,17 @@ async def handle_runner_status_change_running(meta, namespace, logger, **kwargs):
     logger.info("Job succeeded, deleting Dask cluster.")
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-        await customobjectsapi.delete_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskclusters",
-            namespace=namespace,
-            name=meta["labels"]["dask.org/cluster-name"],
-        )
-        api_client.set_default_header("content-type", "application/merge-patch+json")
-        await customobjectsapi.patch_namespaced_custom_object_status(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskjobs",
-            namespace=namespace,
-            name=meta["labels"]["dask.org/cluster-name"],
-            body={
-                "status": {
-                    "jobStatus": "Successful",
-                    "endTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
-                }
-            },
-        )
+    cluster_name = meta["labels"]["dask.org/cluster-name"]
+    cluster = await DaskCluster.get(cluster_name, namespace=namespace)
+    job = await DaskJob.get(cluster_name, namespace=namespace)
+    await cluster.delete()
+    await job.patch(
+        {
+            "jobStatus": "Successful",
+            "endTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
+        },
+        subresource="status",
+    )
 
 
 @kopf.on.field(
@@ -712,169 +662,91 @@ async def handle_runner_status_change_succeeded(meta, namespace, logger, **kwargs):
 )
 async def handle_runner_status_change_succeeded(meta, namespace, logger, **kwargs):
     logger.info("Job failed, deleting Dask cluster.")
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-        await customobjectsapi.delete_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskclusters",
-            namespace=namespace,
-            name=meta["labels"]["dask.org/cluster-name"],
-        )
-        api_client.set_default_header("content-type", "application/merge-patch+json")
-        await customobjectsapi.patch_namespaced_custom_object_status(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskjobs",
-            namespace=namespace,
-            name=meta["labels"]["dask.org/cluster-name"],
-            body={
-                "status": {
-                    "jobStatus": "Failed",
-                    "endTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
-                }
-            },
-        )
+    cluster_name = meta["labels"]["dask.org/cluster-name"]
+    cluster = await DaskCluster.get(cluster_name, namespace=namespace)
+    job = await DaskJob.get(cluster_name, namespace=namespace)
+    await cluster.delete()
+    await job.patch(
+        {
+            "jobStatus": "Failed",
+            "endTime": datetime.utcnow().strftime(KUBERNETES_DATETIME_FORMAT),
+        },
+        subresource="status",
+    )
 
 
 @kopf.on.create("daskautoscaler.kubernetes.dask.org")
-async def daskautoscaler_create(spec, name, namespace, logger, **kwargs):
+async def daskautoscaler_create(body, spec, name, namespace, logger, **kwargs):
     """When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
-    async with kubernetes.client.api_client.ApiClient() as api_client:
-        api = kubernetes.client.CustomObjectsApi(api_client)
-        cluster = await api.get_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskclusters",
-            namespace=namespace,
-            name=spec["cluster"],
-        )
-        new_spec = dict(spec)
-        kopf.adopt(new_spec, owner=cluster)
-        api.api_client.set_default_header(
-            "content-type", "application/merge-patch+json"
-        )
-        await api.patch_namespaced_custom_object(
-            group="kubernetes.dask.org",
-            version="v1",
-            plural="daskautoscalers",
-            namespace=namespace,
-            name=name,
-            body=new_spec,
-        )
-        logger.info(f"Successfully adopted by {spec['cluster']}")
{spec['cluster']}") + autoscaler = await DaskAutoscaler(body) + cluster = await DaskCluster.get(spec["cluster"], namespace=namespace) + new_spec = dict(spec) + kopf.adopt(new_spec, owner=cluster.raw) + await autoscaler.patch(new_spec) + logger.info(f"Successfully adopted by {spec['cluster']}") @kopf.timer("daskautoscaler.kubernetes.dask.org", interval=5.0) -async def daskautoscaler_adapt(spec, name, namespace, logger, **kwargs): - async with kubernetes.client.api_client.ApiClient() as api_client: - coreapi = kubernetes.client.CoreV1Api(api_client) - - pod_ready = False - try: - scheduler_pod = await coreapi.read_namespaced_pod( - f"{spec['cluster']}-scheduler", namespace - ) - if scheduler_pod.status.phase == "Running": - pod_ready = True - except ApiException as e: - if e.status != 404: - raise e - - if not pod_ready: - logger.info("Scheduler not ready, skipping autoscaling") - return - - customobjectsapi = kubernetes.client.CustomObjectsApi(api_client) - customobjectsapi.api_client.set_default_header( - "content-type", "application/merge-patch+json" +async def daskautoscaler_adapt(body, spec, name, namespace, logger, **kwargs): + scheduler_pod = await Pod.get(f"{spec['cluster']}-scheduler", namespace=namespace) + if not scheduler_pod.ready(): + logger.info("Scheduler not ready, skipping autoscaling") + return + + autoscaler = await DaskAutoscaler(body) + worker_group = await DaskWorkerGroup.get( + f"{spec['cluster']}-default", namespace=namespace + ) + current_replicas = int(worker_group.spec["worker"]["replicas"]) + cooldown_until = float( + autoscaler.annotations.get( + DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time() ) + ) - autoscaler_resource = await customobjectsapi.get_namespaced_custom_object( - group="kubernetes.dask.org", - version="v1", - plural="daskautoscalers", - namespace=namespace, - name=name, - ) + # Cooldown autoscaling to prevent thrashing + if time.time() < cooldown_until: + logger.debug("Autoscaler for %s is in cooldown", spec["cluster"]) + return - worker_group_resource = await customobjectsapi.get_namespaced_custom_object( - group="kubernetes.dask.org", - version="v1", - plural="daskworkergroups", + # Ask the scheduler for the desired number of worker + try: + desired_workers = await get_desired_workers( + scheduler_service_name=f"{spec['cluster']}-scheduler", namespace=namespace, - name=f"{spec['cluster']}-default", + logger=logger, ) - - current_replicas = int(worker_group_resource["spec"]["worker"]["replicas"]) - cooldown_until = float( - autoscaler_resource.get("metadata", {}) - .get("annotations", {}) - .get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()) - ) - - # Cooldown autoscaling to prevent thrashing - if time.time() < cooldown_until: - logger.debug("Autoscaler for %s is in cooldown", spec["cluster"]) - return - - # Ask the scheduler for the desired number of worker - try: - desired_workers = await get_desired_workers( - scheduler_service_name=f"{spec['cluster']}-scheduler", - namespace=namespace, - logger=logger, - ) - except SchedulerCommError: - logger.error("Unable to get desired number of workers from scheduler.") - return - - # Ensure the desired number is within the min and max - desired_workers = max(spec["minimum"], desired_workers) - desired_workers = min(spec["maximum"], desired_workers) - - if current_replicas > 0: - max_scale_down = int(current_replicas * 0.25) - max_scale_down = 1 if max_scale_down == 0 else max_scale_down - desired_workers = max(current_replicas - max_scale_down, desired_workers) - - # Update the default 
-        if desired_workers != current_replicas:
-            await customobjectsapi.patch_namespaced_custom_object_scale(
-                group="kubernetes.dask.org",
-                version="v1",
-                plural="daskworkergroups",
-                namespace=namespace,
-                name=f"{spec['cluster']}-default",
-                body={"spec": {"replicas": desired_workers}},
-            )
-
-            cooldown_until = time.time() + 15
-
-            await customobjectsapi.patch_namespaced_custom_object(
-                group="kubernetes.dask.org",
-                version="v1",
-                plural="daskautoscalers",
-                namespace=namespace,
-                name=name,
-                body={
-                    "metadata": {
-                        "annotations": {
-                            DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(
-                                cooldown_until
-                            )
-                        }
+    except SchedulerCommError:
+        logger.error("Unable to get desired number of workers from scheduler.")
+        return
+
+    # Ensure the desired number is within the min and max
+    desired_workers = max(spec["minimum"], desired_workers)
+    desired_workers = min(spec["maximum"], desired_workers)
+
+    if current_replicas > 0:
+        max_scale_down = int(current_replicas * 0.25)
+        max_scale_down = 1 if max_scale_down == 0 else max_scale_down
+        desired_workers = max(current_replicas - max_scale_down, desired_workers)
+
+    # Update the default DaskWorkerGroup
+    if desired_workers != current_replicas:
+        await worker_group.patch({"spec": {"replicas": desired_workers}})
+        cooldown_until = time.time() + 15
+        await autoscaler.patch(
+            {
+                "metadata": {
+                    "annotations": {
+                        DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)
                     }
-                },
-            )
+                }
+            }
+        )
 
-            logger.info(
-                "Autoscaler updated %s worker count from %d to %d",
-                spec["cluster"],
-                current_replicas,
-                desired_workers,
-            )
-        else:
-            logger.debug(
-                "Not autoscaling %s with %d workers", spec["cluster"], current_replicas
-            )
+        logger.info(
+            f"Autoscaler updated {spec['cluster']} worker count from {current_replicas} to {desired_workers}"
+        )
+    else:
+        logger.debug(
+            f"Not autoscaling {spec['cluster']} with {current_replicas} workers"
+        )
diff --git a/requirements.txt b/requirements.txt
index 588d0e6ba..46772fbe1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ kubernetes-asyncio>=12.0.1
 kopf>=1.35.3
 pykube-ng>=22.9.0
 rich>=12.5.1
+kr8s>=0.4.0
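A minimal sketch of how the kr8s-backed custom resource wrappers added to controller.py above can be exercised outside the operator; the cluster name "demo" and namespace "default" are illustrative, and this assumes the Dask CRDs are installed and a kubeconfig is reachable:

```python
import asyncio

from dask_kubernetes.operator.controller.controller import DaskCluster


async def main():
    # Fetch an existing DaskCluster resource and patch its status phase,
    # mirroring the pattern used by handle_scheduler_service_status above.
    cluster = await DaskCluster.get("demo", namespace="default")
    await cluster.patch({"status": {"phase": "Running"}})
    print(cluster.name, cluster.labels)


asyncio.run(main())
```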