diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java index e9533c0fe941c..ef880043e863c 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerTests.java @@ -14,7 +14,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -22,30 +21,27 @@ public class SearchRequestOperationsListenerTests extends OpenSearchTestCase { public void testListenersAreExecuted() { - Map searchPhaseStartMap = new HashMap<>(); - Map searchPhaseEndMap = new HashMap<>(); - Map searchPhaseFailureMap = new HashMap<>(); + Map searchPhaseMap = new HashMap<>(); for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { - searchPhaseStartMap.put(searchPhaseName, new AtomicInteger()); - searchPhaseEndMap.put(searchPhaseName, new AtomicInteger()); - searchPhaseFailureMap.put(searchPhaseName, new AtomicInteger()); + searchPhaseMap.put(searchPhaseName, new SearchRequestStats.StatsHolder()); } SearchRequestOperationsListener testListener = new SearchRequestOperationsListener() { @Override public void onPhaseStart(SearchPhaseContext context) { - searchPhaseStartMap.get(context.getCurrentPhase().getSearchPhaseName()).incrementAndGet(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.inc(); } @Override public void onPhaseEnd(SearchPhaseContext context) { - searchPhaseEndMap.get(context.getCurrentPhase().getSearchPhaseName()).incrementAndGet(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).total.inc(); } @Override public void onPhaseFailure(SearchPhaseContext context) { - searchPhaseFailureMap.get(context.getCurrentPhase().getSearchPhaseName()).incrementAndGet(); + searchPhaseMap.get(context.getCurrentPhase().getSearchPhaseName()).current.dec(); } }; @@ -65,9 +61,9 @@ public void onPhaseFailure(SearchPhaseContext context) { for (SearchPhaseName searchPhaseName : SearchPhaseName.values()) { when(ctx.getCurrentPhase()).thenReturn(searchPhase); - when(searchPhase.getName()).thenReturn(searchPhaseName.getName()); + when(searchPhase.getSearchPhaseName()).thenReturn(searchPhaseName); compositeListener.onPhaseStart(ctx); - assertEquals(totalListeners, searchPhaseStartMap.get(searchPhaseName).get()); + assertEquals(totalListeners, searchPhaseMap.get(searchPhaseName).current.count()); } } }