
Commit

reformatted with precommit
sil-lnagel committed Mar 20, 2024
1 parent a8e7c1f commit 55b59af
Showing 1 changed file with 90 additions and 49 deletions.
139 changes: 90 additions & 49 deletions dask_kubernetes/operator/controller/controller.py
@@ -1,12 +1,13 @@
import asyncio
import copy
import logging
import re
import time
from enum import Enum
from collections import defaultdict
from contextlib import suppress
from datetime import datetime
from enum import Enum
from uuid import uuid4
import re

import aiohttp
import dask.config
@@ -27,10 +28,9 @@
DaskWorkerGroup,
)
from dask_kubernetes.operator.networking import get_scheduler_address
import logging

logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('aiohttp').setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("aiohttp").setLevel(logging.WARNING)

_ANNOTATION_NAMESPACES_TO_IGNORE = (
"kopf.zalando.org",
@@ -55,12 +55,8 @@
"resumed",
"waiting",
)
IDLE_STATES = (
"released",
"error",
"cancelled",
"forgotten"
)
IDLE_STATES = ("released", "error", "cancelled", "forgotten")


class WorkerState(Enum):
IDLE = 1
@@ -294,8 +290,9 @@ async def startup(settings: kopf.OperatorSettings, logger, **kwargs):
# https://kopf.readthedocs.io/en/latest/configuration/#networking-timeouts
settings.networking.request_timeout = 10

# val = dask.config.get("kubernetes.controller.worker-allocation.batch-size")
show_config = lambda config: logger.info(f"{config}: {dask.config.get(config, None)}")
show_config = lambda config: logger.info(
f"{config}: {dask.config.get(config, None)}"
)

logger.info("- configuration -")
show_config("kubernetes.controller.autoscaler.method")
@@ -305,7 +302,6 @@ async def startup(settings: kopf.OperatorSettings, logger, **kwargs):
logger.info("---")



# There may be useful things for us to expose via the liveness probe
# https://kopf.readthedocs.io/en/stable/probing/#probe-handlers
@kopf.on.probe(id="now")
@@ -419,9 +415,7 @@ async def daskworkergroup_create(body, namespace, logger, **kwargs):
)


async def get_workers_to_close(
n_workers, scheduler_service_name, namespace, logger
):
async def get_workers_to_close(n_workers, scheduler_service_name, namespace, logger):
comm_address = await get_scheduler_address(
scheduler_service_name,
namespace,
@@ -435,7 +429,6 @@ async def retire_workers(
return workers_to_close



async def retire_workers(
workers_to_close, scheduler_service_name, worker_group_name, namespace, logger
):
@@ -490,7 +483,7 @@ async def retire_workers(
namespace=namespace,
label_selector={"dask.org/workergroup-name": worker_group_name},
)
return [w.name for w in workers[:-len(workers_to_close)]]
return [w.name for w in workers[: -len(workers_to_close)]]


async def check_scheduler_idle(scheduler_service_name, namespace, logger):
@@ -626,15 +619,22 @@ async def get_managed_pod(deployment, wid, logger):

# this filtering is currently required because deployment.pods() will return other
# pods that are not part of the deployment. potential kr8s bug?
pods = list(filter(lambda x: x.name.startswith(wid), pods)) #
pods = list(filter(lambda x: x.name.startswith(wid), pods)) #
if len(pods) > 1:
logger.warning(f"Deployment {deployment} has {len(pods)} pods but should have exactly 1.")
logger.warning(
f"Deployment {deployment} has {len(pods)} pods but should have exactly 1."
)

return pods[0]


async def determine_worker_state(pod, logger, dashboard_port=8787,
active_states=ACTIVE_STATES, idle_states=IDLE_STATES):
async def determine_worker_state(
pod,
logger,
dashboard_port=8787,
active_states=ACTIVE_STATES,
idle_states=IDLE_STATES,
):
"""Determine if the worker is idle, busy or can't be checked."""
worker_state = WorkerState.UNCERTAIN
metrics_url = f"http://{pod.status['podIP']}:{dashboard_port}/metrics"
@@ -646,23 +646,31 @@ async def determine_worker_state(pod, logger, dashboard_port=8787,
task_states = parse_dask_worker_tasks(
metrics_text,
active_states=active_states,
idle_states=idle_states
idle_states=idle_states,
)
protected_tasks = [(s, task_states[s]) for s in active_states if task_states[s] > 0.0]
protected_tasks = [
(s, task_states[s])
for s in active_states
if task_states[s] > 0.0
]
if protected_tasks:
logger.info(f"pod {pod} is busy.")
logger.info(f"protected tasks: {protected_tasks}")
worker_state = WorkerState.BUSY
else:
worker_state = WorkerState.IDLE
else:
logger.warning(f"Metrics request to pod {pod} failed (http_status={response.status}).")
logger.warning(
f"Metrics request to pod {pod} failed (http_status={response.status})."
)
worker_state = WorkerState.UNCERTAIN
except aiohttp.ClientError as e:
logger.warning(f"Could not query: '{metrics_url}'. Do workers expose /metrics at this URL?")
logger.warning(
f"Could not query: '{metrics_url}'. Do workers expose /metrics at this URL?"
)
logger.warning(f"aiohttp.ClientError: {e}")
worker_state = WorkerState.UNCERTAIN

return worker_state


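Note: determine_worker_state() above delegates the actual parsing of the worker's /metrics response to parse_dask_worker_tasks(), whose body is not part of this diff. The following is a minimal, hypothetical sketch of what such a helper could look like; it assumes (as a guess, not confirmed by this commit) that workers expose a Prometheus gauge of the form dask_worker_tasks{state="executing"} 3.0.

import re
from collections import defaultdict

# Hypothetical helper; the real implementation lives elsewhere in the repository.
_TASK_LINE = re.compile(
    r'^dask_worker_tasks\{state="(?P<state>[^"]+)"\}\s+(?P<count>[0-9.eE+-]+)'
)


def parse_dask_worker_tasks(metrics_text, active_states=(), idle_states=()):
    """Return a mapping of task state -> task count parsed from Prometheus text."""
    counts = defaultdict(float)
    for line in metrics_text.splitlines():
        match = _TASK_LINE.match(line)
        if match:
            counts[match.group("state")] += float(match.group("count"))
    # Ensure every state the controller inspects is present, even when zero.
    for state in (*active_states, *idle_states):
        counts.setdefault(state, 0.0)
    return counts
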
@@ -708,7 +716,9 @@ async def daskworkergroup_replica_update(
dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
)
if workers_needed > 0:
logger.info(f"Starting to scale worker group {name} up to {desired_workers} workers...")
logger.info(
f"Starting to scale worker group {name} up to {desired_workers} workers..."
)
for _ in range(batch_size):
data = build_worker_deployment_spec(
worker_group_name=name,
@@ -732,7 +742,9 @@ async def daskworkergroup_replica_update(
if workers_needed < 0:
workers_not_needed = workers_needed * -1

logger.info(f"Attempting to downscale {name} by -{workers_not_needed} workers.")
logger.info(
f"Attempting to downscale {name} by -{workers_not_needed} workers."
)
worker_ids = await get_workers_to_close(
n_workers=-workers_needed,
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
@@ -744,12 +756,18 @@ async def daskworkergroup_replica_update(

logger.info(f"Workers to close: {worker_ids}")

if dask.config.get("kubernetes.controller.autoscaler.method", "default") != "careful":
if (
dask.config.get("kubernetes.controller.autoscaler.method", "default")
!= "careful"
):
await retire_workers(
workers_to_close=worker_ids,
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
cluster_name=cluster_name
),
worker_group_name=name,
namespace=namespace, logger=logger
namespace=namespace,
logger=logger,
)
for wid in worker_ids:
worker_deployment = await Deployment(wid, namespace=namespace)
@@ -758,7 +776,9 @@ async def daskworkergroup_replica_update(
f"Scaled worker group {name} down to {desired_workers} workers."
)
else:
deployments = [await Deployment(wid, namespace=namespace) for wid in worker_ids]
deployments = [
await Deployment(wid, namespace=namespace) for wid in worker_ids
]

# if we don't wait for ready deployment.pods() will fail with:
# return Box(self.raw["spec"]) KeyError: 'spec' kr8s bug?
@@ -767,13 +787,21 @@ async def daskworkergroup_replica_update(

ready_deployments, not_ready_deployments = [], []
for deployment, ready, wid in zip(deployments, readiness, worker_ids):
(ready_deployments if ready else not_ready_deployments).append((deployment, wid))
(ready_deployments if ready else not_ready_deployments).append(
(deployment, wid)
)

idle_deployments, busy_deployments, uncertain_deployments = [], [], []
dashboard_port = dask.config.get("kubernetes.controller.worker.dashboard-port", 8787)
dashboard_port = dask.config.get(
"kubernetes.controller.worker.dashboard-port", 8787
)
for deployment, wid in ready_deployments:
pod = await get_managed_pod(deployment=deployment, wid=wid, logger=logger)
worker_state = await determine_worker_state(pod=pod, dashboard_port=dashboard_port, logger=logger)
pod = await get_managed_pod(
deployment=deployment, wid=wid, logger=logger
)
worker_state = await determine_worker_state(
pod=pod, dashboard_port=dashboard_port, logger=logger
)

if worker_state == WorkerState.IDLE:
idle_deployments.append((deployment, wid))
@@ -783,18 +811,22 @@ async def daskworkergroup_replica_update(
uncertain_deployments.append((deployment, wid))
else:
logger.error(f"Unknown worker state {worker_state:r}")
raise kopf.PermanentError(f"Unknown worker state {worker_state:r}")
raise kopf.PermanentError(
f"Unknown worker state {worker_state:r}"
)

retired_worker_count = 0
if idle_deployments:
workers_to_retire = [wid for _, wid in idle_deployments]
logger.info(f"Gracefully retire workers: {workers_to_retire}")
await retire_workers(
workers_to_close=workers_to_retire,
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
cluster_name=cluster_name
),
worker_group_name=name,
namespace=namespace,
logger=logger
logger=logger,
)
for deployment, _ in idle_deployments:
await deployment.delete()
@@ -805,22 +837,31 @@ async def daskworkergroup_replica_update(
await worker_deployment.delete()
retired_worker_count += 1


if busy_deployments:
busy_workers = [wid for _, wid in busy_deployments]
logger.info(f"Refusing to retire busy workers: {busy_workers}")

if uncertain_deployments:
uncertain_workers = [wid for _, wid in uncertain_deployments]
logger.info(f"Refusing to retire workers that could not be queried: {uncertain_workers}")
logger.info(
f"Refusing to retire workers that could not be queried: {uncertain_workers}"
)

if retired_worker_count != workers_not_needed:
logging.info(f"Could only retire {retired_worker_count} of {workers_not_needed} workers.")
logging.info(f"(busy={len(busy_deployments)}, uncertain={len(uncertain_deployments)})")
retry_delay = dask.config.get("kubernetes.controller.autoscaler.retry-delay", 6)
raise kopf.TemporaryError(f"Retired ({retired_worker_count}/{workers_not_needed}) workers"
f" busy={len(busy_deployments)}, uncertain={len(uncertain_deployments)})",
delay=retry_delay)
logging.info(
f"Could only retire {retired_worker_count} of {workers_not_needed} workers."
)
logging.info(
f"(busy={len(busy_deployments)}, uncertain={len(uncertain_deployments)})"
)
retry_delay = dask.config.get(
"kubernetes.controller.autoscaler.retry-delay", 6
)
raise kopf.TemporaryError(
f"Retired ({retired_worker_count}/{workers_not_needed}) workers"
f" busy={len(busy_deployments)}, uncertain={len(uncertain_deployments)})",
delay=retry_delay,
)
else:
logger.info(
f"Successfully scaled worker group {name} down to {desired_workers} workers."
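The reconciliation paths in this diff read several dask configuration keys (autoscaler method, worker-allocation batch size and delay, retry delay, and the worker dashboard port). Below is a minimal sketch of how those keys might be set when running the controller; the key names mirror the dask.config.get() calls in the diff, while the values are illustrative assumptions, not project defaults.

import dask.config

# Illustrative values only; key names match the dask.config.get() calls above.
dask.config.set(
    {
        "kubernetes.controller.autoscaler.method": "careful",  # opt into the graceful downscale path
        "kubernetes.controller.autoscaler.retry-delay": 6,  # seconds before kopf retries a partial retirement
        "kubernetes.controller.worker-allocation.batch-size": 10,  # workers created per reconcile pass
        "kubernetes.controller.worker-allocation.delay": 0,  # pause between creation batches
        "kubernetes.controller.worker.dashboard-port": 8787,  # port probed for /metrics
    }
)

print(dask.config.get("kubernetes.controller.autoscaler.method"))  # "careful"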
