diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f439e48ecab7..c563daabcd4c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Skip remote-repositories validations for node-joins when RepositoriesService is not in sync with cluster-state ([#16763](https://github.com/opensearch-project/OpenSearch/pull/16763)) - Fix _list/shards API failing when closed indices are present ([#16606](https://github.com/opensearch-project/OpenSearch/pull/16606)) - Fix remote shards balance ([#15335](https://github.com/opensearch-project/OpenSearch/pull/15335)) +- Fix illegal argument exception when creating a PIT ([#16781](https://github.com/opensearch-project/OpenSearch/pull/16781)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhase.java b/server/src/main/java/org/opensearch/action/search/SearchPhase.java index 0890e9f5de8d4..8eab2ee8dedac 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhase.java @@ -37,6 +37,7 @@ import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Optional; /** * Base class for all individual search phases like collecting distributed frequencies, fetching documents, querying shards. @@ -69,11 +70,15 @@ public String getName() { } /** - * Returns the SearchPhase name as {@link SearchPhaseName}. Exception will come if SearchPhase name is not defined - * in {@link SearchPhaseName} - * @return {@link SearchPhaseName} + * Returns an Optional of the SearchPhase name as {@link SearchPhaseName}. If there's not a matching SearchPhaseName, + * returns an empty Optional. + * @return {@link Optional} */ - public SearchPhaseName getSearchPhaseName() { - return SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT)); + public Optional getSearchPhaseName() { + try { + return Optional.of(SearchPhaseName.valueOf(name.toUpperCase(Locale.ROOT))); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java index 94200d29a4f21..88728436df847 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestStats.java @@ -73,20 +73,22 @@ public long getTookMetric() { @Override protected void onPhaseStart(SearchPhaseContext context) { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> phaseStatsMap.get(name).current.inc()); } @Override protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - StatsHolder phaseStats = phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()); - phaseStats.current.dec(); - phaseStats.total.inc(); - phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos())); + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> { + StatsHolder phaseStats = phaseStatsMap.get(name); + phaseStats.current.dec(); + phaseStats.total.inc(); + phaseStats.timing.inc(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - context.getCurrentPhase().getStartTimeInNanos())); + }); } @Override protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - phaseStatsMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + context.getCurrentPhase().getSearchPhaseName().ifPresent(name -> phaseStatsMap.get(name).current.dec()); } @Override diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 27336e86e52b0..3bc9282e17fc4 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -399,29 +399,29 @@ public void testOnPhaseFailureAndVerifyListeners() { final List requestOperationListeners = List.of(testListener, assertingListener); SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName().get())); SearchDfsQueryThenFetchAsyncAction searchDfsQueryThenFetchAsyncAction = createSearchDfsQueryThenFetchAsyncAction( requestOperationListeners ); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName().get())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -430,15 +430,15 @@ public void run() { action.skipShard(searchShardIterator); action.start(); action.executeNextPhase(action, fetchPhase); - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); action.onPhaseFailure(new SearchPhase("test") { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName().get())); } public void testOnPhaseFailure() { @@ -722,7 +722,7 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.start(); // Verify queryPhase current metric - assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); FetchSearchPhase fetchPhase = createFetchSearchPhase(); @@ -733,12 +733,12 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx action.executeNextPhase(action, fetchPhase); // Verify queryPhase total, current and latency metrics - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertThat(testListener.getPhaseMetric(action.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName().get())); + assertThat(testListener.getPhaseMetric(action.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(action.getSearchPhaseName().get())); // Verify fetchPhase current metric - assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); ExpandSearchPhase expandPhase = createExpandSearchPhase(); @@ -746,18 +746,18 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx TimeUnit.MILLISECONDS.sleep(delay); // Verify fetchPhase total, current and latency metrics - assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(fetchPhase.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(fetchPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName().get())); - assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName().get())); action.executeNextPhase(expandPhase, fetchPhase); action.onPhaseDone(); /* finish phase since we don't have reponse being sent */ - assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(expandPhase.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(expandPhase.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(expandPhase.getSearchPhaseName().get())); } public void testOnPhaseListenersWithDfsType() throws InterruptedException { @@ -772,7 +772,7 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { FetchSearchPhase fetchPhase = createFetchSearchPhase(); searchDfsQueryThenFetchAsyncAction.start(); - assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); TimeUnit.MILLISECONDS.sleep(delay); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); SearchShardIterator searchShardIterator = new SearchShardIterator(null, shardId, Collections.emptyList(), OriginalIndices.NONE); @@ -786,9 +786,9 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { null ); /* finalizing the fetch phase since we do adhoc phase lifecycle calls */ - assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()), greaterThanOrEqualTo(delay)); - assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get()), greaterThanOrEqualTo(delay)); + assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName().get())); } private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java index 990ed95f1aebc..0b62d5a16427a 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java @@ -14,6 +14,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -30,18 +31,18 @@ public void testListenersAreExecuted() { @Override public void onPhaseStart(SearchPhaseContext context) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.inc(); } @Override public void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).total.inc(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).total.inc(); } @Override public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { - searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName().get()).current.dec(); } }; @@ -61,7 +62,7 @@ public void onPhaseFailure(SearchPhaseContext context, Throwable cause) { for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { when(ctx.getCurrentPhase()).thenReturn(searchPhase); - when(searchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(searchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); compositeListener.onPhaseStart(ctx); assertEquals(totalListeners, searchPhaseMap.get(searchPhaseName).current.count()); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java index 3bad3ec3e7d21..876bc395dcd52 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java @@ -16,6 +16,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; import java.util.concurrent.TimeUnit; @@ -68,7 +69,7 @@ public void testSearchRequestPhaseFailure() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); testRequestStats.onPhaseStart(ctx); assertEquals(1, testRequestStats.getPhaseCurrent(searchPhaseName)); testRequestStats.onPhaseFailure(ctx, new Throwable()); @@ -85,7 +86,7 @@ public void testSearchRequestStats() { when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); testRequestStats.onPhaseStart(ctx); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); @@ -116,7 +117,7 @@ public void testSearchRequestStatsOnPhaseStartConcurrently() throws InterruptedE SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); @@ -145,7 +146,7 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); long tookTimeInMillis = randomIntBetween(1, 10); long startTime = System.nanoTime() - TimeUnit.MILLISECONDS.toNanos(tookTimeInMillis); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(startTime); @@ -188,7 +189,7 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte SearchPhaseContext ctx = mock(SearchPhaseContext.class); SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int i = 0; i < numTasks; i++) { threads[i] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); @@ -205,4 +206,25 @@ public void testSearchRequestStatsOnPhaseFailureConcurrently() throws Interrupte assertEquals(0, testRequestStats.getPhaseCurrent(searchPhaseName)); } } + + public void testOtherPhaseNamesAreIgnored() { + // Unrecognized phase names shouldn't be tracked, but should not throw any error. + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + SearchRequestStats testRequestStats = new SearchRequestStats(clusterSettings); + SearchPhaseContext ctx = mock(SearchPhaseContext.class); + SearchPhase mockSearchPhase = new SearchPhase("unrecognized_phase") { + @Override + public void run() {} + }; + when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); + testRequestStats.onPhaseStart(ctx); + testRequestStats.onPhaseEnd( + ctx, + new SearchRequestContext( + new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), + new SearchRequest(), + () -> null + ) + ); + } } diff --git a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java index 594700ea60b3e..519c937e348bc 100644 --- a/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java +++ b/server/src/test/java/org/opensearch/index/search/stats/SearchStatsTests.java @@ -44,6 +44,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -86,7 +87,7 @@ public void testShardLevelSearchGroupStats() throws Exception { SearchPhase mockSearchPhase = mock(SearchPhase.class); when(ctx.getCurrentPhase()).thenReturn(mockSearchPhase); when(mockSearchPhase.getStartTimeInNanos()).thenReturn(System.nanoTime() - TimeUnit.SECONDS.toNanos(paramValue)); - when(mockSearchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); + when(mockSearchPhase.getSearchPhaseName()).thenReturn(Optional.of(searchPhaseName)); for (int iterator = 0; iterator < paramValue; iterator++) { onPhaseStart(testRequestStats, ctx); onPhaseEnd(testRequestStats, ctx);