Skip to content

Commit

Permalink
Include test for additional worker group; test overriding of environm…
Browse files Browse the repository at this point in the history
…ent variables
  • Loading branch information
Jonas Dedden committed Apr 2, 2024
1 parent a0f29cc commit 6f36400
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ metadata:
spec:
cluster: simple
worker:
replicas: 2
replicas: 1
spec:
containers:
- name: worker
Expand All @@ -23,3 +23,5 @@ spec:
env:
- name: WORKER_ENV
value: hello-world # We dont test the value, just the name
- name: DASK_WORKER_NAME
value: test-worker
78 changes: 63 additions & 15 deletions dask_kubernetes/operator/controller/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,36 @@ async def cm(job_file):
yield cm


@pytest.fixture()
def gen_worker_group(k8s_cluster, ns):
"""Yields an instantiated context manager for creating/deleting a worker group."""

@asynccontextmanager
async def cm(worker_group_file):
worker_group_path = os.path.join(DIR, "resources", worker_group_file)
with open(worker_group_path) as f:
worker_group_name = yaml.load(f, yaml.Loader)["metadata"]["name"]

# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path)
while worker_group_name not in k8s_cluster.kubectl(
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
):
await asyncio.sleep(0.1)

try:
yield worker_group_name, ns
finally:
# Test: remove the wait=True, because I think this is blocking the operator
k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path)
while worker_group_name in k8s_cluster.kubectl(
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
):
await asyncio.sleep(0.1)

yield cm


def test_customresources(k8s_cluster):
assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
Expand Down Expand Up @@ -671,32 +701,50 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster):


@pytest.mark.anyio
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster):
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster, gen_worker_group):
with kopf_runner:
async with gen_cluster() as (cluster_name, ns):
async with (
gen_cluster() as (cluster_name, ns),
gen_worker_group("simpleworkergroup.yaml") as (additional_workergroup_name, _),
):
cluster = await DaskCluster.get(cluster_name, namespace=ns)
additional_workergroup = await DaskWorkerGroup.get(additional_workergroup_name, namespace=ns)

worker_groups = []
while not worker_groups:
worker_groups = await cluster.worker_groups()
await asyncio.sleep(0.1)
assert len(worker_groups) == 1 # Just the default worker group
wg = worker_groups[0]
assert isinstance(wg, DaskWorkerGroup)
worker_groups = worker_groups + [additional_workergroup]

pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])
for wg in worker_groups:
assert isinstance(wg, DaskWorkerGroup)

deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])
deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])

assert (await wg.cluster()).name == cluster.name
pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])

assert (await wg.cluster()).name == cluster.name

for deployment in deployments:
assert deployment.labels["dask.org/cluster-name"] == cluster.name
for env in deployment.spec["template"]["spec"]["containers"][0]["env"]:
if env["name"] == "DASK_WORKER_NAME":
if wg.name == additional_workergroup_name:
assert env["value"] == "test-worker"
else:
assert env["value"] == deployment.name
if env["name"] == "DASK_SCHEDULER_ADDRESS":
scheduler_service = await cluster.scheduler_service()
assert f"{scheduler_service.name}.{ns}" in env["value"]


@pytest.mark.anyio
Expand Down

0 comments on commit 6f36400

Please sign in to comment.