From 681087a57a6800f2eef025468e24d1aedcae7b4f Mon Sep 17 00:00:00 2001
From: Bukhtawar Khan <bukhtawa@amazon.com>
Date: Wed, 24 Jul 2024 17:23:55 +0530
Subject: [PATCH] Make reroute iteration time-bound for large shard allocations
 (#14848)

* Make reroute iteration time-bound for large shard allocations

Signed-off-by: Bukhtawar Khan <bukhtawa@amazon.com>
Co-authored-by: Rishab Nahata <rnnahata@amazon.com>
---
 CHANGELOG.md                                  |   1 +
 .../gateway/RecoveryFromGatewayIT.java        | 128 +++++++++++++++++-
 .../routing/allocation/AllocationService.java |   5 +-
 .../allocation/ExistingShardsAllocator.java   |   7 +-
 .../common/settings/ClusterSettings.java      |   2 +
 .../common/util/BatchRunnableExecutor.java    |  66 +++++++++
 .../util/concurrent/TimeoutAwareRunnable.java |  19 +++
 .../gateway/BaseGatewayShardAllocator.java    |  21 +++
 .../gateway/ShardsBatchGatewayAllocator.java  |  86 ++++++++++--
 .../ExistingShardsAllocatorTests.java         | 118 ++++++++++++++++
 .../util/BatchRunnableExecutorTests.java      |  97 +++++++++++++
 .../gateway/GatewayAllocatorTests.java        |  32 +++++
 .../PrimaryShardBatchAllocatorTests.java      |  47 +++++++
 .../ReplicaShardBatchAllocatorTests.java      |  27 ++++
 .../TestShardBatchGatewayAllocator.java       |   5 +-
 15 files changed, 645 insertions(+), 16 deletions(-)
 create mode 100644 server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java
 create mode 100644 server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java
 create mode 100644 server/src/test/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocatorTests.java
 create mode 100644 server/src/test/java/org/opensearch/common/util/BatchRunnableExecutorTests.java

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6aa3d7a58dda4..edc0ca2732f25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Allow @InternalApi annotation on classes not meant to be constructed outside of the OpenSearch core ([#14575](https://github.com/opensearch-project/OpenSearch/pull/14575))
 - Add @InternalApi annotation to japicmp exclusions ([#14597](https://github.com/opensearch-project/OpenSearch/pull/14597))
 - Allow system index warning in OpenSearchRestTestCase.refreshAllIndices ([#14635](https://github.com/opensearch-project/OpenSearch/pull/14635))
+- Make reroute iteration time-bound for large shard allocations ([#14848](https://github.com/opensearch-project/OpenSearch/pull/14848))
 
 ### Deprecated
 - Deprecate batch_size parameter on bulk API ([#14725](https://github.com/opensearch-project/OpenSearch/pull/14725))
diff --git a/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java b/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java
index 6296608c64d37..4085cc3890f30 100644
--- a/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java
@@ -769,7 +769,7 @@ public void testMessyElectionsStillMakeClusterGoGreen() throws Exception {
         ensureGreen("test");
     }
 
-    public void testBatchModeEnabled() throws Exception {
+    public void testBatchModeEnabledWithoutTimeout() throws Exception {
         internalCluster().startClusterManagerOnlyNodes(
             1,
             Settings.builder().put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.getKey(), true).build()
@@ -810,6 +810,132 @@ public void testBatchModeEnabled() throws Exception {
         assertEquals(0, gatewayAllocator.getNumberOfInFlightFetches());
     }
 
+    public void testBatchModeEnabledWithSufficientTimeoutAndClusterGreen() throws Exception {
+        internalCluster().startClusterManagerOnlyNodes(
+            1,
+            Settings.builder()
+                .put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.getKey(), true)
+                .put(ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING.getKey(), "20s")
+                .put(ShardsBatchGatewayAllocator.REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.getKey(), "20s")
+                .build()
+        );
+        List<String> dataOnlyNodes = internalCluster().startDataOnlyNodes(2);
+        createIndex(
+            "test",
+            Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1).build()
+        );
+        ensureGreen("test");
+        Settings node0DataPathSettings = internalCluster().dataPathSettings(dataOnlyNodes.get(0));
+        Settings node1DataPathSettings = internalCluster().dataPathSettings(dataOnlyNodes.get(1));
+        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(dataOnlyNodes.get(0)));
+        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(dataOnlyNodes.get(1)));
+        ensureRed("test");
+        ensureStableCluster(1);
+
+        logger.info("--> Now do a protective reroute");
+        ClusterRerouteResponse clusterRerouteResponse = client().admin().cluster().prepareReroute().setRetryFailed(true).get();
+        assertTrue(clusterRerouteResponse.isAcknowledged());
+
+        ShardsBatchGatewayAllocator gatewayAllocator = internalCluster().getInstance(
+            ShardsBatchGatewayAllocator.class,
+            internalCluster().getClusterManagerName()
+        );
+        assertTrue(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.get(internalCluster().clusterService().getSettings()));
+        assertEquals(1, gatewayAllocator.getNumberOfStartedShardBatches());
+        assertEquals(1, gatewayAllocator.getNumberOfStoreShardBatches());
+
+        // Now start both data nodes and ensure batch mode is working
+        logger.info("--> restarting the stopped nodes");
+        internalCluster().startDataOnlyNode(Settings.builder().put("node.name", dataOnlyNodes.get(0)).put(node0DataPathSettings).build());
+        internalCluster().startDataOnlyNode(Settings.builder().put("node.name", dataOnlyNodes.get(1)).put(node1DataPathSettings).build());
+        ensureStableCluster(3);
+        ensureGreen("test");
+        assertEquals(0, gatewayAllocator.getNumberOfStartedShardBatches());
+        assertEquals(0, gatewayAllocator.getNumberOfStoreShardBatches());
+        assertEquals(0, gatewayAllocator.getNumberOfInFlightFetches());
+    }
+
+    public void testBatchModeEnabledWithInSufficientTimeoutButClusterGreen() throws Exception {
+
+        internalCluster().startClusterManagerOnlyNodes(
+            1,
+            Settings.builder().put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.getKey(), true).build()
+        );
+        List<String> dataOnlyNodes = internalCluster().startDataOnlyNodes(2);
+        createNIndices(50, "test"); // this will create 50p, 50r shards
+        ensureStableCluster(3);
+        IndicesStatsResponse indicesStats = dataNodeClient().admin().indices().prepareStats().get();
+        assertThat(indicesStats.getSuccessfulShards(), equalTo(100));
+        ClusterHealthResponse health = client().admin()
+            .cluster()
+            .health(Requests.clusterHealthRequest().waitForGreenStatus().timeout("1m"))
+            .actionGet();
+        assertFalse(health.isTimedOut());
+        assertEquals(GREEN, health.getStatus());
+
+        String clusterManagerName = internalCluster().getClusterManagerName();
+        Settings clusterManagerDataPathSettings = internalCluster().dataPathSettings(clusterManagerName);
+        Settings node0DataPathSettings = internalCluster().dataPathSettings(dataOnlyNodes.get(0));
+        Settings node1DataPathSettings = internalCluster().dataPathSettings(dataOnlyNodes.get(1));
+
+        internalCluster().stopCurrentClusterManagerNode();
+        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(dataOnlyNodes.get(0)));
+        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(dataOnlyNodes.get(1)));
+
+        // Now start cluster manager node and post that verify batches created
+        internalCluster().startClusterManagerOnlyNodes(
+            1,
+            Settings.builder()
+                .put("node.name", clusterManagerName)
+                .put(clusterManagerDataPathSettings)
+                .put(ShardsBatchGatewayAllocator.GATEWAY_ALLOCATOR_BATCH_SIZE.getKey(), 5)
+                .put(ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING.getKey(), "10ms")
+                .put(ShardsBatchGatewayAllocator.REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.getKey(), "10ms")
+                .put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.getKey(), true)
+                .build()
+        );
+        ensureStableCluster(1);
+
+        logger.info("--> Now do a protective reroute"); // to avoid any race condition in test
+        ClusterRerouteResponse clusterRerouteResponse = client().admin().cluster().prepareReroute().setRetryFailed(true).get();
+        assertTrue(clusterRerouteResponse.isAcknowledged());
+
+        ShardsBatchGatewayAllocator gatewayAllocator = internalCluster().getInstance(
+            ShardsBatchGatewayAllocator.class,
+            internalCluster().getClusterManagerName()
+        );
+
+        assertTrue(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_BATCH_MODE.get(internalCluster().clusterService().getSettings()));
+        assertEquals(10, gatewayAllocator.getNumberOfStartedShardBatches());
+        assertEquals(10, gatewayAllocator.getNumberOfStoreShardBatches());
+        health = client(internalCluster().getClusterManagerName()).admin().cluster().health(Requests.clusterHealthRequest()).actionGet();
+        assertFalse(health.isTimedOut());
+        assertEquals(RED, health.getStatus());
+        assertEquals(100, health.getUnassignedShards());
+        assertEquals(0, health.getInitializingShards());
+        assertEquals(0, health.getActiveShards());
+        assertEquals(0, health.getRelocatingShards());
+        assertEquals(0, health.getNumberOfDataNodes());
+
+        // Now start both data nodes and ensure batch mode is working
+        logger.info("--> restarting the stopped nodes");
+        internalCluster().startDataOnlyNode(Settings.builder().put("node.name", dataOnlyNodes.get(0)).put(node0DataPathSettings).build());
+        internalCluster().startDataOnlyNode(Settings.builder().put("node.name", dataOnlyNodes.get(1)).put(node1DataPathSettings).build());
+        ensureStableCluster(3);
+
+        // wait for cluster to turn green
+        health = client().admin().cluster().health(Requests.clusterHealthRequest().waitForGreenStatus().timeout("5m")).actionGet();
+        assertFalse(health.isTimedOut());
+        assertEquals(GREEN, health.getStatus());
+        assertEquals(0, health.getUnassignedShards());
+        assertEquals(0, health.getInitializingShards());
+        assertEquals(100, health.getActiveShards());
+        assertEquals(0, health.getRelocatingShards());
+        assertEquals(2, health.getNumberOfDataNodes());
+        assertEquals(0, gatewayAllocator.getNumberOfStartedShardBatches());
+        assertEquals(0, gatewayAllocator.getNumberOfStoreShardBatches());
+    }
+
     public void testBatchModeDisabled() throws Exception {
         internalCluster().startClusterManagerOnlyNodes(
             1,
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java
index 5ad3a2fd47ce3..e29a81a2c131f 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java
@@ -72,6 +72,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -617,10 +618,10 @@ private void allocateExistingUnassignedShards(RoutingAllocation allocation) {
 
     private void allocateAllUnassignedShards(RoutingAllocation allocation) {
         ExistingShardsAllocator allocator = existingShardsAllocators.get(ShardsBatchGatewayAllocator.ALLOCATOR_NAME);
-        allocator.allocateAllUnassignedShards(allocation, true);
+        Optional.ofNullable(allocator.allocateAllUnassignedShards(allocation, true)).ifPresent(Runnable::run);
         allocator.afterPrimariesBeforeReplicas(allocation);
         // Replicas Assignment
-        allocator.allocateAllUnassignedShards(allocation, false);
+        Optional.ofNullable(allocator.allocateAllUnassignedShards(allocation, false)).ifPresent(Runnable::run);
     }
 
     private void disassociateDeadNodes(RoutingAllocation allocation) {
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java
index fb2a37237f8b6..eb7a1e7209c37 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java
@@ -41,6 +41,7 @@
 import org.opensearch.gateway.GatewayAllocator;
 import org.opensearch.gateway.ShardsBatchGatewayAllocator;
 
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -108,14 +109,16 @@ void allocateUnassigned(
      *
      * Allocation service will currently run the default implementation of it implemented by {@link ShardsBatchGatewayAllocator}
      */
-    default void allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) {
+    default Runnable allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) {
         RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
+        List<Runnable> runnables = new ArrayList<>();
         while (iterator.hasNext()) {
             ShardRouting shardRouting = iterator.next();
             if (shardRouting.primary() == primary) {
-                allocateUnassigned(shardRouting, allocation, iterator);
+                runnables.add(() -> allocateUnassigned(shardRouting, allocation, iterator));
             }
         }
+        return () -> runnables.forEach(Runnable::run);
     }
 
     /**
diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
index 49801fd3834b8..2f60c731bc554 100644
--- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
+++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
@@ -343,6 +343,8 @@ public void apply(Settings value, Settings current, Settings previous) {
                 GatewayService.RECOVER_AFTER_NODES_SETTING,
                 GatewayService.RECOVER_AFTER_TIME_SETTING,
                 ShardsBatchGatewayAllocator.GATEWAY_ALLOCATOR_BATCH_SIZE,
+                ShardsBatchGatewayAllocator.PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING,
+                ShardsBatchGatewayAllocator.REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING,
                 PersistedClusterStateService.SLOW_WRITE_LOGGING_THRESHOLD,
                 NetworkModule.HTTP_DEFAULT_TYPE_SETTING,
                 NetworkModule.TRANSPORT_DEFAULT_TYPE_SETTING,
diff --git a/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java b/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java
new file mode 100644
index 0000000000000..d3d3304cb909a
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java
@@ -0,0 +1,66 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.util;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.common.Randomness;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.util.concurrent.TimeoutAwareRunnable;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+/**
+ * A {@link Runnable} that iteratively executes a batch of {@link TimeoutAwareRunnable}s. If the elapsed time exceeds the timeout defined by {@link TimeValue} timeout, then all subsequent {@link TimeoutAwareRunnable}s will have their {@link TimeoutAwareRunnable#onTimeout} method invoked and will not be run.
+ *
+ * @opensearch.internal
+ */
+public class BatchRunnableExecutor implements Runnable {
+
+    private final Supplier<TimeValue> timeoutSupplier;
+
+    private final List<TimeoutAwareRunnable> timeoutAwareRunnables;
+
+    private static final Logger logger = LogManager.getLogger(BatchRunnableExecutor.class);
+
+    public BatchRunnableExecutor(List<TimeoutAwareRunnable> timeoutAwareRunnables, Supplier<TimeValue> timeoutSupplier) {
+        this.timeoutSupplier = timeoutSupplier;
+        this.timeoutAwareRunnables = timeoutAwareRunnables;
+    }
+
+    // for tests
+    public List<TimeoutAwareRunnable> getTimeoutAwareRunnables() {
+        return this.timeoutAwareRunnables;
+    }
+
+    @Override
+    public void run() {
+        logger.debug("Starting execution of runnable of size [{}]", timeoutAwareRunnables.size());
+        long startTime = System.nanoTime();
+        if (timeoutAwareRunnables.isEmpty()) {
+            return;
+        }
+        Randomness.shuffle(timeoutAwareRunnables);
+        for (TimeoutAwareRunnable runnable : timeoutAwareRunnables) {
+            if (timeoutSupplier.get().nanos() < 0 || System.nanoTime() - startTime < timeoutSupplier.get().nanos()) {
+                runnable.run();
+            } else {
+                logger.debug("Executing timeout for runnable of size [{}]", timeoutAwareRunnables.size());
+                runnable.onTimeout();
+            }
+        }
+        logger.debug(
+            "Time taken to execute timed runnables in this cycle:[{}ms]",
+            TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
+        );
+    }
+
+}
diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java b/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java
new file mode 100644
index 0000000000000..8d3357ad93095
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java
@@ -0,0 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.util.concurrent;
+
+/**
+ * Runnable that is aware of a timeout
+ *
+ * @opensearch.internal
+ */
+public interface TimeoutAwareRunnable extends Runnable {
+
+    void onTimeout();
+}
diff --git a/server/src/main/java/org/opensearch/gateway/BaseGatewayShardAllocator.java b/server/src/main/java/org/opensearch/gateway/BaseGatewayShardAllocator.java
index 58982e869794f..0d6af943d39e0 100644
--- a/server/src/main/java/org/opensearch/gateway/BaseGatewayShardAllocator.java
+++ b/server/src/main/java/org/opensearch/gateway/BaseGatewayShardAllocator.java
@@ -36,6 +36,7 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.cluster.routing.RecoverySource;
 import org.opensearch.cluster.routing.RoutingNode;
+import org.opensearch.cluster.routing.RoutingNodes;
 import org.opensearch.cluster.routing.ShardRouting;
 import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.opensearch.cluster.routing.allocation.AllocationDecision;
@@ -43,9 +44,12 @@
 import org.opensearch.cluster.routing.allocation.NodeAllocationResult;
 import org.opensearch.cluster.routing.allocation.RoutingAllocation;
 import org.opensearch.cluster.routing.allocation.decider.Decision;
+import org.opensearch.core.index.shard.ShardId;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 /**
  * An abstract class that implements basic functionality for allocating
@@ -78,6 +82,23 @@ public void allocateUnassigned(
         executeDecision(shardRouting, allocateUnassignedDecision, allocation, unassignedAllocationHandler);
     }
 
+    protected void allocateUnassignedBatchOnTimeout(List<ShardRouting> shardRoutings, RoutingAllocation allocation, boolean primary) {
+        Set<ShardId> shardIdsFromBatch = new HashSet<>();
+        for (ShardRouting shardRouting : shardRoutings) {
+            ShardId shardId = shardRouting.shardId();
+            shardIdsFromBatch.add(shardId);
+        }
+        RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
+        while (iterator.hasNext()) {
+            ShardRouting unassignedShard = iterator.next();
+            AllocateUnassignedDecision allocationDecision;
+            if (unassignedShard.primary() == primary && shardIdsFromBatch.contains(unassignedShard.shardId())) {
+                allocationDecision = AllocateUnassignedDecision.throttle(null);
+                executeDecision(unassignedShard, allocationDecision, allocation, iterator);
+            }
+        }
+    }
+
     protected void executeDecision(
         ShardRouting shardRouting,
         AllocateUnassignedDecision allocateUnassignedDecision,
diff --git a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java
index 3c0797cd450d2..55f5388d8f454 100644
--- a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java
+++ b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java
@@ -27,9 +27,13 @@
 import org.opensearch.common.UUIDs;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.common.lease.Releasables;
+import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.util.BatchRunnableExecutor;
 import org.opensearch.common.util.concurrent.ConcurrentCollections;
+import org.opensearch.common.util.concurrent.TimeoutAwareRunnable;
 import org.opensearch.common.util.set.Sets;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.index.shard.ShardId;
@@ -41,6 +45,7 @@
 import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper;
 import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper.StoreFilesMetadata;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -68,6 +73,14 @@ public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator {
     private final long maxBatchSize;
     private static final short DEFAULT_SHARD_BATCH_SIZE = 2000;
 
+    private static final String PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING_KEY =
+        "cluster.routing.allocation.shards_batch_gateway_allocator.primary_allocator_timeout";
+    private static final String REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING_KEY =
+        "cluster.routing.allocation.shards_batch_gateway_allocator.replica_allocator_timeout";
+
+    private TimeValue primaryShardsBatchGatewayAllocatorTimeout;
+    private TimeValue replicaShardsBatchGatewayAllocatorTimeout;
+
     /**
      * Number of shards we send in one batch to data nodes for fetching metadata
      */
@@ -79,6 +92,20 @@ public class ShardsBatchGatewayAllocator implements ExistingShardsAllocator {
         Setting.Property.NodeScope
     );
 
+    public static final Setting<TimeValue> PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING = Setting.timeSetting(
+        PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING_KEY,
+        TimeValue.MINUS_ONE,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
+    public static final Setting<TimeValue> REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING = Setting.timeSetting(
+        REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING_KEY,
+        TimeValue.MINUS_ONE,
+        Setting.Property.NodeScope,
+        Setting.Property.Dynamic
+    );
+
     private final RerouteService rerouteService;
     private final PrimaryShardBatchAllocator primaryShardBatchAllocator;
     private final ReplicaShardBatchAllocator replicaShardBatchAllocator;
@@ -97,7 +124,8 @@ public ShardsBatchGatewayAllocator(
         RerouteService rerouteService,
         TransportNodesListGatewayStartedShardsBatch batchStartedAction,
         TransportNodesListShardStoreMetadataBatch batchStoreAction,
-        Settings settings
+        Settings settings,
+        ClusterSettings clusterSettings
     ) {
         this.rerouteService = rerouteService;
         this.primaryShardBatchAllocator = new InternalPrimaryBatchShardAllocator();
@@ -105,6 +133,10 @@ public ShardsBatchGatewayAllocator(
         this.batchStartedAction = batchStartedAction;
         this.batchStoreAction = batchStoreAction;
         this.maxBatchSize = GATEWAY_ALLOCATOR_BATCH_SIZE.get(settings);
+        this.primaryShardsBatchGatewayAllocatorTimeout = PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(settings);
+        clusterSettings.addSettingsUpdateConsumer(PRIMARY_BATCH_ALLOCATOR_TIMEOUT_SETTING, this::setPrimaryBatchAllocatorTimeout);
+        this.replicaShardsBatchGatewayAllocatorTimeout = REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING.get(settings);
+        clusterSettings.addSettingsUpdateConsumer(REPLICA_BATCH_ALLOCATOR_TIMEOUT_SETTING, this::setReplicaBatchAllocatorTimeout);
     }
 
     @Override
@@ -127,7 +159,10 @@ protected ShardsBatchGatewayAllocator(long batchSize) {
         this.batchStoreAction = null;
         this.replicaShardBatchAllocator = null;
         this.maxBatchSize = batchSize;
+        this.primaryShardsBatchGatewayAllocatorTimeout = null;
+        this.replicaShardsBatchGatewayAllocatorTimeout = null;
     }
+
     // for tests
 
     @Override
@@ -187,14 +222,14 @@ public void allocateUnassigned(
     }
 
     @Override
-    public void allocateAllUnassignedShards(final RoutingAllocation allocation, boolean primary) {
+    public BatchRunnableExecutor allocateAllUnassignedShards(final RoutingAllocation allocation, boolean primary) {
 
         assert primaryShardBatchAllocator != null;
         assert replicaShardBatchAllocator != null;
-        innerAllocateUnassignedBatch(allocation, primaryShardBatchAllocator, replicaShardBatchAllocator, primary);
+        return innerAllocateUnassignedBatch(allocation, primaryShardBatchAllocator, replicaShardBatchAllocator, primary);
     }
 
-    protected void innerAllocateUnassignedBatch(
+    protected BatchRunnableExecutor innerAllocateUnassignedBatch(
         RoutingAllocation allocation,
         PrimaryShardBatchAllocator primaryBatchShardAllocator,
         ReplicaShardBatchAllocator replicaBatchShardAllocator,
@@ -203,20 +238,45 @@ protected void innerAllocateUnassignedBatch(
         // create batches for unassigned shards
         Set<String> batchesToAssign = createAndUpdateBatches(allocation, primary);
         if (batchesToAssign.isEmpty()) {
-            return;
+            return null;
         }
+        List<TimeoutAwareRunnable> runnables = new ArrayList<>();
         if (primary) {
             batchIdToStartedShardBatch.values()
                 .stream()
                 .filter(batch -> batchesToAssign.contains(batch.batchId))
-                .forEach(
-                    shardsBatch -> primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation)
-                );
+                .forEach(shardsBatch -> runnables.add(new TimeoutAwareRunnable() {
+                    @Override
+                    public void onTimeout() {
+                        primaryBatchShardAllocator.allocateUnassignedBatchOnTimeout(
+                            shardsBatch.getBatchedShardRoutings(),
+                            allocation,
+                            true
+                        );
+                    }
+
+                    @Override
+                    public void run() {
+                        primaryBatchShardAllocator.allocateUnassignedBatch(shardsBatch.getBatchedShardRoutings(), allocation);
+                    }
+                }));
+            return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout);
         } else {
             batchIdToStoreShardBatch.values()
                 .stream()
                 .filter(batch -> batchesToAssign.contains(batch.batchId))
-                .forEach(batch -> replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation));
+                .forEach(batch -> runnables.add(new TimeoutAwareRunnable() {
+                    @Override
+                    public void onTimeout() {
+                        replicaBatchShardAllocator.allocateUnassignedBatchOnTimeout(batch.getBatchedShardRoutings(), allocation, false);
+                    }
+
+                    @Override
+                    public void run() {
+                        replicaBatchShardAllocator.allocateUnassignedBatch(batch.getBatchedShardRoutings(), allocation);
+                    }
+                }));
+            return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout);
         }
     }
 
@@ -721,4 +781,12 @@ public int getNumberOfStartedShardBatches() {
     public int getNumberOfStoreShardBatches() {
         return batchIdToStoreShardBatch.size();
     }
+
+    private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) {
+        this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout;
+    }
+
+    private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) {
+        this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout;
+    }
 }
diff --git a/server/src/test/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocatorTests.java b/server/src/test/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocatorTests.java
new file mode 100644
index 0000000000000..1da8f5ef7f695
--- /dev/null
+++ b/server/src/test/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocatorTests.java
@@ -0,0 +1,118 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.cluster.routing.allocation;
+
+import org.opensearch.Version;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.ClusterState;
+import org.opensearch.cluster.OpenSearchAllocationTestCase;
+import org.opensearch.cluster.metadata.IndexMetadata;
+import org.opensearch.cluster.metadata.Metadata;
+import org.opensearch.cluster.node.DiscoveryNodes;
+import org.opensearch.cluster.routing.RoutingTable;
+import org.opensearch.cluster.routing.ShardRouting;
+import org.opensearch.common.settings.Settings;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+public class ExistingShardsAllocatorTests extends OpenSearchAllocationTestCase {
+
+    public void testRunnablesExecutedForUnassignedShards() throws InterruptedException {
+
+        Metadata metadata = Metadata.builder()
+            .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(3).numberOfReplicas(2))
+            .build();
+        RoutingTable initialRoutingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build();
+
+        ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
+            .metadata(metadata)
+            .routingTable(initialRoutingTable)
+            .build();
+        clusterState = ClusterState.builder(clusterState)
+            .nodes(DiscoveryNodes.builder().add(newNode("node1")).add(newNode("node2")).add(newNode("node3")))
+            .build();
+        RoutingAllocation allocation = new RoutingAllocation(
+            yesAllocationDeciders(),
+            clusterState.getRoutingNodes(),
+            clusterState,
+            null,
+            null,
+            0L
+        );
+        CountDownLatch expectedStateLatch = new CountDownLatch(3);
+        TestAllocator testAllocator = new TestAllocator(expectedStateLatch);
+        testAllocator.allocateAllUnassignedShards(allocation, true).run();
+        // if the below condition is passed, then we are sure runnable executed for all primary shards
+        assertTrue(expectedStateLatch.await(30, TimeUnit.SECONDS));
+
+        expectedStateLatch = new CountDownLatch(6);
+        testAllocator = new TestAllocator(expectedStateLatch);
+        testAllocator.allocateAllUnassignedShards(allocation, false).run();
+        // if the below condition is passed, then we are sure runnable executed for all replica shards
+        assertTrue(expectedStateLatch.await(30, TimeUnit.SECONDS));
+    }
+
+    private static class TestAllocator implements ExistingShardsAllocator {
+
+        final CountDownLatch countDownLatch;
+
+        TestAllocator(CountDownLatch latch) {
+            this.countDownLatch = latch;
+        }
+
+        @Override
+        public void beforeAllocation(RoutingAllocation allocation) {
+
+        }
+
+        @Override
+        public void afterPrimariesBeforeReplicas(RoutingAllocation allocation) {
+
+        }
+
+        @Override
+        public void allocateUnassigned(
+            ShardRouting shardRouting,
+            RoutingAllocation allocation,
+            UnassignedAllocationHandler unassignedAllocationHandler
+        ) {
+            countDownLatch.countDown();
+        }
+
+        @Override
+        public AllocateUnassignedDecision explainUnassignedShardAllocation(
+            ShardRouting unassignedShard,
+            RoutingAllocation routingAllocation
+        ) {
+            return null;
+        }
+
+        @Override
+        public void cleanCaches() {
+
+        }
+
+        @Override
+        public void applyStartedShards(List<ShardRouting> startedShards, RoutingAllocation allocation) {
+
+        }
+
+        @Override
+        public void applyFailedShards(List<FailedShard> failedShards, RoutingAllocation allocation) {
+
+        }
+
+        @Override
+        public int getNumberOfInFlightFetches() {
+            return 0;
+        }
+    }
+}
diff --git a/server/src/test/java/org/opensearch/common/util/BatchRunnableExecutorTests.java b/server/src/test/java/org/opensearch/common/util/BatchRunnableExecutorTests.java
new file mode 100644
index 0000000000000..269f89faec54d
--- /dev/null
+++ b/server/src/test/java/org/opensearch/common/util/BatchRunnableExecutorTests.java
@@ -0,0 +1,97 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.util;
+
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.util.concurrent.TimeoutAwareRunnable;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.mockito.Mockito.atMost;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+public class BatchRunnableExecutorTests extends OpenSearchTestCase {
+    private Supplier<TimeValue> timeoutSupplier;
+    private TimeoutAwareRunnable runnable1;
+    private TimeoutAwareRunnable runnable2;
+    private TimeoutAwareRunnable runnable3;
+    private List<TimeoutAwareRunnable> runnableList;
+
+    public void setupRunnables() {
+        timeoutSupplier = mock(Supplier.class);
+        runnable1 = mock(TimeoutAwareRunnable.class);
+        runnable2 = mock(TimeoutAwareRunnable.class);
+        runnable3 = mock(TimeoutAwareRunnable.class);
+        runnableList = Arrays.asList(runnable1, runnable2, runnable3);
+    }
+
+    public void testRunWithoutTimeout() {
+        setupRunnables();
+        timeoutSupplier = () -> TimeValue.timeValueSeconds(1);
+        BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
+        executor.run();
+        verify(runnable1, times(1)).run();
+        verify(runnable2, times(1)).run();
+        verify(runnable3, times(1)).run();
+        verify(runnable1, never()).onTimeout();
+        verify(runnable2, never()).onTimeout();
+        verify(runnable3, never()).onTimeout();
+    }
+
+    public void testRunWithTimeout() {
+        setupRunnables();
+        timeoutSupplier = () -> TimeValue.timeValueNanos(1);
+        BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
+        executor.run();
+        verify(runnable1, times(1)).onTimeout();
+        verify(runnable2, times(1)).onTimeout();
+        verify(runnable3, times(1)).onTimeout();
+        verify(runnable1, never()).run();
+        verify(runnable2, never()).run();
+        verify(runnable3, never()).run();
+    }
+
+    public void testRunWithPartialTimeout() {
+        setupRunnables();
+        timeoutSupplier = () -> TimeValue.timeValueMillis(50);
+        BatchRunnableExecutor executor = new BatchRunnableExecutor(runnableList, timeoutSupplier);
+        doAnswer(invocation -> {
+            Thread.sleep(100);
+            return null;
+        }).when(runnable1).run();
+        executor.run();
+        verify(runnable1, atMost(1)).run();
+        verify(runnable2, atMost(1)).run();
+        verify(runnable3, atMost(1)).run();
+        verify(runnable2, atMost(1)).onTimeout();
+        verify(runnable3, atMost(1)).onTimeout();
+        verify(runnable2, atMost(1)).onTimeout();
+        verify(runnable3, atMost(1)).onTimeout();
+    }
+
+    public void testRunWithEmptyRunnableList() {
+        setupRunnables();
+        BatchRunnableExecutor executor = new BatchRunnableExecutor(Collections.emptyList(), timeoutSupplier);
+        executor.run();
+        verify(runnable1, never()).onTimeout();
+        verify(runnable2, never()).onTimeout();
+        verify(runnable3, never()).onTimeout();
+        verify(runnable1, never()).run();
+        verify(runnable2, never()).run();
+        verify(runnable3, never()).run();
+    }
+}
diff --git a/server/src/test/java/org/opensearch/gateway/GatewayAllocatorTests.java b/server/src/test/java/org/opensearch/gateway/GatewayAllocatorTests.java
index aa31c710c1fbd..bd56123f6df1f 100644
--- a/server/src/test/java/org/opensearch/gateway/GatewayAllocatorTests.java
+++ b/server/src/test/java/org/opensearch/gateway/GatewayAllocatorTests.java
@@ -32,6 +32,7 @@
 import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.opensearch.common.collect.Tuple;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.util.BatchRunnableExecutor;
 import org.opensearch.common.util.set.Sets;
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.snapshots.SnapshotShardSizeInfo;
@@ -61,6 +62,13 @@ public void setUp() throws Exception {
         testShardsBatchGatewayAllocator = new TestShardBatchGatewayAllocator();
     }
 
+    public void testExecutorNotNull() {
+        createIndexAndUpdateClusterState(1, 3, 1);
+        createBatchesAndAssert(1);
+        BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
+        assertNotNull(executor);
+    }
+
     public void testSingleBatchCreation() {
         createIndexAndUpdateClusterState(1, 3, 1);
         createBatchesAndAssert(1);
@@ -336,6 +344,30 @@ public void testGetBatchIdNonExisting() {
         allShardRoutings.forEach(shard -> assertNull(testShardsBatchGatewayAllocator.getBatchId(shard, shard.primary())));
     }
 
+    public void testCreatePrimaryAndReplicaExecutorOfSizeOne() {
+        createIndexAndUpdateClusterState(1, 3, 2);
+        BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
+        assertEquals(executor.getTimeoutAwareRunnables().size(), 1);
+        executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
+        assertEquals(executor.getTimeoutAwareRunnables().size(), 1);
+    }
+
+    public void testCreatePrimaryExecutorOfSizeOneAndReplicaExecutorOfSizeZero() {
+        createIndexAndUpdateClusterState(1, 3, 0);
+        BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
+        assertEquals(executor.getTimeoutAwareRunnables().size(), 1);
+        executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
+        assertNull(executor);
+    }
+
+    public void testCreatePrimaryAndReplicaExecutorOfSizeTwo() {
+        createIndexAndUpdateClusterState(2, 1001, 1);
+        BatchRunnableExecutor executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, true);
+        assertEquals(executor.getTimeoutAwareRunnables().size(), 2);
+        executor = testShardsBatchGatewayAllocator.allocateAllUnassignedShards(testAllocation, false);
+        assertEquals(executor.getTimeoutAwareRunnables().size(), 2);
+    }
+
     private void createIndexAndUpdateClusterState(int count, int numberOfShards, int numberOfReplicas) {
         if (count == 0) return;
         Metadata.Builder metadata = Metadata.builder();
diff --git a/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java b/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java
index 8ad8bcda95f40..270cf465d0f80 100644
--- a/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java
+++ b/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java
@@ -41,6 +41,7 @@
 import org.junit.Before;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -256,6 +257,52 @@ public void testAllocateUnassignedBatchThrottlingAllocationDeciderIsHonoured() {
         assertEquals(UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED, ignoredShards.get(0).unassignedInfo().getLastAllocationStatus());
     }
 
+    public void testAllocateUnassignedBatchOnTimeoutWithMatchingPrimaryShards() {
+        ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
+        setUpShards(1);
+        final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
+        ShardRouting shardRouting = routingAllocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard();
+
+        List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
+        batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
+
+        List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
+        assertEquals(1, ignoredShards.size());
+        assertEquals(UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED, ignoredShards.get(0).unassignedInfo().getLastAllocationStatus());
+    }
+
+    public void testAllocateUnassignedBatchOnTimeoutWithNoMatchingPrimaryShards() {
+        ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
+        setUpShards(1);
+        final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
+        List<ShardRouting> shardRoutings = new ArrayList<>();
+        batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, true);
+
+        List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
+        assertEquals(0, ignoredShards.size());
+    }
+
+    public void testAllocateUnassignedBatchOnTimeoutWithNonPrimaryShards() {
+        ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        AllocationDeciders allocationDeciders = randomAllocationDeciders(Settings.builder().build(), clusterSettings, random());
+        setUpShards(1);
+        final RoutingAllocation routingAllocation = routingAllocationWithOnePrimary(allocationDeciders, CLUSTER_RECOVERED, "allocId-0");
+
+        ShardRouting shardRouting = routingAllocation.routingTable()
+            .getIndicesRouting()
+            .get("test")
+            .shard(shardId.id())
+            .replicaShards()
+            .get(0);
+        List<ShardRouting> shardRoutings = Arrays.asList(shardRouting);
+        batchAllocator.allocateUnassignedBatchOnTimeout(shardRoutings, routingAllocation, false);
+
+        List<ShardRouting> ignoredShards = routingAllocation.routingNodes().unassigned().ignored();
+        assertEquals(1, ignoredShards.size());
+    }
+
     private RoutingAllocation routingAllocationWithOnePrimary(
         AllocationDeciders deciders,
         UnassignedInfo.Reason reason,
diff --git a/server/src/test/java/org/opensearch/gateway/ReplicaShardBatchAllocatorTests.java b/server/src/test/java/org/opensearch/gateway/ReplicaShardBatchAllocatorTests.java
index 526a3990955b8..435fd78be2bcd 100644
--- a/server/src/test/java/org/opensearch/gateway/ReplicaShardBatchAllocatorTests.java
+++ b/server/src/test/java/org/opensearch/gateway/ReplicaShardBatchAllocatorTests.java
@@ -717,6 +717,33 @@ public void testAllocateUnassignedBatchThrottlingAllocationDeciderIsHonoured() t
         assertEquals(UnassignedInfo.AllocationStatus.DECIDERS_THROTTLED, allocateUnassignedDecision.getAllocationStatus());
     }
 
+    public void testAllocateUnassignedBatchOnTimeoutWithUnassignedReplicaShard() {
+        RoutingAllocation allocation = onePrimaryOnNode1And1Replica(yesAllocationDeciders());
+        final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
+        List<ShardRouting> shards = new ArrayList<>();
+        while (iterator.hasNext()) {
+            shards.add(iterator.next());
+        }
+        testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
+        assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(1));
+        assertThat(allocation.routingNodes().unassigned().ignored().get(0).shardId(), equalTo(shardId));
+        assertEquals(
+            UnassignedInfo.AllocationStatus.NO_ATTEMPT,
+            allocation.routingNodes().unassigned().ignored().get(0).unassignedInfo().getLastAllocationStatus()
+        );
+    }
+
+    public void testAllocateUnassignedBatchOnTimeoutWithAlreadyRecoveringReplicaShard() {
+        RoutingAllocation allocation = onePrimaryOnNode1And1ReplicaRecovering(yesAllocationDeciders());
+        final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
+        List<ShardRouting> shards = new ArrayList<>();
+        while (iterator.hasNext()) {
+            shards.add(iterator.next());
+        }
+        testBatchAllocator.allocateUnassignedBatchOnTimeout(shards, allocation, false);
+        assertThat(allocation.routingNodes().unassigned().ignored().size(), equalTo(0));
+    }
+
     private RoutingAllocation onePrimaryOnNode1And1Replica(AllocationDeciders deciders) {
         return onePrimaryOnNode1And1Replica(deciders, Settings.EMPTY, UnassignedInfo.Reason.CLUSTER_RECOVERED);
     }
diff --git a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java
index fbb39c284f0ff..0eb4bb6935bac 100644
--- a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java
+++ b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java
@@ -13,6 +13,7 @@
 import org.opensearch.cluster.routing.ShardRouting;
 import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.opensearch.cluster.routing.allocation.RoutingAllocation;
+import org.opensearch.common.util.BatchRunnableExecutor;
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.gateway.AsyncShardFetch;
 import org.opensearch.gateway.PrimaryShardBatchAllocator;
@@ -102,9 +103,9 @@ protected boolean hasInitiatedFetching(ShardRouting shard) {
     };
 
     @Override
-    public void allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) {
+    public BatchRunnableExecutor allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) {
         currentNodes = allocation.nodes();
-        innerAllocateUnassignedBatch(allocation, primaryBatchShardAllocator, replicaBatchShardAllocator, primary);
+        return innerAllocateUnassignedBatch(allocation, primaryBatchShardAllocator, replicaBatchShardAllocator, primary);
     }
 
     @Override