
Commit

try delaying conversion only
rjzamora committed Dec 8, 2023
1 parent 4c16a25 commit cfe4ad5
Showing 3 changed files with 20 additions and 102 deletions.
96 changes: 5 additions & 91 deletions distributed/shuffle/_core.py
@@ -24,7 +24,7 @@

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import dask_deserialize, dask_serialize, to_serialize
from distributed.protocol import to_serialize
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import ShuffleClosedError
@@ -308,8 +308,7 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
if kwargs.pop("sync", True):
self._sync_output_partition(partition_id, key)
self._sync_output_partition(partition_id, key)
return self._get_output_partition(partition_id, key, **kwargs)

@abc.abstractmethod
@@ -318,11 +317,11 @@ def _get_output_partition(
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

def get_unloaded_output_partition(
def get_raw_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> UnloadedPartition:
) -> Any:
self._sync_output_partition(partition_id, key)
return UnloadedPartition(self, partition_id, key, **kwargs)
return self._get_output_partition(partition_id, key, convert=False, **kwargs)

@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
@@ -467,88 +466,3 @@ def _mean_shard_size(shards: Iterable) -> int:
if count == 10:
break
return size // count if count else 0


class UnloadedPartition:
"""Wrap unloaded shuffle output
The purpose of this class is to keep a shuffled partition
on disk until it is needed by one of its dependent tasks.
Otherwise, the in-memory partition may need to be spilled
back to disk before the dependent task is executed anyway.
If the output tasks of a ``P2PShuffleLayer`` return objects
of type ``UnloadedPartition``, that layer must be followed
by an extra ``Blockwise`` call to ``load_output_partition``
(to ensure the partitions are actually loaded). We want this
extra layer to be ``Blockwise`` so that the loading can be
fused into down-stream tasks. We do NOT want the original
``shuffle_unpack`` tasks to be fused into dependent tasks,
because this would prevent load balancing after the shuffle
(long-running post-shuffle tasks may be pinned to specific
workers, while others sit idle).
Note that serialization automatically loads the wrapped
data, because the object may be moved to a worker that
doesn't have access to the same local storage.
"""

def __init__(
self,
shuffle_run: ShuffleRun,
partition_id: _T_partition_id,
key: Key,
**kwargs: Any,
):
self.shuffle_run = shuffle_run
self.partition_id = partition_id
self.key = key
self.kwargs = kwargs

def pre_serialize(self) -> Any:
"""Make the unloaded partition serializable"""
# TODO: Add mechanism to dispatch on meta.
# Right now, serializing an UnloadedPartition object
# will convert it to `type(self.shuffle_run.meta)`.
# However, it may be beneficial to further delay the
# use of GPU memory for cudf/cupy-based data.
return self.load()

def load(self) -> Any:
"""Load the shuffle output partition into memory"""
with handle_unpack_errors(self.shuffle_run.id):
return self.shuffle_run.get_output_partition(
self.partition_id,
self.key,
# We need sync=False, because `_sync_output_partition`
# was already called for the current shuffle run
sync=False,
**self.kwargs,
)


@dask_serialize.register(UnloadedPartition)
def _serialize_unloaded(obj):
# Convert to LoadedPartition before serializing. Note that
# we don't convert all the way to DataFrame, because this
# adds unnecessary overhead and memory pressure for the
# cudf backend (and minor overhead for pandas)
return None, [pickle.dumps(obj.pre_serialize())]


@dask_deserialize.register(UnloadedPartition)
def _deserialize_unloaded(header, frames):
return pickle.loads(frames[0])


def load_output_partition(
data: UnloadedPartition | _T_partition_type, barrier_key: int
) -> _T_partition_type:
# Used by rearrange_by_column_p2p to "unwrap"
# UnloadedPartition/LoadedPartition data after
# a P2PShuffleLayer
assert barrier_key
if isinstance(data, UnloadedPartition):
data = data.load()
assert not isinstance(data, UnloadedPartition)
return data
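
For context, a minimal, self-contained sketch of the deferred-loading pattern implemented by the removed UnloadedPartition class: the shuffle output task returns a cheap handle, and a separate, fusible step materializes the data. FakeRun, LazyPartition, and load_partition below are illustrative stand-ins, not names from this codebase.

from __future__ import annotations

from typing import Any


class FakeRun:
    """Stand-in for a ShuffleRun that can materialize an output partition."""

    def __init__(self, payload: Any):
        self._payload = payload

    def get_output_partition(self, partition_id: int) -> Any:
        # The real implementation reads shards from disk and converts them.
        return self._payload


class LazyPartition:
    """Defer materialization until a downstream (Blockwise) task asks for it."""

    def __init__(self, run: FakeRun, partition_id: int):
        self.run = run
        self.partition_id = partition_id

    def load(self) -> Any:
        return self.run.get_output_partition(self.partition_id)


def load_partition(maybe_lazy: Any) -> Any:
    # Mirrors the role of the removed load_output_partition helper:
    # unwrap lazy handles, pass already-loaded data through unchanged.
    return maybe_lazy.load() if isinstance(maybe_lazy, LazyPartition) else maybe_lazy


handle = LazyPartition(FakeRun({"rows": 3}), partition_id=0)  # cheap to create and move
print(load_partition(handle))  # {'rows': 3}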
20 changes: 12 additions & 8 deletions distributed/shuffle/_shuffle.py
@@ -39,7 +39,6 @@
get_worker_plugin,
handle_transfer_errors,
handle_unpack_errors,
load_output_partition,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
@@ -91,12 +90,12 @@ def shuffle_unpack(
)


def delayed_shuffle_unpack(
def shuffle_unpack_partial(
id: ShuffleId, output_partition: int, barrier_run_id: int
) -> pd.DataFrame:
with handle_unpack_errors(id):
return get_worker_plugin().get_output_partition(
id, barrier_run_id, output_partition, load=False
id, barrier_run_id, output_partition, convert=False
)


@@ -151,10 +150,10 @@ def rearrange_by_column_p2p(
meta,
[None] * (npartitions + 1),
).map_partitions(
load_output_partition,
layer._tokens[1],
partial(_convert_output_partition, meta=meta),
meta=meta,
enforce_metadata=False,
align_dataframes=False,
)


@@ -305,7 +304,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
delayed_shuffle_unpack,
shuffle_unpack_partial,
token,
part_out,
_barrier_key,
@@ -519,13 +518,14 @@ def _get_output_partition(
self,
partition_id: int,
key: Key,
convert: bool = True,
**kwargs: Any,
) -> pd.DataFrame:
try:
data = self._read_from_disk((partition_id,))
return convert_shards(data, self.meta)
return convert_shards(data, self.meta) if convert else data
except KeyError:
return self.meta.copy()
return self.meta.copy() if convert else None

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]
@@ -580,3 +580,7 @@ def _get_worker_for_range_sharding(
"""Get address of target worker for this output partition using range sharding"""
i = len(workers) * output_partition // npartitions
return workers[i]


def _convert_output_partition(data: pa.Table, meta: Any = None) -> pd.DataFrame:
return meta.copy() if data is None else convert_shards(data, meta)
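
A minimal sketch of the "delay conversion only" approach wired up above: the unpack task returns raw Arrow data (or None for an empty partition), and a cheap, fusible follow-up step converts it to pandas against the expected meta. convert_partition below is an illustrative stand-in for _convert_output_partition and uses Table.to_pandas() rather than convert_shards.

from __future__ import annotations

from functools import partial

import pandas as pd
import pyarrow as pa


def convert_partition(data: pa.Table | None, meta: pd.DataFrame) -> pd.DataFrame:
    # Empty shuffle outputs come back as None; fall back to an empty frame
    # with the expected schema.
    if data is None:
        return meta.copy()
    return data.to_pandas()


meta = pd.DataFrame({"x": pd.Series(dtype="int64"), "y": pd.Series(dtype="float64")})
raw = pa.table({"x": [1, 2], "y": [0.5, 1.5]})

converter = partial(convert_partition, meta=meta)  # same shape as the map_partitions call
print(converter(raw))   # materialized pandas DataFrame
print(converter(None))  # empty DataFrame with the expected columns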
6 changes: 3 additions & 3 deletions distributed/shuffle/_worker_plugin.py
@@ -417,7 +417,7 @@ def get_output_partition(
run_id: int,
partition_id: int | NDIndex,
meta: pd.DataFrame | None = None,
load: bool = True,
convert: bool = True,
) -> Any:
"""
Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin.
@@ -426,13 +426,13 @@
"""
shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
key = thread_state.key
if load:
if convert:
return shuffle_run.get_output_partition(
partition_id=partition_id,
key=key,
meta=meta,
)
return shuffle_run.get_unloaded_output_partition(
return shuffle_run.get_raw_output_partition(
partition_id=partition_id,
key=key,
meta=meta,
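
A minimal sketch of the convert-flag dispatch the plugin change introduces: one entry point that returns a ready-to-use pandas DataFrame when convert=True and hands back the raw Arrow table untouched when convert=False, deferring the conversion cost to a later task. DemoRun and get_output are illustrative stand-ins, not plugin API.

from __future__ import annotations

from typing import Any

import pandas as pd
import pyarrow as pa


class DemoRun:
    def __init__(self, table: pa.Table):
        self._table = table

    def get_output_partition(self, partition_id: int) -> pd.DataFrame:
        # Eager path: convert to pandas immediately.
        return self._table.to_pandas()

    def get_raw_output_partition(self, partition_id: int) -> pa.Table:
        # Deferred path: hand back the Arrow table untouched.
        return self._table


def get_output(run: DemoRun, partition_id: int, convert: bool = True) -> Any:
    if convert:
        return run.get_output_partition(partition_id)
    return run.get_raw_output_partition(partition_id)


run = DemoRun(pa.table({"x": [1, 2, 3]}))
print(type(get_output(run, 0, convert=True)))   # pandas DataFrame
print(type(get_output(run, 0, convert=False)))  # pyarrow Table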
