Skip to content

Commit

Permalink
basic functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 1, 2023
1 parent 9627f27 commit 7de1bfc
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,32 +56,48 @@
from dask.dataframe import DataFrame


from distributed.protocol import dask_serialize
from distributed.protocol import dask_serialize, dask_deserialize
import pickle

# Wrapper around one output partition of a shuffle whose data is only read
# when the `data` property is first accessed.
#
# NOTE(review): this span looks like a rendered diff with indentation stripped:
# two `__init__` headers, two `self._data = ...` assignments and two
# `convert_shards(...)` calls appear back to back. Presumably the first of each
# pair is the removed line and the second the added one - confirm against the
# repository before treating this text as runnable code.
class DelayedShufflePartition:
# Single-line signature: no `data` parameter, `partition_id` typed as `int`.
def __init__(self, shuffle_run: DataFrameShuffleRun | None, partition_id: int):
# Multi-line signature: `partition_id` may be None and an optional `data`
# value can be supplied up front.
def __init__(
self,
shuffle_run: DataFrameShuffleRun | None,
partition_id: int | None,
data = None,
):
self.shuffle_run = shuffle_run
self.partition_id = partition_id
# Cache slot read by the `data` property; None means "not loaded yet".
self._data = None
self._data = data

@property
def data(self):
# Load on first access only; later accesses return the cached value.
if self._data is None:
# Loading needs both a partition id and a shuffle run. These are plain
# asserts, so they are skipped when Python runs with -O.
assert self.partition_id is not None
assert self.shuffle_run is not None
try:
data = self.shuffle_run._read_from_disk((self.partition_id,))
# Variant reading `self.meta`; the visible `__init__` never assigns a
# `meta` attribute on this class.
self._data = convert_shards(data, self.meta)
# Variant reading `meta` from the shuffle run instead.
self._data = convert_shards(data, self.shuffle_run.meta)
except KeyError:
# A KeyError from the block above falls back to a copy of the run's meta.
self._data = self.shuffle_run.meta.copy()
return self._data

# Method form of serialization: hands the loaded data to `dask_serialize`.
def serialize(self):
return dask_serialize(self.data)
# Pickle support. Reading `self.data` here triggers the load above, so the
# rebuilt object gets shuffle_run=None, partition_id=None and the data itself.
def __reduce__(self):
# we return a tuple of class_name to call,
# and optional parameters to pass when re-creating
return (self.__class__, (None, None, self.data))


# Registers a serializer for DelayedShufflePartition with `dask_serialize`.
#
# NOTE(review): one decorator is followed by two function definitions with
# indentation stripped; presumably `serialize` is the removed side of a diff
# and `_serialize` the added side - confirm against the repository.
@dask_serialize.register(DelayedShufflePartition)
# Variant that delegates to `dask_serialize` on the partition's loaded data.
def serialize(obj: DelayedShufflePartition):
return dask_serialize(obj.data)
# Variant that returns a 2-tuple: None, and a one-element list holding the
# pickled bytes of `obj.data`. Reading `obj.data` loads the partition if needed.
def _serialize(obj: DelayedShufflePartition):
return None, [pickle.dumps(obj.data)]

def _get_partition_data(part):
@dask_deserialize.register(DelayedShufflePartition)
def _deserialize(header, frames):
    """Rebuild a ``DelayedShufflePartition`` from serialized frames.

    ``header`` is not read. The first entry of ``frames`` is unpickled and
    passed as the third constructor argument; the first two arguments are
    both ``None``.
    """
    # NOTE(review): pickle.loads can execute arbitrary code - confirm that
    # frames only ever come from trusted peers.
    payload = pickle.loads(frames[0])
    return DelayedShufflePartition(None, None, payload)


def _get_partition_data(part, barrier_key):
    """Return the data behind ``part``.

    A ``DelayedShufflePartition`` is unwrapped through its ``data`` property;
    any other object is returned unchanged.

    ``barrier_key`` is accepted but never read in the body.
    NOTE(review): presumably it exists only so the calling graph task can
    reference the barrier - confirm against the graph construction code.
    """
    return part.data if isinstance(part, DelayedShufflePartition) else part
Expand Down Expand Up @@ -333,6 +349,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
dsk[(name, part_out)] = (
_get_partition_data,
(name_lazy, part_out),
_barrier_key,
)
return dsk

Expand Down

0 comments on commit 7de1bfc

Please sign in to comment.