Skip to content

Commit

Permalink
Merge branch 'main' of github.com:dask/dask-kubernetes into py-312
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson committed Apr 16, 2024
2 parents 21b7d49 + 258d556 commit c0844d6
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 77 deletions.
21 changes: 20 additions & 1 deletion dask_kubernetes/common/objects.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""
Convenience functions for creating pod templates.
"""

import copy
import json
from collections import namedtuple

from kubernetes import client
from kubernetes.client.configuration import Configuration

from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
from dask_kubernetes.constants import (
KUBECLUSTER_CONTAINER_NAME,
MAX_CLUSTER_NAME_LEN,
VALID_CLUSTER_NAME,
)
from dask_kubernetes.exceptions import ValidationError

_FakeResponse = namedtuple("_FakeResponse", ["data"])

Expand Down Expand Up @@ -365,3 +371,16 @@ def clean_pdb_template(pdb_template):
pdb_template.spec.selector = client.V1LabelSelector()

return pdb_template


def validate_cluster_name(cluster_name: str) -> None:
    """Validate a DaskCluster name.

    Parameters
    ----------
    cluster_name
        The candidate DaskCluster name.

    Raises
    ------
    ValidationError
        If ``cluster_name`` is longer than ``MAX_CLUSTER_NAME_LEN`` characters
        or contains characters outside lowercase alphanumerics and ``-``.
    """
    if not VALID_CLUSTER_NAME.match(cluster_name):
        # NOTE: VALID_CLUSTER_NAME enforces both the character set and the
        # length cap (via a lookahead), so a single match covers both rules.
        raise ValidationError(
            message=(
                f"The DaskCluster {cluster_name} is invalid: a lowercase RFC 1123 subdomain must "
                "consist of lower case alphanumeric characters or '-', and must start "
                "and end with an alphanumeric character. The DaskCluster name must also be no "
                f"longer than {MAX_CLUSTER_NAME_LEN} characters."
            )
        )
26 changes: 24 additions & 2 deletions dask_kubernetes/common/tests/test_objects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dask_kubernetes.common.objects import make_pod_from_dict
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
import pytest

from dask_kubernetes.common.objects import make_pod_from_dict, validate_cluster_name
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME, MAX_CLUSTER_NAME_LEN
from dask_kubernetes.exceptions import ValidationError


def test_make_pod_from_dict():
Expand Down Expand Up @@ -64,3 +67,22 @@ def test_make_pod_from_dict_default_container_name():
assert pod.spec.containers[0].name == "dask-0"
assert pod.spec.containers[1].name == "sidecar"
assert pod.spec.containers[2].name == "dask-2"


@pytest.mark.parametrize(
    "cluster_name",
    [
        # One character over the limit.
        "a" * (MAX_CLUSTER_NAME_LEN + 1),
        # Dots are not permitted by the validator's pattern.
        "invalid.chars.in.name",
    ],
)
def test_validate_cluster_name_raises_on_invalid_name(cluster_name):
    # Both over-long names and names with disallowed characters must raise.
    with pytest.raises(ValidationError):
        validate_cluster_name(cluster_name)


def test_validate_cluster_name_success_on_valid_name():
    # A conforming name passes validation; the validator returns None.
    result = validate_cluster_name("valid-cluster-name-123")
    assert result is None
10 changes: 10 additions & 0 deletions dask_kubernetes/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
import re

# Container name used for the Dask container in pods created by KubeCluster.
KUBECLUSTER_CONTAINER_NAME = "dask-container"

# Kubernetes caps most resource names at 63 characters.
KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63

# Template for deriving the scheduler resource name from a cluster name.
SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler"

# The cluster name must leave room for the scheduler suffix ("-scheduler")
# within the Kubernetes resource-name limit.
_SCHEDULER_SUFFIX_LENGTH = len(SCHEDULER_NAME_TEMPLATE.format(cluster_name=""))
MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - _SCHEDULER_SUFFIX_LENGTH

# Lowercase alphanumerics and '-', starting and ending with an alphanumeric;
# the leading lookahead enforces the length cap in the same pattern.
VALID_CLUSTER_NAME = re.compile(
    rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$"
)
7 changes: 7 additions & 0 deletions dask_kubernetes/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ class CrashLoopBackOffError(Exception):

class SchedulerStartupError(Exception):
"""Scheduler failed to start."""


class ValidationError(Exception):
    """Raised when a manifest fails validation (e.g. an invalid cluster name)."""

    def __init__(self, message: str) -> None:
        # Pass the message to Exception so str(exc) and tracebacks display it;
        # the original only set the attribute, leaving str(exc) empty.
        super().__init__(message)
        # Keep the ``message`` attribute for callers that read it directly
        # (e.g. the controller re-raises it as a kopf.PermanentError).
        self.message = message
114 changes: 66 additions & 48 deletions dask_kubernetes/operator/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import time
from collections import defaultdict
from contextlib import suppress
Expand All @@ -14,6 +15,9 @@
from importlib_metadata import entry_points
from kr8s.asyncio.objects import Deployment, Pod, Service

from dask_kubernetes.common.objects import validate_cluster_name
from dask_kubernetes.constants import SCHEDULER_NAME_TEMPLATE
from dask_kubernetes.exceptions import ValidationError
from dask_kubernetes.operator._objects import (
DaskAutoscaler,
DaskCluster,
Expand Down Expand Up @@ -75,18 +79,19 @@ def build_scheduler_deployment_spec(
}
)
metadata = {
"name": f"{cluster_name}-scheduler",
"name": SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
"labels": labels,
"annotations": annotations,
}
spec = {}
spec["replicas"] = 1
spec["selector"] = {
"matchLabels": labels,
}
spec["template"] = {
"metadata": metadata,
"spec": pod_spec,
spec = {
"replicas": 1,
"selector": {
"matchLabels": labels,
},
"template": {
"metadata": metadata,
"spec": pod_spec,
},
}
return {
"apiVersion": "apps/v1",
Expand All @@ -107,7 +112,7 @@ def build_scheduler_service_spec(cluster_name, spec, annotations, labels):
"apiVersion": "v1",
"kind": "Service",
"metadata": {
"name": f"{cluster_name}-scheduler",
"name": SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
"labels": labels,
"annotations": annotations,
},
Expand All @@ -132,38 +137,41 @@ def build_worker_deployment_spec(
"labels": labels,
"annotations": annotations,
}
spec = {}
spec["replicas"] = 1 # make_worker_spec returns dict with a replicas key?
spec["selector"] = {
"matchLabels": labels,
}
spec["template"] = {
"metadata": metadata,
"spec": pod_spec,
spec = {
"replicas": 1,
"selector": {
"matchLabels": labels,
},
"template": {
"metadata": metadata,
"spec": copy.deepcopy(pod_spec),
},
}
deployment_spec = {
"apiVersion": "apps/v1",
"kind": "Deployment",
"metadata": metadata,
"spec": spec,
}
env = [
{
"name": "DASK_WORKER_NAME",
"value": worker_name,
},
{
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
},
]
for i in range(len(deployment_spec["spec"]["template"]["spec"]["containers"])):
if "env" in deployment_spec["spec"]["template"]["spec"]["containers"][i]:
deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"].extend(
env
)
else:
deployment_spec["spec"]["template"]["spec"]["containers"][i]["env"] = env
worker_env = {
"name": "DASK_WORKER_NAME",
"value": worker_name,
}
scheduler_env = {
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
}
for container in deployment_spec["spec"]["template"]["spec"]["containers"]:
if "env" not in container:
container["env"] = [worker_env, scheduler_env]
continue

container_env_names = [env_item["name"] for env_item in container["env"]]

if "DASK_WORKER_NAME" not in container_env_names:
container["env"].append(worker_env)
if "DASK_SCHEDULER_ADDRESS" not in container_env_names:
container["env"].append(scheduler_env)
return deployment_spec


Expand All @@ -187,19 +195,21 @@ def build_job_pod_spec(job_name, cluster_name, namespace, spec, annotations, lab
"labels": labels,
"annotations": annotations,
},
"spec": spec,
"spec": copy.deepcopy(spec),
}
env = [
{
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
},
]
for i in range(len(pod_spec["spec"]["containers"])):
if "env" in pod_spec["spec"]["containers"][i]:
pod_spec["spec"]["containers"][i]["env"].extend(env)
else:
pod_spec["spec"]["containers"][i]["env"] = env
scheduler_env = {
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
}
for container in pod_spec["spec"]["containers"]:
if "env" not in container:
container["env"] = [scheduler_env]
continue

container_env_names = [env_item["name"] for env_item in container["env"]]

if "DASK_SCHEDULER_ADDRESS" not in container_env_names:
container["env"].append(scheduler_env)
return pod_spec


Expand Down Expand Up @@ -273,6 +283,12 @@ async def daskcluster_create(name, namespace, logger, patch, **kwargs):
This allows us to track that the operator is running.
"""
logger.info(f"DaskCluster {name} created in {namespace}.")
try:
validate_cluster_name(name)
except ValidationError as validation_exc:
patch.status["phase"] = "Error"
raise kopf.PermanentError(validation_exc.message)

patch.status["phase"] = "Created"


Expand Down Expand Up @@ -599,7 +615,9 @@ async def daskworkergroup_replica_update(
if workers_needed < 0:
worker_ids = await retire_workers(
n_workers=-workers_needed,
scheduler_service_name=f"{cluster_name}-scheduler",
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
cluster_name=cluster_name
),
worker_group_name=name,
namespace=namespace,
logger=logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ metadata:
spec:
cluster: simple
worker:
replicas: 2
replicas: 1
spec:
containers:
- name: worker
Expand All @@ -23,3 +23,5 @@ spec:
env:
- name: WORKER_ENV
value: hello-world # We don't test the value, just the name
- name: DASK_WORKER_NAME
value: test-worker
Loading

0 comments on commit c0844d6

Please sign in to comment.