Skip to content

Commit

Permalink
Move the cluster name validation into a common module, add it to Kube…
Browse files Browse the repository at this point in the history
…Cluster init, and add tests
  • Loading branch information
Johanna Goergen committed Feb 29, 2024
1 parent bcad9aa commit 4890cce
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 30 deletions.
21 changes: 20 additions & 1 deletion dask_kubernetes/common/objects.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""
Convenience functions for creating pod templates.
"""

import copy
import json
from collections import namedtuple

from kubernetes import client
from kubernetes.client.configuration import Configuration

from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
from dask_kubernetes.constants import (
KUBECLUSTER_CONTAINER_NAME,
MAX_CLUSTER_NAME_LEN,
VALID_CLUSTER_NAME,
)
from dask_kubernetes.exceptions import ValidationError

_FakeResponse = namedtuple("_FakeResponse", ["data"])

Expand Down Expand Up @@ -365,3 +371,16 @@ def clean_pdb_template(pdb_template):
pdb_template.spec.selector = client.V1LabelSelector()

return pdb_template


def validate_cluster_name(cluster_name: str) -> None:
"""Raise exception if cluster name is too long and/or has invalid characters"""
if not VALID_CLUSTER_NAME.match(cluster_name):
raise ValidationError(
message=(
f"The DaskCluster {cluster_name} is invalid: a lowercase RFC 1123 subdomain must "
"consist of lower case alphanumeric characters, '-' or '.', and must start "
"and end with an alphanumeric character. DaskCluster name must also be under "
f"{MAX_CLUSTER_NAME_LEN} characters."
)
)
26 changes: 24 additions & 2 deletions dask_kubernetes/common/tests/test_objects.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dask_kubernetes.common.objects import make_pod_from_dict
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
import pytest

from dask_kubernetes.common.objects import make_pod_from_dict, validate_cluster_name
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME, MAX_CLUSTER_NAME_LEN
from dask_kubernetes.exceptions import ValidationError


def test_make_pod_from_dict():
Expand Down Expand Up @@ -64,3 +67,22 @@ def test_make_pod_from_dict_default_container_name():
assert pod.spec.containers[0].name == "dask-0"
assert pod.spec.containers[1].name == "sidecar"
assert pod.spec.containers[2].name == "dask-2"


@pytest.mark.parametrize(
"cluster_name",
[
(MAX_CLUSTER_NAME_LEN + 1) * "a",
"invalid.chars.in.name",
],
)
def test_validate_cluster_name_raises_on_invalid_name(
cluster_name,
):

with pytest.raises(ValidationError):
validate_cluster_name(cluster_name)


def test_validate_cluster_name_success_on_valid_name():
assert validate_cluster_name("valid-cluster-name-123") is None
10 changes: 10 additions & 0 deletions dask_kubernetes/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
import re

KUBECLUSTER_CONTAINER_NAME = "dask-container"
KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63
SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler"
MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - len(
SCHEDULER_NAME_TEMPLATE.format(cluster_name="")
)
VALID_CLUSTER_NAME = re.compile(
rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$"
)
7 changes: 7 additions & 0 deletions dask_kubernetes/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ class CrashLoopBackOffError(Exception):

class SchedulerStartupError(Exception):
"""Scheduler failed to start."""


class ValidationError(Exception):
"""Manifest validation exception"""

def __init__(self, message: str) -> None:
self.message = message
29 changes: 7 additions & 22 deletions dask_kubernetes/operator/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import re
import time
from collections import defaultdict
from contextlib import suppress
Expand All @@ -15,6 +14,9 @@
from importlib_metadata import entry_points
from kr8s.asyncio.objects import Deployment, Pod, Service

from dask_kubernetes.common.objects import validate_cluster_name
from dask_kubernetes.constants import SCHEDULER_NAME_TEMPLATE
from dask_kubernetes.exceptions import ValidationError
from dask_kubernetes.operator._objects import (
DaskAutoscaler,
DaskCluster,
Expand All @@ -32,19 +34,6 @@
KUBERNETES_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION = "kubernetes.dask.org/cooldown-until"
KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63
SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler"
MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - len(
SCHEDULER_NAME_TEMPLATE.format(cluster_name="")
)
VALID_CLUSTER_NAME = re.compile(
rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$"
)


def _validate_cluster_name(cluster_name: str) -> bool:
return bool(VALID_CLUSTER_NAME.match(cluster_name))


# Load operator plugins from other packages
PLUGINS = []
Expand Down Expand Up @@ -287,15 +276,11 @@ async def daskcluster_create(name, namespace, logger, patch, **kwargs):
This allows us to track that the operator is running.
"""
logger.info(f"DaskCluster {name} created in {namespace}.")

if not _validate_cluster_name(name):
try:
validate_cluster_name(name)
except ValidationError as validation_exc:
patch.status["phase"] = "Error"
raise kopf.PermanentError(
f"The DaskCluster {name} is invalid: a lowercase RFC 1123 subdomain must "
"consist of lower case alphanumeric characters, '-' or '.', and must start "
"and end with an alphanumeric character. DaskCluster name must also be under "
f"{MAX_CLUSTER_NAME_LEN} characters."
)
raise kopf.PermanentError(validation_exc.message)

