From a571fc4e34d5ba6e2631b100dc969b71502acc31 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 5 Jun 2024 15:50:04 +0200 Subject: [PATCH] Improve graph submission time for P2P rechunking by avoiding unpack recursion into indices --- distributed/shuffle/_rechunk.py | 7 ++++--- distributed/tests/test_utils_comm.py | 21 +++++++++++++++++++++ distributed/utils_comm.py | 7 +++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index be36f47f56..b33e90730b 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -130,6 +130,7 @@ from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof +from distributed.utils_comm import DoNotUnpack if TYPE_CHECKING: import numpy as np @@ -445,9 +446,9 @@ def partial_rechunk( rechunk_transfer, input_key, partial_token, - partial_index, - partial_new, - partial_old, + DoNotUnpack(partial_index), + DoNotUnpack(partial_new), + DoNotUnpack(partial_old), disk, ) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index ee0eaff089..94b1c33f3a 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -13,6 +13,7 @@ from distributed.config import get_loop_factory from distributed.core import ConnectionPool, Status from distributed.utils_comm import ( + DoNotUnpack, WrappedKey, gather_from_workers, pack_data, @@ -261,3 +262,23 @@ def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None: res, keys = unpack_remotedata(dsk) assert res == (sc, "arg1") # Notice, the first item (the SC) has NOT been changed assert_eq(keys, set()) + + +def test_unpack_remotedata_custom_tuple(): + # We don't want to recurse into custom tuples. This is used as a sentinel to + # avoid recursion for performance reasons if we know that there are no + # nested futures. This test case is not how this feature should be used in + # practice. + + akey = WrappedKey("a") + + ordinary_tuple = (1, 2, akey) + dont_recurse = DoNotUnpack(ordinary_tuple) + + res, keys = unpack_remotedata(ordinary_tuple) + assert res is not ordinary_tuple + assert any(left != right for left, right in zip(ordinary_tuple, res)) + assert keys == {akey} + res, keys = unpack_remotedata(dont_recurse) + assert not keys + assert res is dont_recurse diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index e0a9eda88b..7c10c25635 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -269,6 +269,13 @@ def _unpack_remotedata_inner( return o +class DoNotUnpack(tuple): + """A tuple sublass to indicate that we should not unpack its contents + + See also unpack_remotedata + """ + + def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]: """Unpack WrappedKey objects from collection