Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory footprint of culling P2P rechunking #8845

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,28 +370,31 @@
indices_to_keep = self._keys_to_indices(keys)
_old_to_new = old_to_new(self.chunks_input, self.chunks)

culled_deps: defaultdict[Key, set[Key]] = defaultdict(set)
for nindex in indices_to_keep:
old_indices_per_axis = []
keepmap[nindex] = True
for index, new_axis in zip(nindex, _old_to_new):
old_indices_per_axis.append(
[old_chunk_index for old_chunk_index, _ in new_axis[index]]
)
for old_nindex in product(*old_indices_per_axis):
culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex)
for ndindex in indices_to_keep:
keepmap[ndindex] = True

# Protect against mutations later on with frozenset
frozen_deps = {
output_task: frozenset(input_tasks)
for output_task, input_tasks in culled_deps.items()
}
culled_deps = {}
# Identify the individual partial rechunks
for ndpartial in _split_partials(_old_to_new):
# Cull partials for which we do not keep any output tasks
if not np.any(keepmap[ndpartial.new]):
continue

Check warning on line 381 in distributed/shuffle/_rechunk.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_rechunk.py#L381

Added line #L381 was not covered by tests

# Within partials, we have all-to-all communication.
# Thus, all output tasks share the same input tasks.
deps = frozenset(
(self.name_input,) + ndindex
for ndindex in _ndindices_of_slice(ndpartial.old)
)

for ndindex in _ndindices_of_slice(ndpartial.new):
culled_deps[(self.name,) + ndindex] = deps

if np.array_equal(keepmap, self.keepmap):
return self, frozen_deps
return self, culled_deps
else:
culled_layer = self._cull(keepmap)
return culled_layer, frozen_deps
return culled_layer, culled_deps

Check warning on line 397 in distributed/shuffle/_rechunk.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_rechunk.py#L397

Added line #L397 was not covered by tests

def _construct_graph(self) -> _T_LowLevelGraph:
import numpy as np
Expand Down Expand Up @@ -695,14 +698,12 @@
return tuple(sliced_axes)


def _partial_ndindex(ndslice: NDSlice) -> np.ndindex:
import numpy as np

return np.ndindex(tuple(slice.stop - slice.start for slice in ndslice))
def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]:
return product(*(range(slc.start, slc.stop) for slc in ndslice))


def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
return tuple(index + offset for index, offset in zip(partial_index, partial_offset))
def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
return tuple(index - offset for index, offset in zip(global_index, partial_offset))


def partial_concatenate(
Expand Down Expand Up @@ -802,8 +803,8 @@
)

transfer_keys = []
for partial_index in _partial_ndindex(ndpartial.old):
global_index = _global_index(partial_index, old_partial_offset)
for global_index in _ndindices_of_slice(ndpartial.old):
partial_index = _partial_index(global_index, old_partial_offset)

input_key = (input_name,) + global_index

Expand All @@ -822,8 +823,8 @@
dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys)

new_partial_offset = tuple(axis.start for axis in ndpartial.new)
for partial_index in _partial_ndindex(ndpartial.new):
global_index = _global_index(partial_index, new_partial_offset)
for global_index in _ndindices_of_slice(ndpartial.new):
partial_index = _partial_index(global_index, new_partial_offset)
if keepmap[global_index]:
dsk[(unpack_group,) + global_index] = (
rechunk_unpack,
Expand Down
Loading