Skip to content

Commit

Permalink
formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
rjzamora committed Dec 4, 2023
1 parent b7a28f0 commit 69a89a2
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions distributed/shuffle/_shuffle.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,11 +2,12 @@

import logging
import os
import pickle
from collections import defaultdict
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial, cached_property
from functools import cached_property, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand All @@ -21,7 +22,7 @@

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import dask_serialize, dask_deserialize
from distributed.protocol import dask_deserialize, dask_serialize
from distributed.shuffle._arrow import (
check_dtype_support,
check_minimal_arrow_version,
Expand Down Expand Up @@ -89,27 +90,27 @@ def __init__(
self.partition_id = partition_id

def load(self) -> LoadedPartition:
with handle_unpack_errors(self.partition_id):
with handle_unpack_errors(self.shuffle_run.id):
try:
data = self.shuffle_run._read_from_disk((self.partition_id,))
except KeyError:
data = None
return LoadedPartition(data, self.shuffle_run.meta, self.partition_id)
return LoadedPartition(data, self.shuffle_run.meta, self.shuffle_run.id)


class LoadedPartition:
def __init__(
self,
data: list[pa.Table],
data: list[pa.Table] | None,
meta: pd.DataFrame,
partition_id: int,
shuffle_id: ShuffleId,
):
self.data = data
self.meta = meta
self.partition_id = partition_id
self.shuffle_id = shuffle_id

def convert(self) -> pd.DataFrame:
with handle_unpack_errors(self.partition_id):
with handle_unpack_errors(self.shuffle_id):
if self.data is None:
data = self.meta.copy()
else:
Expand All @@ -118,35 +119,36 @@ def convert(self) -> pd.DataFrame:


@dask_serialize.register(UnloadedPartition)
def _serialize_unloaded(obj: UnloadedPartition):
import pickle

# Convert to LoadedPartition when serialized. Note that
def _serialize_unloaded(obj: UnloadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
# Convert to LoadedPartition before serializing. Note that
# we don't convert all the way to DataFrame, because this
# adds unnecessary overhead and memory pressure for the
# cudf backend (and minor overhead for pandas)
loaded = obj.load()
return (loaded.partition_id,), [pickle.dumps(loaded.meta), pickle.dumps(loaded.data)]
return (loaded.shuffle_id,), [
pickle.dumps(loaded.meta),
pickle.dumps(loaded.data),
]


@dask_serialize.register(LoadedPartition)
def _serialize_loaded(obj: LoadedPartition):
import pickle

return (obj.partition_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)]
def _serialize_loaded(obj: LoadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
return (obj.shuffle_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)]


@dask_deserialize.register((UnloadedPartition, LoadedPartition))
def _deserialize_loaded(header, frames):
import pickle

partition_id = header[0]
def _deserialize_loaded(
header: tuple[ShuffleId], frames: list[bytes]
) -> LoadedPartition:
shuffle_id = header[0]
meta = pickle.loads(frames[0])
data = pickle.loads(frames[1])
return LoadedPartition(data, meta, partition_id)
return LoadedPartition(data, meta, shuffle_id)


def _get_partition_data(part, barrier_key):
def _get_partition_data(
part: UnloadedPartition | LoadedPartition | pd.DataFrame, barrier_key: int
) -> pd.DataFrame:
# Used by rearrange_by_column_p2p to "unwrap"
# UnloadedPartition/LoadedPartition data after
# a P2PShuffleLayer
Expand Down Expand Up @@ -369,7 +371,7 @@ def cull(
return self, culled_deps

@cached_property
def _tokens(self):
def _tokens(self) -> tuple[str, str]:
token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
_barrier_key = barrier_key(ShuffleId(token))
return token, _barrier_key
Expand Down

0 comments on commit 69a89a2

Please sign in to comment.