Skip to content

Commit

Permalink
Add asserting listener class
Browse files Browse the repository at this point in the history
Signed-off-by: David Zane <[email protected]>
  • Loading branch information
dzane17 committed Feb 28, 2024
1 parent d4c9e72 commit dfe74d0
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
Expand All @@ -98,7 +96,7 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {
private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
private ExecutorService executor;
private SearchRequestOperationsListener assertingListener;
private SearchRequestOperationsListenerAssertingListener assertingListener;
ThreadPool threadPool;

@Before
Expand All @@ -107,27 +105,7 @@ public void setUp() throws Exception {
super.setUp();
executor = Executors.newFixedThreadPool(1);
threadPool = new TestThreadPool(getClass().getName());
assertingListener = new SearchRequestOperationsListener() {
private volatile SearchPhase phase;

@Override
protected void onPhaseStart(SearchPhaseContext context) {
assertThat(phase, is(nullValue()));
phase = context.getCurrentPhase();
}

@Override
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {
assertThat(phase, is(context.getCurrentPhase()));
phase = null;
}

@Override
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {
assertThat(phase, is(context.getCurrentPhase()));
phase = null;
}
};
assertingListener = new SearchRequestOperationsListenerAssertingListener();
}

@After
Expand All @@ -137,6 +115,7 @@ public void tearDown() throws Exception {
executor.shutdown();
assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS);
assertingListener.assertFinished();
}

private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
Expand Down Expand Up @@ -363,7 +342,7 @@ public void testOnPhaseFailureAndVerifyListeners() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats(clusterSettings);

final List<SearchRequestOperationsListener> requestOperationListeners = List.of(testListener);
final List<SearchRequestOperationsListener> requestOperationListeners = List.of(testListener, assertingListener);
SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners);
action.start();
assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName()));
Expand Down Expand Up @@ -395,6 +374,7 @@ public void run() {
SearchShardIterator searchShardIterator = new SearchShardIterator(null, shardId, Collections.emptyList(), OriginalIndices.NONE);
searchShardIterator.resetAndSkip();
action.skipShard(searchShardIterator);
action.start();
action.executeNextPhase(action, fetchPhase);
assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName()));
action.onPhaseFailure(new SearchPhase("test") {
Expand Down Expand Up @@ -626,7 +606,7 @@ public void onFailure(Exception e) {
public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedException {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats(clusterSettings);
final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener));
final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener, assertingListener));

long delay = (randomIntBetween(1, 5));
delay = delay * 10;
Expand Down Expand Up @@ -676,7 +656,7 @@ public void testOnPhaseListenersWithQueryAndThenFetchType() throws InterruptedEx
public void testOnPhaseListenersWithDfsType() throws InterruptedException {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
SearchRequestStats testListener = new SearchRequestStats(clusterSettings);
final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener));
final List<SearchRequestOperationsListener> requestOperationListeners = new ArrayList<>(List.of(testListener, assertingListener));

