diff --git a/CHANGELOG.md b/CHANGELOG.md index 52bf4b996204b..d0b74dc536030 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -159,6 +159,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix memory leak issue in ReorganizingLongHash ([#11953](https://github.com/opensearch-project/OpenSearch/issues/11953)) - Prevent setting remote_snapshot store type on index creation ([#11867](https://github.com/opensearch-project/OpenSearch/pull/11867)) - [BUG] Fix remote shards balancer when filtering throttled nodes ([#11724](https://github.com/opensearch-project/OpenSearch/pull/11724)) +- [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12094](https://github.com/opensearch-project/OpenSearch/pull/12094)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index f18bbb8a1cc13..6f59945bcf533 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -432,16 +432,18 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha } private void onPhaseEnd(SearchRequestContext searchRequestContext) { - if (getCurrentPhase() != null) { + if (getCurrentPhase() != null && SearchPhaseName.isValidName(getName())) { long tookInNanos = System.nanoTime() - getCurrentPhase().getStartTimeInNanos(); searchRequestContext.updatePhaseTookMap(getCurrentPhase().getName(), TimeUnit.NANOSECONDS.toMillis(tookInNanos)); + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - private void onPhaseStart(SearchPhase phase) { + void onPhaseStart(SearchPhase phase) { setCurrentPhase(phase); - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + if (SearchPhaseName.isValidName(phase.getName())) { + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + } } private void onRequestEnd(SearchRequestContext searchRequestContext) { @@ -714,7 +716,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At @Override public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); + if (SearchPhaseName.isValidName(phase.getName())) { + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); + } raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures())); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java index 8cf92934c8a52..c6f3d4c70632d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -10,6 +10,9 @@ import org.opensearch.common.annotation.PublicApi; +import java.util.HashSet; +import java.util.Set; + /** * Enum for different Search Phases in OpenSearch * @@ -25,6 +28,12 @@ public enum SearchPhaseName { CAN_MATCH("can_match"); private final String name; + private static final Set PHASE_NAMES = new HashSet<>(); + static { + for (SearchPhaseName phaseName : SearchPhaseName.values()) { + PHASE_NAMES.add(phaseName.name); + } + } SearchPhaseName(final String name) { this.name = name; @@ -33,4 +42,8 @@ public enum SearchPhaseName { public String getName() { return name; } + + public static boolean isValidName(String phaseName) { + return PHASE_NAMES.contains(phaseName); + } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 05f4308df74fa..1880be5980a51 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -1323,7 +1323,7 @@ private AbstractSearchAsyncAction searchAsyncAction clusters, searchRequestContext ); - return new SearchPhase(action.getName()) { + return new SearchPhase("none") { @Override public void run() { action.start(); diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index e17fbab32a12e..e226d98d85da5 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -38,6 +38,8 @@ import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.common.UUIDs; import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.set.Sets; @@ -334,7 +336,14 @@ public void testOnPhaseFailureAndVerifyListeners() { SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase("test") { + action.onPhaseFailure(new SearchPhase("none") { + @Override + public void run() { + + } + }, "message", null); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + action.onPhaseFailure(new SearchPhase(action.getName()) { @Override public void run() { @@ -348,14 +357,14 @@ public void run() { ); searchDfsQueryThenFetchAsyncAction.start(); assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { + searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase(searchDfsQueryThenFetchAsyncAction.getName()) { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -364,7 +373,7 @@ public void run() { action.skipShard(searchShardIterator); action.executeNextPhase(action, fetchPhase); assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase("test") { + action.onPhaseFailure(new SearchPhase(fetchPhase.getName()) { @Override public void run() { @@ -399,6 +408,30 @@ public void run() { assertEquals(requestIds, releasedContexts); } + public void testOnPhaseStart() { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + SearchRequestStats testListener = new SearchRequestStats(); + + final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); + + action.onPhaseStart(new SearchPhase("test") { + @Override + public void run() {} + }); + action.onPhaseStart(new SearchPhase("none") { + @Override + public void run() {} + }); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); + + action.onPhaseStart(new SearchPhase(action.getName()) { + @Override + public void run() {} + }); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + } + public void testShardNotAvailableWithDisallowPartialFailures() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); AtomicReference exception = new AtomicReference<>();