diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 007992b4d2b..c42f6d94426 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -2,11 +2,12 @@
 
 import logging
 import os
+import pickle
 from collections import defaultdict
 from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from functools import partial, cached_property
+from functools import cached_property, partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
@@ -21,7 +22,7 @@
 
 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
-from distributed.protocol import dask_serialize, dask_deserialize
+from distributed.protocol import dask_deserialize, dask_serialize
 from distributed.shuffle._arrow import (
     check_dtype_support,
     check_minimal_arrow_version,
@@ -89,27 +90,27 @@ def __init__(
         self.partition_id = partition_id
 
     def load(self) -> LoadedPartition:
-        with handle_unpack_errors(self.partition_id):
+        with handle_unpack_errors(self.shuffle_run.id):
             try:
                 data = self.shuffle_run._read_from_disk((self.partition_id,))
             except KeyError:
                 data = None
-        return LoadedPartition(data, self.shuffle_run.meta, self.partition_id)
+        return LoadedPartition(data, self.shuffle_run.meta, self.shuffle_run.id)
 
 
 class LoadedPartition:
     def __init__(
         self,
-        data: list[pa.Table],
+        data: list[pa.Table] | None,
         meta: pd.DataFrame,
-        partition_id: int,
+        shuffle_id: ShuffleId,
     ):
         self.data = data
         self.meta = meta
-        self.partition_id = partition_id
+        self.shuffle_id = shuffle_id
 
     def convert(self) -> pd.DataFrame:
-        with handle_unpack_errors(self.partition_id):
+        with handle_unpack_errors(self.shuffle_id):
             if self.data is None:
                 data = self.meta.copy()
             else:
@@ -118,35 +119,36 @@
 
 
 @dask_serialize.register(UnloadedPartition)
-def _serialize_unloaded(obj: UnloadedPartition):
-    import pickle
-
-    # Convert to LoadedPartition when serialized. Note that
+def _serialize_unloaded(obj: UnloadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
+    # Convert to LoadedPartition before serializing. Note that
     # we don't convert all the way to DataFrame, because this
     # adds unnecessary overhead and memory pressure for the
     # cudf backend (and minor overhead for pandas)
     loaded = obj.load()
-    return (loaded.partition_id,), [pickle.dumps(loaded.meta), pickle.dumps(loaded.data)]
+    return (loaded.shuffle_id,), [
+        pickle.dumps(loaded.meta),
+        pickle.dumps(loaded.data),
+    ]
 
 
 @dask_serialize.register(LoadedPartition)
-def _serialize_loaded(obj: LoadedPartition):
-    import pickle
-
-    return (obj.partition_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)]
+def _serialize_loaded(obj: LoadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
+    return (obj.shuffle_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)]
 
 
 @dask_deserialize.register((UnloadedPartition, LoadedPartition))
-def _deserialize_loaded(header, frames):
-    import pickle
-
-    partition_id = header[0]
+def _deserialize_loaded(
+    header: tuple[ShuffleId], frames: list[bytes]
+) -> LoadedPartition:
+    shuffle_id = header[0]
     meta = pickle.loads(frames[0])
     data = pickle.loads(frames[1])
-    return LoadedPartition(data, meta, partition_id)
+    return LoadedPartition(data, meta, shuffle_id)
 
 
-def _get_partition_data(part, barrier_key):
+def _get_partition_data(
+    part: UnloadedPartition | LoadedPartition | pd.DataFrame, barrier_key: str
+) -> pd.DataFrame:
     # Used by rearrange_by_column_p2p to "unwrap"
     # UnloadedPartition/LoadedPartition data after
     # a P2PShuffleLayer
@@ -369,7 +371,7 @@ def cull(
         return self, culled_deps
 
     @cached_property
-    def _tokens(self):
+    def _tokens(self) -> tuple[str, str]:
         token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
         _barrier_key = barrier_key(ShuffleId(token))
         return token, _barrier_key
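
For reference, a minimal sketch (not part of the patch) of the class-based `dask_serialize`/`dask_deserialize` registry pattern this diff leans on, using a hypothetical `Partition` class in place of `LoadedPartition`:

```python
# Minimal sketch, assuming only distributed's public protocol helpers.
# `Partition` is a hypothetical stand-in for LoadedPartition; the real
# classes live in distributed/shuffle/_shuffle.py as patched above.
import pickle

from distributed.protocol import (
    dask_deserialize,
    dask_serialize,
    deserialize,
    serialize,
)


class Partition:
    def __init__(self, data, shuffle_id):
        self.data = data
        self.shuffle_id = shuffle_id


@dask_serialize.register(Partition)
def _serialize(obj):
    # Small metadata travels in the header tuple; the bulk payload is
    # pickled into frames, mirroring _serialize_loaded in the diff.
    return (obj.shuffle_id,), [pickle.dumps(obj.data)]


@dask_deserialize.register(Partition)
def _deserialize(header, frames):
    return Partition(pickle.loads(frames[0]), header[0])


# Round-trip through the "dask" serializer family, as workers do on comms.
header, frames = serialize(Partition([1, 2, 3], "abc"), serializers=["dask"])
restored = deserialize(header, frames)
assert restored.data == [1, 2, 3]
assert restored.shuffle_id == "abc"
```

Registering the deserializer for both classes, as the patch does with `dask_deserialize.register((UnloadedPartition, LoadedPartition))`, is what makes an `UnloadedPartition` sent over the wire always arrive as a `LoadedPartition`: `_serialize_unloaded` calls `load()` first, so only the loaded form ever needs deserializing.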