Skip to content

Commit

Permalink
Reduce P2P transfer task overhead (#8912)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Oct 30, 2024
1 parent 9da5824 commit f340f18
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 122 deletions.
60 changes: 60 additions & 0 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from tornado.ioloop import IOLoop

import dask.config
from dask._task_spec import Task, _inline_recursively
from dask.core import flatten
from dask.sizeof import sizeof
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta

Expand Down Expand Up @@ -575,3 +577,61 @@ def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int:
raise
except Exception as e:
raise RuntimeError(f"P2P {id} failed during barrier phase") from e


class P2PBarrierTask(Task):
    """A barrier ``Task`` that carries the :class:`ShuffleSpec` inline.

    Embedding the spec in the barrier task ships it along with the graph,
    so it does not need to travel through a separate channel. The extra
    ``spec`` attribute is preserved across ``copy``, ``inline`` and
    pickling.
    """

    # Specification of the shuffle this barrier belongs to
    spec: ShuffleSpec

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        key: Any,
        func: Callable[..., Any],
        /,
        *args: Any,
        spec: ShuffleSpec,
        **kwargs: Any,
    ):
        self.spec = spec
        super().__init__(key, func, *args, **kwargs)

    def copy(self) -> P2PBarrierTask:
        """Return a copy that keeps ``spec`` alongside the base task state."""
        self.unpack()
        assert self.func is not None
        return P2PBarrierTask(
            self.key, self.func, *self.args, spec=self.spec, **self.kwargs
        )

    def __sizeof__(self) -> int:
        # Account for the embedded spec on top of the base task size.
        return super().__sizeof__() + sizeof(self.spec)

    def __repr__(self) -> str:
        return f"P2PBarrierTask({self.key!r})"

    def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
        """Inline dependencies from ``dsk``, preserving ``spec``."""
        self.unpack()
        new_args = _inline_recursively(self.args, dsk)
        new_kwargs = _inline_recursively(self.kwargs, dsk)
        assert self.func is not None
        return P2PBarrierTask(
            self.key, self.func, *new_args, spec=self.spec, **new_kwargs
        )

    def __getstate__(self) -> dict[str, Any]:
        # ``spec`` lives in __slots__, so it must be added to the
        # pickle state explicitly.
        state = super().__getstate__()
        state["spec"] = self.spec
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        super().__setstate__(state)
        self.spec = state["spec"]

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, P2PBarrierTask):
            return False
        if not super().__eq__(value):
            return False
        return self.spec == value.spec

    # Overriding __eq__ implicitly sets __hash__ to None, which would make
    # barrier tasks unhashable; explicitly retain the parent's hash so they
    # stay usable as dict keys / set members.
    __hash__ = Task.__hash__
117 changes: 74 additions & 43 deletions distributed/shuffle/_merge.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
# mypy: ignore-errors
from __future__ import annotations

from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import dask
from dask._task_spec import GraphNode, Task, TaskRef
from dask.base import is_dask_collection
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.tokenize import tokenize
from dask.typing import Key

from distributed.shuffle._arrow import check_minimal_arrow_version
from distributed.shuffle._core import (
P2PBarrierTask,
ShuffleId,
barrier_key,
get_worker_plugin,
p2p_barrier,
)
from distributed.shuffle._shuffle import shuffle_transfer
from distributed.shuffle._shuffle import DataFrameShuffleSpec, shuffle_transfer

if TYPE_CHECKING:
import pandas as pd
from pandas._typing import IndexLabel, MergeHow, Suffixes

# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

from dask.dataframe.core import _Frame

_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode]

_HASH_COLUMN_NAME = "__hash_partition"

Expand Down Expand Up @@ -148,21 +155,11 @@ def merge_transfer(
input: pd.DataFrame,
id: ShuffleId,
input_partition: int,
npartitions: int,
meta: pd.DataFrame,
parts_out: set[int],
disk: bool,
):
return shuffle_transfer(
input=input,
id=id,
input_partition=input_partition,
npartitions=npartitions,
column=_HASH_COLUMN_NAME,
meta=meta,
parts_out=parts_out,
disk=disk,
drop_column=True,
)


Expand Down Expand Up @@ -208,7 +205,7 @@ class HashJoinP2PLayer(Layer):
suffixes: Suffixes
indicator: bool
meta_output: pd.DataFrame
parts_out: Sequence[int]
parts_out: set[int]

