Skip to content

Commit

Permalink
Further reduce allocations in TransportGetSnapshotsAction (elastic#…
Browse files Browse the repository at this point in the history
…110817)

Collecting the list of snapshot IDs over which to iterate within each
repository today involves several other potentially-large intermediate
collections and a bunch of other unnecessary allocations. This commit
replaces those temporary collections with an iterator which saves all
this temporary memory usage.

Relates ES-8906
  • Loading branch information
DaveCTurner authored Jul 12, 2024
1 parent 13c3211 commit c96f801
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
Expand Down Expand Up @@ -248,18 +249,8 @@ void getMultipleReposSnapshotInfo(ActionListener<GetSnapshotsResponse> listener)
return;
}

SubscribableListener

.<RepositoryData>newForked(repositoryDataListener -> {
if (snapshotNamePredicate == SnapshotNamePredicate.MATCH_CURRENT_ONLY) {
repositoryDataListener.onResponse(null);
} else {
repositoriesService.repository(repoName).getRepositoryData(executor, repositoryDataListener);
}
})

SubscribableListener.<RepositoryData>newForked(l -> maybeGetRepositoryData(repoName, l))
.<Void>andThen((l, repositoryData) -> loadSnapshotInfos(repoName, repositoryData, l))

.addListener(listeners.acquire());
}
}
Expand All @@ -268,6 +259,14 @@ void getMultipleReposSnapshotInfo(ActionListener<GetSnapshotsResponse> listener)
.addListener(listener.map(ignored -> buildResponse()), executor, threadPool.getThreadContext());
}

private void maybeGetRepositoryData(String repositoryName, ActionListener<RepositoryData> listener) {
if (snapshotNamePredicate == SnapshotNamePredicate.MATCH_CURRENT_ONLY) {
listener.onResponse(null);
} else {
repositoriesService.repository(repositoryName).getRepositoryData(executor, listener);
}
}

