diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 15dddc6640e77..b2ca769fbf203 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -86,8 +86,6 @@ import java.util.function.BiFunction; import java.util.stream.IntStream; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; @@ -98,7 +96,7 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; - private SearchRequestOperationsListener assertingListener; + private SearchRequestOperationsListenerAssertingListener assertingListener; ThreadPool threadPool; @Before @@ -107,27 +105,7 @@ public void setUp() throws Exception { super.setUp(); executor = Executors.newFixedThreadPool(1); threadPool = new TestThreadPool(getClass().getName()); - assertingListener = new SearchRequestOperationsListener() { - private volatile SearchPhase phase; - - @Override - protected void onPhaseStart(SearchPhaseContext context) { - assertThat(phase, is(nullValue())); - phase = context.getCurrentPhase(); - } - - @Override - protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - assertThat(phase, is(context.getCurrentPhase())); - phase = null; - } - - @Override - protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - assertThat(phase, is(context.getCurrentPhase())); - phase = null; - } - }; + assertingListener = new SearchRequestOperationsListenerAssertingListener(); } @After @@ -137,6 +115,7 @@ public void tearDown() throws Exception { executor.shutdown(); assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS)); ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS); + assertingListener.assertFinished(); } private AbstractSearchAsyncAction createAction( @@ -363,7 +342,7 @@ public void testOnPhaseFailureAndVerifyListeners() { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); SearchRequestStats testListener = new SearchRequestStats(clusterSettings); - final List requestOperationListeners = List.of(testListener); + final List requestOperationListeners = List.of(testListener, assertingListener); SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); @@ -395,6 +374,7 @@ public void run() { SearchShardIterator searchShardIterator = new SearchShardIterator(null, shardId, Collections.emptyList(), OriginalIndices.NONE); searchShardIterator.resetAndSkip(); action.skipShard(searchShardIterator); + action.start(); action.executeNextPhase(action, fetchPhase); assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); action.onPhaseFailure(new SearchPhase("test") { @@ -626,7 +606,7 @@ public void onFailure(Exception e) { public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedException { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); SearchRequestStats testListener = new SearchRequestStats(clusterSettings); - final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + final List requestOperationListeners = new ArrayList<>(List.of(testListener, assertingListener)); long delay = (randomIntBetween(1, 5)); delay = delay * 10; @@ -676,7 +656,7 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx public void testOnPhaseListenersWithDfsType() throws InterruptedException { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); SearchRequestStats testListener = new SearchRequestStats(clusterSettings); - final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + final List requestOperationListeners = new ArrayList<>(List.of(testListener, assertingListener)); SearchDfsQueryThenFetchAsyncAction searchDfsQueryThenFetchAsyncAction = createSearchDfsQueryThenFetchAsyncAction( requestOperationListeners @@ -697,6 +677,13 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()), greaterThanOrEqualTo(delay)); assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertingListener.onPhaseEnd( + searchDfsQueryThenFetchAsyncAction, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequest() + ) + ); } private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( @@ -812,7 +799,13 @@ ShardSearchFailure[] buildShardFailures() { @Override public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray queryResults) { - start(); + new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger).onPhaseEnd( + this, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger), + searchRequest + ) + ); } }; } diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 7d16ee335ea93..8ed1ade7d91c0 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -59,7 +59,6 @@ import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; -import org.junit.After; import org.junit.Before; import java.io.IOException; @@ -67,7 +66,6 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashSet; -import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -76,7 +74,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.stream.IntStream; @@ -84,42 +81,15 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.collection.IsEmptyCollection.empty; public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { private SearchRequestOperationsListener assertingListener; - private Set phases; @Before public void setUp() throws Exception { super.setUp(); - phases = Collections.newSetFromMap(new IdentityHashMap<>()); - assertingListener = new SearchRequestOperationsListener() { - @Override - protected void onPhaseStart(SearchPhaseContext context) { - assertThat(phases.contains(context.getCurrentPhase()), is(false)); - phases.add(context.getCurrentPhase()); - } - - @Override - protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - assertThat(phases.contains(context.getCurrentPhase()), is(true)); - phases.remove(context.getCurrentPhase()); - } - - @Override - protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - assertThat(phases.contains(context.getCurrentPhase()), is(true)); - phases.remove(context.getCurrentPhase()); - } - }; - } - - @After - public void tearDown() throws Exception { - super.tearDown(); - assertBusy(() -> assertThat(phases, empty()), 5, TimeUnit.SECONDS); + assertingListener = new SearchRequestOperationsListenerAssertingListener(); } public void testFilterShards() throws InterruptedException { @@ -381,7 +351,7 @@ public void sendCanMatch( randomIntBetween(1, 32), SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ), NoopTracer.INSTANCE diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 029ee4f29e76b..3c720dd31bb4d 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -57,6 +57,8 @@ import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportRequestOptions; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ -80,6 +82,23 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; public class SearchAsyncActionTests extends OpenSearchTestCase { + private SearchRequestOperationsListenerAssertingListener assertingListener; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + + assertingListener = new SearchRequestOperationsListenerAssertingListener(); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + + assertingListener.assertFinished(); + } public void testSkipSearchShards() throws InterruptedException { SearchRequest request = new SearchRequest(); @@ -139,7 +158,10 @@ public void testSkipSearchShards() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request), + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ), NoopTracer.INSTANCE ) { @@ -189,6 +211,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) { assertEquals(0, searchResponse.getFailedShards()); assertEquals(numSkipped, searchResponse.getSkippedShards()); assertEquals(shardsIter.size(), searchResponse.getSuccessfulShards()); + assertingListener.onPhaseEnd( + asyncAction, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ) + ); } public void testLimitConcurrentShardRequests() throws InterruptedException { @@ -259,7 +288,10 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request), + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ), NoopTracer.INSTANCE ) { @@ -315,6 +347,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) { latch.await(); assertTrue(searchPhaseDidRun.get()); assertEquals(numShards, numRequests.get()); + assertingListener.onPhaseEnd( + asyncAction, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ) + ); } public void testFanOutAndCollect() throws InterruptedException { @@ -378,7 +417,10 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request), + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ), NoopTracer.INSTANCE ) { TestSearchResponse response = new TestSearchResponse(); @@ -436,6 +478,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) { } else { assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty()); } + assertingListener.onPhaseEnd( + asyncAction, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ) + ); executor.shutdown(); } @@ -502,7 +551,10 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request), + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ), NoopTracer.INSTANCE ) { TestSearchResponse response = new TestSearchResponse(); @@ -617,7 +669,10 @@ public void testAllowPartialResults() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request), + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ), NoopTracer.INSTANCE ) { @Override @@ -666,6 +721,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) { assertTrue(searchPhaseDidRun.get()); assertEquals(numShards, numRequests.get()); assertThat(numFailReplicas.get(), greaterThanOrEqualTo(1)); + assertingListener.onPhaseEnd( + asyncAction, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ) + ); } static GroupShardsIterator getShardsIter( diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index bbad2eb935c13..95c96ac56e302 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -63,6 +63,8 @@ import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; +import org.junit.After; +import org.junit.Before; import java.util.Collections; import java.util.List; @@ -78,6 +80,24 @@ import static org.hamcrest.Matchers.instanceOf; public class SearchQueryThenFetchAsyncActionTests extends OpenSearchTestCase { + private SearchRequestOperationsListenerAssertingListener assertingListener; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + + assertingListener = new SearchRequestOperationsListenerAssertingListener(); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + + assertingListener.assertFinished(); + } + public void testBottomFieldSort() throws Exception { testCase(false, false); } @@ -219,7 +239,7 @@ public void sendExecuteQuery( task, SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ), NoopTracer.INSTANCE @@ -262,5 +282,12 @@ public void run() { assertThat(phase.sortedTopDocs.scoreDocs[0], instanceOf(FieldDoc.class)); assertThat(((FieldDoc) phase.sortedTopDocs.scoreDocs[0]).fields.length, equalTo(1)); assertThat(((FieldDoc) phase.sortedTopDocs.scoreDocs[0]).fields[0], equalTo(0)); + assertingListener.onPhaseEnd( + action, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + searchRequest + ) + ); } } diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerAssertingListener.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerAssertingListener.java new file mode 100644 index 0000000000000..327371ebcaf0b --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerAssertingListener.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +class SearchRequestOperationsListenerAssertingListener extends SearchRequestOperationsListener { + private volatile SearchPhase phase; + + @Override + protected void onPhaseStart(SearchPhaseContext context) { + assertThat(phase, is(nullValue())); + phase = context.getCurrentPhase(); + } + + @Override + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { + assertThat(phase, is(context.getCurrentPhase())); + phase = null; + } + + @Override + protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { + assertThat(phase, is(context.getCurrentPhase())); + phase = null; + } + + public void assertFinished() { + assertThat(phase, is(nullValue())); + } +}