
Commit c730903
back to simple load, but leave space for future dispatching
rjzamora committed Dec 7, 2023
1 parent 61a3eb2 commit c730903
Showing 2 changed files with 26 additions and 9 deletions.
31 changes: 24 additions & 7 deletions distributed/shuffle/_core.py
@@ -302,13 +302,14 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
         self.raise_if_closed()
         sync(self._loop, self._ensure_output_worker, partition_id, key)
         if not self.transferred:
-            raise RuntimeError("`get_output_partition` called before barrier task")
+            raise RuntimeError("`_sync_output_partition` called before barrier task")
         sync(self._loop, self.flush_receive)
 
     def get_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
     ) -> _T_partition_type:
-        self._sync_output_partition(partition_id, key)
+        if kwargs.pop("sync", True):
+            self._sync_output_partition(partition_id, key)
         return self._get_output_partition(partition_id, key, **kwargs)
 
     @abc.abstractmethod
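
Note: the new `sync` keyword lets a caller that has already synchronized with the barrier skip a redundant `_sync_output_partition` call. A minimal sketch of the two calling patterns, assuming `shuffle_run`, `partition_id`, and `key` are valid values in scope (illustrative only, not part of this commit):

    # Default path: synchronize with the barrier, then fetch the partition.
    part = shuffle_run.get_output_partition(partition_id, key)

    # A caller that already ran `_sync_output_partition` itself (see
    # `UnloadedPartition.load` below) opts out of the second sync:
    part = shuffle_run.get_output_partition(partition_id, key, sync=False)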
@@ -504,24 +505,39 @@ def __init__(
         self.key = key
         self.kwargs = kwargs
 
+    def pre_serialize(self) -> Any:
+        """Make the unloaded partition serializable"""
+        # TODO: Add mechanism to dispatch on meta.
+        # Right now, serializing an UnloadedPartition object
+        # will convert it to `type(self.shuffle_run.meta)`.
+        # However, it may be beneficial to further delay the
+        # use of GPU memory for cudf/cupy-based data.
+        return self.load()
+
     def load(self) -> Any:
         """Load the shuffle output partition into memory"""
         with handle_unpack_errors(self.shuffle_run.id):
-            return self.shuffle_run._get_output_partition(
-                self.partition_id, self.key, **self.kwargs
+            return self.shuffle_run.get_output_partition(
+                self.partition_id,
+                self.key,
+                # We need sync=False, because `_sync_output_partition`
+                # was already called for the current shuffle run
+                sync=False,
+                **self.kwargs,
             )


 @dask_serialize.register(UnloadedPartition)
-def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]:
+def _serialize_unloaded(obj):
     # Convert to LoadedPartition before serializing. Note that
     # we don't convert all the way to DataFrame, because this
     # adds unnecessary overhead and memory pressure for the
     # cudf backend (and minor overhead for pandas)
-    return None, [pickle.dumps(obj.load())]
+    return None, [pickle.dumps(obj.pre_serialize())]
 
 
 @dask_deserialize.register(UnloadedPartition)
-def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any:
+def _deserialize_unloaded(header, frames):
     return pickle.loads(frames[0])
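
Note: `dask_serialize` and `dask_deserialize` are Distributed's per-type serialization dispatchers, and the pair above is intentionally asymmetric: what comes out of deserialization is the materialized partition, not an `UnloadedPartition`. A self-contained sketch of the same pattern with a hypothetical wrapper class (`LazyValue` and its contents are invented for illustration, not part of this commit):

    import pickle

    from distributed.protocol.serialize import dask_deserialize, dask_serialize


    class LazyValue:
        # Hypothetical stand-in for UnloadedPartition: defers real work.
        def __init__(self, func):
            self.func = func

        def pre_serialize(self):
            # Materialize to a cheaply picklable object before shipping.
            return self.func()


    @dask_serialize.register(LazyValue)
    def _serialize_lazy(obj):
        # No header needed; a single pickled frame carries the payload.
        return None, [pickle.dumps(obj.pre_serialize())]


    @dask_deserialize.register(LazyValue)
    def _deserialize_lazy(header, frames):
        # Deserialization yields the materialized value, not a LazyValue,
        # mirroring the UnloadedPartition behavior above.
        return pickle.loads(frames[0])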


@@ -534,4 +550,5 @@ def load_output_partition(
     assert barrier_key
     if isinstance(data, UnloadedPartition):
         data = data.load()
+    assert not isinstance(data, UnloadedPartition)
     return data
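
Note: the `isinstance` check is needed because `data` can legitimately arrive in either form. If the output partition stayed on the worker that produced it, the task receives the `UnloadedPartition` wrapper and loads it here; if it crossed workers, `_serialize_unloaded` above already materialized it on the wire. The new assertion makes that invariant explicit: no wrapper may leak into downstream tasks.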
4 changes: 2 additions & 2 deletions distributed/shuffle/_shuffle.py
@@ -91,7 +91,7 @@ def shuffle_unpack(
         )
 
 
-def shuffle_unpack_unloaded(
+def delayed_shuffle_unpack(
     id: ShuffleId, output_partition: int, barrier_run_id: int
 ) -> pd.DataFrame:
     with handle_unpack_errors(id):
@@ -305,7 +305,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
         name = self.name
         for part_out in self.parts_out:
             dsk[(name, part_out)] = (
-                shuffle_unpack_unloaded,
+                delayed_shuffle_unpack,
                 token,
                 part_out,
                 _barrier_key,
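
Note: in Dask's low-level graph format, each value is a task tuple of `(callable, *args)`, so the loop above maps every output-partition key to a deferred `delayed_shuffle_unpack(token, part_out, _barrier_key)` call. A standalone sketch of the same graph shape, with invented token and key values:

    # Build the same shape of graph by hand (illustrative values only).
    name = "shuffle-unpack-deadbeef"
    token = "deadbeef"
    _barrier_key = "shuffle-barrier-deadbeef"
    dsk = {
        (name, part_out): (delayed_shuffle_unpack, token, part_out, _barrier_key)
        for part_out in range(4)  # assume four output partitions
    }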
