Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit-comms: preserve partition IDs #1240

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def shuffle_task(
ignore_index: bool,
num_rounds: int,
batchsize: int,
) -> List[DataFrame]:
) -> Dict[int, DataFrame]:
"""Explicit-comms shuffle task

This function is running on each worker participating in the shuffle.
Expand Down Expand Up @@ -360,8 +360,8 @@ async def shuffle_task(

Returns
-------
partitions: list of DataFrames
List of dataframe-partitions
partitions: dict
dict that maps each Partition ID to a dataframe-partition
"""

proxify = get_proxify(s["worker"])
Expand All @@ -387,14 +387,13 @@ async def shuffle_task(
)

# Finally, we concatenate the output dataframes into the final output partitions
ret = []
ret = {}
while out_part_id_to_dataframe_list:
ret.append(
proxify(
dd_concat(
out_part_id_to_dataframe_list.popitem()[1],
ignore_index=ignore_index,
)
part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
ret[part_id] = proxify(
dd_concat(
dataframe_list,
ignore_index=ignore_index,
)
)
# For robustness, we yield this task to give Dask a chance to do bookkeeping
Expand Down Expand Up @@ -529,9 +528,12 @@ def shuffle(

dsk = {}
for rank in ranks:
for i, part_id in enumerate(rank_to_out_part_ids[rank]):
for part_id in rank_to_out_part_ids[rank]:
dsk[(name, part_id)] = c.client.submit(
getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
getitem,
shuffle_result[rank],
part_id,
workers=[c.worker_addresses[rank]],
)

# Create a distributed Dataframe from all the pieces
Expand Down
42 changes: 29 additions & 13 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def check_partitions(df, npartitions):
return True


def _test_dataframe_shuffle(backend, protocol, n_workers):
def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
if backend == "cudf":
cudf = pytest.importorskip("cudf")

Expand All @@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
df = cudf.DataFrame.from_pandas(df)

if _partitions:
df["_partitions"] = 0

for input_nparts in range(1, 5):
for output_nparts in range(1, 5):
ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist(
Expand All @@ -123,33 +126,46 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
with dask.config.set(explicit_comms_batchsize=batchsize):
ddf = explicit_comms_shuffle(
ddf,
["key"],
["_partitions"] if _partitions else ["key"],
npartitions=output_nparts,
batchsize=batchsize,
).persist()

assert ddf.npartitions == output_nparts

# Check that each partition hashes to the same value
result = ddf.map_partitions(
check_partitions, output_nparts
).compute()
assert all(result.to_list())
if _partitions:
# If "_partitions" is the hash key, we except all but
pentschev marked this conversation as resolved.
Show resolved Hide resolved
# the first partition to be empty
assert_eq(ddf.partitions[0].compute(), df)
assert all(
len(ddf.partitions[i].compute()) == 0
for i in range(1, ddf.npartitions)
)
else:

madsbk marked this conversation as resolved.
Show resolved Hide resolved
# Check that each partition hashes to the same value
result = ddf.map_partitions(
check_partitions, output_nparts
).compute()
assert all(result.to_list())

# Check the values (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
assert_eq(got, expected)
# Check the values (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
def test_dataframe_shuffle(backend, protocol, nworkers):
@pytest.mark.parametrize("_partitions", [True, False])
def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
pytest.importorskip("cudf")

p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers))
p = mp.Process(
target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
)
p.start()
p.join()
assert not p.exitcode
Expand Down