name_input_left: str
meta_input_left: pd.DataFrame
Expand Down Expand Up @@ -241,7 +238,7 @@ def __init__(
how: MergeHow = "inner",
suffixes: Suffixes = ("_x", "_y"),
indicator: bool = False,
parts_out: Sequence | None = None,
parts_out: Iterable[int] | None = None,
annotations: dict | None = None,
) -> None:
check_minimal_arrow_version()
Expand All @@ -257,7 +254,10 @@ def __init__(
self.suffixes = suffixes
self.indicator = indicator
self.meta_output = meta_output
self.parts_out = parts_out or list(range(npartitions))
if parts_out:
self.parts_out = set(parts_out)
else:
self.parts_out = set(range(npartitions))
self.n_partitions_left = n_partitions_left
self.n_partitions_right = n_partitions_right
self.left_index = left_index
Expand Down Expand Up @@ -325,7 +325,7 @@ def _dict(self):
self._cached_dict = dsk
return self._cached_dict

def _cull(self, parts_out: Sequence[str]):
def _cull(self, parts_out: Iterable[int]):
return HashJoinP2PLayer(
name=self.name,
name_input_left=self.name_input_left,
Expand Down Expand Up @@ -365,7 +365,7 @@ def cull(self, keys: Iterable[str], all_keys: Any) -> tuple[HashJoinP2PLayer, di
else:
return self, culled_deps

def _construct_graph(self) -> dict[tuple | str, tuple]:
def _construct_graph(self) -> _T_LowLevelGraph:
token_left = tokenize(
# Include self.name to ensure that shuffle IDs are unique for individual
# merge operations. Reusing shuffles between merges is dangerous because of
Expand All @@ -375,6 +375,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.left_on,
self.left_index,
)
shuffle_id_left = ShuffleId(token_left)
token_right = tokenize(
# Include self.name to ensure that shuffle IDs are unique for individual
# merge operations. Reusing shuffles between merges is dangerous because of
Expand All @@ -384,50 +385,79 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.right_on,
self.right_index,
)
dsk: dict[tuple | str, tuple] = {}
shuffle_id_right = ShuffleId(token_right)
dsk: _T_LowLevelGraph = {}
name_left = "hash-join-transfer-" + token_left
name_right = "hash-join-transfer-" + token_right
transfer_keys_left = list()
transfer_keys_right = list()
for i in range(self.n_partitions_left):
transfer_keys_left.append((name_left, i))
dsk[(name_left, i)] = (
t = Task(
(name_left, i),
merge_transfer,
(self.name_input_left, i),
token_left,
TaskRef((self.name_input_left, i)),
shuffle_id_left,
i,
self.npartitions,
self.meta_input_left,
self.parts_out,
self.disk,
)
dsk[t.key] = t
transfer_keys_left.append(t.ref())

transfer_keys_right = list()
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
dsk[(name_right, i)] = (
t = Task(
(name_right, i),
merge_transfer,
(self.name_input_right, i),
token_right,
TaskRef((self.name_input_right, i)),
shuffle_id_right,
i,
self.npartitions,
self.meta_input_right,
self.parts_out,
self.disk,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
_barrier_key_right = barrier_key(ShuffleId(token_right))
dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left)
dsk[_barrier_key_right] = (p2p_barrier, token_right, transfer_keys_right)
dsk[t.key] = t
transfer_keys_right.append(t.ref())

_barrier_key_left = barrier_key(shuffle_id_left)
barrier_left = P2PBarrierTask(
_barrier_key_left,
p2p_barrier,
token_left,
transfer_keys_left,
spec=DataFrameShuffleSpec(
id=shuffle_id_left,
npartitions=self.npartitions,
column=_HASH_COLUMN_NAME,
meta=self.meta_input_left,
parts_out=self.parts_out,
disk=self.disk,
drop_column=True,
),
)
dsk[barrier_left.key] = barrier_left
_barrier_key_right = barrier_key(shuffle_id_right)
barrier_right = P2PBarrierTask(
_barrier_key_right,
p2p_barrier,
token_right,
transfer_keys_right,
spec=DataFrameShuffleSpec(
id=shuffle_id_right,
npartitions=self.npartitions,
column=_HASH_COLUMN_NAME,
meta=self.meta_input_right,
parts_out=self.parts_out,
disk=self.disk,
drop_column=True,
),
)
dsk[barrier_right.key] = barrier_right

name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
t = Task(
(name, part_out),
merge_unpack,
token_left,
token_right,
part_out,
_barrier_key_left,
_barrier_key_right,
barrier_left.ref(),
barrier_right.ref(),
self.how,
self.left_on,
self.right_on,
Expand All @@ -437,4 +467,5 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
self.right_index,
self.indicator,
)
dsk[t.key] = t
return dsk
25 changes: 13 additions & 12 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,18 @@
from distributed.metrics import context_meter
from distributed.shuffle._core import (
NDIndex,
P2PBarrierTask,
ShuffleId,
ShuffleRun,
ShuffleSpec,
barrier_key,
get_worker_plugin,
handle_transfer_errors,
handle_unpack_errors,
p2p_barrier,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import unpickle_bytestream
from distributed.shuffle._shuffle import barrier_key
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

Expand All @@ -164,15 +165,12 @@ def rechunk_transfer(
input: np.ndarray,
id: ShuffleId,
input_chunk: NDIndex,
new: ChunkedAxes,
old: ChunkedAxes,
disk: bool,
) -> int:
with handle_transfer_errors(id):
return get_worker_plugin().add_partition(
input,
partition_id=input_chunk,
spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk),
id=id,
)


Expand Down Expand Up @@ -815,16 +813,19 @@ def partial_rechunk(
key,
rechunk_transfer,
input_key,
partial_token,
ShuffleId(partial_token),
partial_index,
partial_new,
partial_old,
disk,
)
transfer_keys.append(t.ref())

dsk[_barrier_key] = barrier = Task(
_barrier_key, p2p_barrier, partial_token, transfer_keys
dsk[_barrier_key] = barrier = P2PBarrierTask(
_barrier_key,
p2p_barrier,
partial_token,
transfer_keys,
spec=ArrayRechunkSpec(
id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk
),
)

new_partial_offset = tuple(axis.start for axis in ndpartial.new)
Expand All @@ -835,7 +836,7 @@ def partial_rechunk(
dsk[k] = Task(
k,
rechunk_unpack,
partial_token,
ShuffleId(partial_token),
partial_index,
barrier.ref(),
)
Expand Down
Loading

0 comments on commit f340f18

Please sign in to comment.