From e2aa90ef9b961a33978cad247e7162f01b44d649 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Aug 2024 17:28:44 +0200 Subject: [PATCH 1/4] Pre-allocate references --- distributed/shuffle/_rechunk.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 0f0dfbf21f..1621cefd50 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -370,7 +370,13 @@ def cull( indices_to_keep = self._keys_to_indices(keys) _old_to_new = old_to_new(self.chunks_input, self.chunks) - culled_deps: defaultdict[Key, set[Key]] = defaultdict(set) + # Pre-allocate old block references, to allow reuse and reduce the + # graph's memory footprint a bit. + old_blocks = np.empty([len(c) for c in self.chunks_input], dtype="O") + for ndindex in np.ndindex(old_blocks.shape): + old_blocks[ndindex] = (self.name_input,) + ndindex + + culled_deps: dict[Key, set[Key]] = {} for nindex in indices_to_keep: old_indices_per_axis = [] keepmap[nindex] = True @@ -378,8 +384,10 @@ def cull( old_indices_per_axis.append( [old_chunk_index for old_chunk_index, _ in new_axis[index]] ) - for old_nindex in product(*old_indices_per_axis): - culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex) + culled_deps_for_nindex = { + old_blocks[old_nindex] for old_nindex in product(*old_indices_per_axis) + } + culled_deps[(self.name,) + nindex] = culled_deps_for_nindex # Protect against mutations later on with frozenset frozen_deps = { From 10dac617868b79a97da4793f2ffbb8495a4efb9d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Aug 2024 17:59:21 +0200 Subject: [PATCH 2/4] Reuse sets --- distributed/shuffle/_rechunk.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 1621cefd50..e50d9f39f4 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -376,7 +376,9 @@ def cull( for ndindex in np.ndindex(old_blocks.shape): old_blocks[ndindex] = (self.name_input,) + ndindex - culled_deps: dict[Key, set[Key]] = {} + keys_for_indices: dict[frozenset[tuple[int, ...]], frozenset[Key]] = {} + + culled_deps: dict[Key, frozenset[Key]] = {} for nindex in indices_to_keep: old_indices_per_axis = [] keepmap[nindex] = True @@ -384,22 +386,19 @@ def cull( old_indices_per_axis.append( [old_chunk_index for old_chunk_index, _ in new_axis[index]] ) - culled_deps_for_nindex = { - old_blocks[old_nindex] for old_nindex in product(*old_indices_per_axis) - } - culled_deps[(self.name,) + nindex] = culled_deps_for_nindex - - # Protect against mutations later on with frozenset - frozen_deps = { - output_task: frozenset(input_tasks) - for output_task, input_tasks in culled_deps.items() - } + indices = frozenset(product(*old_indices_per_axis)) + if indices not in keys_for_indices: + keys_for_indices[indices] = frozenset( + old_blocks[index] for index in indices + ) + + culled_deps[(self.name,) + nindex] = keys_for_indices[indices] if np.array_equal(keepmap, self.keepmap): - return self, frozen_deps + return self, culled_deps else: culled_layer = self._cull(keepmap) - return culled_layer, frozen_deps + return culled_layer, culled_deps def _construct_graph(self) -> _T_LowLevelGraph: import numpy as np From 01641ec8f5fc61defc70489c124386b87395d832 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Aug 2024 20:44:26 +0200 Subject: [PATCH 3/4] Simplify culling --- distributed/shuffle/_rechunk.py | 54 ++++++++++++++------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index e50d9f39f4..2470f3c750 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -370,29 +370,21 @@ def cull( indices_to_keep = self._keys_to_indices(keys) _old_to_new = old_to_new(self.chunks_input, self.chunks) - # Pre-allocate old block references, to allow reuse and reduce the - # graph's memory footprint a bit. - old_blocks = np.empty([len(c) for c in self.chunks_input], dtype="O") - for ndindex in np.ndindex(old_blocks.shape): - old_blocks[ndindex] = (self.name_input,) + ndindex - - keys_for_indices: dict[frozenset[tuple[int, ...]], frozenset[Key]] = {} - - culled_deps: dict[Key, frozenset[Key]] = {} - for nindex in indices_to_keep: - old_indices_per_axis = [] - keepmap[nindex] = True - for index, new_axis in zip(nindex, _old_to_new): - old_indices_per_axis.append( - [old_chunk_index for old_chunk_index, _ in new_axis[index]] - ) - indices = frozenset(product(*old_indices_per_axis)) - if indices not in keys_for_indices: - keys_for_indices[indices] = frozenset( - old_blocks[index] for index in indices - ) + for ndindex in indices_to_keep: + keepmap[ndindex] = True - culled_deps[(self.name,) + nindex] = keys_for_indices[indices] + culled_deps = {} + for ndpartial in _split_partials(_old_to_new): + if not np.any(keepmap[ndpartial.new]): + continue + + deps = frozenset( + (self.name_input,) + ndindex + for ndindex in _ndindices_of_slice(ndpartial.old) + ) + + for ndindex in _ndindices_of_slice(ndpartial.new): + culled_deps[(self.name,) + ndindex] = deps if np.array_equal(keepmap, self.keepmap): return self, culled_deps @@ -702,14 +694,12 @@ def _slice_new_chunks_into_partials( return tuple(sliced_axes) -def _partial_ndindex(ndslice: NDSlice) -> np.ndindex: - import numpy as np - - return np.ndindex(tuple(slice.stop - slice.start for slice in ndslice)) +def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]: + return product(*(range(slc.start, slc.stop) for slc in ndslice)) -def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex: - return tuple(index + offset for index, offset in zip(partial_index, partial_offset)) +def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex: + return tuple(index - offset for index, offset in zip(global_index, partial_offset)) def partial_concatenate( @@ -809,8 +799,8 @@ def partial_rechunk( ) transfer_keys = [] - for partial_index in _partial_ndindex(ndpartial.old): - global_index = _global_index(partial_index, old_partial_offset) + for global_index in _ndindices_of_slice(ndpartial.old): + partial_index = _partial_index(global_index, old_partial_offset) input_key = (input_name,) + global_index @@ -829,8 +819,8 @@ def partial_rechunk( dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys) new_partial_offset = tuple(axis.start for axis in ndpartial.new) - for partial_index in _partial_ndindex(ndpartial.new): - global_index = _global_index(partial_index, new_partial_offset) + for global_index in _ndindices_of_slice(ndpartial.new): + partial_index = _partial_index(global_index, new_partial_offset) if keepmap[global_index]: dsk[(unpack_group,) + global_index] = ( rechunk_unpack, From 4db28ae6d2d7ab2cd874ab76b3f207c080be794c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Aug 2024 20:49:06 +0200 Subject: [PATCH 4/4] Add comments --- distributed/shuffle/_rechunk.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 2470f3c750..efae7d80b7 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -374,10 +374,14 @@ def cull( keepmap[ndindex] = True culled_deps = {} + # Identify the individual partial rechunks for ndpartial in _split_partials(_old_to_new): + # Cull partials for which we do not keep any output tasks if not np.any(keepmap[ndpartial.new]): continue + # Within partials, we have all-to-all communication. + # Thus, all output tasks share the same input tasks. deps = frozenset( (self.name_input,) + ndindex for ndindex in _ndindices_of_slice(ndpartial.old)