diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/get/TransportGetSnapshotsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/get/TransportGetSnapshotsAction.java index ff5fdbaa787fe..269e53dafe0ce 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/get/TransportGetSnapshotsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/get/TransportGetSnapshotsAction.java @@ -57,6 +57,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Queue; @@ -248,18 +249,8 @@ void getMultipleReposSnapshotInfo(ActionListener listener) return; } - SubscribableListener - - .newForked(repositoryDataListener -> { - if (snapshotNamePredicate == SnapshotNamePredicate.MATCH_CURRENT_ONLY) { - repositoryDataListener.onResponse(null); - } else { - repositoriesService.repository(repoName).getRepositoryData(executor, repositoryDataListener); - } - }) - + SubscribableListener.newForked(l -> maybeGetRepositoryData(repoName, l)) .andThen((l, repositoryData) -> loadSnapshotInfos(repoName, repositoryData, l)) - .addListener(listeners.acquire()); } } @@ -268,6 +259,14 @@ void getMultipleReposSnapshotInfo(ActionListener listener) .addListener(listener.map(ignored -> buildResponse()), executor, threadPool.getThreadContext()); } + private void maybeGetRepositoryData(String repositoryName, ActionListener listener) { + if (snapshotNamePredicate == SnapshotNamePredicate.MATCH_CURRENT_ONLY) { + listener.onResponse(null); + } else { + repositoriesService.repository(repositoryName).getRepositoryData(executor, listener); + } + } + private boolean skipRepository(String repositoryName) { if (sortBy == SnapshotSortKey.REPOSITORY && fromSortValue != null) { // If we are sorting by repository name with an offset given by fromSortValue, skip earlier repositories @@ -277,61 +276,101 @@ private boolean skipRepository(String repositoryName) { } } - private void loadSnapshotInfos(String repo, @Nullable RepositoryData repositoryData, ActionListener listener) { + private void loadSnapshotInfos(String repositoryName, @Nullable RepositoryData repositoryData, ActionListener listener) { assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.MANAGEMENT); - if (cancellableTask.notifyIfCancelled(listener)) { - return; - } - - final Set unmatchedRequiredNames = new HashSet<>(snapshotNamePredicate.requiredNames()); - final Set toResolve = new HashSet<>(); - - for (final var snapshotInProgress : snapshotsInProgress.forRepo(repo)) { - final var snapshotName = snapshotInProgress.snapshot().getSnapshotId().getName(); - unmatchedRequiredNames.remove(snapshotName); - if (snapshotNamePredicate.test(snapshotName, true)) { - toResolve.add(snapshotInProgress.snapshot()); - } - } - - if (repositoryData != null) { - for (final var snapshotId : repositoryData.getSnapshotIds()) { - final var snapshotName = snapshotId.getName(); - unmatchedRequiredNames.remove(snapshotName); - if (snapshotNamePredicate.test(snapshotName, false) && matchesPredicates(snapshotId, repositoryData)) { - toResolve.add(new Snapshot(repo, snapshotId)); - } - } - } - - if (unmatchedRequiredNames.isEmpty() == false) { - throw new SnapshotMissingException(repo, unmatchedRequiredNames.iterator().next()); - } + cancellableTask.ensureNotCancelled(); + ensureRequiredNamesPresent(repositoryName, repositoryData); if (verbose) { - loadSnapshotInfos(repo, toResolve.stream().map(Snapshot::getSnapshotId).toList(), listener); + loadSnapshotInfos(repositoryName, getSnapshotIdIterator(repositoryName, repositoryData), listener); } else { assert fromSortValuePredicates.isMatchAll() : "filtering is not supported in non-verbose mode"; assert slmPolicyPredicate == SlmPolicyPredicate.MATCH_ALL_POLICIES : "filtering is not supported in non-verbose mode"; addSimpleSnapshotInfos( - toResolve, - repo, + getSnapshotIdIterator(repositoryName, repositoryData), + repositoryName, repositoryData, - snapshotsInProgress.forRepo(repo).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList() + snapshotsInProgress.forRepo(repositoryName).stream().map(entry -> SnapshotInfo.inProgress(entry).basic()).toList() ); listener.onResponse(null); } } - private void loadSnapshotInfos(String repositoryName, Collection snapshotIds, ActionListener listener) { + /** + * Check that the repository contains every required name according to {@link #snapshotNamePredicate}. + * + * @throws SnapshotMissingException if one or more required names are missing. + */ + private void ensureRequiredNamesPresent(String repositoryName, @Nullable RepositoryData repositoryData) { + if (snapshotNamePredicate.requiredNames().isEmpty()) { + return; + } + + final var unmatchedRequiredNames = new HashSet<>(snapshotNamePredicate.requiredNames()); + for (final var snapshotInProgress : snapshotsInProgress.forRepo(repositoryName)) { + unmatchedRequiredNames.remove(snapshotInProgress.snapshot().getSnapshotId().getName()); + } + if (unmatchedRequiredNames.isEmpty()) { + return; + } + if (repositoryData != null) { + for (final var snapshotId : repositoryData.getSnapshotIds()) { + unmatchedRequiredNames.remove(snapshotId.getName()); + } + if (unmatchedRequiredNames.isEmpty()) { + return; + } + } + throw new SnapshotMissingException(repositoryName, unmatchedRequiredNames.iterator().next()); + } + + /** + * @return an iterator over the snapshot IDs in the given repository which match {@link #snapshotNamePredicate}. + */ + private Iterator getSnapshotIdIterator(String repositoryName, @Nullable RepositoryData repositoryData) { + + // now iterate through the snapshots again, returning matching IDs (or null) + final Set matchingInProgressSnapshots = new HashSet<>(); + return Iterators.concat( + // matching in-progress snapshots first + Iterators.filter( + Iterators.map( + snapshotsInProgress.forRepo(repositoryName).iterator(), + snapshotInProgress -> snapshotInProgress.snapshot().getSnapshotId() + ), + snapshotId -> { + if (snapshotNamePredicate.test(snapshotId.getName(), true)) { + matchingInProgressSnapshots.add(snapshotId); + return true; + } else { + return false; + } + } + ), + repositoryData == null + // only returning in-progress snapshots + ? Collections.emptyIterator() + // also return matching completed snapshots (except any ones that were also found to be in-progress) + : Iterators.filter( + repositoryData.getSnapshotIds().iterator(), + snapshotId -> matchingInProgressSnapshots.contains(snapshotId) == false + && snapshotNamePredicate.test(snapshotId.getName(), false) + && matchesPredicates(snapshotId, repositoryData) + ) + ); + } + + private void loadSnapshotInfos(String repositoryName, Iterator snapshotIdIterator, ActionListener listener) { if (cancellableTask.notifyIfCancelled(listener)) { return; } final AtomicInteger repositoryTotalCount = new AtomicInteger(); - final List snapshots = new ArrayList<>(snapshotIds.size()); - final Set snapshotIdsToIterate = new HashSet<>(snapshotIds); + final Set snapshotIdsToIterate = new HashSet<>(); + snapshotIdIterator.forEachRemaining(snapshotIdsToIterate::add); + + final List snapshots = new ArrayList<>(snapshotIdsToIterate.size()); // first, look at the snapshots in progress final List entries = SnapshotsService.currentSnapshots( snapshotsInProgress, @@ -409,7 +448,7 @@ public void onFailure(Exception e) { } }) - // no need to synchronize access to snapshots: Repository#getSnapshotInfo fails fast but we're on the success path here + // no need to synchronize access to snapshots: all writes happen-before this read .andThenAccept(ignored -> addResults(repositoryTotalCount.get(), snapshots)) .addListener(listener); @@ -422,9 +461,9 @@ private void addResults(int repositoryTotalCount, List snapshots) } private void addSimpleSnapshotInfos( - final Set toResolve, - final String repoName, - final RepositoryData repositoryData, + final Iterator snapshotIdIterator, + final String repositoryName, + @Nullable final RepositoryData repositoryData, final List currentSnapshots ) { if (repositoryData == null) { @@ -433,11 +472,14 @@ private void addSimpleSnapshotInfos( return; } // else want non-current snapshots as well, which are found in the repository data + final Set toResolve = new HashSet<>(); + snapshotIdIterator.forEachRemaining(toResolve::add); + List snapshotInfos = new ArrayList<>(currentSnapshots.size() + toResolve.size()); int repositoryTotalCount = 0; for (SnapshotInfo snapshotInfo : currentSnapshots) { assert snapshotInfo.startTime() == 0L && snapshotInfo.endTime() == 0L && snapshotInfo.totalShards() == 0L : snapshotInfo; - if (toResolve.remove(snapshotInfo.snapshot())) { + if (toResolve.remove(snapshotInfo.snapshot().getSnapshotId())) { repositoryTotalCount += 1; if (afterPredicate.test(snapshotInfo)) { snapshotInfos.add(snapshotInfo); @@ -448,19 +490,19 @@ private void addSimpleSnapshotInfos( if (indices) { for (IndexId indexId : repositoryData.getIndices().values()) { for (SnapshotId snapshotId : repositoryData.getSnapshots(indexId)) { - if (toResolve.contains(new Snapshot(repoName, snapshotId))) { + if (toResolve.contains(snapshotId)) { snapshotsToIndices.computeIfAbsent(snapshotId, (k) -> new ArrayList<>()).add(indexId.getName()); } } } } - for (Snapshot snapshot : toResolve) { + for (SnapshotId snapshotId : toResolve) { final var snapshotInfo = new SnapshotInfo( - snapshot, - snapshotsToIndices.getOrDefault(snapshot.getSnapshotId(), Collections.emptyList()), + new Snapshot(repositoryName, snapshotId), + snapshotsToIndices.getOrDefault(snapshotId, Collections.emptyList()), Collections.emptyList(), Collections.emptyList(), - repositoryData.getSnapshotState(snapshot.getSnapshotId()) + repositoryData.getSnapshotState(snapshotId) ); repositoryTotalCount += 1; if (afterPredicate.test(snapshotInfo)) { diff --git a/server/src/main/java/org/elasticsearch/common/collect/Iterators.java b/server/src/main/java/org/elasticsearch/common/collect/Iterators.java index 165280e370025..d029f8e3becc0 100644 --- a/server/src/main/java/org/elasticsearch/common/collect/Iterators.java +++ b/server/src/main/java/org/elasticsearch/common/collect/Iterators.java @@ -21,6 +21,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.IntFunction; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.function.ToIntFunction; @@ -179,6 +180,59 @@ public void forEachRemaining(Consumer action) { } } + /** + * @param input An iterator over non-null values. + * @param predicate The predicate with which to filter the input. + * @return an iterator which returns the values from {@code input} which match {@code predicate}. + */ + public static Iterator filter(Iterator input, Predicate predicate) { + while (input.hasNext()) { + final var value = input.next(); + assert value != null; + if (predicate.test(value)) { + return new FilterIterator<>(value, input, predicate); + } + } + return Collections.emptyIterator(); + } + + private static final class FilterIterator implements Iterator { + private final Iterator input; + private final Predicate predicate; + private T next; + + FilterIterator(T value, Iterator input, Predicate predicate) { + this.next = value; + this.input = input; + this.predicate = predicate; + assert next != null; + assert predicate.test(next); + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public T next() { + if (hasNext() == false) { + throw new NoSuchElementException(); + } + final var value = next; + while (input.hasNext()) { + final var laterValue = input.next(); + assert laterValue != null; + if (predicate.test(laterValue)) { + next = laterValue; + return value; + } + } + next = null; + return value; + } + } + public static Iterator flatMap(Iterator input, Function> fn) { while (input.hasNext()) { final var value = fn.apply(input.next()); diff --git a/server/src/test/java/org/elasticsearch/common/collect/IteratorsTests.java b/server/src/test/java/org/elasticsearch/common/collect/IteratorsTests.java index 67f74df78e256..a3573d081397a 100644 --- a/server/src/test/java/org/elasticsearch/common/collect/IteratorsTests.java +++ b/server/src/test/java/org/elasticsearch/common/collect/IteratorsTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.common.collect; import org.elasticsearch.common.Randomness; +import org.elasticsearch.core.Assertions; import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; @@ -23,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiPredicate; +import java.util.function.Predicate; import java.util.function.ToIntFunction; import java.util.stream.IntStream; @@ -219,6 +221,27 @@ public void testMap() { assertEquals(array.length, index.get()); } + public void testFilter() { + assertSame(Collections.emptyIterator(), Iterators.filter(Collections.emptyIterator(), i -> fail(null, "not called"))); + + final var array = randomIntegerArray(); + assertSame(Collections.emptyIterator(), Iterators.filter(Iterators.forArray(array), i -> false)); + + final var threshold = array.length > 0 && randomBoolean() ? randomFrom(array) : randomIntBetween(0, 1000); + final Predicate predicate = i -> i <= threshold; + final var expectedResults = Arrays.stream(array).filter(predicate).toList(); + final var index = new AtomicInteger(); + Iterators.filter(Iterators.forArray(array), predicate) + .forEachRemaining(i -> assertEquals(expectedResults.get(index.getAndIncrement()), i)); + + if (Assertions.ENABLED) { + final var predicateCalled = new AtomicBoolean(); + final var inputIterator = Iterators.forArray(new Object[] { null }); + expectThrows(AssertionError.class, () -> Iterators.filter(inputIterator, i -> predicateCalled.compareAndSet(false, true))); + assertFalse(predicateCalled.get()); + } + } + public void testFailFast() { final var array = randomIntegerArray(); assertEmptyIterator(Iterators.failFast(Iterators.forArray(array), () -> true));