From 4aa35aa5adcd9a2064b39bd214666508e5c66154 Mon Sep 17 00:00:00 2001
From: "Mads R. B. Kristensen"
Date: Tue, 26 Sep 2023 10:32:57 +0200
Subject: [PATCH 1/3] shuffle_task() now returns a dict mapping partition IDs
 to dataframes

---
 dask_cuda/explicit_comms/dataframe/shuffle.py | 26 ++++++++++---------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py
index 0ca1c48e..854115fe 100644
--- a/dask_cuda/explicit_comms/dataframe/shuffle.py
+++ b/dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -328,7 +328,7 @@ async def shuffle_task(
     ignore_index: bool,
     num_rounds: int,
     batchsize: int,
-) -> List[DataFrame]:
+) -> Dict[int, DataFrame]:
     """Explicit-comms shuffle task
 
     This function is running on each worker participating in the shuffle.
@@ -360,8 +360,8 @@ async def shuffle_task(
 
     Returns
     -------
-    partitions: list of DataFrames
-        List of dataframe-partitions
+    partitions: dict
+        dict that maps each Partition ID to a dataframe-partition
     """
 
     proxify = get_proxify(s["worker"])
@@ -387,14 +387,13 @@ async def shuffle_task(
     )
 
     # Finally, we concatenate the output dataframes into the final output partitions
-    ret = []
+    ret = {}
     while out_part_id_to_dataframe_list:
-        ret.append(
-            proxify(
-                dd_concat(
-                    out_part_id_to_dataframe_list.popitem()[1],
-                    ignore_index=ignore_index,
-                )
+        part_id, dataframe_list = out_part_id_to_dataframe_list.popitem()
+        ret[part_id] = proxify(
+            dd_concat(
+                dataframe_list,
+                ignore_index=ignore_index,
             )
         )
         # For robustness, we yield this task to give Dask a chance to do bookkeeping
@@ -529,9 +528,12 @@ def shuffle(
 
     dsk = {}
     for rank in ranks:
-        for i, part_id in enumerate(rank_to_out_part_ids[rank]):
+        for part_id in rank_to_out_part_ids[rank]:
             dsk[(name, part_id)] = c.client.submit(
-                getitem, shuffle_result[rank], i, workers=[c.worker_addresses[rank]]
+                getitem,
+                shuffle_result[rank],
+                part_id,
+                workers=[c.worker_addresses[rank]],
             )
 
     # Create a distributed Dataframe from all the pieces
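The patch above replaces a positionally indexed list with a dict keyed by global
partition ID, so the final loop in shuffle() can fetch each piece by part_id
instead of translating IDs through rank-local enumerate() indices. A minimal
sketch of that difference, with plain strings standing in for dataframes and a
hypothetical shuffle_result_on_rank mapping (illustrative only, not dask_cuda's
API):

    from operator import getitem

    # Hypothetical per-worker shuffle output: global partition IDs 3 and 7
    # happen to live on this worker.
    shuffle_result_on_rank = {3: "dataframe-3", 7: "dataframe-7"}

    # Before the patch: a list, so the caller must know that global ID 3
    # maps to rank-local index 0.
    as_list = list(shuffle_result_on_rank.values())
    assert getitem(as_list, 0) == "dataframe-3"

    # After the patch: the global partition ID is the lookup key itself.
    assert getitem(shuffle_result_on_rank, 3) == "dataframe-3"
    assert getitem(shuffle_result_on_rank, 7) == "dataframe-7"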
Kristensen" Date: Tue, 26 Sep 2023 12:51:25 +0200 Subject: [PATCH 2/3] test _partitions --- dask_cuda/tests/test_explicit_comms.py | 42 ++++++++++++++++++-------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 1a15370b..49bd7058 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -93,7 +93,7 @@ def check_partitions(df, npartitions): return True -def _test_dataframe_shuffle(backend, protocol, n_workers): +def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): if backend == "cudf": cudf = pytest.importorskip("cudf") @@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): if backend == "cudf": df = cudf.DataFrame.from_pandas(df) + if _partitions: + df["_partitions"] = 0 + for input_nparts in range(1, 5): for output_nparts in range(1, 5): ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist( @@ -123,33 +126,46 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): with dask.config.set(explicit_comms_batchsize=batchsize): ddf = explicit_comms_shuffle( ddf, - ["key"], + ["_partitions"] if _partitions else ["key"], npartitions=output_nparts, batchsize=batchsize, ).persist() assert ddf.npartitions == output_nparts - # Check that each partition hashes to the same value - result = ddf.map_partitions( - check_partitions, output_nparts - ).compute() - assert all(result.to_list()) + if _partitions: + # If "_partitions" is the hash key, we except all but + # the first partition to be empty + assert_eq(ddf.partitions[0].compute(), df) + assert all( + len(ddf.partitions[i].compute()) == 0 + for i in range(1, ddf.npartitions) + ) + else: + + # Check that each partition hashes to the same value + result = ddf.map_partitions( + check_partitions, output_nparts + ).compute() + assert all(result.to_list()) - # Check the values (ignoring the row order) - expected = df.sort_values("key") - got = ddf.compute().sort_values("key") - assert_eq(got, expected) + # Check the values (ignoring the row order) + expected = df.sort_values("key") + got = ddf.compute().sort_values("key") + assert_eq(got, expected) @pytest.mark.parametrize("nworkers", [1, 2, 3]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) -def test_dataframe_shuffle(backend, protocol, nworkers): +@pytest.mark.parametrize("_partitions", [True, False]) +def test_dataframe_shuffle(backend, protocol, nworkers, _partitions): if backend == "cudf": pytest.importorskip("cudf") - p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers)) + p = mp.Process( + target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions) + ) p.start() p.join() assert not p.exitcode From 95070d2c7acddce71e360a5e4732d153da2a2abc Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 26 Sep 2023 16:09:43 +0200 Subject: [PATCH 3/3] Typo Co-authored-by: Richard (Rick) Zamora --- dask_cuda/tests/test_explicit_comms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 49bd7058..ae4e3332 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -134,7 +134,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): assert ddf.npartitions == output_nparts if _partitions: - # If "_partitions" is the hash key, we except all but + # If "_partitions" is the hash key, we expect all but # the first partition to be empty assert_eq(ddf.partitions[0].compute(), df) assert all( @@ -142,7 +142,6 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): for i in range(1, ddf.npartitions) ) else: - # Check that each partition hashes to the same value result = ddf.map_partitions( check_partitions, output_nparts