Skip to content

Commit

Permalink
[Backport 2.x] Optimise unassigned shards iteration after allocator t…
Browse files Browse the repository at this point in the history
…imeout & Fix responsibility check for existing shards allocator when timed out (#15648)

* Optimise unassigned shards iteration after allocator timeout (#14977)

Signed-off-by: Rishab Nahata <[email protected]>
  • Loading branch information
imRishN authored Sep 4, 2024
1 parent 7fdc07a commit 3fa710b
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ public void run() {
"Time taken to execute timed runnables in this cycle:[{}ms]",
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
);
onComplete();
}

/**
* Callback method that is invoked after all {@link TimeoutAwareRunnable} instances in the batch have been processed.
* By default, this method does nothing, but it can be overridden by subclasses or modified in the implementation if
* there is a need to perform additional actions once the batch execution is completed.
*/
public void onComplete() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.opensearch.core.index.shard.ShardId;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -82,23 +81,29 @@ public void allocateUnassigned(
executeDecision(shardRouting, allocateUnassignedDecision, allocation, unassignedAllocationHandler);
}

protected void allocateUnassignedBatchOnTimeout(List<ShardRouting> shardRoutings, RoutingAllocation allocation, boolean primary) {
Set<ShardId> shardIdsFromBatch = new HashSet<>();
for (ShardRouting shardRouting : shardRoutings) {
ShardId shardId = shardRouting.shardId();
shardIdsFromBatch.add(shardId);
protected void allocateUnassignedBatchOnTimeout(Set<ShardId> shardIds, RoutingAllocation allocation, boolean primary) {
if (shardIds.isEmpty()) {
return;
}
RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
while (iterator.hasNext()) {
ShardRouting unassignedShard = iterator.next();
AllocateUnassignedDecision allocationDecision;
if (unassignedShard.primary() == primary && shardIdsFromBatch.contains(unassignedShard.shardId())) {
if (unassignedShard.primary() == primary && shardIds.contains(unassignedShard.shardId())) {
if (isResponsibleFor(unassignedShard) == false) {
continue;
}
allocationDecision = AllocateUnassignedDecision.throttle(null);
executeDecision(unassignedShard, allocationDecision, allocation, iterator);
}
}
}

/**
* Is the allocator responsible for allocating the given {@link ShardRouting}?
*/
protected abstract boolean isResponsibleFor(ShardRouting shardRouting);

protected void executeDecision(
ShardRouting shardRouting,
AllocateUnassignedDecision allocateUnassignedDecision,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public abstract class PrimaryShardAllocator extends BaseGatewayShardAllocator {
/**
* Is the allocator responsible for allocating the given {@link ShardRouting}?
*/
protected static boolean isResponsibleFor(final ShardRouting shard) {
protected boolean isResponsibleFor(final ShardRouting shard) {
return shard.primary() // must be primary
&& shard.unassigned() // must be unassigned
// only handle either an existing store or a snapshot recovery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public void processExistingRecoveries(RoutingAllocation allocation) {
/**
* Is the allocator responsible for allocating the given {@link ShardRouting}?
*/
protected static boolean isResponsibleFor(final ShardRouting shard) {
protected boolean isResponsibleFor(final ShardRouting shard) {
return shard.primary() == false // must be a replica
&& shard.unassigned() // must be unassigned
// if we are allocating a replica because of index creation, no need to go and find a copy, there isn't one...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private AllocateUnassignedDecision getUnassignedShardAllocationDecision(
RoutingAllocation allocation,
Supplier<Map<DiscoveryNode, StoreFilesMetadata>> nodeStoreFileMetaDataMapSupplier
) {
if (!isResponsibleFor(shardRouting)) {
if (isResponsibleFor(shardRouting) == false) {
return AllocateUnassignedDecision.NOT_TAKEN;
}
Tuple<Decision, Map<String, NodeAllocationResult>> result = canBeAllocatedToAtLeastOneNode(shardRouting, allocation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,41 +277,51 @@ protected BatchRunnableExecutor innerAllocateUnassignedBatch(
}
List<TimeoutAwareRunnable> runnables = new ArrayList<>();
if (primary) {
Set<ShardId> timedOutPrimaryShardIds = new HashSet<>();
batchIdToStartedShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(shardsBatch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(
shardsBatch.getBatchedShardRoutings(),
allocation,
true
);
timedOutPrimaryShardIds.addAll(shardsBatch.getBatchedShards());
}

@Override
public void run() {
primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] primary shards", timedOutPrimaryShardIds.size());
primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutPrimaryShardIds, allocation, true);
}
};
} else {
Set<ShardId> timedOutReplicaShardIds = new HashSet<>();
batchIdToStoreShardBatch.values()
.stream()
.filter(batch -> batchesToAssign.contains(batch.batchId))
.forEach(batch -> runnables.add(new TimeoutAwareRunnable() {
@Override
public void onTimeout() {
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(batch.getBatchedShardRoutings(), allocation, false);
timedOutReplicaShardIds.addAll(batch.getBatchedShards());
}

@Override
public void run() {
replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation);
}
}));
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout);
return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout) {
@Override
public void onComplete() {
logger.trace("Triggering oncomplete after timeout for [{}] replica shards", timedOutReplicaShardIds.size());
replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(timedOutReplicaShardIds, allocation, false);
}
};
}
}

Expand Down Expand Up @@ -846,11 +856,11 @@ public int getNumberOfStoreShardBatches() {
return batchIdToStoreShardBatch.size();
}

private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
protected void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout;
}

private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
protected void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.function.Supplier;

import static org.mockito.Mockito.atMost;
Expand Down Expand Up @@ -42,33 +43,53 @@ public void setupRunnables() {
public void testRunWithoutTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueSeconds(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).run();
verify(runnable2, times(1)).run();
verify(runnable3, times(1)).run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueNanos(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, times(1)).onTimeout();
verify(runnable2, times(1)).onTimeout();
verify(runnable3, times(1)).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithPartialTimeout() {
setupRunnables();
timeoutSupplier = () -> TimeValue.timeValueMillis(50);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
doAnswer(invocation -> {
Thread.sleep(100);
return null;
Expand All @@ -81,17 +102,25 @@ public void testRunWithPartialTimeout() {
verify(runnable3, atMost(1)).onTimeout();
verify(runnable2, atMost(1)).onTimeout();
verify(runnable3, atMost(1)).onTimeout();
assertEquals(0, countDownLatch.getCount());
}

public void testRunWithEmptyRunnableList() {
setupRunnables();
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier);
CountDownLatch countDownLatch = new CountDownLatch(1);
BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier) {
@Override
public void onComplete() {
countDownLatch.countDown();
}
};
executor.run();
verify(runnable1, never()).onTimeout();
verify(runnable2, never()).onTimeout();
verify(runnable3, never()).onTimeout();
verify(runnable1, never()).run();
verify(runnable2, never()).run();
verify(runnable3, never()).run();
assertEquals(1, countDownLatch.getCount());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BatchRunnableExecutor;
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.index.shard.ShardId;
Expand All @@ -45,6 +46,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.gateway.ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING;
Expand Down Expand Up @@ -423,6 +426,24 @@ public void testReplicaAllocatorTimeout() {
assertEquals(-1, REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(build).getMillis());
}

public void testCollectTimedOutShards() throws InterruptedException {
createIndexAndUpdateClusterState(2, 5, 2);
CountDownLatch latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
latch = new CountDownLatch(10);
testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator(latch);
testShardsBatchGatewayAllocator.setPrimaryBatchAllocatorTimeout(TimeValue.ZERO);
testShardsBatchGatewayAllocator.setReplicaBatchAllocatorTimeout(TimeValue.ZERO);
executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
executor.run();
assertTrue(latch.await(1, TimeUnit.MINUTES));
}

private void createIndexAndUpdateClusterState(int count, int numberOfShards, int numberOfReplicas) {
if (count == 0) return;
Metadata.Builder metadata = Metadata.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -52,6 +51,7 @@
import java.util.stream.Collectors;

import static org.opensearch.cluster.routing.UnassignedInfo.Reason.CLUSTER_RECOVERED;
import static org.opensearch.cluster.routing.UnassignedInfo.Reason.INDEX_CREATED;

public class PrimaryShardBatchAllocatorTests extends OpenSearchAllocationTestCase {

Expand Down Expand Up @@ -264,8 +264,9 @@ public void testAllocateUnassignedBatchOnTimeoutWithMatchingPrimaryShards() {
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();

List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
Expand All @@ -277,30 +278,25 @@ public void testAllocateUnassignedBatchOnTimeoutWithNoMatchingPrimaryShards() {
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
List<ShardRouting> shardRoutings = new ArrayList<>();
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
batchAllocator.allocateUnassignedBatchOnTimeout(new HashSet<>(), routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(0, ignoredShards.size());
}

public void testAllocateUnassignedBatchOnTimeoutWithNonPrimaryShards() {
public void testAllocateUnassignedBatchOnTimeoutSkipIgnoringNewPrimaryShards() {
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
setUpShards(1);
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, INDEX_CREATED);
ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();

ShardRouting shardRouting = routingAllocation.routingTable()
.getIndicesRouting()
.get("test")
.shard(shardId.id())
.replicaShards()
.get(0);
List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, false);
Set<ShardId> shardIds = new HashSet<>();
shardIds.add(shardRouting.shardId());
batchAllocator.allocateUnassignedBatchOnTimeout(shardIds, routingAllocation, true);

List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
assertEquals(1, ignoredShards.size());
assertEquals(0, ignoredShards.size());
}

private RoutingAllocation routingAllocationWithOnePrimary(
Expand Down
Loading

0 comments on commit 3fa710b

Please sign in to comment.