Skip to content

Commit

Permalink
Add execute tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmpaul committed Oct 12, 2023
1 parent 9341961 commit c61ec65
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 53 deletions.
21 changes: 0 additions & 21 deletions dask_mpi/tests/execute_basic.py

This file was deleted.

43 changes: 43 additions & 0 deletions dask_mpi/tests/execute_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from argparse import ArgumentParser
from time import sleep

from distributed import Client
from distributed.metrics import time

from dask_mpi import execute


def client_func(m=4, c=1, s=0, x=True):
xranks = {c, s} if x else set()
worker_ranks = set(i for i in range(m) if i not in xranks)

with Client() as c:
start = time()
while len(c.scheduler_info()["workers"]) != len(worker_ranks):
assert time() < start + 10
sleep(0.2)

actual_worker_ranks = set(v["name"] for k,v in c.scheduler_info()["workers"].items())
assert actual_worker_ranks == worker_ranks

for i in actual_worker_ranks:
assert c.submit(lambda x: x + 1, 10, workers=i).result() == 11


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-m", type=int, default=None)
parser.add_argument("-c", type=int, default=None)
parser.add_argument("-s", type=int, default=None)
parser.add_argument("-x", type=lambda v: v.lower() != "false", default=None)
kwargs = vars(parser.parse_args())

execute_kwargs = {k:v for k,v in kwargs.items() if v is not None}
if "c" in execute_kwargs:
execute_kwargs["client_rank"] = execute_kwargs["c"]
if "s" in execute_kwargs:
execute_kwargs["scheduler_rank"] = execute_kwargs["s"]
if "x" in execute_kwargs:
execute_kwargs["exclusive_workers"] = execute_kwargs["x"]

execute(client_func, **execute_kwargs)
34 changes: 34 additions & 0 deletions dask_mpi/tests/test_execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import absolute_import, division, print_function

import os
import subprocess
import sys

import pytest

pytest.importorskip("mpi4py")


@pytest.mark.parametrize(
"mpisize,execute_args,retcode",
[
(4, [], 0),
(1, [], 1), # Set too few processes to start cluster
(4, ["-c", "2", "-s", "3"], 0),
(5, ["-s", "3"], 0),
(3, ["-c", "2", "-s", "2"], 0),
(2, ["-c", "0", "-s", "0", "-x", "False"], 0),
(1, ["-c", "0", "-s", "0", "-x", "False"], 0),
(1, ["-c", "0", "-s", "0", "-x", "True"], 1),
]
)
def test_execute(mpisize, execute_args, retcode, mpirun):
script_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "execute_script.py"
)

execute_args += ["-m", str(mpisize)]
p = subprocess.Popen(mpirun + ["-n", str(mpisize), sys.executable, script_file] + execute_args)

p.communicate()
assert p.returncode == retcode
32 changes: 0 additions & 32 deletions dask_mpi/tests/test_execute_basic.py

This file was deleted.

0 comments on commit c61ec65

Please sign in to comment.