enable cudf backend
rjzamora committed May 8, 2024
1 parent 95d5377 commit f191770
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 6 additions & 0 deletions distributed/shuffle/_pickle.py
@@ -57,6 +57,9 @@ def pickle_dataframe_shard(
     Parameters:
     obj: pandas
     """
+    if hasattr(shard, "to_pandas"):
+        # Handle cudf-backed data
+        shard = shard.to_pandas(nullable=True)
     return pickle_bytelist(
         (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
     )
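
A minimal sketch of the conversion this hunk adds, assuming cudf is installed; the `shard` name mirrors the function's parameter. `to_pandas(nullable=True)` copies the shard from device to host memory and keeps missing values in pandas nullable extension dtypes (e.g. Int64) instead of casting integer columns with nulls to float64, so the block-based pickling below it works unchanged:

import cudf

shard = cudf.DataFrame({"a": [1, 2, None], "b": [0.5, 1.5, 2.5]})
if hasattr(shard, "to_pandas"):
    # Handle cudf-backed data, as in the hunk above
    shard = shard.to_pandas(nullable=True)

print(shard.dtypes)  # "a" arrives as nullable Int64 rather than float64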
@@ -104,6 +107,9 @@ def unpickle_and_concat_dataframe_shards(

     # Actually load memory-mapped buffers into memory and close the file
     # descriptors
+    if hasattr(type(meta), "from_pandas"):
+        # Handle cudf-backed data
+        return type(meta).from_pandas(dd.methods.concat(shards, copy=True))
     return dd.methods.concat(shards, copy=True)
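
A hedged sketch of the restore path on the receiving side, with `restore_backend` as a hypothetical helper and pd.concat standing in for dd.methods.concat. `meta` is the empty DataFrame describing the expected output, so when its type exposes a `from_pandas` classmethod (as cudf.DataFrame does) the concatenated pandas shards are converted back to the GPU backend, while a plain pandas `meta` falls through untouched:

import pandas as pd

def restore_backend(meta, shards):
    # pd.concat stands in for dd.methods.concat in this sketch
    result = pd.concat(shards, copy=True)
    if hasattr(type(meta), "from_pandas"):
        # Handle cudf-backed data, as in the hunk above
        return type(meta).from_pandas(result)
    return result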


2 changes: 1 addition & 1 deletion distributed/shuffle/_shuffle.py
@@ -332,7 +332,7 @@ def split_by_worker(
     """
     out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)

-    base = df[column].values.base  # type: ignore[union-attr]
+    base = df[column].values.base or []  # type: ignore[union-attr]
     for output_part_id, part in df.groupby(column, observed=True):
         assert isinstance(output_part_id, int)
         if output_part_id not in worker_for:
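
The apparent motivation for the `or []`, assuming numpy semantics: an ndarray that owns its memory reports `base` as None, while a view keeps a reference to the array it was sliced from, and a cudf-backed column may likewise offer no usable base. Substituting an empty list lets the surrounding code hold a reference to `base` unconditionally. A small self-contained illustration:

import numpy as np

owner = np.arange(4)      # owns its buffer, so owner.base is None
view = owner[1:]          # borrows owner's buffer, so view.base is owner
assert owner.base is None
assert view.base is owner

base = owner.base or []   # None is replaced by an empty-list placeholder
assert base == []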