private boolean skipRepository(String repositoryName) {
if (sortBy == SnapshotSortKey.REPOSITORY && fromSortValue != null) {
// If we are sorting by repository name with an offset given by fromSortValue, skip earlier repositories
Expand All @@ -277,61 +276,101 @@ private boolean skipRepository(String repositoryName) {
}
}

private void loadSnapshotInfos(String repo, @Nullable RepositoryData repositoryData, ActionListener<Void> listener) {
private void loadSnapshotInfos(String repositoryName, @Nullable RepositoryData repositoryData, ActionListener<Void> listener) {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.MANAGEMENT);

if (cancellableTask.notifyIfCancelled(listener)) {
return;
}

final Set<String> unmatchedRequiredNames = new HashSet<>(snapshotNamePredicate.requiredNames());
final Set<Snapshot> toResolve = new HashSet<>();

for (final var snapshotInProgress : snapshotsInProgress.forRepo(repo)) {
final var snapshotName = snapshotInProgress.snapshot().getSnapshotId().getName();
unmatchedRequiredNames.remove(snapshotName);
if (snapshotNamePredicate.test(snapshotName, true)) {
toResolve.add(snapshotInProgress.snapshot());
}
}

if (repositoryData != null) {
for (final var snapshotId : repositoryData.getSnapshotIds()) {
final var snapshotName = snapshotId.getName();
unmatchedRequiredNames.remove(snapshotName);
if (snapshotNamePredicate.test(snapshotName, false) && matchesPredicates(snapshotId, repositoryData)) {
toResolve.add(new Snapshot(repo, snapshotId));
}
}
}

if (unmatchedRequiredNames.isEmpty() == false) {
throw new SnapshotMissingException(repo, unmatchedRequiredNames.iterator().next());
}
cancellableTask.ensureNotCancelled();
ensureRequiredNamesPresent(repositoryName, repositoryData);

if (verbose) {
loadSnapshotInfos(repo, toResolve.stream().map(Snapshot::getSnapshotId).toList(), listener);
loadSnapshotInfos(repositoryName, getSnapshotIdIterator(repositoryName, repositoryData), listener);
} else {
assert fromSortValuePredicates.isMatchAll() : "filtering is not supported in non-verbose mode";
assert slmPolicyPredicate == SlmPolicyPredicate.MATCH_ALL_POLICIES : "filtering is not supported in non-verbose mode";

addSimpleSnapshotInfos(
toResolve,
repo,
getSnapshotIdIterator(repositoryName, repositoryData),
repositoryName,
repositoryData,
snapshotsInProgress.forRepo(repo).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList()
snapshotsInProgress.forRepo(repositoryName).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList()
);
listener.onResponse(null);
}
}

private void loadSnapshotInfos(String repositoryName, Collection<SnapshotId> snapshotIds, ActionListener<Void> listener) {
/**
* Check that the repository contains every <i>required</i> name according to {@link #snapshotNamePredicate}.
*
* @throws SnapshotMissingException if one or more required names are missing.
*/
private void ensureRequiredNamesPresent(String repositoryName, @Nullable RepositoryData repositoryData) {
if (snapshotNamePredicate.requiredNames().isEmpty()) {
return;
}

final var unmatchedRequiredNames = new HashSet<>(snapshotNamePredicate.requiredNames());
for (final var snapshotInProgress : snapshotsInProgress.forRepo(repositoryName)) {
unmatchedRequiredNames.remove(snapshotInProgress.snapshot().getSnapshotId().getName());
}
if (unmatchedRequiredNames.isEmpty()) {
return;
}
if (repositoryData != null) {
for (final var snapshotId : repositoryData.getSnapshotIds()) {
unmatchedRequiredNames.remove(snapshotId.getName());
}
if (unmatchedRequiredNames.isEmpty()) {
return;
}
}
throw new SnapshotMissingException(repositoryName, unmatchedRequiredNames.iterator().next());
}

/**
* @return an iterator over the snapshot IDs in the given repository which match {@link #snapshotNamePredicate}.
*/
private Iterator<SnapshotId> getSnapshotIdIterator(String repositoryName, @Nullable RepositoryData repositoryData) {

// now iterate through the snapshots again, returning matching IDs (or null)
final Set<SnapshotId> matchingInProgressSnapshots = new HashSet<>();
return Iterators.concat(
// matching in-progress snapshots first
Iterators.filter(
Iterators.map(
snapshotsInProgress.forRepo(repositoryName).iterator(),
snapshotInProgress -> snapshotInProgress.snapshot().getSnapshotId()
),
snapshotId -> {
if (snapshotNamePredicate.test(snapshotId.getName(), true)) {
matchingInProgressSnapshots.add(snapshotId);
return true;
} else {
return false;
}
}
),
repositoryData == null
// only returning in-progress snapshots
? Collections.emptyIterator()
// also return matching completed snapshots (except any ones that were also found to be in-progress)
: Iterators.filter(
repositoryData.getSnapshotIds().iterator(),
snapshotId -> matchingInProgressSnapshots.contains(snapshotId) == false
&& snapshotNamePredicate.test(snapshotId.getName(), false)
&& matchesPredicates(snapshotId, repositoryData)
)
);
}

private void loadSnapshotInfos(String repositoryName, Iterator<SnapshotId> snapshotIdIterator, ActionListener<Void> listener) {
if (cancellableTask.notifyIfCancelled(listener)) {
return;
}
final AtomicInteger repositoryTotalCount = new AtomicInteger();
final List<SnapshotInfo> snapshots = new ArrayList<>(snapshotIds.size());
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>(snapshotIds);
final Set<SnapshotId> snapshotIdsToIterate = new HashSet<>();
snapshotIdIterator.forEachRemaining(snapshotIdsToIterate::add);

final List<SnapshotInfo> snapshots = new ArrayList<>(snapshotIdsToIterate.size());
// first, look at the snapshots in progress
final List<SnapshotsInProgress.Entry> entries = SnapshotsService.currentSnapshots(
snapshotsInProgress,
Expand Down Expand Up @@ -409,7 +448,7 @@ public void onFailure(Exception e) {
}
})

// no need to synchronize access to snapshots: Repository#getSnapshotInfo fails fast but we're on the success path here
// no need to synchronize access to snapshots: all writes happen-before this read
.andThenAccept(ignored -> addResults(repositoryTotalCount.get(), snapshots))

.addListener(listener);
Expand All @@ -422,9 +461,9 @@ private void addResults(int repositoryTotalCount, List<SnapshotInfo> snapshots)
}

private void addSimpleSnapshotInfos(
final Set<Snapshot> toResolve,
final String repoName,
final RepositoryData repositoryData,
final Iterator<SnapshotId> snapshotIdIterator,
final String repositoryName,
@Nullable final RepositoryData repositoryData,
final List<SnapshotInfo> currentSnapshots
) {
if (repositoryData == null) {
Expand All @@ -433,11 +472,14 @@ private void addSimpleSnapshotInfos(
return;
} // else want non-current snapshots as well, which are found in the repository data

final Set<SnapshotId> toResolve = new HashSet<>();
snapshotIdIterator.forEachRemaining(toResolve::add);

List<SnapshotInfo> snapshotInfos = new ArrayList<>(currentSnapshots.size() + toResolve.size());
int repositoryTotalCount = 0;
for (SnapshotInfo snapshotInfo : currentSnapshots) {
assert snapshotInfo.startTime() == 0L && snapshotInfo.endTime() == 0L && snapshotInfo.totalShards() == 0L : snapshotInfo;
if (toResolve.remove(snapshotInfo.snapshot())) {
if (toResolve.remove(snapshotInfo.snapshot().getSnapshotId())) {
repositoryTotalCount += 1;
if (afterPredicate.test(snapshotInfo)) {
snapshotInfos.add(snapshotInfo);
Expand All @@ -448,19 +490,19 @@ private void addSimpleSnapshotInfos(
if (indices) {
for (IndexId indexId : repositoryData.getIndices().values()) {
for (SnapshotId snapshotId : repositoryData.getSnapshots(indexId)) {
if (toResolve.contains(new Snapshot(repoName, snapshotId))) {
if (toResolve.contains(snapshotId)) {
snapshotsToIndices.computeIfAbsent(snapshotId, (k) -> new ArrayList<>()).add(indexId.getName());
}
}
}
}
for (Snapshot snapshot : toResolve) {
for (SnapshotId snapshotId : toResolve) {
final var snapshotInfo = new SnapshotInfo(
snapshot,
snapshotsToIndices.getOrDefault(snapshot.getSnapshotId(), Collections.emptyList()),
new Snapshot(repositoryName, snapshotId),
snapshotsToIndices.getOrDefault(snapshotId, Collections.emptyList()),
Collections.emptyList(),
Collections.emptyList(),
repositoryData.getSnapshotState(snapshot.getSnapshotId())
repositoryData.getSnapshotState(snapshotId)
);
repositoryTotalCount += 1;
if (afterPredicate.test(snapshotInfo)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.function.ToIntFunction;

Expand Down Expand Up @@ -179,6 +180,59 @@ public void forEachRemaining(Consumer<? super U> action) {
}
}

/**
* @param input An iterator over <i>non-null</i> values.
* @param predicate The predicate with which to filter the input.
* @return an iterator which returns the values from {@code input} which match {@code predicate}.
*/
public static <T> Iterator<T> filter(Iterator<? extends T> input, Predicate<T> predicate) {
while (input.hasNext()) {
final var value = input.next();
assert value != null;
if (predicate.test(value)) {
return new FilterIterator<>(value, input, predicate);
}
}
return Collections.emptyIterator();
}

private static final class FilterIterator<T> implements Iterator<T> {
private final Iterator<? extends T> input;
private final Predicate<T> predicate;
private T next;

FilterIterator(T value, Iterator<? extends T> input, Predicate<T> predicate) {
this.next = value;
this.input = input;
this.predicate = predicate;
assert next != null;
assert predicate.test(next);
}

@Override
public boolean hasNext() {
return next != null;
}

@Override
public T next() {
if (hasNext() == false) {
throw new NoSuchElementException();
}
final var value = next;
while (input.hasNext()) {
final var laterValue = input.next();
assert laterValue != null;
if (predicate.test(laterValue)) {
next = laterValue;
return value;
}
}
next = null;
return value;
}
}

public static <T, U> Iterator<U> flatMap(Iterator<? extends T> input, Function<T, Iterator<? extends U>> fn) {
while (input.hasNext()) {
final var value = fn.apply(input.next());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.common.collect;

import org.elasticsearch.common.Randomness;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;

Expand All @@ -23,6 +24,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -219,6 +221,27 @@ public void testMap() {
assertEquals(array.length, index.get());
}

public void testFilter() {
assertSame(Collections.emptyIterator(), Iterators.filter(Collections.emptyIterator(), i -> fail(null, "not called")));

final var array = randomIntegerArray();
assertSame(Collections.emptyIterator(), Iterators.filter(Iterators.forArray(array), i -> false));

final var threshold = array.length > 0 && randomBoolean() ? randomFrom(array) : randomIntBetween(0, 1000);
final Predicate<Integer> predicate = i -> i <= threshold;
final var expectedResults = Arrays.stream(array).filter(predicate).toList();
final var index = new AtomicInteger();
Iterators.filter(Iterators.forArray(array), predicate)
.forEachRemaining(i -> assertEquals(expectedResults.get(index.getAndIncrement()), i));

if (Assertions.ENABLED) {
final var predicateCalled = new AtomicBoolean();
final var inputIterator = Iterators.forArray(new Object[] { null });
expectThrows(AssertionError.class, () -> Iterators.filter(inputIterator, i -> predicateCalled.compareAndSet(false, true)));
assertFalse(predicateCalled.get());
}
}

public void testFailFast() {
final var array = randomIntegerArray();
assertEmptyIterator(Iterators.failFast(Iterators.forArray(array), () -> true));
Expand Down

0 comments on commit c96f801

Please sign in to comment.