diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 2a624dc5af0..2eebed75e1e 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -56,32 +56,48 @@
 from dask.dataframe import DataFrame
 
-from distributed.protocol import dask_serialize
+from distributed.protocol import dask_serialize, dask_deserialize
+import pickle
 
 
 class DelayedShufflePartition:
-    def __init__(self, shuffle_run: DataFrameShuffleRun | None, partition_id: int):
+    def __init__(
+        self,
+        shuffle_run: DataFrameShuffleRun | None,
+        partition_id: int | None,
+        data=None,
+    ):
         self.shuffle_run = shuffle_run
         self.partition_id = partition_id
-        self._data = None
+        self._data = data
 
     @property
     def data(self):
         if self._data is None:
+            assert self.partition_id is not None
+            assert self.shuffle_run is not None
             try:
                 data = self.shuffle_run._read_from_disk((self.partition_id,))
-                self._data = convert_shards(data, self.meta)
+                self._data = convert_shards(data, self.shuffle_run.meta)
             except KeyError:
                 self._data = self.shuffle_run.meta.copy()
         return self._data
 
-    def serialize(self):
-        return dask_serialize(self.data)
+    def __reduce__(self):
+        # Return the class to call and the arguments to pass when
+        # re-creating the instance; this materializes the data eagerly.
+        return (self.__class__, (None, None, self.data))
+
 
 @dask_serialize.register(DelayedShufflePartition)
-def serialize(obj: DelayedShufflePartition):
-    return dask_serialize(obj.data)
+def _serialize(obj: DelayedShufflePartition):
+    return None, [pickle.dumps(obj.data)]
 
 
-def _get_partition_data(part):
+@dask_deserialize.register(DelayedShufflePartition)
+def _deserialize(header, frames):
+    return DelayedShufflePartition(None, None, pickle.loads(frames[0]))
+
+
+def _get_partition_data(part, barrier_key):
     if isinstance(part, DelayedShufflePartition):
         return part.data
     return part
@@ -333,6 +349,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
             dsk[(name, part_out)] = (
                 _get_partition_data,
                 (name_lazy, part_out),
+                _barrier_key,
             )
         return dsk
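
For reference, a minimal standalone sketch of the `__reduce__` pattern the diff adopts. `LazyPartition` and `FakeShuffleRun` are hypothetical stand-ins for `DelayedShufflePartition` and `DataFrameShuffleRun`; the point is that pickling forces materialization of the partition data, so the unpickled copy carries the data itself and no reference to a (non-picklable) shuffle run:

```python
# Illustrative sketch only -- not the distributed codebase.
import pickle


class FakeShuffleRun:
    """Hypothetical stand-in for DataFrameShuffleRun: loads a partition on demand."""

    def load(self, partition_id):
        return {"partition": partition_id, "rows": [1, 2, 3]}


class LazyPartition:
    def __init__(self, shuffle_run, partition_id, data=None):
        self.shuffle_run = shuffle_run
        self.partition_id = partition_id
        self._data = data

    @property
    def data(self):
        # Materialize on first access, then cache.
        if self._data is None:
            self._data = self.shuffle_run.load(self.partition_id)
        return self._data

    def __reduce__(self):
        # Pickle the materialized data instead of the run reference:
        # unpickling calls LazyPartition(None, None, <data>).
        return (self.__class__, (None, None, self.data))


part = LazyPartition(FakeShuffleRun(), partition_id=7)
clone = pickle.loads(pickle.dumps(part))
assert clone.shuffle_run is None
assert clone.data == {"partition": 7, "rows": [1, 2, 3]}
```

The registered `dask_serialize`/`dask_deserialize` pair in the diff follows the same idea at the protocol level: serialization emits `pickle.dumps(obj.data)` as a single frame, and deserialization rebuilds a `DelayedShufflePartition` that holds the data directly, with no shuffle run or partition id attached.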