patch.status["phase"] = "Created"

Expand Down
63 changes: 59 additions & 4 deletions dask_kubernetes/operator/controller/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dask.distributed import Client
from kr8s.asyncio.objects import Deployment, Pod, Service

from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN
from dask_kubernetes.operator._objects import DaskCluster, DaskJob, DaskWorkerGroup
from dask_kubernetes.operator.controller import (
KUBERNETES_DATETIME_FORMAT,
Expand All @@ -22,17 +23,32 @@

_EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"}
_EXPECTED_LABELS = {"test-label": "label-value"}
DEFAULT_CLUSTER_NAME = "simple"


@pytest.fixture()
def gen_cluster(k8s_cluster, ns):
def gen_cluster_manifest(tmp_path):
def factory(cluster_name=DEFAULT_CLUSTER_NAME):
original_manifest_path = os.path.join(DIR, "resources", "simplecluster.yaml")
with open(original_manifest_path, "r") as original_manifest_file:
manifest = yaml.safe_load(original_manifest_file)

manifest["metadata"]["name"] = cluster_name
new_manifest_path = tmp_path / "cluster.yaml"
new_manifest_path.write_text(yaml.safe_dump(manifest))
return tmp_path

return factory


@pytest.fixture()
def gen_cluster(k8s_cluster, ns, gen_cluster_manifest):
"""Yields an instantiated context manager for creating/deleting a simple cluster."""

@asynccontextmanager
async def cm():
cluster_path = os.path.join(DIR, "resources", "simplecluster.yaml")
cluster_name = "simple"
async def cm(cluster_name=DEFAULT_CLUSTER_NAME):

cluster_path = gen_cluster_manifest(cluster_name)
# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path)
while cluster_name not in k8s_cluster.kubectl(
Expand Down Expand Up @@ -687,3 +703,42 @@ async def test_object_dask_job(k8s_cluster, kopf_runner, gen_job):

cluster = await job.cluster()
assert isinstance(cluster, DaskCluster)


async def _get_cluster_status(k8s_cluster, ns, cluster_name):
"""
Will loop infinitely in search of non-falsey cluster status.
Make sure there is a timeout on any test which calls this.
"""
while True:
cluster_status = k8s_cluster.kubectl(
"get",
"-n",
ns,
"daskcluster.kubernetes.dask.org",
cluster_name,
"-o",
"jsonpath='{.status.phase}'",
).strip("'")
if cluster_status:
return cluster_status
await asyncio.sleep(0.1)


@pytest.mark.timeout(180)
@pytest.mark.anyio
@pytest.mark.parametrize(
"cluster_name,expected_status",
[
("valid-name", "Created"),
((MAX_CLUSTER_NAME_LEN + 1) * "a", "Error"),
("invalid.chars.in.name", "Error"),
],
)
async def test_create_cluster_validates_name(
cluster_name, expected_status, k8s_cluster, kopf_runner, gen_cluster
):
with kopf_runner:
async with gen_cluster(cluster_name=cluster_name) as (_, ns):
actual_status = await _get_cluster_status(k8s_cluster, ns, cluster_name)
assert expected_status == actual_status
2 changes: 2 additions & 0 deletions dask_kubernetes/operator/kubecluster/kubecluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from rich.table import Table
from tornado.ioloop import IOLoop

from dask_kubernetes.common.objects import validate_cluster_name
from dask_kubernetes.exceptions import CrashLoopBackOffError, SchedulerStartupError
from dask_kubernetes.operator._objects import (
DaskAutoscaler,
Expand Down Expand Up @@ -258,6 +259,7 @@ def __init__(
name = name.format(
user=getpass.getuser(), uuid=str(uuid.uuid4())[:10], **os.environ
)
validate_cluster_name(name)
self._instances.add(self)
self._rich_spinner = Spinner("dots", speed=0.5)
self._startup_component_status: dict = {}
Expand Down
20 changes: 19 additions & 1 deletion dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dask.distributed import Client
from distributed.utils import TimeoutError

from dask_kubernetes.exceptions import SchedulerStartupError
from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN
from dask_kubernetes.exceptions import SchedulerStartupError, ValidationError
from dask_kubernetes.operator import KubeCluster, make_cluster_spec


Expand Down Expand Up @@ -202,3 +203,20 @@ def test_typo_resource_limits(ns):
},
namespace=ns,
)


@pytest.mark.parametrize(
"cluster_name",
[
(MAX_CLUSTER_NAME_LEN + 1) * "a",
"invalid.chars.in.name",
],
)
def test_invalid_cluster_name_fails(cluster_name, kopf_runner, docker_image, ns):
with kopf_runner:
with pytest.raises(ValidationError):
KubeCluster(
name=cluster_name,
namespace=ns,
image=docker_image,
)

0 comments on commit 4890cce

Please sign in to comment.