Skip to content

Commit

Permalink
Update NodeConnectionsService test
Browse files Browse the repository at this point in the history
Signed-off-by: Rahul Karajgikar <[email protected]>
  • Loading branch information
Rahul Karajgikar committed Sep 25, 2024
1 parent 1ded2cc commit 9a060cc
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public class NodeConnectionsService extends AbstractLifecycleComponent {
protected final Map<DiscoveryNode, ConnectionTarget> targetsByNode = new HashMap<>();

private final TimeValue reconnectInterval;
private volatile ConnectionChecker connectionChecker;
protected volatile ConnectionChecker connectionChecker;

@Inject
public NodeConnectionsService(Settings settings, ThreadPool threadPool, TransportService transportService) {
Expand Down Expand Up @@ -224,7 +224,7 @@ private void awaitPendingActivity(Runnable onCompletion) {
* nodes which are in the process of disconnecting. The onCompletion handler is called after all ongoing connection/disconnection
* attempts have completed.
*/
private void connectDisconnectedTargets(Runnable onCompletion) {
protected void connectDisconnectedTargets(Runnable onCompletion) {
final List<Runnable> runnables = new ArrayList<>();
synchronized (mutex) {
final Collection<ConnectionTarget> connectionTargets = targetsByNode.values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ public void testConnectionCheckerRetriesIfPendingDisconnection() throws Interrup
transportService.start();
transportService.acceptIncomingRequests();

final NodeConnectionsService service = new NodeConnectionsService(
final TestNodeConnectionsService service = new TestNodeConnectionsService(
settings.build(),
deterministicTaskQueue.getThreadPool(),
transportService
Expand All @@ -532,37 +532,53 @@ public void testConnectionCheckerRetriesIfPendingDisconnection() throws Interrup
deterministicTaskQueue.runAllRunnableTasks();
assertTrue(connectionCompleted.get());

// now trigger a disconnect, and then set pending disconnections to true to fail any new connections
// reset any logs as we want to assert for exceptions that show up after this
// reset connect to node count to assert for later
logger.info("--> resetting captured logs and counters");
testLogsAppender.clearCapturedLogs();
// this ensures we only track connection attempts that happen after the disconnection
transportService.resetConnectToNodeCallCount();

// block connection checker reconnection attempts until after we set pending disconnections
logger.info("--> disabling connection checker, and triggering disconnect");
service.setShouldReconnect(false);
transportService.disconnectFromNode(node);

// set pending disconnections to true to fail future reconnection attempts
final long maxDisconnectionTime = 1000;
deterministicTaskQueue.scheduleNow(new Runnable() {
@Override
public void run() {
transportService.disconnectFromNode(node);
logger.info("--> setting pending disconnections to fail next connection attempts");
service.setPendingDisconnections(new HashSet<>(Collections.singleton(node)));
// we reset the connection count during the first disconnection
// we also clear the captured logs as we want to assert for exceptions that show up after this
testLogsAppender.clearCapturedLogs();
transportService.resetConnectToNodeCallCount();
}

@Override
public String toString() {
return "scheduled disconnection of " + node;
}
});
// the first task in our task queue is the runnable that sets the pending disconnections;
// re-enabling the connection checker here lets it enqueue the subsequent reconnection attempts
logger.info("--> re-enabling reconnection checker");
service.setShouldReconnect(true);

final long maxReconnectionTime = 2000;
final int expectedReconnectionAttempts = 5;
final int expectedReconnectionAttempts = 10;

// ensure the disconnect task completes, and run for additional time to check for reconnections
// exit early if we see enough reconnection attempts
logger.info("--> verifying connectionchecker is trying to reconnect");
runTasksUntilExpectedReconnectionAttempts(
// this will first run the task to set the pending disconnections, then will execute the reconnection tasks
// exit early when we have enough reconnection attempts
logger.info("--> running tasks in order until expected reconnection attempts");
runTasksInOrderUntilExpectedReconnectionAttempts(
deterministicTaskQueue,
maxDisconnectionTime + maxReconnectionTime,
transportService,
expectedReconnectionAttempts
);
logger.info("--> verifying that connectionchecker tried to reconnect");

// assert that the connections failed
assertFalse("connected to " + node, transportService.nodeConnected(node));

// assert that we saw at least the required number of reconnection attempts, and the exceptions that showed up are as expected
logger.info("--> number of reconnection attempts: {}", transportService.getConnectToNodeCallCount());
Expand All @@ -578,7 +594,6 @@ public String toString() {
TimeUnit.SECONDS
);
assertTrue("Expected log for reconnection failure was not found in the required time period", logFound);
assertFalse("connected to " + node, transportService.nodeConnected(node));

// clear the pending disconnections and ensure the connection is re-established automatically by the ConnectionChecker
logger.info("--> clearing pending disconnections to allow connections to re-establish");
Expand All @@ -598,7 +613,7 @@ private void runTasksUntil(DeterministicTaskQueue deterministicTaskQueue, long e
deterministicTaskQueue.runAllRunnableTasks();
}

private void runTasksUntilExpectedReconnectionAttempts(
private void runTasksInOrderUntilExpectedReconnectionAttempts(
DeterministicTaskQueue deterministicTaskQueue,
long endTimeMillis,
TestTransportService transportService,
Expand All @@ -608,12 +623,12 @@ private void runTasksUntilExpectedReconnectionAttempts(
while ((deterministicTaskQueue.getCurrentTimeMillis() < endTimeMillis)
&& (transportService.getConnectToNodeCallCount() <= expectedReconnectionAttempts)) {
if (deterministicTaskQueue.hasRunnableTasks() && randomBoolean()) {
deterministicTaskQueue.runRandomTask();
deterministicTaskQueue.runNextTask();
} else if (deterministicTaskQueue.hasDeferredTasks()) {
deterministicTaskQueue.advanceTime();
}
}
deterministicTaskQueue.runAllRunnableTasks();
deterministicTaskQueue.runAllRunnableTasksInEnqueuedOrder();
}

private void ensureConnections(NodeConnectionsService service) {
Expand Down Expand Up @@ -736,6 +751,37 @@ public void resetConnectToNodeCallCount() {
}
}

/**
 * Test variant of {@link NodeConnectionsService} whose connection checker can be told to skip
 * its reconnection attempts via {@link #setShouldReconnect(boolean)}.
 *
 * On start it installs a {@link StoppableConnectionChecker} in the inherited
 * {@code connectionChecker} field instead of whatever the parent would install.
 */
private class TestNodeConnectionsService extends NodeConnectionsService {
    // While false, the checker does not call connectDisconnectedTargets; it only reschedules itself.
    private boolean shouldReconnect = true;

    public TestNodeConnectionsService(Settings settings, ThreadPool threadPool, TransportService transportService) {
        super(settings, threadPool, transportService);
    }

    /**
     * Enables or disables reconnection attempts made by the checker installed in {@link #doStart()}.
     *
     * @param shouldReconnect {@code true} to let the checker reconnect disconnected targets,
     *                        {@code false} to make it skip that step
     */
    public void setShouldReconnect(boolean shouldReconnect) {
        this.shouldReconnect = shouldReconnect;
    }

    /**
     * Creates a {@link StoppableConnectionChecker}, publishes it as the current checker and
     * asks it to schedule its first check.
     */
    @Override
    protected void doStart() {
        final StoppableConnectionChecker checker = new StoppableConnectionChecker();
        connectionChecker = checker;
        checker.scheduleNextCheck();
    }

    /**
     * Connection checker that honours the enclosing service's {@code shouldReconnect} flag.
     */
    class StoppableConnectionChecker extends NodeConnectionsService.ConnectionChecker {
        /**
         * Reconnects disconnected targets only when this instance is the currently installed
         * checker and reconnection is enabled; in every other case it goes straight to
         * {@code scheduleNextCheck()} (inherited; its own guards are not visible here).
         */
        @Override
        protected void doRun() {
            final boolean reconnectNow = connectionChecker == this && shouldReconnect;
            if (reconnectNow == false) {
                // Skip the reconnection attempt but still ask for the next check to be scheduled.
                scheduleNextCheck();
                return;
            }
            connectDisconnectedTargets(this::scheduleNextCheck);
        }
    }
}

private static final class MockTransport implements Transport {
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final RequestHandlers requestHandlers = new RequestHandlers();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ public void runAllRunnableTasks() {
}
}

/**
 * Runs runnable tasks one at a time, always taking the task at index 0 of the runnable list,
 * until {@link #hasRunnableTasks()} reports that none remain. Tasks that become runnable
 * while this method is executing are run as well, because the condition is re-checked after
 * every task.
 */
public void runAllRunnableTasksInEnqueuedOrder() {
    if (hasRunnableTasks() == false) {
        return;
    }
    do {
        // index 0 is the head of the runnable list (presumably the earliest enqueued task)
        runTask(0);
    } while (hasRunnableTasks());
}

public void runAllTasks() {
while (hasDeferredTasks() || hasRunnableTasks()) {
if (hasDeferredTasks() && random.nextBoolean()) {
Expand Down Expand Up @@ -141,6 +147,11 @@ public void runRandomTask() {
runTask(RandomNumbers.randomIntBetween(random, 0, runnableTasks.size() - 1));
}

/**
 * Runs the single task at index 0 of the runnable list. Callers must make sure at least one
 * task is runnable; this is checked with an assertion only.
 */
public void runNextTask() {
    assert hasRunnableTasks();
    final int headIndex = 0;
    runTask(headIndex);
}

private void runTask(final int index) {
final Runnable task = runnableTasks.remove(index);
logger.trace("running task {} of {}: {}", index, runnableTasks.size() + 1, task);
Expand Down

0 comments on commit 9a060cc

Please sign in to comment.