diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index dbae956e161..2867997229f 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -24,7 +24,7 @@
 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
-from distributed.protocol import dask_deserialize, dask_serialize, to_serialize
+from distributed.protocol import to_serialize
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
 from distributed.shuffle._exceptions import ShuffleClosedError
@@ -308,8 +308,7 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> Non
     def get_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
     ) -> _T_partition_type:
-        if kwargs.pop("sync", True):
-            self._sync_output_partition(partition_id, key)
+        self._sync_output_partition(partition_id, key)
         return self._get_output_partition(partition_id, key, **kwargs)

     @abc.abstractmethod
@@ -318,11 +317,11 @@ def _get_output_partition(
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""

-    def get_unloaded_output_partition(
+    def get_raw_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
-    ) -> UnloadedPartition:
+    ) -> Any:
         self._sync_output_partition(partition_id, key)
-        return UnloadedPartition(self, partition_id, key, **kwargs)
+        return self._get_output_partition(partition_id, key, convert=False, **kwargs)

     @abc.abstractmethod
     def read(self, path: Path) -> tuple[Any, int]:
@@ -467,88 +466,3 @@ def _mean_shard_size(shards: Iterable) -> int:
             if count == 10:
                 break
     return size // count if count else 0
-
-
-class UnloadedPartition:
-    """Wrap unloaded shuffle output
-
-    The purpose of this class is to keep a shuffled partition
-    on disk until it is needed by one of its dependent tasks.
-    Otherwise, the in-memory partition may need to be spilled
-    back to disk before the dependent task is executed anyway.
-
-    If the output tasks of a ``P2PShuffleLayer`` return objects
-    of type ``UnloadedPartition``, that layer must be followed
-    by an extra ``Blockwise`` call to ``load_output_partition``
-    (to ensure the partitions are actually loaded). We want this
-    extra layer to be ``Blockwise`` so that the loading can be
-    fused into down-stream tasks. We do NOT want the original
-    ``shuffle_unpack`` tasks to be fused into dependent tasks,
-    because this would prevent load balancing after the shuffle
-    (long-running post-shuffle tasks may be pinned to specific
-    workers, while others sit idle).
-
-    Note that serialization automatically loads the wrapped
-    data, because the object may be moved to a worker that
-    doesn't have access to the same local storage.
-    """
-
-    def __init__(
-        self,
-        shuffle_run: ShuffleRun,
-        partition_id: _T_partition_id,
-        key: Key,
-        **kwargs: Any,
-    ):
-        self.shuffle_run = shuffle_run
-        self.partition_id = partition_id
-        self.key = key
-        self.kwargs = kwargs
-
-    def pre_serialize(self) -> Any:
-        """Make the unloaded partition serializable"""
-        # TODO: Add mechanism to dispatch on meta.
-        # Right now, serializing an UnloadedPartition object
-        # will convert it to `type(self.shuffle_run.meta)`.
-        # However, it may be beneficial to futher delay the
-        # use of GPU memory for cudf/cupy-based data.
-        return self.load()
-
-    def load(self) -> Any:
-        """Load the shuffle output partition into memory"""
-        with handle_unpack_errors(self.shuffle_run.id):
-            return self.shuffle_run.get_output_partition(
-                self.partition_id,
-                self.key,
-                # We need sync=False, because `_sync_output_partition`
-                # was already called for the current shuffle run
-                sync=False,
-                **self.kwargs,
-            )
-
-
-@dask_serialize.register(UnloadedPartition)
-def _serialize_unloaded(obj):
-    # Convert to LoadedPartition before serializing. Note that
-    # we don't convert all the way to DataFrame, because this
-    # adds unnecessary overhead and memory pressure for the
-    # cudf backend (and minor overhead for pandas)
-    return None, [pickle.dumps(obj.pre_serialize())]
-
-
-@dask_deserialize.register(UnloadedPartition)
-def _deserialize_unloaded(header, frames):
-    return pickle.loads(frames[0])
-
-
-def load_output_partition(
-    data: UnloadedPartition | _T_partition_type, barrier_key: int
-) -> _T_partition_type:
-    # Used by rearrange_by_column_p2p to "unwrap"
-    # UnloadedPartition/LoadedPartition data after
-    # a P2PShuffleLayer
-    assert barrier_key
-    if isinstance(data, UnloadedPartition):
-        data = data.load()
-    assert not isinstance(data, UnloadedPartition)
-    return data
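Reviewer note: after this change, `get_output_partition` and `get_raw_output_partition` share the same `_sync_output_partition` step and differ only in the `convert` flag forwarded to `_get_output_partition`. A minimal sketch of the two call paths, assuming the pandas/pyarrow run defined in `_shuffle.py` below; `run`, `pid`, and `key` are placeholder names, not part of this patch:

    # Sketch only (placeholder names): `run` is a DataFrameShuffleRun,
    # `pid` and `key` come from the executing task.
    df = run.get_output_partition(pid, key)       # shards converted to pd.DataFrame
    raw = run.get_raw_output_partition(pid, key)  # pa.Table as read from disk, unconverted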
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index c78944d752c..e4b8169cb3a 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -39,7 +39,6 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
-    load_output_partition,
 )
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
@@ -91,12 +90,12 @@
 )


-def delayed_shuffle_unpack(
+def shuffle_unpack_partial(
     id: ShuffleId, output_partition: int, barrier_run_id: int
 ) -> pd.DataFrame:
     with handle_unpack_errors(id):
         return get_worker_plugin().get_output_partition(
-            id, barrier_run_id, output_partition, load=False
+            id, barrier_run_id, output_partition, convert=False
         )


@@ -151,10 +150,10 @@
         meta,
         [None] * (npartitions + 1),
     ).map_partitions(
-        load_output_partition,
-        layer._tokens[1],
+        partial(_convert_output_partition, meta=meta),
         meta=meta,
         enforce_metadata=False,
+        align_dataframes=False,
     )


@@ -305,7 +304,7 @@
         name = self.name
         for part_out in self.parts_out:
             dsk[(name, part_out)] = (
-                delayed_shuffle_unpack,
+                shuffle_unpack_partial,
                 token,
                 part_out,
                 _barrier_key,
@@ -519,13 +518,14 @@ def _get_output_partition(
         self,
         partition_id: int,
         key: Key,
+        convert: bool = True,
         **kwargs: Any,
     ) -> pd.DataFrame:
         try:
             data = self._read_from_disk((partition_id,))
-            return convert_shards(data, self.meta)
+            return convert_shards(data, self.meta) if convert else data
         except KeyError:
-            return self.meta.copy()
+            return self.meta.copy() if convert else None

     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
@@ -580,3 +580,7 @@ def _get_worker_for_range_sharding(
     """Get address of target worker for this output partition using range sharding"""
     i = len(workers) * output_partition // npartitions
     return workers[i]
+
+
+def _convert_output_partition(data: pa.Table, meta: Any = None) -> pd.DataFrame:
+    return meta.copy() if data is None else convert_shards(data, meta)
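As the removed `UnloadedPartition` docstring explains, the unpack tasks must stay unfused to preserve post-shuffle load balancing; this patch keeps the conversion in a separate `map_partitions` (`Blockwise`) layer that can still fuse into downstream tasks, while `align_dataframes=False` skips an alignment pass over partitions that already line up. A sketch of the per-partition behavior of the new helper (illustrative values only; `partial` is `functools.partial`):

    # Sketch only: what partial(_convert_output_partition, meta=meta)
    # does for one output partition (_convert_output_partition is the
    # helper added above in _shuffle.py).
    from functools import partial

    import pandas as pd
    import pyarrow as pa

    meta = pd.DataFrame({"a": pd.Series(dtype="int64")})
    fn = partial(_convert_output_partition, meta=meta)
    fn(None)                        # missing output -> meta.copy() (empty DataFrame)
    fn(pa.table({"a": [1, 2, 3]}))  # pa.Table -> pd.DataFrame via convert_shards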
diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index 3cbb6f594b3..a56dc0b0a47 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -417,7 +417,7 @@ def get_output_partition(
         run_id: int,
         partition_id: int | NDIndex,
         meta: pd.DataFrame | None = None,
-        load: bool = True,
+        convert: bool = True,
     ) -> Any:
         """
         Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin.
@@ -426,13 +426,13 @@
         """
         shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
         key = thread_state.key
-        if load:
+        if convert:
             return shuffle_run.get_output_partition(
                 partition_id=partition_id,
                 key=key,
                 meta=meta,
             )
-        return shuffle_run.get_unloaded_output_partition(
+        return shuffle_run.get_raw_output_partition(
             partition_id=partition_id,
             key=key,
             meta=meta,
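End to end, tasks pick the output form through the renamed flag; `shuffle_unpack_partial` in `_shuffle.py` is the only caller passing `convert=False`. A minimal task-side sketch, with placeholder ids:

    # Sketch only: shuffle_id, run_id, and partition_id are placeholders.
    plugin = get_worker_plugin()
    df = plugin.get_output_partition(shuffle_id, run_id, partition_id)  # converted
    raw = plugin.get_output_partition(
        shuffle_id, run_id, partition_id, convert=False
    )  # raw, unconverted shards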