From c730903b6ea24822d8ffdc2a6bb4577d4ad6d489 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 7 Dec 2023 12:58:26 -0600 Subject: [PATCH] back to simple load, but leave space for future dispatching --- distributed/shuffle/_core.py | 31 ++++++++++++++++++++++++------- distributed/shuffle/_shuffle.py | 4 ++-- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 27718bffbc..dbae956e16 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -302,13 +302,14 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> Non self.raise_if_closed() sync(self._loop, self._ensure_output_worker, partition_id, key) if not self.transferred: - raise RuntimeError("`get_output_partition` called before barrier task") + raise RuntimeError("`_sync_output_partition` called before barrier task") sync(self._loop, self.flush_receive) def get_output_partition( self, partition_id: _T_partition_id, key: Key, **kwargs: Any ) -> _T_partition_type: - self._sync_output_partition(partition_id, key) + if kwargs.pop("sync", True): + self._sync_output_partition(partition_id, key) return self._get_output_partition(partition_id, key, **kwargs) @abc.abstractmethod @@ -504,24 +505,39 @@ def __init__( self.key = key self.kwargs = kwargs + def pre_serialize(self) -> Any: + """Make the unloaded partition serializable""" + # TODO: Add mechanism to dispatch on meta. + # Right now, serializing an UnloadedPartition object + # will convert it to `type(self.shuffle_run.meta)`. + # However, it may be beneficial to futher delay the + # use of GPU memory for cudf/cupy-based data. + return self.load() + def load(self) -> Any: + """Load the shuffle output partition into memory""" with handle_unpack_errors(self.shuffle_run.id): - return self.shuffle_run._get_output_partition( - self.partition_id, self.key, **self.kwargs + return self.shuffle_run.get_output_partition( + self.partition_id, + self.key, + # We need sync=False, because `_sync_output_partition` + # was already called for the current shuffle run + sync=False, + **self.kwargs, ) @dask_serialize.register(UnloadedPartition) -def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]: +def _serialize_unloaded(obj): # Convert to LoadedPartition before serializing. Note that # we don't convert all the way to DataFrame, because this # adds unnecessary overhead and memory pressure for the # cudf backend (and minor overhead for pandas) - return None, [pickle.dumps(obj.load())] + return None, [pickle.dumps(obj.pre_serialize())] @dask_deserialize.register(UnloadedPartition) -def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any: +def _deserialize_unloaded(header, frames): return pickle.loads(frames[0]) @@ -534,4 +550,5 @@ def load_output_partition( assert barrier_key if isinstance(data, UnloadedPartition): data = data.load() + assert not isinstance(data, UnloadedPartition) return data diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index aac566b658..c78944d752 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -91,7 +91,7 @@ def shuffle_unpack( ) -def shuffle_unpack_unloaded( +def delayed_shuffle_unpack( id: ShuffleId, output_partition: int, barrier_run_id: int ) -> pd.DataFrame: with handle_unpack_errors(id): @@ -305,7 +305,7 @@ def _construct_graph(self) -> _T_LowLevelGraph: name = self.name for part_out in self.parts_out: dsk[(name, part_out)] = ( - shuffle_unpack_unloaded, + delayed_shuffle_unpack, token, part_out, _barrier_key,