breaking GetPartition into load and convert stages
rjzamora committed Dec 4, 2023
1 parent 07ba167 commit 389182f
Showing 1 changed file with 85 additions and 37 deletions.
122 changes: 85 additions & 37 deletions distributed/shuffle/_shuffle.py
@@ -21,6 +21,7 @@

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
+from distributed.protocol import dask_serialize, dask_deserialize
from distributed.shuffle._arrow import (
    check_dtype_support,
    check_minimal_arrow_version,
@@ -56,56 +57,108 @@
from dask.dataframe import DataFrame


-from distributed.protocol import dask_serialize, dask_deserialize
-import pickle

-class DelayedShufflePartition:
+class GetPartition:
"""Wrap/Delay partition-loading logic
The purpose of this class is to keep a shuffled partition
on disk until it is needed by one of its dependent tasks.
Otherwise, the in-memory partition may need to be spilled
back to disk before the dependent task is executed anyway.
If ``shuffle_unpack`` returns a ``GetPartition`` object,
``P2PShuffleLayer`` must be followed by an extra ``Blockwise``
call to ``_get_partition_data`` (to unwrap the data). We want
an extra ``Blockwise`` layer here so that the unwrap step can
be fused into down-stream tasks. We do NOT want the original
``shuffle_unpack`` tasks to be fused into dependent tasks,
because this would prevent effective load balancing after the
shuffle (long-running post-shuffle tasks may be pinned to
specific workers, while others sit idle).
````
"""
    def __init__(
        self,
        shuffle_run: DataFrameShuffleRun | None,
-        partition_id: int | None,
-        data = None,
+        partition_id: int,
+        empty: bool = False,
+        meta: Any = None,
+        loaded_data: Any = None,
+        converted_data: Any = None,
    ):
        self.shuffle_run = shuffle_run
        self.partition_id = partition_id
-        self._data = data
+        self.empty = empty
+        self._meta = meta
+        self._loaded_data = loaded_data
+        self._converted_data = converted_data

+    def load_data(self):
+        if self._loaded_data is None and not self.empty:
+            with handle_unpack_errors(self.partition_id):
+                try:
+                    self._loaded_data = self.shuffle_run._read_from_disk((self.partition_id,))
+                except KeyError:
+                    self.empty = True

+    def convert_data(self):
+        if self._converted_data is None:
+            if self.empty:
+                self._converted_data = self.meta.copy()
+            else:
+                self.load_data()
+                if self.empty:
+                    # load_data found no shards on disk for this partition
+                    self._converted_data = self.meta.copy()
+                else:
+                    with handle_unpack_errors(self.partition_id):
+                        self._converted_data = convert_shards(self._loaded_data, self.meta)
+                        self._loaded_data = None  # No longer needed

+    @property
+    def meta(self):
+        if self.shuffle_run is not None:
+            return self.shuffle_run.meta
+        return self._meta

    @property
    def data(self):
-        if self._data is None:
-            assert self.partition_id is not None
-            assert self.shuffle_run is not None
-            try:
-                data = self.shuffle_run._read_from_disk((self.partition_id,))
-                self._data = convert_shards(data, self.shuffle_run.meta)
-            except KeyError:
-                self._data = self.shuffle_run.meta.copy()
-        return self._data
+        if self._converted_data is None:
+            self.convert_data()
+        return self._converted_data

    def __reduce__(self):
-        # we return a tuple of class_name to call,
-        # and optional parameters to pass when re-creating
-        return (self.__class__, (None, None, self.data))
+        # Always load data when this class is serialized
+        return (
+            self.__class__, (
+                None,
+                self.partition_id,
+                self.empty,
+                self.meta,
+                self._loaded_data,
+                self._converted_data,
+            )
+        )


-@dask_serialize.register(DelayedShufflePartition)
-def _serialize(obj: DelayedShufflePartition):
-    return None, [pickle.dumps(obj.data)]
+@dask_serialize.register(GetPartition)
+def _serialize_get_partition(obj: GetPartition):
+    import pickle
+
+    obj.load_data()
+    return (obj.partition_id, obj.empty), [pickle.dumps(obj.meta), pickle.dumps(obj._loaded_data)]


-@dask_deserialize.register(DelayedShufflePartition)
-def _deserialize(header, frames):
-    return DelayedShufflePartition(None, None, pickle.loads(frames[0]))
+@dask_deserialize.register(GetPartition)
+def _deserialize_get_partition(header, frames):
+    import pickle
+
+    partition_id, empty = header[:2]
+    meta = pickle.loads(frames[0])
+    loaded_data = pickle.loads(frames[1])
+    return GetPartition(None, partition_id, empty, meta, loaded_data, None)


-def _get_partition_data(part, barrier_key):
-    if isinstance(part, DelayedShufflePartition):
-        return part.data
-    return part
+def _get_partition_data(part, barrier_key):
+    assert barrier_key
+    if isinstance(part, GetPartition):
+        return part.data
+    return part


def shuffle_transfer(
@@ -564,12 +617,7 @@ def _get_output_partition(
        key: Key,
        **kwargs: Any,
    ) -> pd.DataFrame:
-        return DelayedShufflePartition(self, partition_id)
-        # try:
-        #     data = self._read_from_disk((partition_id,))
-        #     return convert_shards(data, self.meta)
-        # except KeyError:
-        #     return self.meta.copy()
+        return GetPartition(self, partition_id)

    def _get_assigned_worker(self, id: int) -> str:
        return self.worker_for[id]
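For orientation, here is a minimal usage sketch (not part of the commit) of the intended two-stage flow, assuming the GetPartition class and _get_partition_data helper from this branch are importable from distributed.shuffle._shuffle. It exercises the empty-partition path so no shuffle run or on-disk shards are required; in the real task graph, _get_output_partition constructs GetPartition(self, partition_id) and a trailing Blockwise task calls _get_partition_data.

import pandas as pd

from distributed.shuffle._shuffle import GetPartition, _get_partition_data

# Hypothetical output schema; normally this comes from the shuffle run's meta
meta = pd.DataFrame({"x": pd.Series(dtype="int64"), "y": pd.Series(dtype="float64")})

# No shuffle run is attached here, so the partition is marked empty and meta
# is passed explicitly
part = GetPartition(None, 0, empty=True, meta=meta)

part.load_data()     # stage 1: read raw shards from disk (a no-op for an empty partition)
part.convert_data()  # stage 2: convert shards to a DataFrame (falls back to meta.copy() here)

# A trailing Blockwise task unwraps the DataFrame for downstream consumers;
# barrier_key only needs to be truthy for this sketch
df = _get_partition_data(part, barrier_key="example-barrier")
assert df.empty and list(df.columns) == ["x", "y"]

Splitting load_data from convert_data separates reading raw shards from disk from the conversion into a DataFrame, while the data property keeps the previous lazy behavior for callers that only want the final result.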
