diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index ed075277d72bd..916db1bffb738 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -32,6 +32,7 @@ import org.opensearch.indices.recovery.ForceSyncRequest; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationCollection; import org.opensearch.indices.replication.common.ReplicationFailedException; import org.opensearch.indices.replication.common.ReplicationLuceneIndex; @@ -49,6 +50,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.atLeastOnce; @@ -70,10 +72,7 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { private IndexShard replicaShard; private IndexShard primaryShard; private ReplicationCheckpoint checkpoint; - private SegmentReplicationSource replicationSource; private SegmentReplicationTargetService sut; - - private ReplicationCheckpoint initialCheckpoint; private ReplicationCheckpoint aheadCheckpoint; private ReplicationCheckpoint newPrimaryCheckpoint; @@ -83,11 +82,10 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { private DiscoveryNode localNode; private IndicesService indicesService; - private ClusterService clusterService; private SegmentReplicationState state; - private static long TRANSPORT_TIMEOUT = 30000;// 30sec + private static final long TRANSPORT_TIMEOUT = 30000;// 30sec @Override public void setUp() throws Exception { @@ -107,9 +105,6 @@ public void setUp() throws Exception { 0L, replicaShard.getLatestReplicationCheckpoint().getCodec() ); - SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class); - replicationSource = mock(SegmentReplicationSource.class); - when(replicationSourceFactory.get(replicaShard)).thenReturn(replicationSource); testThreadPool = new TestThreadPool("test", Settings.EMPTY); localNode = new DiscoveryNode( @@ -130,7 +125,7 @@ public void setUp() throws Exception { transportService.acceptIncomingRequests(); indicesService = mock(IndicesService.class); - clusterService = mock(ClusterService.class); + ClusterService clusterService = mock(ClusterService.class); ClusterState clusterState = mock(ClusterState.class); RoutingTable mockRoutingTable = mock(RoutingTable.class); when(clusterService.state()).thenReturn(clusterState); @@ -139,7 +134,7 @@ public void setUp() throws Exception { when(clusterState.nodes()).thenReturn(DiscoveryNodes.builder().add(localNode).build()); sut = prepareForReplication(primaryShard, replicaShard, transportService, indicesService, clusterService); - initialCheckpoint = replicaShard.getLatestReplicationCheckpoint(); + ReplicationCheckpoint initialCheckpoint = replicaShard.getLatestReplicationCheckpoint(); aheadCheckpoint = new ReplicationCheckpoint( initialCheckpoint.getShardId(), initialCheckpoint.getPrimaryTerm(), @@ -246,7 +241,46 @@ public void testAlreadyOnNewCheckpoint() { } public void testShardAlreadyReplicating() { - sut.startReplication(replicaShard, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)); + CountDownLatch blockGetCheckpointMetadata = new CountDownLatch(1); + SegmentReplicationSource source = new TestReplicationSource() { + @Override + public void getCheckpointMetadata( + long replicationId, + ReplicationCheckpoint checkpoint, + ActionListener listener + ) { + try { + blockGetCheckpointMetadata.await(); + final CopyState copyState = new CopyState( + ReplicationCheckpoint.empty(primaryShard.shardId(), primaryShard.getLatestReplicationCheckpoint().getCodec()), + primaryShard + ); + listener.onResponse( + new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) + ); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void getSegmentFiles( + long replicationId, + ReplicationCheckpoint checkpoint, + List filesToFetch, + IndexShard indexShard, + ActionListener listener + ) { + listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); + } + }; + final SegmentReplicationTarget target = spy( + new SegmentReplicationTarget(replicaShard, source, mock(SegmentReplicationTargetService.SegmentReplicationListener.class)) + ); + // Start first round of segment replication. + sut.startReplication(target); + + // Start second round of segment replication, this should fail to start as first round is still in-progress sut.startReplication(replicaShard, new SegmentReplicationTargetService.SegmentReplicationListener() { @Override public void onReplicationDone(SegmentReplicationState state) { @@ -259,6 +293,7 @@ public void onReplicationFailure(SegmentReplicationState state, ReplicationFaile assertFalse(sendShardFailure); } }); + blockGetCheckpointMetadata.countDown(); } public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws InterruptedException {