SearchDfsQueryThenFetchAsyncAction searchDfsQueryThenFetchAsyncAction = createSearchDfsQueryThenFetchAsyncAction(
requestOperationListeners
Expand All @@ -697,6 +677,13 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException {
assertThat(testListener.getPhaseMetric(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()), greaterThanOrEqualTo(delay));
assertEquals(1, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
assertingListener.onPhaseEnd(
searchDfsQueryThenFetchAsyncAction,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
new SearchRequest()
)
);
}

private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction(
Expand Down Expand Up @@ -812,7 +799,13 @@ ShardSearchFailure[] buildShardFailures() {

@Override
public void sendSearchResponse(InternalSearchResponse internalSearchResponse, AtomicArray<SearchPhaseResult> queryResults) {
start();
new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger).onPhaseEnd(
this,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger),
searchRequest
)
);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@
import org.opensearch.test.InternalAggregationTestCase;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.Transport;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -76,50 +74,22 @@
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.collection.IsEmptyCollection.empty;

public class CanMatchPreFilterSearchPhaseTests extends OpenSearchTestCase {
private SearchRequestOperationsListener assertingListener;
private Set<SearchPhase> phases;

@Before
public void setUp() throws Exception {
super.setUp();

phases = Collections.newSetFromMap(new IdentityHashMap<>());
assertingListener = new SearchRequestOperationsListener() {
@Override
protected void onPhaseStart(SearchPhaseContext context) {
assertThat(phases.contains(context.getCurrentPhase()), is(false));
phases.add(context.getCurrentPhase());
}

@Override
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {
assertThat(phases.contains(context.getCurrentPhase()), is(true));
phases.remove(context.getCurrentPhase());
}

@Override
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {
assertThat(phases.contains(context.getCurrentPhase()), is(true));
phases.remove(context.getCurrentPhase());
}
};
}

@After
public void tearDown() throws Exception {
super.tearDown();
assertBusy(() -> assertThat(phases, empty()), 5, TimeUnit.SECONDS);
assertingListener = new SearchRequestOperationsListenerAssertingListener();
}

public void testFilterShards() throws InterruptedException {
Expand Down Expand Up @@ -381,7 +351,7 @@ public void sendCanMatch(
randomIntBetween(1, 32),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()),
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
searchRequest
),
NoopTracer.INSTANCE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestOptions;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -80,6 +82,23 @@
import static org.hamcrest.Matchers.greaterThanOrEqualTo;

public class SearchAsyncActionTests extends OpenSearchTestCase {
private SearchRequestOperationsListenerAssertingListener assertingListener;

@Before
@Override
public void setUp() throws Exception {
super.setUp();

assertingListener = new SearchRequestOperationsListenerAssertingListener();
}

@After
@Override
public void tearDown() throws Exception {
super.tearDown();

assertingListener.assertFinished();
}

public void testSkipSearchShards() throws InterruptedException {
SearchRequest request = new SearchRequest();
Expand Down Expand Up @@ -139,7 +158,10 @@ public void testSkipSearchShards() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
),
NoopTracer.INSTANCE
) {

Expand Down Expand Up @@ -189,6 +211,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
assertEquals(0, searchResponse.getFailedShards());
assertEquals(numSkipped, searchResponse.getSkippedShards());
assertEquals(shardsIter.size(), searchResponse.getSuccessfulShards());
assertingListener.onPhaseEnd(
asyncAction,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
)
);
}

public void testLimitConcurrentShardRequests() throws InterruptedException {
Expand Down Expand Up @@ -259,7 +288,10 @@ public void testLimitConcurrentShardRequests() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
),
NoopTracer.INSTANCE
) {

Expand Down Expand Up @@ -315,6 +347,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
latch.await();
assertTrue(searchPhaseDidRun.get());
assertEquals(numShards, numRequests.get());
assertingListener.onPhaseEnd(
asyncAction,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
)
);
}

public void testFanOutAndCollect() throws InterruptedException {
Expand Down Expand Up @@ -378,7 +417,10 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
),
NoopTracer.INSTANCE
) {
TestSearchResponse response = new TestSearchResponse();
Expand Down Expand Up @@ -436,6 +478,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
} else {
assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty());
}
assertingListener.onPhaseEnd(
asyncAction,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
)
);
executor.shutdown();
}

Expand Down Expand Up @@ -502,7 +551,10 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
),
NoopTracer.INSTANCE
) {
TestSearchResponse response = new TestSearchResponse();
Expand Down Expand Up @@ -617,7 +669,10 @@ public void testAllowPartialResults() throws InterruptedException {
new ArraySearchPhaseResults<>(shardsIter.size()),
request.getMaxConcurrentShardRequests(),
SearchResponse.Clusters.EMPTY,
new SearchRequestContext(new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), request),
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
),
NoopTracer.INSTANCE
) {
@Override
Expand Down Expand Up @@ -666,6 +721,13 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
assertTrue(searchPhaseDidRun.get());
assertEquals(numShards, numRequests.get());
assertThat(numFailReplicas.get(), greaterThanOrEqualTo(1));
assertingListener.onPhaseEnd(
asyncAction,
new SearchRequestContext(
new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()),
request
)
);
}

static GroupShardsIterator<SearchShardIterator> getShardsIter(
Expand Down
Loading

0 comments on commit dfe74d0

Please sign in to comment.