Skip to content

Commit

Permalink
Merge branch 'rank-placement' of https://github.com/dask/dask-mpi int…
Browse files Browse the repository at this point in the history
…o rank-placement
  • Loading branch information
kmpaul committed Oct 12, 2023
2 parents 82038c1 + 51f21e1 commit cc7bf1b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
4 changes: 3 additions & 1 deletion dask_mpi/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ async def run_scheduler(with_worker=False, with_client=False):
comm.Barrier()

if with_worker:
asyncio.get_event_loop().create_task(run_worker(with_client=with_client))
asyncio.get_event_loop().create_task(
run_worker(with_client=with_client)
)

elif with_client:
asyncio.get_event_loop().create_task(run_client())
Expand Down
6 changes: 4 additions & 2 deletions dask_mpi/tests/execute_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def client_func(m=4, c=1, s=0, x=True):
assert time() < start + 10
sleep(0.2)

actual_worker_ranks = set(v["name"] for k,v in c.scheduler_info()["workers"].items())
actual_worker_ranks = set(
v["name"] for k, v in c.scheduler_info()["workers"].items()
)
assert actual_worker_ranks == worker_ranks

for i in actual_worker_ranks:
Expand All @@ -32,7 +34,7 @@ def client_func(m=4, c=1, s=0, x=True):
parser.add_argument("-x", type=lambda v: v.lower() != "false", default=None)
kwargs = vars(parser.parse_args())

execute_kwargs = {k:v for k,v in kwargs.items() if v is not None}
execute_kwargs = {k: v for k, v in kwargs.items() if v is not None}
if "c" in execute_kwargs:
execute_kwargs["client_rank"] = execute_kwargs["c"]
if "s" in execute_kwargs:
Expand Down
6 changes: 4 additions & 2 deletions dask_mpi/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
(2, ["-c", "0", "-s", "0", "-x", "False"], 0),
(1, ["-c", "0", "-s", "0", "-x", "False"], 0),
(1, ["-c", "0", "-s", "0", "-x", "True"], 1),
]
],
)
def test_execute(mpisize, execute_args, retcode, mpirun):
script_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "execute_script.py"
)

execute_args += ["-m", str(mpisize)]
p = subprocess.Popen(mpirun + ["-n", str(mpisize), sys.executable, script_file] + execute_args)
p = subprocess.Popen(
mpirun + ["-n", str(mpisize), sys.executable, script_file] + execute_args
)

p.communicate()
assert p.returncode == retcode

0 comments on commit cc7bf1b

Please sign in to comment.