Forward-merge branch-23.10 to branch-23.12 #1245

Merged 1 commit into branch-23.12 from branch-23.10 on Sep 26, 2023

dask_cuda/explicit_comms/dataframe/shuffle.py (26 changes: 14 additions & 12 deletions)

@@ -328,7 +328,7 @@ async def shuffle_task(
     ignore_index: bool,
     num_rounds: int,
     batchsize: int,
-) -> List[DataFrame]:
+) -> Dict[int, DataFrame]:
     """Explicit-comms shuffle task

     This function is running on each worker participating in the shuffle.
@@ -360,8 +360,8 @@ async def shuffle_task(

     Returns
     -------
-    partitions: list of DataFrames
-        List of dataframe-partitions
+    partitions: dict
+        dict that maps each Partition ID to a dataframe-partition
     """

     proxify = get_proxify(s["worker"])
@@ -387,14 +387,13 @@ async def shuffle_task(
     )

     # Finally, we concatenate the output dataframes into the final output partitions
-    ret = []
+    ret = {}
     while out_part_id_to_dataframe_list:
-        ret.append(
-            proxify(
-                dd_concat(
-                    out_part_id_to_dataframe_list.popitem()[1],
-                    ignore_index=ignore_index,
-                )
-            )
-        )
+        part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
+        ret[part_id] = proxify(
+            dd_concat(
+                dataframe_list,
+                ignore_index=ignore_index,
+            )
+        )
     # For robustness, we yield this task to give Dask a chance to do bookkeeping
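
Note on the hunk above: `dict.popitem()` pops entries in last-in-first-out order, so a list filled from it loses the association between list position and partition ID; keying the result by partition ID makes retrieval order-independent. A minimal sketch of the difference, with toy strings standing in for dataframe partitions (not the dask-cuda code):

```python
# Toy sketch: strings stand in for concatenated dataframe partitions.
out = {0: "df_a", 1: "df_b", 2: "df_c"}  # partition ID -> frame

# Old shape of the result: a list filled by popitem(), which pops in
# LIFO order, so list position no longer matches partition ID.
as_list = []
while out:
    as_list.append(out.popitem()[1])
assert as_list == ["df_c", "df_b", "df_a"]  # as_list[0] holds partition 2

# New shape of the result: a dict keyed by partition ID, so lookups are
# independent of the order in which popitem() drained the input.
out = {0: "df_a", 1: "df_b", 2: "df_c"}
as_dict = {}
while out:
    part_id, frame = out.popitem()
    as_dict[part_id] = frame
assert as_dict[0] == "df_a"
```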
@@ -529,9 +528,12 @@ def shuffle(

     dsk = {}
     for rank in ranks:
-        for i, part_id in enumerate(rank_to_out_part_ids[rank]):
+        for part_id in rank_to_out_part_ids[rank]:
             dsk[(name, part_id)] = c.client.submit(
-                getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
+                getitem,
+                shuffle_result[rank],
+                part_id,
+                workers=[c.worker_addresses[rank]],
             )

     # Create a distributed Dataframe from all the pieces
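
With `shuffle_task` now returning a dict, the graph-building loop can fetch each output partition by its ID instead of by position. A small sketch of the lookup the submitted tasks perform, using plain `operator.getitem` on in-memory dicts with hypothetical names (no Dask client involved):

```python
from operator import getitem

# Hypothetical stand-ins for the per-rank shuffle results; each maps
# partition ID -> dataframe partition, matching the new return type.
shuffle_result = {
    0: {0: "part0", 3: "part3"},
    1: {1: "part1", 2: "part2"},
}
rank_to_out_part_ids = {0: {0, 3}, 1: {1, 2}}
name = "explicit-comms-shuffle"  # hypothetical graph-key prefix

dsk = {}
for rank, part_ids in rank_to_out_part_ids.items():
    for part_id in part_ids:
        # getitem(d, k) == d[k]: the graph key and the lookup key are
        # both the partition ID, so no positional index is needed.
        dsk[(name, part_id)] = getitem(shuffle_result[rank], part_id)

assert dsk[(name, 2)] == "part2"
```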

dask_cuda/tests/test_explicit_comms.py (43 changes: 29 additions & 14 deletions)

@@ -93,7 +93,7 @@ def check_partitions(df, npartitions):
     return True


-def _test_dataframe_shuffle(backend, protocol, n_workers):
+def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
     if backend == "cudf":
         cudf = pytest.importorskip("cudf")

@@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
     if backend == "cudf":
         df = cudf.DataFrame.from_pandas(df)

+    if _partitions:
+        df["_partitions"] = 0
+
     for input_nparts in range(1, 5):
         for output_nparts in range(1, 5):
             ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist(
@@ -123,33 +126,45 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
             with dask.config.set(explicit_comms_batchsize=batchsize):
                 ddf = explicit_comms_shuffle(
                     ddf,
-                    ["key"],
+                    ["_partitions"] if _partitions else ["key"],
                     npartitions=output_nparts,
                     batchsize=batchsize,
                 ).persist()

                 assert ddf.npartitions == output_nparts

-                # Check that each partition hashes to the same value
-                result = ddf.map_partitions(
-                    check_partitions, output_nparts
-                ).compute()
-                assert all(result.to_list())
-
-                # Check the values (ignoring the row order)
-                expected = df.sort_values("key")
-                got = ddf.compute().sort_values("key")
-                assert_eq(got, expected)
+                if _partitions:
+                    # If "_partitions" is the hash key, we expect all but
+                    # the first partition to be empty
+                    assert_eq(ddf.partitions[0].compute(), df)
+                    assert all(
+                        len(ddf.partitions[i].compute()) == 0
+                        for i in range(1, ddf.npartitions)
+                    )
+                else:
+                    # Check that each partition hashes to the same value
+                    result = ddf.map_partitions(
+                        check_partitions, output_nparts
+                    ).compute()
+                    assert all(result.to_list())
+
+                    # Check the values (ignoring the row order)
+                    expected = df.sort_values("key")
+                    got = ddf.compute().sort_values("key")
+                    assert_eq(got, expected)


 @pytest.mark.parametrize("nworkers", [1, 2, 3])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
 @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
-def test_dataframe_shuffle(backend, protocol, nworkers):
+@pytest.mark.parametrize("_partitions", [True, False])
+def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
     if backend == "cudf":
         pytest.importorskip("cudf")

-    p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers))
+    p = mp.Process(
+        target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
+    )
     p.start()
     p.join()
     assert not p.exitcode
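
The new `_partitions` branch of the test leans on Dask's shuffle convention that a `_partitions` key column routes each row directly to the partition index the column holds, bypassing hashing; with the column set to 0 everywhere, every row should land in partition 0. A pandas-only toy illustration of that routing (a sketch of the convention, not the test code):

```python
import pandas as pd

df = pd.DataFrame({"key": [5, 1, 9, 3]})
df["_partitions"] = 0  # route every row to output partition 0
npartitions = 4

# Each row goes to the partition whose index the "_partitions" column names.
partitions = {pid: df[df["_partitions"] == pid] for pid in range(npartitions)}

assert len(partitions[0]) == len(df)  # partition 0 holds all rows
assert all(partitions[p].empty for p in range(1, npartitions))  # rest empty
```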