From 6f364001a088c5dd0086af4e2bb811eef367b45b Mon Sep 17 00:00:00 2001 From: Jonas Dedden Date: Tue, 2 Apr 2024 14:23:46 +0200 Subject: [PATCH] Include test for additional worker group; test overriding of environment variables --- .../tests/resources/simpleworkergroup.yaml | 4 +- .../controller/tests/test_controller.py | 78 +++++++++++++++---- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml b/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml index cd7da0e92..e99ebe608 100644 --- a/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml +++ b/dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml @@ -5,7 +5,7 @@ metadata: spec: cluster: simple worker: - replicas: 2 + replicas: 1 spec: containers: - name: worker @@ -23,3 +23,5 @@ spec: env: - name: WORKER_ENV value: hello-world # We dont test the value, just the name + - name: DASK_WORKER_NAME + value: test-worker diff --git a/dask_kubernetes/operator/controller/tests/test_controller.py b/dask_kubernetes/operator/controller/tests/test_controller.py index 144865955..357ffaa49 100644 --- a/dask_kubernetes/operator/controller/tests/test_controller.py +++ b/dask_kubernetes/operator/controller/tests/test_controller.py @@ -95,6 +95,36 @@ async def cm(job_file): yield cm +@pytest.fixture() +def gen_worker_group(k8s_cluster, ns): + """Yields an instantiated context manager for creating/deleting a worker group.""" + + @asynccontextmanager + async def cm(worker_group_file): + worker_group_path = os.path.join(DIR, "resources", worker_group_file) + with open(worker_group_path) as f: + worker_group_name = yaml.load(f, yaml.Loader)["metadata"]["name"] + + # Create cluster resource + k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path) + while worker_group_name not in k8s_cluster.kubectl( + "get", "daskworkergroups.kubernetes.dask.org", "-n", ns + ): + await asyncio.sleep(0.1) + + try: + yield worker_group_name, ns + finally: + # Test: remove the wait=True, because I think this is blocking the operator + k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path) + while worker_group_name in k8s_cluster.kubectl( + "get", "daskworkergroups.kubernetes.dask.org", "-n", ns + ): + await asyncio.sleep(0.1) + + yield cm + + def test_customresources(k8s_cluster): assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd") assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd") @@ -671,32 +701,50 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster): @pytest.mark.anyio -async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster): +async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster, gen_worker_group): with kopf_runner: - async with gen_cluster() as (cluster_name, ns): + async with ( + gen_cluster() as (cluster_name, ns), + gen_worker_group("simpleworkergroup.yaml") as (additional_workergroup_name, _), + ): cluster = await DaskCluster.get(cluster_name, namespace=ns) + additional_workergroup = await DaskWorkerGroup.get(additional_workergroup_name, namespace=ns) worker_groups = [] while not worker_groups: worker_groups = await cluster.worker_groups() await asyncio.sleep(0.1) assert len(worker_groups) == 1 # Just the default worker group - wg = worker_groups[0] - assert isinstance(wg, DaskWorkerGroup) + worker_groups = worker_groups + [additional_workergroup] - pods = [] - while not pods: - pods = await wg.pods() - await asyncio.sleep(0.1) - assert all([isinstance(p, Pod) for p in pods]) + for wg in worker_groups: + assert isinstance(wg, DaskWorkerGroup) - deployments = [] - while not deployments: - deployments = await wg.deployments() - await asyncio.sleep(0.1) - assert all([isinstance(d, Deployment) for d in deployments]) + deployments = [] + while not deployments: + deployments = await wg.deployments() + await asyncio.sleep(0.1) + assert all([isinstance(d, Deployment) for d in deployments]) - assert (await wg.cluster()).name == cluster.name + pods = [] + while not pods: + pods = await wg.pods() + await asyncio.sleep(0.1) + assert all([isinstance(p, Pod) for p in pods]) + + assert (await wg.cluster()).name == cluster.name + + for deployment in deployments: + assert deployment.labels["dask.org/cluster-name"] == cluster.name + for env in deployment.spec["template"]["spec"]["containers"][0]["env"]: + if env["name"] == "DASK_WORKER_NAME": + if wg.name == additional_workergroup_name: + assert env["value"] == "test-worker" + else: + assert env["value"] == deployment.name + if env["name"] == "DASK_SCHEDULER_ADDRESS": + scheduler_service = await cluster.scheduler_service() + assert f"{scheduler_service.name}.{ns}" in env["value"] @pytest.mark.anyio