experimenting
rjzamora committed Dec 1, 2023
1 parent 2f04dcb commit a6cbdcd
Showing 1 changed file with 52 additions and 7 deletions.
59 changes: 52 additions & 7 deletions distributed/shuffle/_shuffle.py
@@ -56,6 +56,42 @@
from dask.dataframe import DataFrame


from distributed.protocol import dask_serialize

class DelayedShufflePartition:
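    """Lazily materialized output partition of a DataFrame shuffle.

    Wraps a ``DataFrameShuffleRun`` and an output partition id; the partition
    data is only read from disk (and then cached) on first access of ``data``.
    """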
def __init__(self, shuffle_run: DataFrameShuffleRun | None, partition_id: int):
self.shuffle_run = shuffle_run
self.partition_id = partition_id
self._data = None

@property
def data(self):
if self._data is None:
try:
data = self.shuffle_run._read_from_disk((self.partition_id,))
                self._data = convert_shards(data, self.shuffle_run.meta)
except KeyError:
self._data = self.shuffle_run.meta.copy()
return self._data

def serialize(self):
return dask_serialize(self.data)

# Register DelayedShufflePartition with distributed's ``dask_serialize``
# dispatch so that shipping a wrapper between workers serializes the
# materialized partition data rather than the wrapper object itself.
@dask_serialize.register(DelayedShufflePartition)
def _serialize_delayed_shuffle_partition(obj: DelayedShufflePartition):
    return dask_serialize(obj.data)

def _get_partition_data(part):
if isinstance(part, DelayedShufflePartition):
return part.data
return part
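
# Usage sketch (illustrative, not part of the change itself; assumes a live
# ``DataFrameShuffleRun`` instance named ``shuffle_run`` and a valid output
# partition id). Materialization is deferred until ``.data`` is first
# accessed, and the result is cached on the wrapper:
#
#     part = DelayedShufflePartition(shuffle_run, partition_id=3)
#     df = part.data   # first access: shards are read from disk and converted
#     df = part.data   # later accesses return the cached DataFrame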

def shuffle_transfer(
input: pd.DataFrame,
id: ShuffleId,
@@ -282,14 +318,22 @@ def _construct_graph(self) -> _T_LowLevelGraph:

dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)

        name_lazy = f"lazy-{self.name}"
        for part_out in self.parts_out:
            dsk[(name_lazy, part_out)] = (
shuffle_unpack,
token,
part_out,
_barrier_key,
)

# TODO: Do this in a Blockwise layer after the shuffle
name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
_get_partition_data,
(name_lazy, part_out),
)
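
        # Resulting graph shape for one output partition ``i`` (illustrative;
        # ``<name>`` stands for ``self.name``):
        #
        #   dsk[("lazy-<name>", i)] = (shuffle_unpack, token, i, _barrier_key)
        #   dsk[("<name>", i)]      = (_get_partition_data, ("lazy-<name>", i))
        #
        # ``shuffle_unpack`` now produces a DelayedShufflePartition, and the
        # follow-up task materializes it into a DataFrame via _get_partition_data.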
return dsk


@@ -501,11 +545,12 @@ def _get_output_partition(
key: Key,
**kwargs: Any,
    ) -> DelayedShufflePartition:
        # The partition is no longer materialized eagerly here; that work now
        # happens lazily in DelayedShufflePartition.data.
        return DelayedShufflePartition(self, partition_id)
        # try:
        #     data = self._read_from_disk((partition_id,))
        #     return convert_shards(data, self.meta)
        # except KeyError:
        #     return self.meta.copy()

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]