diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 6d46ff7228..803a33fe6c 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -193,6 +193,38 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword): with dask.config.set(config): x2 = rechunk(x, chunks=new, method=keyword) expected_algorithm = keyword if keyword is not None else config_value + if expected_algorithm == "p2p": + assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + elif expected_algorithm == "tasks": + assert not any( + key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__() + ) + # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic) + else: + assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + + assert x2.chunks == new + assert np.all(await c.compute(x2) == a) + + +@pytest.mark.parametrize( + ["new", "expected_algorithm"], + [ + # All-to-all rechunking defaults to P2P + (((1,) * 100, (100,)), "p2p"), + # Localized rechunking defaults to tasks + (((50, 50), (2,) * 50), "tasks"), + # Less local rechunking first defaults to tasks, + (((25, 25, 25, 25), (4,) * 25), "tasks"), + # then switches to p2p + (((10,) * 10, (10,) * 10), "p2p"), + ], +) +@gen_cluster(client=True) +async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm): + a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100)) + x = da.from_array(a, chunks=(100, 1)) + x2 = rechunk(x, chunks=new) if expected_algorithm == "p2p": assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) else: