diff --git a/CHANGELOG.md b/CHANGELOG.md index 82ef8d647c819..7a2b864bd41df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -109,6 +109,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Fixed - Fix for deserilization bug in weighted round-robin metadata ([#11679](https://github.com/opensearch-project/OpenSearch/pull/11679)) +- [Revert] [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12035](https://github.com/opensearch-project/OpenSearch/pull/12035)) +- Add support of special WrappingSearchAsyncActionPhase so the onPhaseStart() will always be followed by onPhaseEnd() within AbstractSearchAsyncAction ([#12293](https://github.com/opensearch-project/OpenSearch/pull/12293)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 3c27d3ce59e4c..9ce71f2d86798 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -432,18 +432,16 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha } private void onPhaseEnd(SearchRequestContext searchRequestContext) { - if (getCurrentPhase() != null && SearchPhaseName.isValidName(getName())) { + if (getCurrentPhase() != null) { long tookInNanos = System.nanoTime() - getCurrentPhase().getStartTimeInNanos(); searchRequestContext.updatePhaseTookMap(getCurrentPhase().getName(), TimeUnit.NANOSECONDS.toMillis(tookInNanos)); - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - void onPhaseStart(SearchPhase phase) { + private void onPhaseStart(SearchPhase phase) { setCurrentPhase(phase); - if (SearchPhaseName.isValidName(phase.getName())) { - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); - } + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); } private void onRequestEnd(SearchRequestContext searchRequestContext) { @@ -454,10 +452,19 @@ private void executePhase(SearchPhase phase) { try { onPhaseStart(phase); phase.recordAndRun(); + // The WrappingSearchAsyncActionPhase (see please CanMatchPreFilterSearchPhase as one example) is a special case + // of search phase that wraps SearchAsyncActionPhase as SearchPhase. The AbstractSearchAsyncAction manages own + // onPhaseStart / onPhaseFailure / OnPhaseDone callbacks and the wrapping SearchPhase is being abandoned + // (fe, has no onPhaseEnd callbacks called ever). To fix that, the explicit onPhaseEnd is being called + // since SearchPhase::recordAndRun would delegate to AbstractSearchAsyncAction::start internally. + if (phase instanceof WrappingSearchAsyncActionPhase) { + onPhaseEnd(searchRequestContext); + } } catch (Exception e) { if (logger.isDebugEnabled()) { logger.debug(new ParameterizedMessage("Failed to execute [{}] while moving to [{}] phase", request, phase.getName()), e); } + onPhaseFailure(phase, "", e); } } @@ -716,9 +723,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At @Override public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { - if (SearchPhaseName.isValidName(phase.getName())) { - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); - } + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures())); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java index c6f3d4c70632d..8cf92934c8a52 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -10,9 +10,6 @@ import org.opensearch.common.annotation.PublicApi; -import java.util.HashSet; -import java.util.Set; - /** * Enum for different Search Phases in OpenSearch * @@ -28,12 +25,6 @@ public enum SearchPhaseName { CAN_MATCH("can_match"); private final String name; - private static final Set PHASE_NAMES = new HashSet<>(); - static { - for (SearchPhaseName phaseName : SearchPhaseName.values()) { - PHASE_NAMES.add(phaseName.name); - } - } SearchPhaseName(final String name) { this.name = name; @@ -42,8 +33,4 @@ public enum SearchPhaseName { public String getName() { return name; } - - public static boolean isValidName(String phaseName) { - return PHASE_NAMES.contains(phaseName); - } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 79e599ec9387b..3d1a25a8aa01f 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -1220,8 +1220,8 @@ private AbstractSearchAsyncAction searchAsyncAction timeProvider, clusterState, task, - (iter) -> { - AbstractSearchAsyncAction action = searchAsyncAction( + (iter) -> new WrappingSearchAsyncActionPhase( + searchAsyncAction( task, searchRequest, executor, @@ -1237,14 +1237,8 @@ private AbstractSearchAsyncAction searchAsyncAction threadPool, clusters, searchRequestContext - ); - return new SearchPhase("none") { - @Override - public void run() { - action.start(); - } - }; - }, + ) + ), clusters, searchRequestContext ); diff --git a/server/src/main/java/org/opensearch/action/search/WrappingSearchAsyncActionPhase.java b/server/src/main/java/org/opensearch/action/search/WrappingSearchAsyncActionPhase.java new file mode 100644 index 0000000000000..3c1ad52a1fe6a --- /dev/null +++ b/server/src/main/java/org/opensearch/action/search/WrappingSearchAsyncActionPhase.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import org.opensearch.search.SearchPhaseResult; + +/** + * The WrappingSearchAsyncActionPhase (see please {@link CanMatchPreFilterSearchPhase} as one example) is a special case + * of search phase that wraps SearchAsyncActionPhase as {@link SearchPhase}. The {@link AbstractSearchAsyncAction} manages own + * onPhaseStart / onPhaseFailure / OnPhaseDone callbacks and but just wrapping it with the SearchPhase causes + * only some callbacks being called. The {@link AbstractSearchAsyncAction} has special treatment of {@link WrappingSearchAsyncActionPhase}. + */ +class WrappingSearchAsyncActionPhase extends SearchPhase { + private final AbstractSearchAsyncAction action; + + protected WrappingSearchAsyncActionPhase(AbstractSearchAsyncAction action) { + super(action.getName()); + this.action = action; + } + + @Override + public void run() { + action.start(); + } + + SearchPhase getSearchPhase() { + return action; + } +} diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index a7cbbffc51ed4..601aa9dc1856e 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -85,6 +85,8 @@ import java.util.function.BiFunction; import java.util.stream.IntStream; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; @@ -95,6 +97,7 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase { private final List> resolvedNodes = new ArrayList<>(); private final Set releasedContexts = new CopyOnWriteArraySet<>(); private ExecutorService executor; + private SearchRequestOperationsListener assertingListener; ThreadPool threadPool; @Before @@ -103,6 +106,27 @@ public void setUp() throws Exception { super.setUp(); executor = Executors.newFixedThreadPool(1); threadPool = new TestThreadPool(getClass().getName()); + assertingListener = new SearchRequestOperationsListener() { + private volatile SearchPhase phase; + + @Override + protected void onPhaseStart(SearchPhaseContext context) { + assertThat(phase, is(nullValue())); + phase = context.getCurrentPhase(); + } + + @Override + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { + assertThat(phase, is(context.getCurrentPhase())); + phase = null; + } + + @Override + protected void onPhaseFailure(SearchPhaseContext context) { + assertThat(phase, is(context.getCurrentPhase())); + phase = null; + } + }; } @After @@ -178,7 +202,10 @@ private AbstractSearchAsyncAction createAction( results, request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request) + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + request + ) ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { @@ -334,18 +361,11 @@ public void testOnPhaseFailureAndVerifyListeners() { ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); SearchRequestStats testListener = new SearchRequestStats(clusterSettings); - final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + final List requestOperationListeners = List.of(testListener); SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase("none") { - @Override - public void run() { - - } - }, "message", null); - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase(action.getName()) { + action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { @@ -359,14 +379,14 @@ public void run() { ); searchDfsQueryThenFetchAsyncAction.start(); assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase(searchDfsQueryThenFetchAsyncAction.getName()) { + searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -375,7 +395,7 @@ public void run() { action.skipShard(searchShardIterator); action.executeNextPhase(action, fetchPhase); assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase(fetchPhase.getName()) { + action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { @@ -410,30 +430,6 @@ public void run() { assertEquals(requestIds, releasedContexts); } - public void testOnPhaseStart() { - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - SearchRequestStats testListener = new SearchRequestStats(clusterSettings); - - final List requestOperationListeners = new ArrayList<>(List.of(testListener)); - SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); - - action.onPhaseStart(new SearchPhase("test") { - @Override - public void run() {} - }); - action.onPhaseStart(new SearchPhase("none") { - @Override - public void run() {} - }); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - - action.onPhaseStart(new SearchPhase(action.getName()) { - @Override - public void run() {} - }); - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); - } - public void testShardNotAvailableWithDisallowPartialFailures() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); AtomicReference exception = new AtomicReference<>(); diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 56dcf66d5607d..30fc50f91dabd 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -32,6 +32,7 @@ package org.opensearch.action.search; import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.action.OriginalIndices; @@ -41,37 +42,84 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.shard.ShardId; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchService; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.dfs.DfsSearchResult; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.sort.MinAndMax; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; +import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.stream.IntStream; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.collection.IsEmptyCollection.empty; public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase { + private SearchRequestOperationsListener assertingListener; + private Set phases; + + @Before + public void setUp() throws Exception { + super.setUp(); + + phases = Collections.newSetFromMap(new IdentityHashMap<>()); + assertingListener = new SearchRequestOperationsListener() { + @Override + protected void onPhaseStart(SearchPhaseContext context) { + assertThat(phases.contains(context.getCurrentPhase()), is(false)); + phases.add(context.getCurrentPhase()); + } + + @Override + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { + assertThat(phases.contains(context.getCurrentPhase()), is(true)); + phases.remove(context.getCurrentPhase()); + } + + @Override + protected void onPhaseFailure(SearchPhaseContext context) { + assertThat(phases.contains(context.getCurrentPhase()), is(true)); + phases.remove(context.getCurrentPhase()); + } + }; + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + assertBusy(() -> assertThat(phases, empty()), 5, TimeUnit.SECONDS); + } public void testFilterShards() throws InterruptedException { @@ -135,11 +183,12 @@ public void sendCanMatch( public void run() throws IOException { result.set(iter); latch.countDown(); + assertingListener.onPhaseEnd(new MockSearchPhaseContext(1, searchRequest, this), null); } }, SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ) ); @@ -230,11 +279,12 @@ public void sendCanMatch( public void run() throws IOException { result.set(iter); latch.countDown(); + assertingListener.onPhaseEnd(new MockSearchPhaseContext(1, searchRequest, this), null); } }, SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ) ); @@ -366,6 +416,7 @@ protected void executePhaseOnShard( canMatchPhase.start(); latch.await(); + executor.shutdown(); } @@ -443,17 +494,19 @@ public void sendCanMatch( public void run() { result.set(iter); latch.countDown(); + assertingListener.onPhaseEnd(new MockSearchPhaseContext(1, searchRequest, this), null); } }, SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ) ); canMatchPhase.start(); latch.await(); + ShardId[] expected = IntStream.range(0, shardIds.size()) .boxed() .sorted(Comparator.comparing(minAndMaxes::get, MinAndMax.getComparator(order)).thenComparing(shardIds::get)) @@ -546,17 +599,19 @@ public void sendCanMatch( public void run() { result.set(iter); latch.countDown(); + assertingListener.onPhaseEnd(new MockSearchPhaseContext(1, searchRequest, this), null); } }, SearchResponse.Clusters.EMPTY, new SearchRequestContext( - new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), searchRequest ) ); canMatchPhase.start(); latch.await(); + int shardId = 0; for (SearchShardIterator i : result.get()) { assertThat(i.shardId().id(), equalTo(shardId++)); @@ -565,4 +620,190 @@ public void run() { assertThat(result.get().size(), equalTo(numShards)); } } + + public void testAsyncAction() throws InterruptedException { + + final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( + 0, + System.nanoTime(), + System::nanoTime + ); + + Map lookup = new ConcurrentHashMap<>(); + DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT); + DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT); + lookup.put("node_1", new SearchAsyncActionTests.MockConnection(primaryNode)); + lookup.put("node_2", new SearchAsyncActionTests.MockConnection(replicaNode)); + final boolean shard1 = randomBoolean(); + final boolean shard2 = randomBoolean(); + + SearchTransportService searchTransportService = new SearchTransportService(null, null) { + @Override + public void sendCanMatch( + Transport.Connection connection, + ShardSearchRequest request, + SearchTask task, + ActionListener listener + ) { + new Thread( + () -> listener.onResponse(new SearchService.CanMatchResponse(request.shardId().id() == 0 ? shard1 : shard2, null)) + ).start(); + } + }; + + AtomicReference> result = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + GroupShardsIterator shardsIter = SearchAsyncActionTests.getShardsIter( + "idx", + new OriginalIndices(new String[] { "idx" }, SearchRequest.DEFAULT_INDICES_OPTIONS), + 2, + randomBoolean(), + primaryNode, + replicaNode + ); + final SearchRequest searchRequest = new SearchRequest(); + searchRequest.allowPartialSearchResults(true); + + SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); + ExecutorService executor = OpenSearchExecutors.newDirectExecutorService(); + SearchRequestContext searchRequestContext = new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), + searchRequest + ); + + SearchPhaseController controller = new SearchPhaseController( + writableRegistry(), + r -> InternalAggregationTestCase.emptyReduceContextBuilder() + ); + + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + task.getProgressListener(), + writableRegistry(), + shardsIter.size(), + exc -> {} + ); + + CanMatchPreFilterSearchPhase canMatchPhase = new CanMatchPreFilterSearchPhase( + logger, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), + Collections.emptyMap(), + Collections.emptyMap(), + executor, + searchRequest, + null, + shardsIter, + timeProvider, + ClusterState.EMPTY_STATE, + null, + (iter) -> { + AbstractSearchAsyncAction action = new SearchDfsQueryAsyncAction( + logger, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), + Collections.emptyMap(), + Collections.emptyMap(), + controller, + executor, + resultConsumer, + searchRequest, + null, + shardsIter, + timeProvider, + ClusterState.EMPTY_STATE, + task, + SearchResponse.Clusters.EMPTY, + searchRequestContext + ); + return new WrappingSearchAsyncActionPhase(action) { + @Override + public void run() { + super.run(); + latch.countDown(); + } + }; + }, + SearchResponse.Clusters.EMPTY, + searchRequestContext + ); + + canMatchPhase.start(); + latch.await(); + + assertThat(result.get(), is(nullValue())); + } + + private static final class SearchDfsQueryAsyncAction extends AbstractSearchAsyncAction { + private final SearchRequestOperationsListener listener; + + SearchDfsQueryAsyncAction( + final Logger logger, + final SearchTransportService searchTransportService, + final BiFunction nodeIdToConnection, + final Map aliasFilter, + final Map concreteIndexBoosts, + final Map> indexRoutings, + final SearchPhaseController searchPhaseController, + final Executor executor, + final QueryPhaseResultConsumer queryPhaseResultConsumer, + final SearchRequest request, + final ActionListener listener, + final GroupShardsIterator shardsIts, + final TransportSearchAction.SearchTimeProvider timeProvider, + final ClusterState clusterState, + final SearchTask task, + SearchResponse.Clusters clusters, + SearchRequestContext searchRequestContext + ) { + super( + SearchPhaseName.DFS_PRE_QUERY.getName(), + logger, + searchTransportService, + nodeIdToConnection, + aliasFilter, + concreteIndexBoosts, + indexRoutings, + executor, + request, + listener, + shardsIts, + timeProvider, + clusterState, + task, + new ArraySearchPhaseResults<>(shardsIts.size()), + request.getMaxConcurrentShardRequests(), + clusters, + searchRequestContext + ); + this.listener = searchRequestContext.getSearchRequestOperationsListener(); + } + + @Override + protected void executePhaseOnShard( + final SearchShardIterator shardIt, + final SearchShardTarget shard, + final SearchActionListener listener + ) { + final DfsSearchResult response = new DfsSearchResult(shardIt.getSearchContextId(), shard, null); + response.setShardIndex(shard.getShardId().getId()); + listener.innerOnResponse(response); + } + + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("last") { + @Override + public void run() throws IOException { + listener.onPhaseEnd(context, null); + } + }; + } + } + } diff --git a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java index 04a00a09dcbc4..cc10da8fc1f12 100644 --- a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java @@ -67,17 +67,27 @@ public final class MockSearchPhaseContext implements SearchPhaseContext { final Set releasedSearchContexts = new HashSet<>(); final SearchRequest searchRequest; final AtomicReference searchResponse = new AtomicReference<>(); + final SearchPhase currentPhase; public MockSearchPhaseContext(int numShards) { this(numShards, new SearchRequest()); } public MockSearchPhaseContext(int numShards, SearchRequest searchRequest) { + this(numShards, searchRequest, null); + } + + public MockSearchPhaseContext(int numShards, SearchRequest searchRequest, SearchPhase currentPhase) { this.numShards = numShards; this.searchRequest = searchRequest; + this.currentPhase = currentPhase; numSuccess = new AtomicInteger(numShards); } + public MockSearchPhaseContext(int numShards, SearchPhase currentPhase) { + this(numShards, new SearchRequest(), currentPhase); + } + public void assertNoFailure() { if (phaseFailure.get() != null) { throw new AssertionError(phaseFailure.get()); @@ -106,7 +116,7 @@ public SearchRequest getRequest() { @Override public SearchPhase getCurrentPhase() { - return null; + return currentPhase; } @Override