diff --git a/.gitignore b/.gitignore
index 5fc18e6..3227925 100644
--- a/.gitignore
+++ b/.gitignore
@@ -115,6 +115,7 @@ global.lock
 purge.lock
 /temp/
 /dask-worker-space/
+/dask-scratch-space/
 
 # VSCode files
 .vscode/
diff --git a/dask_mpi/__init__.py b/dask_mpi/__init__.py
index c2ec8cb..43ece32 100644
--- a/dask_mpi/__init__.py
+++ b/dask_mpi/__init__.py
@@ -1,5 +1,6 @@
-from . import _version
-from .core import initialize, send_close_signal
-from .exceptions import WorldTooSmallException
+from ._version import get_versions
+from .exceptions import WorldTooSmallException  # noqa
+from .execute import execute, send_close_signal  # noqa
+from .initialize import initialize  # noqa
 
 __version__ = _version.get_versions()["version"]
diff --git a/dask_mpi/cli.py b/dask_mpi/cli.py
index 2cccdfd..fc2d848 100644
--- a/dask_mpi/cli.py
+++ b/dask_mpi/cli.py
@@ -23,6 +23,12 @@
     type=int,
     help="Specify scheduler port number. Defaults to random.",
 )
+@click.option(
+    "--scheduler-rank",
+    default=0,
+    type=int,
+    help="The MPI rank on which the scheduler will launch. Defaults to 0.",
+)
 @click.option(
     "--interface", type=str, default=None, help="Network interface like 'eth0' or 'ib0'"
 )
@@ -56,6 +62,14 @@
     default=True,
     help="Start workers in nanny process for management (deprecated use --worker-class instead)",
 )
+@click.option(
+    "--exclusive-workers/--inclusive-workers",
+    default=True,
+    help=(
+        "Whether to force workers to run on unoccupied MPI ranks. If false, "
+        "then a worker will be launched on the same rank as the scheduler."
+    ),
+)
 @click.option(
     "--worker-class",
     type=str,
@@ -90,27 +104,30 @@ def main(
     scheduler_address,
     scheduler_file,
+    scheduler_port,
+    scheduler_rank,
     interface,
+    protocol,
     nthreads,
-    local_directory,
     memory_limit,
+    local_directory,
     scheduler,
     dashboard,
     dashboard_address,
     nanny,
+    exclusive_workers,
     worker_class,
     worker_options,
-    scheduler_port,
-    protocol,
     name,
 ):
     comm = MPI.COMM_WORLD
     world_size = comm.Get_size()
 
-    if scheduler and world_size < 2:
+    min_world_size = 1 + scheduler * max(scheduler_rank, exclusive_workers)
+    if world_size < min_world_size:
         raise WorldTooSmallException(
-            f"Not enough MPI ranks to start cluster, found {world_size}, "
-            "needs at least 2, one each for the scheduler and a worker."
+            f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and "
+            f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}."
         )
 
     rank = comm.Get_rank()
 
@@ -120,47 +137,52 @@ def main(
     except TypeError:
         worker_options = {}
 
-    if rank == 0 and scheduler:
+    async def run_worker():
+        WorkerType = import_term(worker_class)
+        if not nanny:
+            WorkerType = Worker
+            raise DeprecationWarning(
+                "Option --no-nanny is deprecated, use --worker-class instead"
+            )
+        opts = {
+            "interface": interface,
+            "protocol": protocol,
+            "nthreads": nthreads,
+            "memory_limit": memory_limit,
+            "local_directory": local_directory,
+            "name": f"{name}-{rank}",
+            "scheduler_file": scheduler_file,
+            **worker_options,
+        }
+        if scheduler_address:
+            opts["scheduler_ip"] = scheduler_address
 
-        async def run_scheduler():
-            async with Scheduler(
-                interface=interface,
-                protocol=protocol,
-                dashboard=dashboard,
-                dashboard_address=dashboard_address,
-                scheduler_file=scheduler_file,
-                port=scheduler_port,
-            ) as s:
-                comm.Barrier()
-                await s.finished()
+        async with WorkerType(**opts) as worker:
+            await worker.finished()
 
-        asyncio.get_event_loop().run_until_complete(run_scheduler())
+    async def run_scheduler(launch_worker=False):
+        async with Scheduler(
+            interface=interface,
+            protocol=protocol,
+            dashboard=dashboard,
+            dashboard_address=dashboard_address,
+            scheduler_file=scheduler_file,
+            port=scheduler_port,
+        ) as scheduler:
+            comm.Barrier()
+            if launch_worker:
+                asyncio.get_event_loop().create_task(run_worker())
+
+            await scheduler.finished()
+
+    if rank == scheduler_rank and scheduler:
+        asyncio.get_event_loop().run_until_complete(
+            run_scheduler(launch_worker=not exclusive_workers)
+        )
     else:
         comm.Barrier()
 
-        async def run_worker():
-            WorkerType = import_term(worker_class)
-            if not nanny:
-                raise DeprecationWarning(
-                    "Option --no-nanny is deprectaed, use --worker-class instead"
-                )
-                WorkerType = Worker
-            opts = {
-                "interface": interface,
-                "protocol": protocol,
-                "nthreads": nthreads,
-                "memory_limit": memory_limit,
-                "local_directory": local_directory,
-                "name": f"{name}-{rank}",
-                "scheduler_file": scheduler_file,
-                **worker_options,
-            }
-            if scheduler_address:
-                opts["scheduler_ip"] = scheduler_address
-            async with WorkerType(**opts) as worker:
-                await worker.finished()
-
-        asyncio.get_event_loop().run_until_complete(run_worker())
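As an illustration of the new options (assuming a scheduler file named scheduler.json), the scheduler can be pinned to a non-default rank and made to share its rank with a worker along these lines:

    mpirun -np 4 dask-mpi --scheduler-file scheduler.json --scheduler-rank 1 --inclusive-workers

With --inclusive-workers and the default scheduler rank of 0, min_world_size works out to 1, so a single MPI rank can host both the scheduler and a worker; with the default --exclusive-workers, at least two ranks are required.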
diff --git a/dask_mpi/execute.py b/dask_mpi/execute.py
new file mode 100644
index 0000000..fb6df43
--- /dev/null
+++ b/dask_mpi/execute.py
@@ -0,0 +1,210 @@
+import asyncio
+import threading
+
+import dask
+from distributed import Client, Nanny, Scheduler
+from distributed.utils import import_term
+
+from .exceptions import WorldTooSmallException
+
+
+def execute(
+    client_function=None,
+    client_args=(),
+    client_kwargs=None,
+    client_rank=1,
+    scheduler=True,
+    scheduler_rank=0,
+    scheduler_address=None,
+    scheduler_port=None,
+    scheduler_file=None,
+    interface=None,
+    nthreads=1,
+    local_directory="",
+    memory_limit="auto",
+    nanny=False,
+    dashboard=True,
+    dashboard_address=":8787",
+    protocol=None,
+    exclusive_workers=True,
+    worker_class="distributed.Worker",
+    worker_options=None,
+    worker_name=None,
+    comm=None,
+):
+    """
+    Execute a function on a given MPI rank with a Dask cluster launched using mpi4py
+
+    Using mpi4py, the MPI rank given by scheduler_rank (0 by default) launches the
+    scheduler, the rank given by client_rank (1 by default) runs client_function,
+    and all other MPI ranks launch workers. All ranks block while their event
+    loops run.
+
+    The client function is run in a separate thread on the client rank. When it
+    returns, send_close_signal is called automatically, which shuts down the
+    scheduler and the workers and lets every other rank finish its event loop
+    and exit.
+
+    Parameters
+    ----------
+    client_function : callable
+        A function containing Dask client code to execute with a Dask cluster. If
+        it is not callable, then no client code will be executed.
+    client_args : list
+        Arguments to the client function
+    client_rank : int
+        The MPI rank on which to run client_function.
+    scheduler_rank : int
+        The MPI rank on which to run the Dask scheduler
+    scheduler_address : str
+        IP address of the scheduler, used if the scheduler is not launched
+    scheduler_port : int
+        Specify scheduler port number. Defaults to random.
+    scheduler_file : str
+        Filename to JSON encoded scheduler information.
+    interface : str
+        Network interface like 'eth0' or 'ib0'
+    nthreads : int
+        Number of threads per worker
+    local_directory : str
+        Directory to place worker files
+    memory_limit : int, float, or 'auto'
+        Number of bytes before spilling data to disk. This can be an
+        integer (nbytes), float (fraction of total memory), or 'auto'.
+    nanny : bool
+        Start workers in nanny process for management (deprecated, use worker_class instead)
+    dashboard : bool
+        Enable Bokeh visual diagnostics
+    dashboard_address : str
+        Bokeh port for visual diagnostics
+    protocol : str
+        Protocol like 'inproc' or 'tcp'
+    exclusive_workers : bool
+        Whether to only run Dask workers on their own MPI ranks
+    worker_class : str
+        Class to use when creating workers
+    worker_options : dict
+        Options to pass to workers
+    worker_name : str
+        Prefix for name given to workers. If defined, each worker will be named
+        '{worker_name}-{rank}'. Otherwise, the name of each worker is just '{rank}'.
+    comm : mpi4py.MPI.Intracomm
+        Optional MPI communicator to use instead of COMM_WORLD
+    client_kwargs : dict
+        Keyword arguments to the client function
+    """
+    if comm is None:
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
+
+    world_size = comm.Get_size()
+    min_world_size = 1 + max(client_rank, scheduler_rank, exclusive_workers)
+    if world_size < min_world_size:
+        raise WorldTooSmallException(
+            f"Not enough MPI ranks to start cluster with exclusive_workers={exclusive_workers} and "
+            f"scheduler_rank={scheduler_rank}, found {world_size} MPI ranks but needs {min_world_size}."
+        )
+
+    rank = comm.Get_rank()
+
+    if not worker_options:
+        worker_options = {}
+
+    async def run_client():
+        def wrapped_function(*args, **kwargs):
+            client_function(*args, **kwargs)
+            send_close_signal()
+
+        threading.Thread(
+            target=wrapped_function, args=client_args, kwargs=client_kwargs
+        ).start()
+
+    async def run_worker(with_client=False):
+        WorkerType = import_term(worker_class)
+        if nanny:
+            WorkerType = Nanny
+            raise DeprecationWarning(
+                "Option nanny=True is deprecated, use worker_class='distributed.Nanny' instead"
+            )
+        opts = {
+            "interface": interface,
+            "protocol": protocol,
+            "nthreads": nthreads,
+            "memory_limit": memory_limit,
+            "local_directory": local_directory,
+            "name": rank if not worker_name else f"{worker_name}-{rank}",
+            **worker_options,
+        }
+        if not scheduler and scheduler_address:
+            opts["scheduler_ip"] = scheduler_address
+        async with WorkerType(**opts) as worker:
+            if with_client:
+                asyncio.get_event_loop().create_task(run_client())
+
+            await worker.finished()
+
+    async def run_scheduler(with_worker=False, with_client=False):
+        async with Scheduler(
+            interface=interface,
+            protocol=protocol,
+            dashboard=dashboard,
+            dashboard_address=dashboard_address,
+            scheduler_file=scheduler_file,
+            port=scheduler_port,
+        ) as scheduler:
+            dask.config.set(scheduler_address=scheduler.address)
+            comm.bcast(scheduler.address, root=scheduler_rank)
+            comm.Barrier()
+
+            if with_worker:
+                asyncio.get_event_loop().create_task(
+                    run_worker(with_client=with_client)
+                )
+
+            elif with_client:
+                asyncio.get_event_loop().create_task(run_client())
+
+            await scheduler.finished()
+
+    with_scheduler = scheduler and (rank == scheduler_rank)
+    with_client = callable(client_function) and (rank == client_rank)
+
+    if with_scheduler:
+        run_coro = run_scheduler(
+            with_worker=not exclusive_workers,
+            with_client=with_client,
+        )
+
+    else:
+        if scheduler:
+            scheduler_address = comm.bcast(None, root=scheduler_rank)
+        elif scheduler_address is None:
+            raise ValueError(
+                "Must provide scheduler_address if executing with scheduler=False"
+            )
+        dask.config.set(scheduler_address=scheduler_address)
+        comm.Barrier()
+
+        if with_client and exclusive_workers:
+            run_coro = run_client()
+        else:
+            run_coro = run_worker(with_client=with_client)
+
+    asyncio.get_event_loop().run_until_complete(run_coro)
+
+
+def send_close_signal():
+    """
+    The client can call this function to explicitly stop
+    the event loop.
+
+    This is not needed in normal usage, where it is run
+    automatically when the client code exits python.
+
+    You only need to call this manually when using exit=False
+    in initialize.
+    """
+
+    with Client() as c:
+        c.shutdown()
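To make the new execute() interface concrete, here is a minimal sketch of a script that could be launched under mpirun; the script name, the client function body, and its argument are illustrative assumptions:

    # sketch.py -- hypothetical usage, launched with: mpirun -np 4 python sketch.py
    from distributed import Client

    from dask_mpi import execute


    def client_code(x):
        # Runs only on client_rank (rank 1 by default); the other ranks host the
        # scheduler and the workers until this function returns.
        with Client() as client:
            assert client.submit(lambda v: v + 1, x).result() == x + 1


    if __name__ == "__main__":
        # Rank 0 runs the scheduler, rank 1 runs client_code(10) in a thread, and
        # ranks 2-3 run workers; shutdown is triggered when client_code returns.
        execute(client_code, client_args=(10,))

Passing exclusive_workers=False would additionally start a worker on the scheduler rank, mirroring the --inclusive-workers CLI flag.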
+ """ + + with Client() as c: + c.shutdown() diff --git a/dask_mpi/core.py b/dask_mpi/initialize.py similarity index 91% rename from dask_mpi/core.py rename to dask_mpi/initialize.py index e133a42..63c8219 100644 --- a/dask_mpi/core.py +++ b/dask_mpi/initialize.py @@ -3,10 +3,11 @@ import sys import dask -from distributed import Client, Nanny, Scheduler +from distributed import Nanny, Scheduler from distributed.utils import import_term from .exceptions import WorldTooSmallException +from .execute import send_close_signal def initialize( @@ -121,10 +122,10 @@ async def run_scheduler(): async def run_worker(): WorkerType = import_term(worker_class) if nanny: + WorkerType = Nanny raise DeprecationWarning( "Option nanny=True is deprectaed, use worker_class='distributed.Nanny' instead" ) - WorkerType = Nanny opts = { "interface": interface, "protocol": protocol, @@ -142,19 +143,3 @@ async def run_worker(): sys.exit() else: return False - - -def send_close_signal(): - """ - The client can call this function to explicitly stop - the event loop. - - This is not needed in normal usage, where it is run - automatically when the client code exits python. - - You only need to call this manually when using exit=False - in initialize. - """ - - with Client() as c: - c.shutdown() diff --git a/dask_mpi/tests/execute_basic.py b/dask_mpi/tests/execute_basic.py new file mode 100644 index 0000000..7af4e9f --- /dev/null +++ b/dask_mpi/tests/execute_basic.py @@ -0,0 +1,40 @@ +import sys +from time import sleep + +from distributed import Client +from distributed.metrics import time + +from dask_mpi import execute + + +def client_func(m, c, s, x): + xranks = {c, s} if x else set() + worker_ranks = set(i for i in range(m) if i not in xranks) + + with Client() as c: + start = time() + while len(c.scheduler_info()["workers"]) != len(worker_ranks): + assert time() < start + 10 + sleep(0.2) + + actual_worker_ranks = set( + v["name"] for k, v in c.scheduler_info()["workers"].items() + ) + assert actual_worker_ranks == worker_ranks + + for i in actual_worker_ranks: + assert c.submit(lambda x: x + 1, 10, workers=i).result() == 11 + + +if __name__ == "__main__": + vmap = {"True": True, "False": False, "None": None} + int_or_bool = lambda s: vmap[s] if s in vmap else int(s) + args = [int_or_bool(i) for i in sys.argv[1:]] + + execute( + client_function=client_func, + client_args=args, + client_rank=args[1], + scheduler_rank=args[2], + exclusive_workers=args[3], + ) diff --git a/dask_mpi/tests/execute_no_exit.py b/dask_mpi/tests/execute_no_exit.py new file mode 100644 index 0000000..0735684 --- /dev/null +++ b/dask_mpi/tests/execute_no_exit.py @@ -0,0 +1,31 @@ +from time import sleep, time + +from distributed import Client +from mpi4py.MPI import COMM_WORLD as world + +from dask_mpi import execute + +# Split our MPI world into two pieces, one consisting just of +# the old rank 3 process and the other with everything else +new_comm_assignment = 1 if world.rank == 3 else 0 +comm = world.Split(new_comm_assignment) + +if world.rank != 3: + + def client_code(): + with Client() as c: + start = time() + while len(c.scheduler_info()["workers"]) != 1: + assert time() < start + 10 + sleep(0.2) + + c.submit(lambda x: x + 1, 10).result() == 11 + c.submit(lambda x: x + 1, 20).result() == 21 + + execute(client_code, comm=comm) + +# check that our original comm is intact +world.Barrier() +x = 100 if world.rank == 0 else 200 +x = world.bcast(x) +assert x == 100 diff --git a/dask_mpi/tests/core_basic.py b/dask_mpi/tests/initialize_basic.py 
similarity index 100%
rename from dask_mpi/tests/core_basic.py
rename to dask_mpi/tests/initialize_basic.py
diff --git a/dask_mpi/tests/core_no_exit.py b/dask_mpi/tests/initialize_no_exit.py
similarity index 100%
rename from dask_mpi/tests/core_no_exit.py
rename to dask_mpi/tests/initialize_no_exit.py
diff --git a/dask_mpi/tests/test_cli.py b/dask_mpi/tests/test_cli.py
index e10e370..4a2f754 100644
--- a/dask_mpi/tests/test_cli.py
+++ b/dask_mpi/tests/test_cli.py
@@ -8,10 +8,11 @@
 import pytest
 import requests
+from dask.utils import tmpfile
 from distributed import Client
 from distributed.comm.addressing import get_address_host_port
 from distributed.metrics import time
-from distributed.utils import import_term, tmpfile
+from distributed.utils import import_term
 from distributed.utils_test import cleanup, loop, loop_in_thread, popen  # noqa: F401
 
 pytest.importorskip("mpi4py")
@@ -51,6 +52,27 @@ def test_basic(loop, worker_class, mpirun):
             assert c.submit(lambda x: x + 1, 10).result() == 11
 
 
+def test_inclusive_workers(loop, mpirun):
+    with tmpfile(extension="json") as fn:
+        cmd = mpirun + [
+            "-np",
+            "4",
+            "dask-mpi",
+            "--scheduler-file",
+            fn,
+            "--inclusive-workers",
+        ]
+
+        with popen(cmd):
+            with Client(scheduler_file=fn) as client:
+                start = time()
+                while len(client.scheduler_info()["workers"]) < 4:
+                    assert time() < start + 10
+                    sleep(0.1)
+
+                assert client.submit(lambda x: x + 1, 10).result() == 11
+
+
 def test_small_world(mpirun):
     with tmpfile(extension="json") as fn:
         # Set too few processes to start cluster
@@ -69,6 +91,27 @@
         assert p.returncode != 0
 
 
+def test_inclusive_small_world(mpirun):
+    with tmpfile(extension="json") as fn:
+        cmd = mpirun + [
+            "-np",
+            "1",
+            "dask-mpi",
+            "--scheduler-file",
+            fn,
+            "--inclusive-workers",
+        ]
+
+        with popen(cmd):
+            with Client(scheduler_file=fn) as client:
+                start = time()
+                while len(client.scheduler_info()["workers"]) < 1:
+                    assert time() < start + 10
+                    sleep(0.1)
+
+                assert client.submit(lambda x: x + 1, 10).result() == 11
+
+
 def test_no_scheduler(loop, mpirun):
     with tmpfile(extension="json") as fn:
         cmd = mpirun + ["-np", "2", "dask-mpi", "--scheduler-file", fn]
@@ -98,6 +141,35 @@
                 sleep(0.2)
 
 
+def test_scheduler_rank(loop, mpirun):
+    with tmpfile(extension="json") as fn:
+        cmd = mpirun + [
+            "-np",
+            "2",
+            "dask-mpi",
+            "--scheduler-file",
+            fn,
+            "--exclusive-workers",
+            "--scheduler-rank",
+            "1",
+        ]
+
+        with popen(cmd, stdin=FNULL):
+            with Client(scheduler_file=fn) as client:
+                start = time()
+                while len(client.scheduler_info()["workers"]) < 1:
+                    assert time() < start + 10
+                    sleep(0.2)
+
+                worker_infos = client.scheduler_info()["workers"]
+                assert len(worker_infos) == 1
+
+                worker_info = next(iter(worker_infos.values()))
+                assert worker_info["name"].rsplit("-")[-1] == "0"
+
+                assert client.submit(lambda x: x + 1, 10).result() == 11
+
+
 @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"])
 def test_non_default_ports(loop, nanny, mpirun):
     with tmpfile(extension="json") as fn:
@@ -150,23 +222,6 @@ def test_dashboard(loop, mpirun):
         requests.get("http://localhost:59583/status/")
 
 
-@pytest.mark.skip(reason="Should we expose this option?")
-def test_bokeh_worker(loop, mpirun):
-    with tmpfile(extension="json") as fn:
-        cmd = mpirun + [
-            "-np",
-            "2",
-            "dask-mpi",
-            "--scheduler-file",
-            fn,
-            "--bokeh-worker-port",
-            "59584",
-        ]
-
-        with popen(cmd, stdin=FNULL):
-            check_port_okay(59584)
-
-
 def tmpfile_static(extension="", dir=None):
     """
     utility function for test_stale_sched test
diff --git a/dask_mpi/tests/test_execute.py b/dask_mpi/tests/test_execute.py
new file mode 100644
index 0000000..99b74fa
--- /dev/null
+++ b/dask_mpi/tests/test_execute.py
@@ -0,0 +1,47 @@
+from __future__ import absolute_import, division, print_function
+
+import os
+import subprocess
+import sys
+
+import pytest
+
+pytest.importorskip("mpi4py")
+
+
+@pytest.mark.parametrize(
+    "mpisize,crank,srank,xworkers,retcode",
+    [
+        (4, 1, 0, True, 0),  # DEFAULTS
+        (1, 1, 0, True, 1),  # Set too few processes to start cluster
+        (4, 2, 3, True, 0),
+        (5, 1, 3, True, 0),
+        (3, 2, 2, True, 0),
+        (2, 0, 0, False, 0),
+        (1, 0, 0, False, 0),
+        (1, 0, 0, True, 1),
+    ],
+)
+def test_basic(mpisize, crank, srank, xworkers, retcode, mpirun):
+    script_file = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "execute_basic.py"
+    )
+
+    script_args = [str(v) for v in (mpisize, crank, srank, xworkers)]
+    p = subprocess.Popen(
+        mpirun + ["-n", script_args[0], sys.executable, script_file] + script_args
+    )
+
+    p.communicate()
+    assert p.returncode == retcode
+
+
+def test_no_exit(mpirun):
+    script_file = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "execute_no_exit.py"
+    )
+
+    p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])
+
+    p.communicate()
+    assert p.returncode == 0
diff --git a/dask_mpi/tests/test_core.py b/dask_mpi/tests/test_initialize.py
similarity index 58%
rename from dask_mpi/tests/test_core.py
rename to dask_mpi/tests/test_initialize.py
index abe4a60..56757c4 100644
--- a/dask_mpi/tests/test_core.py
+++ b/dask_mpi/tests/test_initialize.py
@@ -11,7 +11,7 @@
 
 def test_basic(mpirun):
     script_file = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "core_basic.py"
+        os.path.dirname(os.path.realpath(__file__)), "initialize_basic.py"
     )
 
     p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])
@@ -22,7 +22,7 @@ def test_basic(mpirun):
 
 def test_small_world(mpirun):
     script_file = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "core_basic.py"
+        os.path.dirname(os.path.realpath(__file__)), "initialize_basic.py"
     )
 
     # Set too few processes to start cluster
@@ -30,3 +30,14 @@ def test_small_world(mpirun):
 
     p.communicate()
     assert p.returncode != 0
+
+
+def test_no_exit(mpirun):
+    script_file = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "initialize_no_exit.py"
+    )
+
+    p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])
+
+    p.communicate()
+    assert p.returncode == 0
diff --git a/dask_mpi/tests/test_no_exit.py b/dask_mpi/tests/test_no_exit.py
deleted file mode 100644
index 65c27d5..0000000
--- a/dask_mpi/tests/test_no_exit.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from __future__ import absolute_import, division, print_function
-
-import os
-import subprocess
-import sys
-
-import pytest
-
-pytest.importorskip("mpi4py")
-
-
-def test_no_exit(mpirun):
-    script_file = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "core_no_exit.py"
-    )
-
-    p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])
-
-    p.communicate()
-    assert p.returncode == 0
diff --git a/docs/environment.yml b/docs/environment.yml
index 209c20f..3c70b3d 100644
--- a/docs/environment.yml
+++ b/docs/environment.yml
@@ -3,15 +3,23 @@ channels:
   - conda-forge
   - nodefaults
 dependencies:
+  - python<3.12
   - dask>=2.19
   - distributed>=2.19
   - mpich
   - mpi4py>=3.0.3
   - versioneer
-  - sphinx>=5.0
+  - sphinx
   - pygments
   - pip
   - pip:
+    #>>>> See: https://github.com/dask/dask-sphinx-theme/issues/68
+    - sphinxcontrib-applehelp<1.0.5
+    - sphinxcontrib-devhelp<1.0.6
+    - sphinxcontrib-htmlhelp<2.0.5
+    - sphinxcontrib-serializinghtml<1.1.10
+    - sphinxcontrib-qthelp<1.0.7
+    #<<<<
     - dask-sphinx-theme>=3.0.5
     - numpydoc
     - sphinx-click
diff --git a/setup.py b/setup.py
index 8a1aedf..d197b71 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@ def environment_dependencies(obj, dependencies=None):
     license="BSD 3-Clause",
     include_package_data=True,
     install_requires=install_requires,
-    python_requires=">=3.6",
+    python_requires=">=3.6,<3.12",
     packages=["dask_mpi"],
     long_description=long_description,
     entry_points="""