Skip to content

Commit

Permalink
Improve graph submission time for P2P rechunking by avoiding unpack r…
Browse files Browse the repository at this point in the history
…ecursion into indices (#8672)
  • Loading branch information
fjetter authored Jun 5, 2024
1 parent 7cbfc4d commit 5708bdf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
7 changes: 4 additions & 3 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof
from distributed.utils_comm import DoNotUnpack

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -445,9 +446,9 @@ def partial_rechunk(
rechunk_transfer,
input_key,
partial_token,
partial_index,
partial_new,
partial_old,
DoNotUnpack(partial_index),
DoNotUnpack(partial_new),
DoNotUnpack(partial_old),
disk,
)

Expand Down
21 changes: 21 additions & 0 deletions distributed/tests/test_utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from distributed.config import get_loop_factory
from distributed.core import ConnectionPool, Status
from distributed.utils_comm import (
DoNotUnpack,
WrappedKey,
gather_from_workers,
pack_data,
Expand Down Expand Up @@ -261,3 +262,23 @@ def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
res, keys = unpack_remotedata(dsk)
assert res == (sc, "arg1") # Notice, the first item (the SC) has NOT been changed
assert_eq(keys, set())


def test_unpack_remotedata_custom_tuple():
# We don't want to recurse into custom tuples. This is used as a sentinel to
# avoid recursion for performance reasons if we know that there are no
# nested futures. This test case is not how this feature should be used in
# practice.

akey = WrappedKey("a")

ordinary_tuple = (1, 2, akey)
dont_recurse = DoNotUnpack(ordinary_tuple)

res, keys = unpack_remotedata(ordinary_tuple)
assert res is not ordinary_tuple
assert any(left != right for left, right in zip(ordinary_tuple, res))
assert keys == {akey}
res, keys = unpack_remotedata(dont_recurse)
assert not keys
assert res is dont_recurse
7 changes: 7 additions & 0 deletions distributed/utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ def _unpack_remotedata_inner(
return o


class DoNotUnpack(tuple):
"""A tuple sublass to indicate that we should not unpack its contents
See also unpack_remotedata
"""


def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]:
"""Unpack WrappedKey objects from collection
Expand Down

0 comments on commit 5708bdf

Please sign in to comment.