Skip to content

Commit

Permalink
[Bug] Check phase name before SearchRequestOperationsListener onPhase…
Browse files Browse the repository at this point in the history
…Start (opensearch-project#12094) (opensearch-project#12094)

Signed-off-by: David Zane <[email protected]>
(cherry picked from commit fb2c5f2)
  • Loading branch information
dzane17 authored Jan 31, 2024
1 parent de636c1 commit 3a951e5
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix memory leak issue in ReorganizingLongHash ([#11953](https://github.com/opensearch-project/OpenSearch/issues/11953))
- Prevent setting remote_snapshot store type on index creation ([#11867](https://github.com/opensearch-project/OpenSearch/pull/11867))
- [BUG] Fix remote shards balancer when filtering throttled nodes ([#11724](https://github.com/opensearch-project/OpenSearch/pull/11724))
- [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12094](https://github.com/opensearch-project/OpenSearch/pull/12094))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,16 +432,18 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha
}

private void onPhaseEnd(SearchRequestContext searchRequestContext) {
if (getCurrentPhase() != null) {
if (getCurrentPhase() != null && SearchPhaseName.isValidName(getName())) {
long tookInNanos = System.nanoTime() - getCurrentPhase().getStartTimeInNanos();
searchRequestContext.updatePhaseTookMap(getCurrentPhase().getName(), TimeUnit.NANOSECONDS.toMillis(tookInNanos));
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext);
}
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext);
}

private void onPhaseStart(SearchPhase phase) {
void onPhaseStart(SearchPhase phase) {
setCurrentPhase(phase);
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this);
if (SearchPhaseName.isValidName(phase.getName())) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this);
}
}

private void onRequestEnd(SearchRequestContext searchRequestContext) {
Expand Down Expand Up @@ -714,7 +716,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At

@Override
public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this);
if (SearchPhaseName.isValidName(phase.getName())) {
this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this);
}
raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import org.opensearch.common.annotation.PublicApi;

import java.util.HashSet;
import java.util.Set;

/**
* Enum for different Search Phases in OpenSearch
*
Expand All @@ -25,6 +28,12 @@ public enum SearchPhaseName {
CAN_MATCH("can_match");

private final String name;
private static final Set<String> PHASE_NAMES = new HashSet<>();
static {
for (SearchPhaseName phaseName : SearchPhaseName.values()) {
PHASE_NAMES.add(phaseName.name);
}
}

SearchPhaseName(final String name) {
this.name = name;
Expand All @@ -33,4 +42,8 @@ public enum SearchPhaseName {
public String getName() {
return name;
}

public static boolean isValidName(String phaseName) {
return PHASE_NAMES.contains(phaseName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1323,7 +1323,7 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
clusters,
searchRequestContext
);
return new SearchPhase(action.getName()) {
return new SearchPhase("none") {
@Override
public void run() {
action.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.common.UUIDs;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.set.Sets;
Expand Down Expand Up @@ -334,7 +336,14 @@ public void testOnPhaseFailureAndVerifyListeners() {
SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners);
action.start();
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase("test") {
action.onPhaseFailure(new SearchPhase("none") {
@Override
public void run() {

}
}, "message", null);
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase(action.getName()) {
@Override
public void run() {

Expand All @@ -348,14 +357,14 @@ public void run() {
);
searchDfsQueryThenFetchAsyncAction.start();
assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") {
searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase(searchDfsQueryThenFetchAsyncAction.getName()) {
@Override
public void run() {

}
}, "message", null);
assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));

FetchSearchPhase fetchPhase = createFetchSearchPhase();
ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt());
Expand All @@ -364,7 +373,7 @@ public void run() {
action.skipShard(searchShardIterator);
action.executeNextPhase(action, fetchPhase);
assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase("test") {
action.onPhaseFailure(new SearchPhase(fetchPhase.getName()) {
@Override
public void run() {

Expand Down Expand Up @@ -399,6 +408,30 @@ public void run() {
assertEquals(requestIds, releasedContexts);
}

public void testOnPhaseStart() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats();

final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener));
SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners);

action.onPhaseStart(new SearchPhase("test") {
@Override
public void run() {}
});
action.onPhaseStart(new SearchPhase("none") {
@Override
public void run() {}
});
assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName()));

action.onPhaseStart(new SearchPhase(action.getName()) {
@Override
public void run() {}
});
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
}

public void testShardNotAvailableWithDisallowPartialFailures() {
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false);
AtomicReference<Exception> exception = new AtomicReference<>();
Expand Down

0 comments on commit 3a951e5

Please sign in to comment.