diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py index 9226bd7d3b8..8f393b7ddf9 100644 --- a/distributed/shuffle/_pickle.py +++ b/distributed/shuffle/_pickle.py @@ -57,6 +57,9 @@ def pickle_dataframe_shard( Parameters: obj: pandas """ + if hasattr(shard, "to_pandas"): + # Handle cudf-backed data + shard = shard.to_pandas(nullable=True) return pickle_bytelist( (input_part_id, shard.index, *shard._mgr.blocks), prelude=False ) @@ -104,6 +107,9 @@ def unpickle_and_concat_dataframe_shards( # Actually load memory-mapped buffers into memory and close the file # descriptors + if hasattr(type(meta), "from_pandas"): + # Handle cudf-backed data + return type(meta).from_pandas(dd.methods.concat(shards, copy=True)) return dd.methods.concat(shards, copy=True) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index fe74263a6ff..e6dd30660e1 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -332,7 +332,7 @@ def split_by_worker( """ out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list) - base = df[column].values.base # type: ignore[union-attr] + base = df[column].values.base or [] # type: ignore[union-attr] for output_part_id, part in df.groupby(column, observed=True): assert isinstance(output_part_id, int) if output_part_id not in worker_for: