From 4890cce7e9c07c716fef3fc80da4bf6f6762013d Mon Sep 17 00:00:00 2001 From: Johanna Goergen Date: Thu, 29 Feb 2024 14:38:52 +0100 Subject: [PATCH] Move the cluster name validation into a common module, add it to KubeCluster init, and add tests --- dask_kubernetes/common/objects.py | 21 ++++++- dask_kubernetes/common/tests/test_objects.py | 26 +++++++- dask_kubernetes/constants.py | 10 +++ dask_kubernetes/exceptions.py | 7 +++ .../operator/controller/controller.py | 29 +++------ .../controller/tests/test_controller.py | 63 +++++++++++++++++-- .../operator/kubecluster/kubecluster.py | 2 + .../kubecluster/tests/test_kubecluster.py | 20 +++++- 8 files changed, 148 insertions(+), 30 deletions(-) diff --git a/dask_kubernetes/common/objects.py b/dask_kubernetes/common/objects.py index 44c90fd56..1eedeebd7 100644 --- a/dask_kubernetes/common/objects.py +++ b/dask_kubernetes/common/objects.py @@ -1,6 +1,7 @@ """ Convenience functions for creating pod templates. """ + import copy import json from collections import namedtuple @@ -8,7 +9,12 @@ from kubernetes import client from kubernetes.client.configuration import Configuration -from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME +from dask_kubernetes.constants import ( + KUBECLUSTER_CONTAINER_NAME, + MAX_CLUSTER_NAME_LEN, + VALID_CLUSTER_NAME, +) +from dask_kubernetes.exceptions import ValidationError _FakeResponse = namedtuple("_FakeResponse", ["data"]) @@ -365,3 +371,16 @@ def clean_pdb_template(pdb_template): pdb_template.spec.selector = client.V1LabelSelector() return pdb_template + + +def validate_cluster_name(cluster_name: str) -> None: + """Raise exception if cluster name is too long and/or has invalid characters""" + if not VALID_CLUSTER_NAME.match(cluster_name): + raise ValidationError( + message=( + f"The DaskCluster {cluster_name} is invalid: a lowercase RFC 1123 subdomain must " + "consist of lower case alphanumeric characters, '-' or '.', and must start " + "and end with an alphanumeric character. DaskCluster name must also be under " + f"{MAX_CLUSTER_NAME_LEN} characters." 
+ ) + ) diff --git a/dask_kubernetes/common/tests/test_objects.py b/dask_kubernetes/common/tests/test_objects.py index 61ef99b3f..23e318961 100644 --- a/dask_kubernetes/common/tests/test_objects.py +++ b/dask_kubernetes/common/tests/test_objects.py @@ -1,5 +1,8 @@ -from dask_kubernetes.common.objects import make_pod_from_dict -from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME +import pytest + +from dask_kubernetes.common.objects import make_pod_from_dict, validate_cluster_name +from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME, MAX_CLUSTER_NAME_LEN +from dask_kubernetes.exceptions import ValidationError def test_make_pod_from_dict(): @@ -64,3 +67,22 @@ def test_make_pod_from_dict_default_container_name(): assert pod.spec.containers[0].name == "dask-0" assert pod.spec.containers[1].name == "sidecar" assert pod.spec.containers[2].name == "dask-2" + + +@pytest.mark.parametrize( + "cluster_name", + [ + (MAX_CLUSTER_NAME_LEN + 1) * "a", + "invalid.chars.in.name", + ], +) +def test_validate_cluster_name_raises_on_invalid_name( + cluster_name, +): + + with pytest.raises(ValidationError): + validate_cluster_name(cluster_name) + + +def test_validate_cluster_name_success_on_valid_name(): + assert validate_cluster_name("valid-cluster-name-123") is None diff --git a/dask_kubernetes/constants.py b/dask_kubernetes/constants.py index f22c804a5..5133e2de7 100644 --- a/dask_kubernetes/constants.py +++ b/dask_kubernetes/constants.py @@ -1 +1,11 @@ +import re + KUBECLUSTER_CONTAINER_NAME = "dask-container" +KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63 +SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler" +MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - len( + SCHEDULER_NAME_TEMPLATE.format(cluster_name="") +) +VALID_CLUSTER_NAME = re.compile( + rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$" +) diff --git a/dask_kubernetes/exceptions.py b/dask_kubernetes/exceptions.py index d501aab48..dc107c8c3 100644 --- a/dask_kubernetes/exceptions.py +++ b/dask_kubernetes/exceptions.py @@ -4,3 +4,10 @@ class CrashLoopBackOffError(Exception): class SchedulerStartupError(Exception): """Scheduler failed to start.""" + + +class ValidationError(Exception): + """Manifest validation exception""" + + def __init__(self, message: str) -> None: + self.message = message diff --git a/dask_kubernetes/operator/controller/controller.py b/dask_kubernetes/operator/controller/controller.py index 9c622224f..3056f15bf 100644 --- a/dask_kubernetes/operator/controller/controller.py +++ b/dask_kubernetes/operator/controller/controller.py @@ -1,5 +1,4 @@ import asyncio -import re import time from collections import defaultdict from contextlib import suppress @@ -15,6 +14,9 @@ from importlib_metadata import entry_points from kr8s.asyncio.objects import Deployment, Pod, Service +from dask_kubernetes.common.objects import validate_cluster_name +from dask_kubernetes.constants import SCHEDULER_NAME_TEMPLATE +from dask_kubernetes.exceptions import ValidationError from dask_kubernetes.operator._objects import ( DaskAutoscaler, DaskCluster, @@ -32,19 +34,6 @@ KUBERNETES_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION = "kubernetes.dask.org/cooldown-until" -KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63 -SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler" -MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - len( - SCHEDULER_NAME_TEMPLATE.format(cluster_name="") -) -VALID_CLUSTER_NAME = re.compile( - 
rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$" -) - - -def _validate_cluster_name(cluster_name: str) -> bool: - return bool(VALID_CLUSTER_NAME.match(cluster_name)) - # Load operator plugins from other packages PLUGINS = [] @@ -287,15 +276,11 @@ async def daskcluster_create(name, namespace, logger, patch, **kwargs): This allows us to track that the operator is running. """ logger.info(f"DaskCluster {name} created in {namespace}.") - - if not _validate_cluster_name(name): + try: + validate_cluster_name(name) + except ValidationError as validation_exc: patch.status["phase"] = "Error" - raise kopf.PermanentError( - f"The DaskCluster {name} is invalid: a lowercase RFC 1123 subdomain must " - "consist of lower case alphanumeric characters, '-' or '.', and must start " - "and end with an alphanumeric character. DaskCluster name must also be under " - f"{MAX_CLUSTER_NAME_LEN} characters." - ) + raise kopf.PermanentError(validation_exc.message) patch.status["phase"] = "Created" diff --git a/dask_kubernetes/operator/controller/tests/test_controller.py b/dask_kubernetes/operator/controller/tests/test_controller.py index 0bc3b0997..75ad3489a 100644 --- a/dask_kubernetes/operator/controller/tests/test_controller.py +++ b/dask_kubernetes/operator/controller/tests/test_controller.py @@ -11,6 +11,7 @@ from dask.distributed import Client from kr8s.asyncio.objects import Deployment, Pod, Service +from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN from dask_kubernetes.operator._objects import DaskCluster, DaskJob, DaskWorkerGroup from dask_kubernetes.operator.controller import ( KUBERNETES_DATETIME_FORMAT, @@ -22,17 +23,32 @@ _EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"} _EXPECTED_LABELS = {"test-label": "label-value"} +DEFAULT_CLUSTER_NAME = "simple" @pytest.fixture() -def gen_cluster(k8s_cluster, ns): +def gen_cluster_manifest(tmp_path): + def factory(cluster_name=DEFAULT_CLUSTER_NAME): + original_manifest_path = os.path.join(DIR, "resources", "simplecluster.yaml") + with open(original_manifest_path, "r") as original_manifest_file: + manifest = yaml.safe_load(original_manifest_file) + + manifest["metadata"]["name"] = cluster_name + new_manifest_path = tmp_path / "cluster.yaml" + new_manifest_path.write_text(yaml.safe_dump(manifest)) + return tmp_path + + return factory + + +@pytest.fixture() +def gen_cluster(k8s_cluster, ns, gen_cluster_manifest): """Yields an instantiated context manager for creating/deleting a simple cluster.""" @asynccontextmanager - async def cm(): - cluster_path = os.path.join(DIR, "resources", "simplecluster.yaml") - cluster_name = "simple" + async def cm(cluster_name=DEFAULT_CLUSTER_NAME): + cluster_path = gen_cluster_manifest(cluster_name) # Create cluster resource k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path) while cluster_name not in k8s_cluster.kubectl( @@ -687,3 +703,42 @@ async def test_object_dask_job(k8s_cluster, kopf_runner, gen_job): cluster = await job.cluster() assert isinstance(cluster, DaskCluster) + + +async def _get_cluster_status(k8s_cluster, ns, cluster_name): + """ + Will loop infinitely in search of non-falsey cluster status. + Make sure there is a timeout on any test which calls this. 
+ """ + while True: + cluster_status = k8s_cluster.kubectl( + "get", + "-n", + ns, + "daskcluster.kubernetes.dask.org", + cluster_name, + "-o", + "jsonpath='{.status.phase}'", + ).strip("'") + if cluster_status: + return cluster_status + await asyncio.sleep(0.1) + + +@pytest.mark.timeout(180) +@pytest.mark.anyio +@pytest.mark.parametrize( + "cluster_name,expected_status", + [ + ("valid-name", "Created"), + ((MAX_CLUSTER_NAME_LEN + 1) * "a", "Error"), + ("invalid.chars.in.name", "Error"), + ], +) +async def test_create_cluster_validates_name( + cluster_name, expected_status, k8s_cluster, kopf_runner, gen_cluster +): + with kopf_runner: + async with gen_cluster(cluster_name=cluster_name) as (_, ns): + actual_status = await _get_cluster_status(k8s_cluster, ns, cluster_name) + assert expected_status == actual_status diff --git a/dask_kubernetes/operator/kubecluster/kubecluster.py b/dask_kubernetes/operator/kubecluster/kubecluster.py index d5891ab4a..5b2b540d0 100644 --- a/dask_kubernetes/operator/kubecluster/kubecluster.py +++ b/dask_kubernetes/operator/kubecluster/kubecluster.py @@ -29,6 +29,7 @@ from rich.table import Table from tornado.ioloop import IOLoop +from dask_kubernetes.common.objects import validate_cluster_name from dask_kubernetes.exceptions import CrashLoopBackOffError, SchedulerStartupError from dask_kubernetes.operator._objects import ( DaskAutoscaler, @@ -258,6 +259,7 @@ def __init__( name = name.format( user=getpass.getuser(), uuid=str(uuid.uuid4())[:10], **os.environ ) + validate_cluster_name(name) self._instances.add(self) self._rich_spinner = Spinner("dots", speed=0.5) self._startup_component_status: dict = {} diff --git a/dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py b/dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py index 931f24b90..662e711fa 100644 --- a/dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py +++ b/dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py @@ -2,7 +2,8 @@ from dask.distributed import Client from distributed.utils import TimeoutError -from dask_kubernetes.exceptions import SchedulerStartupError +from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN +from dask_kubernetes.exceptions import SchedulerStartupError, ValidationError from dask_kubernetes.operator import KubeCluster, make_cluster_spec @@ -202,3 +203,20 @@ def test_typo_resource_limits(ns): }, namespace=ns, ) + + +@pytest.mark.parametrize( + "cluster_name", + [ + (MAX_CLUSTER_NAME_LEN + 1) * "a", + "invalid.chars.in.name", + ], +) +def test_invalid_cluster_name_fails(cluster_name, kopf_runner, docker_image, ns): + with kopf_runner: + with pytest.raises(ValidationError): + KubeCluster( + name=cluster_name, + namespace=ns, + image=docker_image, + )
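
Usage sketch (illustrative only; assumes this patch is applied and dask_kubernetes is importable). The import paths and the length budget come straight from the diff (MAX_CLUSTER_NAME_LEN works out to 63 minus len("-scheduler"), i.e. 53); the loop and sample names below are a hypothetical demo, not code from this change:

    # Hypothetical demo script, not part of the patch; exercises the new shared helper.
    from dask_kubernetes.common.objects import validate_cluster_name
    from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN  # 63 - len("-scheduler") == 53
    from dask_kubernetes.exceptions import ValidationError

    for name in [
        "valid-cluster-name-123",          # accepted: lowercase alphanumerics and hyphens
        "invalid.chars.in.name",           # rejected: dots are not matched by VALID_CLUSTER_NAME
        (MAX_CLUSTER_NAME_LEN + 1) * "a",  # rejected: exceeds the 53-character budget
    ]:
        try:
            validate_cluster_name(name)    # returns None when the name is valid
            print(f"{name!r}: ok")
        except ValidationError as exc:
            print(f"{name!r}: {exc.message}")

With this change the same helper runs in two places: the kopf create handler converts the ValidationError into a kopf.PermanentError and marks the DaskCluster phase as "Error", while KubeCluster.__init__ raises it directly, so invalid names also fail fast on the client side before any resources are created.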