Skip to content

Commit

Permalink
adds DaskScheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
normanrz committed Sep 12, 2023
1 parent 9601235 commit c3c275e
Show file tree
Hide file tree
Showing 7 changed files with 512 additions and 29 deletions.
10 changes: 9 additions & 1 deletion cluster_tools/cluster_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing_extensions import Literal

from cluster_tools.executors.dask import DaskExecutor
from cluster_tools.executors.debug_sequential import DebugSequentialExecutor
from cluster_tools.executors.multiprocessing_ import MultiprocessingExecutor
from cluster_tools.executors.pickle_ import PickleExecutor
Expand Down Expand Up @@ -70,6 +71,11 @@ def get_executor(
...


@overload
def get_executor(environment: Literal["dask"], **kwargs: Any) -> DaskExecutor:
    # Typing overload: environment "dask" yields a DaskExecutor; the kwargs
    # are forwarded to its constructor (and from there to dask's Client).
    ...


@overload
def get_executor(
environment: Literal["multiprocessing"], **kwargs: Any
Expand Down Expand Up @@ -105,6 +111,8 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
return PBSExecutor(**kwargs)
elif environment == "kubernetes":
return KubernetesExecutor(**kwargs)
elif environment == "dask":
return DaskExecutor(**kwargs)
elif environment == "multiprocessing":
global did_start_test_multiprocessing
if not did_start_test_multiprocessing:
Expand All @@ -121,4 +129,4 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
raise Exception("Unknown executor: {}".format(environment))


Executor = Union[ClusterExecutor, MultiprocessingExecutor]
Executor = Union[ClusterExecutor, MultiprocessingExecutor, DaskExecutor]
103 changes: 103 additions & 0 deletions cluster_tools/cluster_tools/executors/dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
from concurrent import futures
from concurrent.futures import Future
from functools import partial
from typing import Any, Callable, Iterable, Iterator, List, Optional, TypeVar, cast

from dask.distributed import Client, as_completed
from typing_extensions import ParamSpec

from cluster_tools._utils.warning import enrich_future_with_uncaught_warning
from cluster_tools.executors.multiprocessing_ import CFutDict, MultiprocessingExecutor

_T = TypeVar("_T")
_P = ParamSpec("_P")
_S = TypeVar("_S")


class DaskExecutor(futures.Executor):
    """Futures-style executor backed by a ``dask.distributed`` cluster.

    All constructor keyword arguments are forwarded verbatim to
    ``dask.distributed.Client``, so this can attach to an existing
    scheduler or start a local one — whatever ``Client(**kwargs)`` does.
    """

    # Handle to the dask scheduler; all submissions go through this client.
    client: Client

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        # kwargs are passed straight through to dask's Client constructor.
        self.client = Client(**kwargs)

    @classmethod
    def as_completed(cls, futures: List[Future[_T]]) -> Iterator[Future[_T]]:
        # Delegates to dask.distributed.as_completed (module-level import),
        # which yields futures in completion order.
        return as_completed(futures)

    def submit(  # type: ignore[override]
        self,
        __fn: Callable[_P, _T],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> "Future[_T]":
        """Submit ``__fn(*args, **kwargs)`` to the dask cluster.

        Honors the internal ``__cfut_options`` protocol: when present, its
        ``output_pickle_path`` is popped from ``kwargs`` and the call is
        wrapped so the result is additionally persisted to that path
        (reusing ``MultiprocessingExecutor._execute_and_persist_function``).
        """
        if "__cfut_options" in kwargs:
            output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[
                "output_pickle_path"
            ]
            del kwargs["__cfut_options"]

            # Wrap the user function so its result is also pickled to disk.
            __fn = partial(
                MultiprocessingExecutor._execute_and_persist_function,
                output_pickle_path,
                __fn,
            )
        # All arguments are bound into a zero-arg partial before submission;
        # presumably this keeps dask from inspecting/resolving the arguments
        # itself — TODO confirm against dask's Client.submit semantics.
        fut = self.client.submit(partial(__fn, *args, **kwargs))

        enrich_future_with_uncaught_warning(fut)
        return fut

    def map_unordered(self, fn: Callable[_P, _T], args: Any) -> Iterator[_T]:
        """Apply ``fn`` to each element of ``args`` and yield the results in
        completion order (not input order)."""
        futs: List["Future[_T]"] = self.map_to_futures(fn, args)

        # Return a separate generator to avoid that map_unordered
        # is executed lazily (otherwise, jobs would be submitted
        # lazily, as well).
        def result_generator() -> Iterator[_T]:
            for fut in as_completed(futs):
                yield fut.result()

        return result_generator()

    def map_to_futures(
        self,
        fn: Callable[[_S], _T],
        args: Iterable[_S],  # TODO change: allow more than one arg per call
        output_pickle_path_getter: Optional[Callable[[_S], os.PathLike]] = None,
    ) -> List["Future[_T]"]:
        """Submit ``fn`` once per element of ``args``; return the futures.

        When ``output_pickle_path_getter`` is given, each result is also
        persisted to the per-argument path via the ``__cfut_options``
        protocol handled in :meth:`submit`.
        """
        if output_pickle_path_getter is not None:
            futs = [
                self.submit(  # type: ignore[call-arg]
                    fn,
                    arg,
                    __cfut_options={
                        "output_pickle_path": output_pickle_path_getter(arg)
                    },
                )
                for arg in args
            ]
        else:
            futs = [self.submit(fn, arg) for arg in args]

        return futs

    def map(
        self,
        fn: Callable[[_S], _T],
        args: Iterable[_S],  # TODO change: allow more than one arg per call,
        timeout: Optional[float] = None,
        chunksize: Optional[int] = 1,
    ) -> List[_T]:
        # Materialize eagerly so all jobs are submitted up front (the base
        # Executor.map would otherwise consume lazily).
        return list(super().map(fn, args, timeout=timeout, chunksize=chunksize))

    def forward_log(self, fut: "Future[_T]") -> _T:
        # NOTE(review): no explicit log handling here — presumably dask
        # forwards worker logs itself; this only blocks on the result.
        # TODO confirm.
        return fut.result()

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        # NOTE(review): ``cancel_futures`` is accepted for interface
        # compatibility but currently ignored.
        if wait:
            # 24h timeout — presumably generous enough for outstanding work
            # to drain; verify against Client.close's timeout semantics.
            self.client.close(timeout=60 * 60 * 24)
        else:
            self.client.close()
4 changes: 4 additions & 0 deletions cluster_tools/cluster_tools/executors/multiprocessing_.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def __init__(
else:
self._mp_logging_handler_pool = _MultiprocessingLoggingHandlerPool()

@classmethod
def as_completed(cls, futures: List[Future[_T]]) -> Iterator[Future[_T]]:
    """Yield the given futures as they complete.

    Bug fix: the ``futures`` parameter shadows the ``concurrent.futures``
    module, so the original ``futures.as_completed(futures)`` looked up
    ``as_completed`` on the *list* argument and raised ``AttributeError``.
    Import the function locally under an unshadowed name instead.
    """
    from concurrent.futures import as_completed as _as_completed

    return _as_completed(futures)

def submit( # type: ignore[override]
self,
__fn: Callable[_P, _T],
Expand Down
4 changes: 4 additions & 0 deletions cluster_tools/cluster_tools/schedulers/cluster_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def __init__(
if "logging_setup_fn" in kwargs:
self.meta_data["logging_setup_fn"] = kwargs["logging_setup_fn"]

@classmethod
def as_completed(cls, futures: List[Future[_T]]) -> Iterator[Future[_T]]:
    """Yield the given futures as they complete.

    Bug fix: the ``futures`` parameter shadows the ``concurrent.futures``
    module, so the original ``futures.as_completed(futures)`` looked up
    ``as_completed`` on the *list* argument and raised ``AttributeError``.
    Import the function locally under an unshadowed name instead.
    """
    from concurrent.futures import as_completed as _as_completed

    return _as_completed(futures)

@classmethod
@abstractmethod
def executor_key(cls) -> str:
Expand Down
Loading

0 comments on commit c3c275e

Please sign in to comment.