diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index c14fb1dd0da..2a624dc5af0 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -56,6 +56,42 @@ from dask.dataframe import DataFrame +from distributed.protocol import dask_serialize + +class DelayedShufflePartition: + def __init__(self, shuffle_run: DataFrameShuffleRun | None, partition_id: int): + self.shuffle_run = shuffle_run + self.partition_id = partition_id + self._data = None + + @property + def data(self): + if self._data is None: + try: + data = self.shuffle_run._read_from_disk((self.partition_id,)) + self._data = convert_shards(data, self.meta) + except KeyError: + self._data = self.shuffle_run.meta.copy() + return self._data + + def serialize(self): + return dask_serialize(self.data) + +@dask_serialize.register(DelayedShufflePartition) +def serialize(obj: DelayedShufflePartition): + return dask_serialize(obj.data) + +def _get_partition_data(part): + if isinstance(part, DelayedShufflePartition): + return part.data + return part + + + + + + + def shuffle_transfer( input: pd.DataFrame, id: ShuffleId, @@ -282,14 +318,22 @@ def _construct_graph(self) -> _T_LowLevelGraph: dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys) - name = self.name + name_lazy = f"lazy-{self.name}" for part_out in self.parts_out: - dsk[(name, part_out)] = ( + dsk[(name_lazy, part_out)] = ( shuffle_unpack, token, part_out, _barrier_key, ) + + # TODO: Do this in a Blockwise layer after the shuffle + name = self.name + for part_out in self.parts_out: + dsk[(name, part_out)] = ( + _get_partition_data, + (name_lazy, part_out), + ) return dsk @@ -501,11 +545,12 @@ def _get_output_partition( key: Key, **kwargs: Any, ) -> pd.DataFrame: - try: - data = self._read_from_disk((partition_id,)) - return convert_shards(data, self.meta) - except KeyError: - return self.meta.copy() + return DelayedShufflePartition(self, partition_id) + # try: + # data = self._read_from_disk((partition_id,)) + # return convert_shards(data, self.meta) + # except KeyError: + # return self.meta.copy() def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id]