From 8a8df982d39e9b3a782517704fb72c2d6625a200 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 28 Jul 2023 17:09:52 +0200 Subject: [PATCH 01/10] Factor out static part of update_graph --- distributed/scheduler.py | 341 +++++++++++++++++++-------------------- 1 file changed, 163 insertions(+), 178 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9e7fc3689f..4f2d0b366d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -129,8 +129,6 @@ # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias - from dask.highlevelgraph import HighLevelGraph - # Not to be confused with distributed.worker_state_machine.TaskStateState TaskStateState: TypeAlias = Literal[ "released", @@ -4318,12 +4316,141 @@ async def add_nanny(self) -> dict[str, Any]: } return msg + @staticmethod + def _materialize_graph( + graph_header: dict, graph_frames: list[bytes], global_annotations: dict + ) -> tuple[dict, dict, dict]: + try: + from distributed.protocol import deserialize + + graph = deserialize(graph_header, graph_frames).data + del graph_header, graph_frames + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + raise RuntimeError(textwrap.dedent(msg)) from e + + from distributed.worker import dumps_task + + dsk = dask.utils.ensure_dict(graph) + + annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for annotations_type, value in global_annotations.items(): + annotations_by_type[annotations_type].update( + {k: (value(k) if callable(value) else value) for k in dsk} + ) + + for layer in graph.layers.values(): + if layer.annotations: + annot = layer.annotations + for annot_type, value in annot.items(): + annotations_by_type[annot_type].update( + { + stringify(k): (value(k) if callable(value) else value) + for k in layer + } + ) + + dependencies, _ = get_deps(dsk) + + # Remove `Future` objects from graph and note any future dependencies + dsk2 = {} + fut_deps = {} + for k, v in dsk.items(): + dsk2[k], futs = unpack_remotedata(v, byte_keys=True) + if futs: + fut_deps[k] = futs + dsk = dsk2 + + # - Add in deps for any tasks that depend on futures + for k, futures in fut_deps.items(): + dependencies[k].update(f.key for f in futures) + new_dsk = {} + # Annotation callables are evaluated on the non-stringified version of + # the keys + exclusive = set(graph) + for k, v in dsk.items(): + new_k = stringify(k) + new_dsk[new_k] = stringify(v, exclusive=exclusive) + dsk = new_dsk + dependencies = { + stringify(k): {stringify(dep) for dep in deps} + for k, deps in dependencies.items() + } + + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps + + # Remove aliases + for k in list(dsk): + if dsk[k] is k: + del dsk[k] + dsk = valmap(dumps_task, dsk) + + return dsk, dependencies, annotations_by_type + + @staticmethod + def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys): + n = 0 + lost_keys = set() + while len(dsk) != n: # walk through new tasks, cancel any bad deps + n = len(dsk) + for k, deps in list(dependencies.items()): + if any( + dep not in known_tasks and dep not in dsk for dep in deps + ): # bad key + lost_keys.add(k) + 
logger.info("User asked for computation on lost data, %s", k) + del dsk[k] + del dependencies[k] + if k in keys: + keys.remove(k) + + # Avoid computation that is already finished + already_in_memory = set() # tasks that are already done + for k, v in dependencies.items(): + if v and k in known_tasks: + ts = known_tasks[k] + if ts.state in ("memory", "erred"): + already_in_memory.add(k) + + done = set(already_in_memory) + if already_in_memory: + dependents = dask.core.reverse_dict(dependencies) + stack = list(already_in_memory) + while stack: # remove unnecessary dependencies + key = stack.pop() + try: + deps = dependencies[key] + except KeyError: + deps = known_tasks[key].dependencies + for dep in deps: + if dep in dependents: + child_deps = dependents[dep] + elif dep in known_tasks: + child_deps = known_tasks[dep].dependencies + else: + child_deps = set() + if all(d in done for d in child_deps): + if dep in known_tasks and dep not in done: + done.add(dep) + stack.append(dep) + for anc in done: + dsk.pop(anc, None) + dependencies.pop(anc, None) + return lost_keys + def update_graph( self, client: str, graph_header: dict, graph_frames: list[bytes], - keys: list[str], + keys: set[str], internal_priority: dict[str, int] | None, submitting_task: str | None, user_priority: int | dict[str, int] = 0, @@ -4334,47 +4461,32 @@ def update_graph( stimulus_id: str | None = None, ) -> None: start = time() - try: - # TODO: deserialization + materialization should be offloaded to a - # thread since this is non-trivial compute time that blocks the - # event loop. This likely requires us to use a lock since we need to - # guarantee ordering of update_graph calls (as long as there is just - # a single offload thread, this is not a problem) - from distributed.protocol import deserialize - - graph = deserialize(graph_header, graph_frames).data - except Exception as e: - msg = """\ - Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments - """ - try: - raise RuntimeError(textwrap.dedent(msg)) from e - except RuntimeError as e: - err = error_message(e) - for key in keys: - self.report( - { - "op": "task-erred", - "key": key, - "exception": err["exception"], - "traceback": err["traceback"], - } - ) - - return annotations = annotations or {} if isinstance(annotations, ToPickle): # type: ignore # FIXME: what the heck? 
annotations = annotations.data # type: ignore - - stimulus_id = stimulus_id or f"update-graph-{time()}" - ( - dsk, - dependencies, - annotations_by_type, - ) = self.materialize_graph(graph, annotations) - - if internal_priority is None: + try: + ( + dsk, + dependencies, + annotations_by_type, + ) = self._materialize_graph(graph_header, graph_frames, annotations) + del graph_header, graph_frames + except RuntimeError as e: + err = error_message(e) + for key in keys: + self.report( + { + "op": "task-erred", + "key": key, + "exception": err["exception"], + "traceback": err["traceback"], + } + ) + keys = set(keys) + lost_keys = self._match_graph_with_tasks(self.tasks, dsk, dependencies, keys) + ordered: dict = {} + if not internal_priority: # Removing all non-local keys before calling order() dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys stripped_deps = { @@ -4382,24 +4494,18 @@ def update_graph( for k, v in dependencies.items() if k in dsk_keys } - internal_priority = dask.order.order(dsk, dependencies=stripped_deps) + ordered = dask.order.order(dsk, dependencies=stripped_deps) + assert ordered - requested_keys = set(keys) - del keys + stimulus_id = stimulus_id or f"update-graph-{time()}" + + # FIXME: How can I log this cleanly? if len(dsk) > 1: self.log_event( ["all", client], {"action": "update_graph", "count": len(dsk)} ) - self._pop_known_tasks( - known_tasks=self.tasks, dsk=dsk, dependencies=dependencies - ) - if lost_keys := self._pop_lost_tasks( - dsk=dsk, - known_tasks=self.tasks, - dependencies=dependencies, - keys=requested_keys, - ): + if lost_keys: self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) self.client_releases_keys( keys=lost_keys, client=client, stimulus_id=stimulus_id @@ -4420,7 +4526,7 @@ def update_graph( computation.annotations.update(annotations) runnable, touched_tasks, new_tasks = self._generate_taskstates( - keys=requested_keys, + keys=keys, dsk=dsk, dependencies=dependencies, computation=computation, @@ -4432,7 +4538,7 @@ def update_graph( ) self._set_priorities( - internal_priority=internal_priority, + internal_priority=internal_priority or ordered, submitting_task=submitting_task, user_priority=user_priority, fifo_timeout=fifo_timeout, @@ -4440,11 +4546,11 @@ def update_graph( tasks=runnable, ) - self.client_desires_keys(keys=requested_keys, client=client) + self.client_desires_keys(keys=keys, client=client) # Add actors if actors is True: - actors = list(requested_keys) + actors = list(keys) for actor in actors or []: ts = self.tasks[actor] ts.actor = True @@ -4496,7 +4602,7 @@ def update_graph( self, client=client, tasks=[ts.key for ts in touched_tasks], - keys=requested_keys, + keys=keys, dependencies=dependencies, annotations=dict(annotations_for_plugin), priority=priority, @@ -4670,127 +4776,6 @@ def _set_priorities( isinstance(el, (int, float)) for el in ts.priority ) - @staticmethod - def _pop_lost_tasks( - dsk: dict, keys: set[str], known_tasks: dict[str, TaskState], dependencies: dict - ) -> set[str]: - n = 0 - out = set() - while len(dsk) != n: # walk through new tasks, cancel any bad deps - n = len(dsk) - for k, deps in list(dependencies.items()): - if any( - dep not in known_tasks and dep not in dsk for dep in deps - ): # bad key - out.add(k) - logger.info("User asked for computation on lost data, %s", k) - del dsk[k] - del dependencies[k] - if k in keys: - keys.remove(k) - return out - - @staticmethod - def _pop_known_tasks( - known_tasks: dict[str, TaskState], dsk: dict, dependencies: dict - ) -> 
set[str]: - # Avoid computation that is already finished - already_in_memory = set() # tasks that are already done - for k, v in dependencies.items(): - if v and k in known_tasks: - ts = known_tasks[k] - if ts.state in ("memory", "erred"): - already_in_memory.add(k) - - done = set(already_in_memory) - if already_in_memory: - dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) - while stack: # remove unnecessary dependencies - key = stack.pop() - try: - deps = dependencies[key] - except KeyError: - deps = known_tasks[key].dependencies - for dep in deps: - if dep in dependents: - child_deps = dependents[dep] - elif dep in known_tasks: - child_deps = known_tasks[dep].dependencies - else: - child_deps = set() - if all(d in done for d in child_deps): - if dep in known_tasks and dep not in done: - done.add(dep) - stack.append(dep) - for anc in done: - dsk.pop(anc, None) - dependencies.pop(anc, None) - return done - - @staticmethod - def materialize_graph( - hlg: HighLevelGraph, global_annotations: dict - ) -> tuple[dict, dict, dict]: - from distributed.worker import dumps_task - - dsk = dask.utils.ensure_dict(hlg) - - annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for type_, value in global_annotations.items(): - annotations_by_type[type_].update( - {stringify(k): (value(k) if callable(value) else value) for k in dsk} - ) - for layer in hlg.layers.values(): - if layer.annotations: - annot = layer.annotations - for annot_type, value in annot.items(): - annotations_by_type[annot_type].update( - { - stringify(k): (value(k) if callable(value) else value) - for k in layer - } - ) - - dependencies, _ = get_deps(dsk) - - # Remove `Future` objects from graph and note any future dependencies - dsk2 = {} - fut_deps = {} - for k, v in dsk.items(): - dsk2[k], futs = unpack_remotedata(v, byte_keys=True) - if futs: - fut_deps[k] = futs - dsk = dsk2 - - # - Add in deps for any tasks that depend on futures - for k, futures in fut_deps.items(): - dependencies[k].update(f.key for f in futures) - new_dsk = {} - exclusive = set(hlg) - for k, v in dsk.items(): - new_k = stringify(k) - new_dsk[new_k] = stringify(v, exclusive=exclusive) - dsk = new_dsk - dependencies = { - stringify(k): {stringify(dep) for dep in deps} - for k, deps in dependencies.items() - } - - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps - - # Remove aliases - for k in list(dsk): - if dsk[k] is k: - del dsk[k] - dsk = valmap(dumps_task, dsk) - return dsk, dependencies, dict(annotations_by_type) - def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: """Respond to an event which may have opened spots on worker threadpools From d192690f1254627f7e23d3cf0aaf99a6801a621b Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 28 Jul 2023 19:24:47 +0200 Subject: [PATCH 02/10] offload static components of graph materialization --- distributed/scheduler.py | 321 ++++++++++++++++------------ distributed/tests/test_client.py | 1 - distributed/tests/test_scheduler.py | 2 +- 3 files changed, 187 insertions(+), 137 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4f2d0b366d..c5552bd365 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -92,6 +92,7 @@ from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode from distributed.proctitle import setproctitle +from 
distributed.protocol import deserialize from distributed.protocol.pickle import dumps, loads from distributed.protocol.serialize import Serialized, ToPickle, serialize from distributed.publish import PublishExtension @@ -112,6 +113,7 @@ key_split_group, log_errors, no_default, + offload, recursive_to_dict, validate_key, wait_for, @@ -124,11 +126,14 @@ ) from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension +from distributed.worker import dumps_task if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) from typing_extensions import TypeAlias + from dask.highlevelgraph import HighLevelGraph + # Not to be confused with distributed.worker_state_machine.TaskStateState TaskStateState: TypeAlias = Literal[ "released", @@ -3761,6 +3766,7 @@ def __init__( setproctitle("dask scheduler [not started]") Scheduler._instances.add(self) self.rpc.allow_offload = False + self._update_graph_lock = asyncio.Lock() ################## # Administration # @@ -4316,93 +4322,14 @@ async def add_nanny(self) -> dict[str, Any]: } return msg - @staticmethod - def _materialize_graph( - graph_header: dict, graph_frames: list[bytes], global_annotations: dict - ) -> tuple[dict, dict, dict]: - try: - from distributed.protocol import deserialize - - graph = deserialize(graph_header, graph_frames).data - del graph_header, graph_frames - except Exception as e: - msg = """\ - Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments - """ - raise RuntimeError(textwrap.dedent(msg)) from e - - from distributed.worker import dumps_task - - dsk = dask.utils.ensure_dict(graph) - - annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) - for annotations_type, value in global_annotations.items(): - annotations_by_type[annotations_type].update( - {k: (value(k) if callable(value) else value) for k in dsk} - ) - - for layer in graph.layers.values(): - if layer.annotations: - annot = layer.annotations - for annot_type, value in annot.items(): - annotations_by_type[annot_type].update( - { - stringify(k): (value(k) if callable(value) else value) - for k in layer - } - ) - - dependencies, _ = get_deps(dsk) - - # Remove `Future` objects from graph and note any future dependencies - dsk2 = {} - fut_deps = {} - for k, v in dsk.items(): - dsk2[k], futs = unpack_remotedata(v, byte_keys=True) - if futs: - fut_deps[k] = futs - dsk = dsk2 - - # - Add in deps for any tasks that depend on futures - for k, futures in fut_deps.items(): - dependencies[k].update(f.key for f in futures) - new_dsk = {} - # Annotation callables are evaluated on the non-stringified version of - # the keys - exclusive = set(graph) - for k, v in dsk.items(): - new_k = stringify(k) - new_dsk[new_k] = stringify(v, exclusive=exclusive) - dsk = new_dsk - dependencies = { - stringify(k): {stringify(dep) for dep in deps} - for k, deps in dependencies.items() - } - - # Remove any self-dependencies (happens on test_publish_bag() and others) - for k, v in dependencies.items(): - deps = set(v) - if k in deps: - deps.remove(k) - dependencies[k] = deps - - # Remove aliases - for k in list(dsk): - if dsk[k] is k: - del dsk[k] - dsk = valmap(dumps_task, dsk) - - return dsk, dependencies, annotations_by_type - - @staticmethod - def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys): + def 
_match_graph_with_tasks(self, dsk, dependencies, keys): n = 0 lost_keys = set() while len(dsk) != n: # walk through new tasks, cancel any bad deps n = len(dsk) for k, deps in list(dependencies.items()): if any( - dep not in known_tasks and dep not in dsk for dep in deps + dep not in self.tasks and dep not in dsk for dep in deps ): # bad key lost_keys.add(k) logger.info("User asked for computation on lost data, %s", k) @@ -4414,8 +4341,8 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys): # Avoid computation that is already finished already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in known_tasks: - ts = known_tasks[k] + if v and k in self.tasks: + ts = self.tasks[k] if ts.state in ("memory", "erred"): already_in_memory.add(k) @@ -4428,16 +4355,16 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys): try: deps = dependencies[key] except KeyError: - deps = known_tasks[key].dependencies + deps = self.tasks[key].dependencies for dep in deps: if dep in dependents: child_deps = dependents[dep] - elif dep in known_tasks: - child_deps = known_tasks[dep].dependencies + elif dep in self.tasks: + child_deps = self.tasks[dep].dependencies else: child_deps = set() if all(d in done for d in child_deps): - if dep in known_tasks and dep not in done: + if dep in self.tasks and dep not in done: done.add(dep) stack.append(dep) for anc in done: @@ -4445,61 +4372,37 @@ def _match_graph_with_tasks(known_tasks, dsk, dependencies, keys): dependencies.pop(anc, None) return lost_keys - def update_graph( + def _create_taskstate_from_graph( self, - client: str, - graph_header: dict, - graph_frames: list[bytes], + *, + start: float, + dsk: dict, + dependencies: dict, keys: set[str], - internal_priority: dict[str, int] | None, + ordered: dict[str, int], + client: str, + annotations_by_type: dict, + global_annotations: dict | None, + stimulus_id: str, submitting_task: str | None, user_priority: int | dict[str, int] = 0, actors: bool | list[str] | None = None, fifo_timeout: float = 0.0, code: tuple[SourceCode, ...] = (), - annotations: dict | None = None, - stimulus_id: str | None = None, ) -> None: - start = time() - annotations = annotations or {} - if isinstance(annotations, ToPickle): # type: ignore - # FIXME: what the heck? - annotations = annotations.data # type: ignore - try: - ( - dsk, - dependencies, - annotations_by_type, - ) = self._materialize_graph(graph_header, graph_frames, annotations) - del graph_header, graph_frames - except RuntimeError as e: - err = error_message(e) - for key in keys: - self.report( - { - "op": "task-erred", - "key": key, - "exception": err["exception"], - "traceback": err["traceback"], - } - ) - keys = set(keys) - lost_keys = self._match_graph_with_tasks(self.tasks, dsk, dependencies, keys) - ordered: dict = {} - if not internal_priority: - # Removing all non-local keys before calling order() - dsk_keys = set(dsk) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - ordered = dask.order.order(dsk, dependencies=stripped_deps) - assert ordered + """ + Take a low level graph and create the necessary scheduler state to + compute it. - stimulus_id = stimulus_id or f"update-graph-{time()}" + WARNING + ------- + This method must not be made async since nothing here is concurrency + safe. All interactions with TaskState objects here should be happening + in the same event loop tick. 
+ """ + + lost_keys = self._match_graph_with_tasks(dsk, dependencies, keys) - # FIXME: How can I log this cleanly? if len(dsk) > 1: self.log_event( ["all", client], {"action": "update_graph", "count": len(dsk)} @@ -4520,10 +4423,11 @@ def update_graph( if code: # add new code blocks computation.code.add(code) - if annotations: + if global_annotations: # FIXME: This is kind of inconsistent since it only includes global # annotations. - computation.annotations.update(annotations) + computation.annotations.update(global_annotations) + del global_annotations runnable, touched_tasks, new_tasks = self._generate_taskstates( keys=keys, @@ -4538,7 +4442,7 @@ def update_graph( ) self._set_priorities( - internal_priority=internal_priority or ordered, + internal_priority=ordered, submitting_task=submitting_task, user_priority=user_priority, fifo_timeout=fifo_timeout, @@ -4616,6 +4520,90 @@ def update_graph( if ts.state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) + @log_errors + async def update_graph( + self, + client: str, + graph_header: dict, + graph_frames: list[bytes], + keys: set[str], + internal_priority: dict[str, int] | None, + submitting_task: str | None, + user_priority: int | dict[str, int] = 0, + actors: bool | list[str] | None = None, + fifo_timeout: float = 0.0, + code: tuple[SourceCode, ...] = (), + annotations: dict | None = None, + stimulus_id: str | None = None, + ) -> None: + # FIXME: Apparently empty dicts arrive as a ToPickle object + if isinstance(annotations, ToPickle): + annotations = annotations.data # type: ignore[unreachable] + start = time() + async with self._update_graph_lock: + try: + graph = deserialize(graph_header, graph_frames).data + del graph_header, graph_frames + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + raise RuntimeError(textwrap.dedent(msg)) from e + + except RuntimeError as e: + err = error_message(e) + for key in keys: + self.report( + { + "op": "task-erred", + "key": key, + "exception": err["exception"], + "traceback": err["traceback"], + } + ) + else: + ( + dsk, + dependencies, + annotations_by_type, + ) = await offload( + _materialize_graph, + graph=graph, + global_annotations=annotations or {}, + ) + del graph + if not internal_priority: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = await offload( + dask.order.order, dsk=dsk, dependencies=stripped_deps + ) + + self._create_taskstate_from_graph( + dsk=dsk, + client=client, + dependencies=dependencies, + keys=set(keys), + ordered=internal_priority or {}, + submitting_task=submitting_task, + user_priority=user_priority, + actors=actors, + fifo_timeout=fifo_timeout, + code=code, + annotations_by_type=annotations_by_type, + # FIXME: This is just used to attach to Computation objects. 
This + # should be removed + global_annotations=annotations, + start=start, + stimulus_id=stimulus_id or f"update-graph-{start}", + ) end = time() self.digest_metric("update-graph-duration", end - start) @@ -8490,3 +8478,66 @@ def transition( self.metadata[key] = ts.metadata self.state[key] = finish self.keys.discard(key) + + +def _materialize_graph( + graph: HighLevelGraph, global_annotations: dict +) -> tuple[dict, dict, dict]: + dsk = dask.utils.ensure_dict(graph) + annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for annotations_type, value in global_annotations.items(): + annotations_by_type[annotations_type].update( + {stringify(k): (value(k) if callable(value) else value) for k in dsk} + ) + + for layer in graph.layers.values(): + if layer.annotations: + annot = layer.annotations + for annot_type, value in annot.items(): + annotations_by_type[annot_type].update( + { + stringify(k): (value(k) if callable(value) else value) + for k in layer + } + ) + dependencies, _ = get_deps(dsk) + + # Remove `Future` objects from graph and note any future dependencies + dsk2 = {} + fut_deps = {} + for k, v in dsk.items(): + dsk2[k], futs = unpack_remotedata(v, byte_keys=True) + if futs: + fut_deps[k] = futs + dsk = dsk2 + + # - Add in deps for any tasks that depend on futures + for k, futures in fut_deps.items(): + dependencies[k].update(f.key for f in futures) + new_dsk = {} + # Annotation callables are evaluated on the non-stringified version of + # the keys + exclusive = set(graph) + for k, v in dsk.items(): + new_k = stringify(k) + new_dsk[new_k] = stringify(v, exclusive=exclusive) + dsk = new_dsk + dependencies = { + stringify(k): {stringify(dep) for dep in deps} + for k, deps in dependencies.items() + } + + # Remove any self-dependencies (happens on test_publish_bag() and others) + for k, v in dependencies.items(): + deps = set(v) + if k in deps: + deps.remove(k) + dependencies[k] = deps + + # Remove aliases + for k in list(dsk): + if dsk[k] is k: + del dsk[k] + dsk = valmap(dumps_task, dsk) + + return dsk, dependencies, annotations_by_type diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5a9c5e8014..3409be2963 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3719,7 +3719,6 @@ async def test_scatter_raises_if_no_workers(c, s): await c.scatter(1, timeout=0.5) -@pytest.mark.slow @gen_test() async def test_reconnect(): port = open_port() diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 275148f605..9a78fd2a5f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1358,7 +1358,7 @@ async def test_update_graph_culls(s, a, b): ) header, frames = serialize(ToPickle(dsk), on_error="raise") - s.update_graph( + await s.update_graph( graph_header=header, graph_frames=frames, keys=["y"], From 18a6f31ff45efc2e68289d5116286a9f314dbcac Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 3 Aug 2023 15:11:12 +0200 Subject: [PATCH 03/10] handle errors properly --- distributed/scheduler.py | 100 +++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c5552bd365..1cb44e00e3 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4542,14 +4542,57 @@ async def update_graph( start = time() async with self._update_graph_lock: try: - graph = deserialize(graph_header, graph_frames).data - del graph_header, 
graph_frames - except Exception as e: - msg = """\ - Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments - """ - raise RuntimeError(textwrap.dedent(msg)) from e + try: + graph = deserialize(graph_header, graph_frames).data + del graph_header, graph_frames + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + raise RuntimeError(textwrap.dedent(msg)) from e + else: + ( + dsk, + dependencies, + annotations_by_type, + ) = await offload( + _materialize_graph, + graph=graph, + global_annotations=annotations or {}, + ) + del graph + if not internal_priority: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = await offload( + dask.order.order, dsk=dsk, dependencies=stripped_deps + ) + self._create_taskstate_from_graph( + dsk=dsk, + client=client, + dependencies=dependencies, + keys=set(keys), + ordered=internal_priority or {}, + submitting_task=submitting_task, + user_priority=user_priority, + actors=actors, + fifo_timeout=fifo_timeout, + code=code, + annotations_by_type=annotations_by_type, + # FIXME: This is just used to attach to Computation + # objects. This should be removed + global_annotations=annotations, + start=start, + stimulus_id=stimulus_id or f"update-graph-{start}", + ) except RuntimeError as e: err = error_message(e) for key in keys: @@ -4561,49 +4604,6 @@ async def update_graph( "traceback": err["traceback"], } ) - else: - ( - dsk, - dependencies, - annotations_by_type, - ) = await offload( - _materialize_graph, - graph=graph, - global_annotations=annotations or {}, - ) - del graph - if not internal_priority: - # Removing all non-local keys before calling order() - dsk_keys = set( - dsk - ) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - internal_priority = await offload( - dask.order.order, dsk=dsk, dependencies=stripped_deps - ) - - self._create_taskstate_from_graph( - dsk=dsk, - client=client, - dependencies=dependencies, - keys=set(keys), - ordered=internal_priority or {}, - submitting_task=submitting_task, - user_priority=user_priority, - actors=actors, - fifo_timeout=fifo_timeout, - code=code, - annotations_by_type=annotations_by_type, - # FIXME: This is just used to attach to Computation objects. 
This - # should be removed - global_annotations=annotations, - start=start, - stimulus_id=stimulus_id or f"update-graph-{start}", - ) end = time() self.digest_metric("update-graph-duration", end - start) From c002bd909e1f695cf313ee978d899a81db011fe6 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 3 Aug 2023 19:33:33 +0200 Subject: [PATCH 04/10] more robust test_steal_more_attractive_tasks --- distributed/tests/test_steal.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index a0dadfcc32..ac439ab675 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -641,10 +641,11 @@ async def test_steal_when_more_tasks(c, s, a, *rest): "slowidentity": 0.2, "slow2": 1, }, - "distributed.scheduler.work-stealing-interval": "20ms", }, ) async def test_steal_more_attractive_tasks(c, s, a, *rest): + ext = s.extensions["stealing"] + def slow2(x): sleep(1) return x @@ -652,9 +653,17 @@ def slow2(x): x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB await wait(x) + # We have to stop the extension entirely since otherwise a tick might + # already allow a stealing request may sneak in before all tasks are on the + # scheduler + await ext.stop() futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) + while future.key not in s.tasks: + await asyncio.sleep(0.01) + # Now call it once explicitly to move the heavy task + ext.balance() while not any(w.state.tasks for w in rest): await asyncio.sleep(0.01) From 4d80e59a081078fa20553306852d6b0ce0cb68cd Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 09:03:59 +0200 Subject: [PATCH 05/10] Extend comment about stopped stealing extension --- distributed/tests/test_steal.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index ac439ab675..10569265ed 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -653,9 +653,19 @@ def slow2(x): x = c.submit(mul, b"0", 100000000, workers=a.address) # 100 MB await wait(x) - # We have to stop the extension entirely since otherwise a tick might - # already allow a stealing request may sneak in before all tasks are on the - # scheduler + # The submits below are all individual update_graph calls which are very + # likely submitted in the same batch. + # Prior to https://github.com/dask/distributed/pull/8049, the entire batch + # would be processed by the scheduler in the same event loop tick. + # Therefore, the first PC `stealing.balance` call would be guaranteed to see + # all the tasks and make the correct decision. + # After the PR, the batch is processed in multiple event loop ticks, so the + # first PC `stealing.balance` call would potentially only see the first + # tasks and would try to rebalance them instead of the slow and heavy one. + # To guarantee that the stealing extension sees all tasks, we're stopping + # the callback and are calling balance ourselves once we are certain the + # tasks are all on the scheduler. 
+ # Related https://github.com/dask/distributed/pull/5443 await ext.stop() futures = [c.submit(slowidentity, x, pure=False, delay=0.2) for i in range(10)] future = c.submit(slow2, x, priority=-1) From bcf7a3920219453db9f2a7a190e2e181e92d9000 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 09:43:12 +0200 Subject: [PATCH 06/10] Remove lock --- distributed/scheduler.py | 122 +++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 62 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1cb44e00e3..58ee51e912 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3766,7 +3766,6 @@ def __init__( setproctitle("dask scheduler [not started]") Scheduler._instances.add(self) self.rpc.allow_offload = False - self._update_graph_lock = asyncio.Lock() ################## # Administration # @@ -4540,70 +4539,69 @@ async def update_graph( if isinstance(annotations, ToPickle): annotations = annotations.data # type: ignore[unreachable] start = time() - async with self._update_graph_lock: + try: try: - try: - graph = deserialize(graph_header, graph_frames).data - del graph_header, graph_frames - except Exception as e: - msg = """\ - Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments - """ - raise RuntimeError(textwrap.dedent(msg)) from e - else: - ( - dsk, - dependencies, - annotations_by_type, - ) = await offload( - _materialize_graph, - graph=graph, - global_annotations=annotations or {}, + graph = deserialize(graph_header, graph_frames).data + del graph_header, graph_frames + except Exception as e: + msg = """\ + Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments + """ + raise RuntimeError(textwrap.dedent(msg)) from e + else: + ( + dsk, + dependencies, + annotations_by_type, + ) = await offload( + _materialize_graph, + graph=graph, + global_annotations=annotations or {}, + ) + del graph + if not internal_priority: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = await offload( + dask.order.order, dsk=dsk, dependencies=stripped_deps ) - del graph - if not internal_priority: - # Removing all non-local keys before calling order() - dsk_keys = set( - dsk - ) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - internal_priority = await offload( - dask.order.order, dsk=dsk, dependencies=stripped_deps - ) - self._create_taskstate_from_graph( - dsk=dsk, - client=client, - dependencies=dependencies, - keys=set(keys), - ordered=internal_priority or {}, - submitting_task=submitting_task, - user_priority=user_priority, - actors=actors, - fifo_timeout=fifo_timeout, - code=code, - annotations_by_type=annotations_by_type, - # FIXME: This is just used to attach to Computation - # objects. 
This should be removed - global_annotations=annotations, - start=start, - stimulus_id=stimulus_id or f"update-graph-{start}", - ) - except RuntimeError as e: - err = error_message(e) - for key in keys: - self.report( - { - "op": "task-erred", - "key": key, - "exception": err["exception"], - "traceback": err["traceback"], - } - ) + self._create_taskstate_from_graph( + dsk=dsk, + client=client, + dependencies=dependencies, + keys=set(keys), + ordered=internal_priority or {}, + submitting_task=submitting_task, + user_priority=user_priority, + actors=actors, + fifo_timeout=fifo_timeout, + code=code, + annotations_by_type=annotations_by_type, + # FIXME: This is just used to attach to Computation + # objects. This should be removed + global_annotations=annotations, + start=start, + stimulus_id=stimulus_id or f"update-graph-{start}", + ) + except RuntimeError as e: + err = error_message(e) + for key in keys: + self.report( + { + "op": "task-erred", + "key": key, + "exception": err["exception"], + "traceback": err["traceback"], + } + ) end = time() self.digest_metric("update-graph-duration", end - start) From c2250b9d6b2c8353a5b20dc7e58d42aaffe30ecf Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 15:24:30 +0200 Subject: [PATCH 07/10] wait for rootish test_decide_worker_rootish_while_last_worker_is_retiring --- distributed/tests/test_scheduler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9a78fd2a5f..8c5bb3e75c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -303,6 +303,15 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a): while a.state.executing_count != 1 or b.state.executing_count != 1: await asyncio.sleep(0.01) + # Rootish is a dynamic property as it is defined right now. Since the + # above submit calls are individual update_graph calls, waiting for + # tasks to be in executing state on the worker is not sufficient to + # guarantee that all the y tasks are already on the scheduler. 
Only + # after at least 5 have been registered, will the task be flagged as + # rootish + while "y-2" not in s.tasks or not s.is_rootish(s.tasks["y-2"]): + await asyncio.sleep(0.01) + # - y-2 has no restrictions # - TaskGroup(y) has more than 4 tasks (total_nthreads * 2) # - TaskGroup(y) has less than 5 dependency groups From 08ef21ac82c2833a4925b032d50d19047ff70f79 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 15:37:51 +0200 Subject: [PATCH 08/10] review nits --- distributed/scheduler.py | 100 +++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 58ee51e912..d07a3a28c6 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4321,7 +4321,9 @@ async def add_nanny(self) -> dict[str, Any]: } return msg - def _match_graph_with_tasks(self, dsk, dependencies, keys): + def _match_graph_with_tasks( + self, dsk: dict[str, Any], dependencies: dict[str, set[str]], keys: set[str] + ) -> set[str]: n = 0 lost_keys = set() while len(dsk) != n: # walk through new tasks, cancel any bad deps @@ -4336,30 +4338,29 @@ def _match_graph_with_tasks(self, dsk, dependencies, keys): del dependencies[k] if k in keys: keys.remove(k) - + del deps # Avoid computation that is already finished - already_in_memory = set() # tasks that are already done + done = set() # tasks that are already done for k, v in dependencies.items(): if v and k in self.tasks: ts = self.tasks[k] if ts.state in ("memory", "erred"): - already_in_memory.add(k) + done.add(k) - done = set(already_in_memory) - if already_in_memory: + if done: dependents = dask.core.reverse_dict(dependencies) - stack = list(already_in_memory) + stack = list(done) while stack: # remove unnecessary dependencies key = stack.pop() try: deps = dependencies[key] except KeyError: - deps = self.tasks[key].dependencies + deps = {ts.key for ts in self.tasks[key].dependencies} for dep in deps: if dep in dependents: child_deps = dependents[dep] elif dep in self.tasks: - child_deps = self.tasks[dep].dependencies + child_deps = {ts.key for ts in self.tasks[key].dependencies} else: child_deps = set() if all(d in done for d in child_deps): @@ -4548,49 +4549,48 @@ async def update_graph( Error during deserialization of the task graph. This frequently occurs if the Scheduler and Client have different environments. 
For more information, see https://docs.dask.org/en/stable/deployment-considerations.html#consistent-software-environments """ raise RuntimeError(textwrap.dedent(msg)) from e - else: - ( - dsk, - dependencies, - annotations_by_type, - ) = await offload( - _materialize_graph, - graph=graph, - global_annotations=annotations or {}, + ( + dsk, + dependencies, + annotations_by_type, + ) = await offload( + _materialize_graph, + graph=graph, + global_annotations=annotations or {}, + ) + del graph + if not internal_priority: + # Removing all non-local keys before calling order() + dsk_keys = set( + dsk + ) # intersection() of sets is much faster than dict_keys + stripped_deps = { + k: v.intersection(dsk_keys) + for k, v in dependencies.items() + if k in dsk_keys + } + internal_priority = await offload( + dask.order.order, dsk=dsk, dependencies=stripped_deps ) - del graph - if not internal_priority: - # Removing all non-local keys before calling order() - dsk_keys = set( - dsk - ) # intersection() of sets is much faster than dict_keys - stripped_deps = { - k: v.intersection(dsk_keys) - for k, v in dependencies.items() - if k in dsk_keys - } - internal_priority = await offload( - dask.order.order, dsk=dsk, dependencies=stripped_deps - ) - self._create_taskstate_from_graph( - dsk=dsk, - client=client, - dependencies=dependencies, - keys=set(keys), - ordered=internal_priority or {}, - submitting_task=submitting_task, - user_priority=user_priority, - actors=actors, - fifo_timeout=fifo_timeout, - code=code, - annotations_by_type=annotations_by_type, - # FIXME: This is just used to attach to Computation - # objects. This should be removed - global_annotations=annotations, - start=start, - stimulus_id=stimulus_id or f"update-graph-{start}", - ) + self._create_taskstate_from_graph( + dsk=dsk, + client=client, + dependencies=dependencies, + keys=set(keys), + ordered=internal_priority or {}, + submitting_task=submitting_task, + user_priority=user_priority, + actors=actors, + fifo_timeout=fifo_timeout, + code=code, + annotations_by_type=annotations_by_type, + # FIXME: This is just used to attach to Computation + # objects. 
This should be removed + global_annotations=annotations, + start=start, + stimulus_id=stimulus_id or f"update-graph-{start}", + ) except RuntimeError as e: err = error_message(e) for key in keys: From 36087733fe3131ac9ed9523b73e17eef5a9f30c6 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 15:43:48 +0200 Subject: [PATCH 09/10] Restore slow marker for test_reconnect --- distributed/tests/test_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3409be2963..5a9c5e8014 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3719,6 +3719,7 @@ async def test_scatter_raises_if_no_workers(c, s): await c.scatter(1, timeout=0.5) +@pytest.mark.slow @gen_test() async def test_reconnect(): port = open_port() From 8e7c690af0ec861d8845d3464fa600e51189d27f Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 4 Aug 2023 16:08:39 +0200 Subject: [PATCH 10/10] indent del --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d07a3a28c6..9e1b16fd76 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4338,7 +4338,7 @@ def _match_graph_with_tasks( del dependencies[k] if k in keys: keys.remove(k) - del deps + del deps # Avoid computation that is already finished done = set() # tasks that are already done for k, v in dependencies.items():
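
A note on the pattern these patches converge on: update_graph now splits into a CPU-bound, side-effect-free phase (deserialize and materialize the graph, compute dask.order priorities) that is pushed off the event loop via offload, and a synchronous phase (_create_taskstate_from_graph) that mutates scheduler state within a single event-loop tick, with deserialization and materialization failures reported back to the client as "task-erred" instead of crashing the scheduler. The sketch below illustrates that split in isolation; it is not code from the patches. The names materialize, apply_to_state, and update_graph here are simplified stand-ins, and the executor-based hand-off merely mimics what distributed's offload helper does.

# Illustrative sketch only -- simplified stand-ins, not scheduler code.
import asyncio
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)  # plays the role of the single offload thread


def materialize(dsk):
    # CPU-bound and side-effect free, so it is safe to run off the event loop.
    dependencies = {key: set() for key in dsk}
    priorities = {key: i for i, key in enumerate(dsk)}
    return dsk, dependencies, priorities


def apply_to_state(state, dsk, dependencies, priorities):
    # Mutates shared state, so it must stay synchronous and run within a
    # single event-loop tick (mirroring the warning in the docstring of
    # _create_taskstate_from_graph above).
    for key in dsk:
        state[key] = {"deps": dependencies[key], "priority": priorities[key]}


async def update_graph(state, dsk, keys):
    loop = asyncio.get_running_loop()
    try:
        dsk, dependencies, priorities = await loop.run_in_executor(
            _executor, materialize, dsk
        )
    except Exception as exc:
        # The patches report "task-erred" to the client rather than letting
        # the scheduler fall over; here the requested keys are just marked.
        for key in keys:
            state[key] = {"state": "erred", "exception": exc}
        return
    apply_to_state(state, dsk, dependencies, priorities)


async def main():
    state = {}
    await update_graph(state, {"x": 1, "y": 2}, keys={"x", "y"})
    print(state)


if __name__ == "__main__":
    asyncio.run(main())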