diff --git a/cluster_tools/cluster_tools/__init__.py b/cluster_tools/cluster_tools/__init__.py index b2dfe159d..f6f975799 100644 --- a/cluster_tools/cluster_tools/__init__.py +++ b/cluster_tools/cluster_tools/__init__.py @@ -114,7 +114,7 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor": if "client" in kwargs: return DaskExecutor(kwargs["client"]) else: - return DaskExecutor.from_kwargs(**kwargs) + return DaskExecutor.from_config(**kwargs) elif environment == "multiprocessing": global did_start_test_multiprocessing if not did_start_test_multiprocessing: diff --git a/cluster_tools/cluster_tools/executors/dask.py b/cluster_tools/cluster_tools/executors/dask.py index eb46cb4ee..223081173 100644 --- a/cluster_tools/cluster_tools/executors/dask.py +++ b/cluster_tools/cluster_tools/executors/dask.py @@ -6,6 +6,7 @@ TYPE_CHECKING, Any, Callable, + Dict, Iterable, Iterator, List, @@ -37,13 +38,13 @@ def __init__( self.client = client @classmethod - def from_kwargs( + def from_config( cls, - **kwargs: Any, + job_resources: Dict[str, Any], ) -> "DaskExecutor": from distributed import Client - return cls(Client(**kwargs)) + return cls(Client(**job_resources)) @classmethod def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]: