From 1ea89420c446683115bc6eafd0de2f2d3ea01f53 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 19 Mar 2024 20:06:30 +0000 Subject: [PATCH] Add PrimaryShardBatchAllocator to take allocation decisions for a batch of shards (#8916) * Add PrimaryShardBatchAllocator to take allocation decisions for a batch of shards Signed-off-by: Aman Khare (cherry picked from commit a499d1eb1a0683bd93c170c6a6e90be2a1690f65) Signed-off-by: github-actions[bot] --- .../gateway/RecoveryFromGatewayIT.java | 36 +- .../TransportIndicesShardStoresAction.java | 7 +- .../gateway/PrimaryShardAllocator.java | 72 ++-- .../gateway/PrimaryShardBatchAllocator.java | 150 ++++++++ ...ansportNodesGatewayStartedShardHelper.java | 166 ++++++++- ...ransportNodesListGatewayStartedShards.java | 97 ++--- ...ortNodesListGatewayStartedShardsBatch.java | 134 +------ .../gateway/PrimaryShardAllocatorTests.java | 10 +- .../PrimaryShardBatchAllocatorTests.java | 340 ++++++++++++++++++ .../test/gateway/TestGatewayAllocator.java | 10 +- 10 files changed, 754 insertions(+), 268 deletions(-) create mode 100644 server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java create mode 100644 server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java diff --git a/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java b/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java index 6c248a32c9928..ba03532a9aa2f 100644 --- a/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/gateway/RecoveryFromGatewayIT.java @@ -56,6 +56,7 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.env.NodeEnvironment; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard; import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; import org.opensearch.index.MergePolicyProvider; @@ -720,11 +721,11 @@ public Settings onNodeStopped(String nodeName) throws Exception { ); assertThat(response.getNodes(), hasSize(1)); - assertThat(response.getNodes().get(0).allocationId(), notNullValue()); + assertThat(response.getNodes().get(0).getGatewayShardStarted().allocationId(), notNullValue()); if (corrupt) { - assertThat(response.getNodes().get(0).storeException(), notNullValue()); + assertThat(response.getNodes().get(0).getGatewayShardStarted().storeException(), notNullValue()); } else { - assertThat(response.getNodes().get(0).storeException(), nullValue()); + assertThat(response.getNodes().get(0).getGatewayShardStarted().storeException(), nullValue()); } // start another node so cluster consistency checks won't time out due to the lack of state @@ -764,11 +765,11 @@ public void testSingleShardFetchUsingBatchAction() { ); final Index index = resolveIndex(indexName); final ShardId shardId = new ShardId(index, 0); - TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap() + GatewayStartedShard gatewayStartedShard = response.getNodesMap() .get(searchShardsResponse.getNodes()[0].getId()) .getNodeGatewayStartedShardsBatch() .get(shardId); - assertNodeGatewayStartedShardsHappyCase(nodeGatewayStartedShards); + assertNodeGatewayStartedShardsHappyCase(gatewayStartedShard); } public void testShardFetchMultiNodeMultiIndexesUsingBatchAction() { @@ -792,11 +793,8 @@ public void 
testShardFetchMultiNodeMultiIndexesUsingBatchAction() { ShardId shardId = clusterSearchShardsGroup.getShardId(); assertEquals(1, clusterSearchShardsGroup.getShards().length); String nodeId = clusterSearchShardsGroup.getShards()[0].currentNodeId(); - TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap() - .get(nodeId) - .getNodeGatewayStartedShardsBatch() - .get(shardId); - assertNodeGatewayStartedShardsHappyCase(nodeGatewayStartedShards); + GatewayStartedShard gatewayStartedShard = response.getNodesMap().get(nodeId).getNodeGatewayStartedShardsBatch().get(shardId); + assertNodeGatewayStartedShardsHappyCase(gatewayStartedShard); } } @@ -816,13 +814,13 @@ public void testShardFetchCorruptedShardsUsingBatchAction() throws Exception { new TransportNodesListGatewayStartedShardsBatch.Request(getDiscoveryNodes(), shardIdShardAttributesMap) ); DiscoveryNode[] discoveryNodes = getDiscoveryNodes(); - TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards = response.getNodesMap() + GatewayStartedShard gatewayStartedShard = response.getNodesMap() .get(discoveryNodes[0].getId()) .getNodeGatewayStartedShardsBatch() .get(shardId); - assertNotNull(nodeGatewayStartedShards.storeException()); - assertNotNull(nodeGatewayStartedShards.allocationId()); - assertTrue(nodeGatewayStartedShards.primary()); + assertNotNull(gatewayStartedShard.storeException()); + assertNotNull(gatewayStartedShard.allocationId()); + assertTrue(gatewayStartedShard.primary()); } public void testSingleShardStoreFetchUsingBatchAction() throws ExecutionException, InterruptedException { @@ -950,12 +948,10 @@ private void assertNodeStoreFilesMetadataSuccessCase( assertNotNull(storeFileMetadata.peerRecoveryRetentionLeases()); } - private void assertNodeGatewayStartedShardsHappyCase( - TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeGatewayStartedShards - ) { - assertNull(nodeGatewayStartedShards.storeException()); - assertNotNull(nodeGatewayStartedShards.allocationId()); - assertTrue(nodeGatewayStartedShards.primary()); + private void assertNodeGatewayStartedShardsHappyCase(GatewayStartedShard gatewayStartedShard) { + assertNull(gatewayStartedShard.storeException()); + assertNotNull(gatewayStartedShard.allocationId()); + assertTrue(gatewayStartedShard.primary()); } private void prepareIndex(String indexName, int numberOfPrimaryShards) { diff --git a/server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java b/server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java index 04166c88a00ad..3fbf9ac1bb570 100644 --- a/server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java +++ b/server/src/main/java/org/opensearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java @@ -258,9 +258,9 @@ void finish() { storeStatuses.add( new IndicesShardStoresResponse.StoreStatus( response.getNode(), - response.allocationId(), + response.getGatewayShardStarted().allocationId(), allocationStatus, - response.storeException() + response.getGatewayShardStarted().storeException() ) ); } @@ -308,7 +308,8 @@ private IndicesShardStoresResponse.StoreStatus.AllocationStatus getAllocationSta * A shard exists/existed in a node only if shard state file exists in the node */ private boolean shardExistsInNode(final NodeGatewayStartedShards response) { - return response.storeException() != null || 
response.allocationId() != null; + return response.getGatewayShardStarted().storeException() != null + || response.getGatewayShardStarted().allocationId() != null; } @Override diff --git a/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java b/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java index 5046873830c01..f41545cbdf9bf 100644 --- a/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/PrimaryShardAllocator.java @@ -50,6 +50,7 @@ import org.opensearch.cluster.routing.allocation.decider.Decision.Type; import org.opensearch.env.ShardLockObtainFailedException; import org.opensearch.gateway.AsyncShardFetch.FetchResult; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.NodeGatewayStartedShard; import org.opensearch.gateway.TransportNodesListGatewayStartedShards.NodeGatewayStartedShards; import java.util.ArrayList; @@ -125,27 +126,37 @@ public AllocateUnassignedDecision makeAllocationDecision( return decision; } final FetchResult shardState = fetchData(unassignedShard, allocation); - List nodeShardStates = adaptToNodeStartedShardList(shardState); + List nodeShardStates = adaptToNodeStartedShardList(shardState); return getAllocationDecision(unassignedShard, allocation, nodeShardStates, logger); } /** - * Transforms {@link FetchResult} of {@link NodeGatewayStartedShards} to {@link List} of {@link NodeGatewayStartedShards} + * Transforms {@link FetchResult} of {@link NodeGatewayStartedShards} to {@link List} of {@link NodeGatewayStartedShard} * Returns null if {@link FetchResult} does not have any data. */ - private static List adaptToNodeStartedShardList(FetchResult shardsState) { + private static List adaptToNodeStartedShardList(FetchResult shardsState) { if (!shardsState.hasData()) { return null; } - List nodeShardStates = new ArrayList<>(); - shardsState.getData().forEach((node, nodeGatewayStartedShard) -> { nodeShardStates.add(nodeGatewayStartedShard); }); + List nodeShardStates = new ArrayList<>(); + shardsState.getData().forEach((node, nodeGatewayStartedShard) -> { + nodeShardStates.add( + new NodeGatewayStartedShard( + nodeGatewayStartedShard.getGatewayShardStarted().allocationId(), + nodeGatewayStartedShard.getGatewayShardStarted().primary(), + nodeGatewayStartedShard.getGatewayShardStarted().replicationCheckpoint(), + nodeGatewayStartedShard.getGatewayShardStarted().storeException(), + node + ) + ); + }); return nodeShardStates; } protected AllocateUnassignedDecision getAllocationDecision( ShardRouting unassignedShard, RoutingAllocation allocation, - List shardState, + List shardState, Logger logger ) { final boolean explain = allocation.debugDecision(); @@ -236,7 +247,7 @@ protected AllocateUnassignedDecision getAllocationDecision( nodesToAllocate = buildNodesToAllocate(allocation, nodeShardsResult.orderedAllocationCandidates, unassignedShard, true); if (nodesToAllocate.yesNodeShards.isEmpty() == false) { final DecidedNode decidedNode = nodesToAllocate.yesNodeShards.get(0); - final NodeGatewayStartedShards nodeShardState = decidedNode.nodeShardState; + final NodeGatewayStartedShard nodeShardState = decidedNode.nodeShardState; logger.debug( "[{}][{}]: allocating [{}] to [{}] on forced primary allocation", unassignedShard.index(), @@ -296,11 +307,11 @@ protected AllocateUnassignedDecision getAllocationDecision( */ private static List buildNodeDecisions( NodesToAllocate nodesToAllocate, - List fetchedShardData, + List fetchedShardData, Set inSyncAllocationIds 
) { List nodeResults = new ArrayList<>(); - Collection ineligibleShards = new ArrayList<>(); + Collection ineligibleShards = new ArrayList<>(); if (nodesToAllocate != null) { final Set discoNodes = new HashSet<>(); nodeResults.addAll( @@ -334,21 +345,21 @@ private static List buildNodeDecisions( return nodeResults; } - private static ShardStoreInfo shardStoreInfo(NodeGatewayStartedShards nodeShardState, Set inSyncAllocationIds) { + private static ShardStoreInfo shardStoreInfo(NodeGatewayStartedShard nodeShardState, Set inSyncAllocationIds) { final Exception storeErr = nodeShardState.storeException(); final boolean inSync = nodeShardState.allocationId() != null && inSyncAllocationIds.contains(nodeShardState.allocationId()); return new ShardStoreInfo(nodeShardState.allocationId(), inSync, storeErr); } - private static final Comparator NO_STORE_EXCEPTION_FIRST_COMPARATOR = Comparator.comparing( - (NodeGatewayStartedShards state) -> state.storeException() == null + private static final Comparator NO_STORE_EXCEPTION_FIRST_COMPARATOR = Comparator.comparing( + (NodeGatewayStartedShard state) -> state.storeException() == null ).reversed(); - private static final Comparator PRIMARY_FIRST_COMPARATOR = Comparator.comparing( - NodeGatewayStartedShards::primary + private static final Comparator PRIMARY_FIRST_COMPARATOR = Comparator.comparing( + NodeGatewayStartedShard::primary ).reversed(); - private static final Comparator HIGHEST_REPLICATION_CHECKPOINT_FIRST_COMPARATOR = Comparator.comparing( - NodeGatewayStartedShards::replicationCheckpoint, + private static final Comparator HIGHEST_REPLICATION_CHECKPOINT_FIRST_COMPARATOR = Comparator.comparing( + NodeGatewayStartedShard::replicationCheckpoint, Comparator.nullsLast(Comparator.naturalOrder()) ); @@ -362,12 +373,12 @@ protected static NodeShardsResult buildNodeShardsResult( boolean matchAnyShard, Set ignoreNodes, Set inSyncAllocationIds, - List shardState, + List shardState, Logger logger ) { - List nodeShardStates = new ArrayList<>(); + List nodeShardStates = new ArrayList<>(); int numberOfAllocationsFound = 0; - for (NodeGatewayStartedShards nodeShardState : shardState) { + for (NodeGatewayStartedShard nodeShardState : shardState) { DiscoveryNode node = nodeShardState.getNode(); String allocationId = nodeShardState.allocationId(); @@ -432,21 +443,18 @@ protected static NodeShardsResult buildNodeShardsResult( return new NodeShardsResult(nodeShardStates, numberOfAllocationsFound); } - private static Comparator createActiveShardComparator( - boolean matchAnyShard, - Set inSyncAllocationIds - ) { + private static Comparator createActiveShardComparator(boolean matchAnyShard, Set inSyncAllocationIds) { /** * Orders the active shards copies based on below comparators * 1. No store exception i.e. shard copy is readable * 2. Prefer previous primary shard * 3. Prefer shard copy with the highest replication checkpoint. It is NO-OP for doc rep enabled indices. 
*/ - final Comparator comparator; // allocation preference + final Comparator comparator; // allocation preference if (matchAnyShard) { // prefer shards with matching allocation ids - Comparator matchingAllocationsFirst = Comparator.comparing( - (NodeGatewayStartedShards state) -> inSyncAllocationIds.contains(state.allocationId()) + Comparator matchingAllocationsFirst = Comparator.comparing( + (NodeGatewayStartedShard state) -> inSyncAllocationIds.contains(state.allocationId()) ).reversed(); comparator = matchingAllocationsFirst.thenComparing(NO_STORE_EXCEPTION_FIRST_COMPARATOR) .thenComparing(PRIMARY_FIRST_COMPARATOR) @@ -464,14 +472,14 @@ private static Comparator createActiveShardComparator( */ private static NodesToAllocate buildNodesToAllocate( RoutingAllocation allocation, - List nodeShardStates, + List nodeShardStates, ShardRouting shardRouting, boolean forceAllocate ) { List yesNodeShards = new ArrayList<>(); List throttledNodeShards = new ArrayList<>(); List noNodeShards = new ArrayList<>(); - for (NodeGatewayStartedShards nodeShardState : nodeShardStates) { + for (NodeGatewayStartedShard nodeShardState : nodeShardStates) { RoutingNode node = allocation.routingNodes().node(nodeShardState.getNode().getId()); if (node == null) { continue; @@ -502,10 +510,10 @@ private static NodesToAllocate buildNodesToAllocate( * This class encapsulates the result of a call to {@link #buildNodeShardsResult} */ static class NodeShardsResult { - final List orderedAllocationCandidates; + final List orderedAllocationCandidates; final int allocationsFound; - NodeShardsResult(List orderedAllocationCandidates, int allocationsFound) { + NodeShardsResult(List orderedAllocationCandidates, int allocationsFound) { this.orderedAllocationCandidates = orderedAllocationCandidates; this.allocationsFound = allocationsFound; } @@ -531,10 +539,10 @@ protected static class NodesToAllocate { * by the allocator for allocating to the node that holds the shard copy. */ private static class DecidedNode { - final NodeGatewayStartedShards nodeShardState; + final NodeGatewayStartedShard nodeShardState; final Decision decision; - private DecidedNode(NodeGatewayStartedShards nodeShardState, Decision decision) { + private DecidedNode(NodeGatewayStartedShard nodeShardState, Decision decision) { this.nodeShardState = nodeShardState; this.decision = decision; } diff --git a/server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java b/server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java new file mode 100644 index 0000000000000..8d222903b6f29 --- /dev/null +++ b/server/src/main/java/org/opensearch/gateway/PrimaryShardBatchAllocator.java @@ -0,0 +1,150 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.gateway; + +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.routing.RoutingNodes; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision; +import org.opensearch.cluster.routing.allocation.RoutingAllocation; +import org.opensearch.gateway.AsyncShardFetch.FetchResult; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.NodeGatewayStartedShard; +import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * PrimaryShardBatchAllocator is similar to {@link org.opensearch.gateway.PrimaryShardAllocator}; the only difference is + * that it can allocate multiple unassigned primary shards, whereas PrimaryShardAllocator can only allocate a single + * unassigned shard. + * The primary shard batch allocator allocates multiple unassigned primary shards to nodes that hold + * valid copies of the unassigned primaries. It does this by iterating over all unassigned + * primary shards in the routing table and fetching shard metadata from each node in the cluster + * that holds a copy of the shard. The shard metadata from each node is compared against the + * set of valid allocation IDs and for all valid shard copies (if any), the primary shard batch allocator + * executes the allocation deciders to choose a copy to assign the primary shard to. + *

+ * Note that the PrimaryShardBatchAllocator does *not* allocate primaries on index creation + * (see {@link org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator}), + * nor does it allocate primaries when a primary shard failed and there is a valid replica + * copy that can immediately be promoted to primary, as this takes place in {@link RoutingNodes#failShard}. + * + * @opensearch.internal + */ +public abstract class PrimaryShardBatchAllocator extends PrimaryShardAllocator { + + abstract protected FetchResult fetchData( + List eligibleShards, + List inEligibleShards, + RoutingAllocation allocation + ); + + protected FetchResult fetchData( + ShardRouting shard, + RoutingAllocation allocation + ) { + logger.error("fetchData for single shard called via batch allocator, shard id {}", shard.shardId()); + throw new IllegalStateException("PrimaryShardBatchAllocator should only be used for a batch of shards"); + } + + @Override + public AllocateUnassignedDecision makeAllocationDecision(ShardRouting unassignedShard, RoutingAllocation allocation, Logger logger) { + return makeAllocationDecision(Collections.singletonList(unassignedShard), allocation, logger).get(unassignedShard); + } + + /** + * Build allocation decisions for all the shards present in the batch identified by batchId. + * + * @param shards set of shards given for allocation + * @param allocation current allocation of all the shards + * @param logger logger used for logging + * @return shard to allocation decision map + */ + @Override + public HashMap makeAllocationDecision( + List shards, + RoutingAllocation allocation, + Logger logger + ) { + HashMap shardAllocationDecisions = new HashMap<>(); + List eligibleShards = new ArrayList<>(); + List inEligibleShards = new ArrayList<>(); + // identify ineligible shards + for (ShardRouting shard : shards) { + AllocateUnassignedDecision decision = getInEligibleShardDecision(shard, allocation); + if (decision != null) { + inEligibleShards.add(shard); + shardAllocationDecisions.put(shard, decision); + } else { + eligibleShards.add(shard); + } + } + // Do not call fetchData if there are no eligible shards + if (eligibleShards.isEmpty()) { + return shardAllocationDecisions; + } + // only fetch data for eligible shards + final FetchResult shardsState = fetchData(eligibleShards, inEligibleShards, allocation); + + // process the received data + for (ShardRouting unassignedShard : eligibleShards) { + List nodeShardStates = adaptToNodeShardStates(unassignedShard, shardsState); + // get allocation decision for this shard + shardAllocationDecisions.put(unassignedShard, getAllocationDecision(unassignedShard, allocation, nodeShardStates, logger)); + } + return shardAllocationDecisions; + } + + /** + * Transforms {@link FetchResult} of {@link NodeGatewayStartedShardsBatch} to {@link List} of {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards}. + *

+ * Returns null if {@link FetchResult} does not have any data. + *

+ * shardsState contain the Data, there key is DiscoveryNode but value is Map of ShardId + * and NodeGatewayStartedShardsBatch so to get one shard level data (from all the nodes), we'll traverse the map + * and construct the nodeShardState along the way before making any allocation decision. As metadata for a + * particular shard is needed from all the discovery nodes. + * + * @param unassignedShard unassigned shard + * @param shardsState fetch data result for the whole batch + * @return shard state returned from each node + */ + private static List adaptToNodeShardStates( + ShardRouting unassignedShard, + FetchResult shardsState + ) { + if (!shardsState.hasData()) { + return null; + } + List nodeShardStates = new ArrayList<>(); + Map nodeResponses = shardsState.getData(); + + // build data for a shard from all the nodes + nodeResponses.forEach((node, nodeGatewayStartedShardsBatch) -> { + TransportNodesGatewayStartedShardHelper.GatewayStartedShard shardData = nodeGatewayStartedShardsBatch + .getNodeGatewayStartedShardsBatch() + .get(unassignedShard.shardId()); + nodeShardStates.add( + new NodeGatewayStartedShard( + shardData.allocationId(), + shardData.primary(), + shardData.replicationCheckpoint(), + shardData.storeException(), + node + ) + ); + }); + return nodeShardStates; + } +} diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java b/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java index 403e3e96fa209..27cce76b1b694 100644 --- a/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java @@ -12,8 +12,11 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.OpenSearchException; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.NodeEnvironment; @@ -23,8 +26,10 @@ import org.opensearch.index.shard.ShardStateMetadata; import org.opensearch.index.store.Store; import org.opensearch.indices.IndicesService; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import java.io.IOException; +import java.util.Objects; /** * This class has the common code used in {@link TransportNodesListGatewayStartedShards} and @@ -37,7 +42,7 @@ * @opensearch.internal */ public class TransportNodesGatewayStartedShardHelper { - public static TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard getShardInfoOnLocalNode( + public static GatewayStartedShard getShardInfoOnLocalNode( Logger logger, final ShardId shardId, NamedXContentRegistry namedXContentRegistry, @@ -90,25 +95,168 @@ public static TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShar exception ); String allocationId = shardStateMetadata.allocationId != null ? 
shardStateMetadata.allocationId.getId() : null; - return new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard( - allocationId, - shardStateMetadata.primary, - null, - exception - ); + return new GatewayStartedShard(allocationId, shardStateMetadata.primary, null, exception); } } logger.debug("{} shard state info found: [{}]", shardId, shardStateMetadata); String allocationId = shardStateMetadata.allocationId != null ? shardStateMetadata.allocationId.getId() : null; final IndexShard shard = indicesService.getShardOrNull(shardId); - return new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard( + return new GatewayStartedShard( allocationId, shardStateMetadata.primary, shard != null ? shard.getLatestReplicationCheckpoint() : null ); } logger.trace("{} no local shard info found", shardId); - return new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(null, false, null); + return new GatewayStartedShard(null, false, null); + } + + /** + * This class encapsulates the metadata about a started shard that needs to be persisted or sent between nodes. + * This is used in {@link TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch} to construct the response for each node, instead of + * {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards} because we don't need to save an extra + * {@link DiscoveryNode} object like in {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards} + * which reduces memory footprint of its objects. + * + * @opensearch.internal + */ + public static class GatewayStartedShard { + private final String allocationId; + private final boolean primary; + private final Exception storeException; + private final ReplicationCheckpoint replicationCheckpoint; + + public GatewayStartedShard(StreamInput in) throws IOException { + allocationId = in.readOptionalString(); + primary = in.readBoolean(); + if (in.readBoolean()) { + storeException = in.readException(); + } else { + storeException = null; + } + if (in.readBoolean()) { + replicationCheckpoint = new ReplicationCheckpoint(in); + } else { + replicationCheckpoint = null; + } + } + + public GatewayStartedShard(String allocationId, boolean primary, ReplicationCheckpoint replicationCheckpoint) { + this(allocationId, primary, replicationCheckpoint, null); + } + + public GatewayStartedShard( + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint, + Exception storeException + ) { + this.allocationId = allocationId; + this.primary = primary; + this.replicationCheckpoint = replicationCheckpoint; + this.storeException = storeException; + } + + public String allocationId() { + return this.allocationId; + } + + public boolean primary() { + return this.primary; + } + + public ReplicationCheckpoint replicationCheckpoint() { + return this.replicationCheckpoint; + } + + public Exception storeException() { + return this.storeException; + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(allocationId); + out.writeBoolean(primary); + if (storeException != null) { + out.writeBoolean(true); + out.writeException(storeException); + } else { + out.writeBoolean(false); + } + if (replicationCheckpoint != null) { + out.writeBoolean(true); + replicationCheckpoint.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + 
GatewayStartedShard that = (GatewayStartedShard) o; + + return primary == that.primary + && Objects.equals(allocationId, that.allocationId) + && Objects.equals(storeException, that.storeException) + && Objects.equals(replicationCheckpoint, that.replicationCheckpoint); + } + + @Override + public int hashCode() { + int result = (allocationId != null ? allocationId.hashCode() : 0); + result = 31 * result + (primary ? 1 : 0); + result = 31 * result + (storeException != null ? storeException.hashCode() : 0); + result = 31 * result + (replicationCheckpoint != null ? replicationCheckpoint.hashCode() : 0); + return result; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + buf.append("NodeGatewayStartedShards[").append("allocationId=").append(allocationId).append(",primary=").append(primary); + if (storeException != null) { + buf.append(",storeException=").append(storeException); + } + if (replicationCheckpoint != null) { + buf.append(",ReplicationCheckpoint=").append(replicationCheckpoint.toString()); + } + buf.append("]"); + return buf.toString(); + } + } + + /** + * This class extends the {@link GatewayStartedShard} which contains all necessary shard metadata like + * allocationId and replication checkpoint. It also has DiscoveryNode which is needed by + * {@link PrimaryShardAllocator} and {@link PrimaryShardBatchAllocator} to make allocation decision. + * This class removes the dependency of + * {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards} to make allocation decisions by + * {@link PrimaryShardAllocator} or {@link PrimaryShardBatchAllocator}. + */ + public static class NodeGatewayStartedShard extends GatewayStartedShard { + + private final DiscoveryNode node; + + public NodeGatewayStartedShard( + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint, + Exception storeException, + DiscoveryNode node + ) { + super(allocationId, primary, replicationCheckpoint, storeException); + this.node = node; + } + + public DiscoveryNode getNode() { + return node; + } } } diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShards.java b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShards.java index fbb5afd1ab36c..939051e962f5e 100644 --- a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShards.java +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShards.java @@ -54,6 +54,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.NodeEnvironment; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard; import org.opensearch.indices.IndicesService; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.indices.store.ShardAttributes; @@ -154,7 +155,7 @@ protected NodesGatewayStartedShards newResponse( @Override protected NodeGatewayStartedShards nodeOperation(NodeRequest request) { try { - TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard shardInfo = getShardInfoOnLocalNode( + GatewayStartedShard shardInfo = getShardInfoOnLocalNode( logger, request.getShardId(), namedXContentRegistry, @@ -166,10 +167,12 @@ protected NodeGatewayStartedShards nodeOperation(NodeRequest request) { ); return new NodeGatewayStartedShards( clusterService.localNode(), - shardInfo.allocationId(), - shardInfo.primary(), - 
shardInfo.replicationCheckpoint(), - shardInfo.storeException() + new GatewayStartedShard( + shardInfo.allocationId(), + shardInfo.primary(), + shardInfo.replicationCheckpoint(), + shardInfo.storeException() + ) ); } catch (Exception e) { throw new OpenSearchException("failed to load started shards", e); @@ -302,81 +305,51 @@ public String getCustomDataPath() { * @opensearch.internal */ public static class NodeGatewayStartedShards extends BaseNodeResponse { - private final String allocationId; - private final boolean primary; - private final Exception storeException; - private final ReplicationCheckpoint replicationCheckpoint; + private final GatewayStartedShard gatewayStartedShard; public NodeGatewayStartedShards(StreamInput in) throws IOException { super(in); - allocationId = in.readOptionalString(); - primary = in.readBoolean(); + String allocationId = in.readOptionalString(); + boolean primary = in.readBoolean(); + Exception storeException; if (in.readBoolean()) { storeException = in.readException(); } else { storeException = null; } + ReplicationCheckpoint replicationCheckpoint; if (in.getVersion().onOrAfter(Version.V_2_3_0) && in.readBoolean()) { replicationCheckpoint = new ReplicationCheckpoint(in); } else { replicationCheckpoint = null; } + this.gatewayStartedShard = new GatewayStartedShard(allocationId, primary, replicationCheckpoint, storeException); } - public NodeGatewayStartedShards( - DiscoveryNode node, - String allocationId, - boolean primary, - ReplicationCheckpoint replicationCheckpoint - ) { - this(node, allocationId, primary, replicationCheckpoint, null); + public GatewayStartedShard getGatewayShardStarted() { + return gatewayStartedShard; } - public NodeGatewayStartedShards( - DiscoveryNode node, - String allocationId, - boolean primary, - ReplicationCheckpoint replicationCheckpoint, - Exception storeException - ) { + public NodeGatewayStartedShards(DiscoveryNode node, GatewayStartedShard gatewayStartedShard) { super(node); - this.allocationId = allocationId; - this.primary = primary; - this.replicationCheckpoint = replicationCheckpoint; - this.storeException = storeException; - } - - public String allocationId() { - return this.allocationId; - } - - public boolean primary() { - return this.primary; - } - - public ReplicationCheckpoint replicationCheckpoint() { - return this.replicationCheckpoint; - } - - public Exception storeException() { - return this.storeException; + this.gatewayStartedShard = gatewayStartedShard; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeOptionalString(allocationId); - out.writeBoolean(primary); - if (storeException != null) { + out.writeOptionalString(gatewayStartedShard.allocationId()); + out.writeBoolean(gatewayStartedShard.primary()); + if (gatewayStartedShard.storeException() != null) { out.writeBoolean(true); - out.writeException(storeException); + out.writeException(gatewayStartedShard.storeException()); } else { out.writeBoolean(false); } if (out.getVersion().onOrAfter(Version.V_2_3_0)) { - if (replicationCheckpoint != null) { + if (gatewayStartedShard.replicationCheckpoint() != null) { out.writeBoolean(true); - replicationCheckpoint.writeTo(out); + gatewayStartedShard.replicationCheckpoint().writeTo(out); } else { out.writeBoolean(false); } @@ -394,33 +367,17 @@ public boolean equals(Object o) { NodeGatewayStartedShards that = (NodeGatewayStartedShards) o; - return primary == that.primary - && Objects.equals(allocationId, that.allocationId) - && Objects.equals(storeException, 
that.storeException) - && Objects.equals(replicationCheckpoint, that.replicationCheckpoint); + return gatewayStartedShard.equals(that.gatewayStartedShard); } @Override public int hashCode() { - int result = (allocationId != null ? allocationId.hashCode() : 0); - result = 31 * result + (primary ? 1 : 0); - result = 31 * result + (storeException != null ? storeException.hashCode() : 0); - result = 31 * result + (replicationCheckpoint != null ? replicationCheckpoint.hashCode() : 0); - return result; + return gatewayStartedShard.hashCode(); } @Override public String toString() { - StringBuilder buf = new StringBuilder(); - buf.append("NodeGatewayStartedShards[").append("allocationId=").append(allocationId).append(",primary=").append(primary); - if (storeException != null) { - buf.append(",storeException=").append(storeException); - } - if (replicationCheckpoint != null) { - buf.append(",ReplicationCheckpoint=").append(replicationCheckpoint.toString()); - } - buf.append("]"); - return buf.toString(); + return gatewayStartedShard.toString(); } } } diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java index 638eae84bc8be..dad9cd94d9ba7 100644 --- a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java @@ -28,8 +28,8 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.NodeEnvironment; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard; import org.opensearch.indices.IndicesService; -import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.indices.store.ShardAttributes; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -135,7 +135,7 @@ protected NodesGatewayStartedShardsBatch newResponse( */ @Override protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) { - Map shardsOnNode = new HashMap<>(); + Map shardsOnNode = new HashMap<>(); for (ShardAttributes shardAttr : request.shardAttributes.values()) { final ShardId shardId = shardAttr.getShardId(); try { @@ -155,7 +155,7 @@ protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) { } catch (Exception e) { shardsOnNode.put( shardId, - new NodeGatewayStartedShard(null, false, null, new OpenSearchException("failed to load started shards", e)) + new GatewayStartedShard(null, false, null, new OpenSearchException("failed to load started shards", e)) ); } } @@ -248,126 +248,6 @@ public void writeTo(StreamOutput out) throws IOException { } } - /** - * This class encapsulates the metadata about a started shard that needs to be persisted or sent between nodes. - * This is used in {@link NodeGatewayStartedShardsBatch} to construct the response for each node, instead of - * {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards} because we don't need to save an extra - * {@link DiscoveryNode} object like in {@link TransportNodesListGatewayStartedShards.NodeGatewayStartedShards} - * which reduces memory footprint of its objects. 
- * - * @opensearch.internal - */ - public static class NodeGatewayStartedShard { - private final String allocationId; - private final boolean primary; - private final Exception storeException; - private final ReplicationCheckpoint replicationCheckpoint; - - public NodeGatewayStartedShard(StreamInput in) throws IOException { - allocationId = in.readOptionalString(); - primary = in.readBoolean(); - if (in.readBoolean()) { - storeException = in.readException(); - } else { - storeException = null; - } - if (in.readBoolean()) { - replicationCheckpoint = new ReplicationCheckpoint(in); - } else { - replicationCheckpoint = null; - } - } - - public NodeGatewayStartedShard(String allocationId, boolean primary, ReplicationCheckpoint replicationCheckpoint) { - this(allocationId, primary, replicationCheckpoint, null); - } - - public NodeGatewayStartedShard( - String allocationId, - boolean primary, - ReplicationCheckpoint replicationCheckpoint, - Exception storeException - ) { - this.allocationId = allocationId; - this.primary = primary; - this.replicationCheckpoint = replicationCheckpoint; - this.storeException = storeException; - } - - public String allocationId() { - return this.allocationId; - } - - public boolean primary() { - return this.primary; - } - - public ReplicationCheckpoint replicationCheckpoint() { - return this.replicationCheckpoint; - } - - public Exception storeException() { - return this.storeException; - } - - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(allocationId); - out.writeBoolean(primary); - if (storeException != null) { - out.writeBoolean(true); - out.writeException(storeException); - } else { - out.writeBoolean(false); - } - if (replicationCheckpoint != null) { - out.writeBoolean(true); - replicationCheckpoint.writeTo(out); - } else { - out.writeBoolean(false); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - NodeGatewayStartedShard that = (NodeGatewayStartedShard) o; - - return primary == that.primary - && Objects.equals(allocationId, that.allocationId) - && Objects.equals(storeException, that.storeException) - && Objects.equals(replicationCheckpoint, that.replicationCheckpoint); - } - - @Override - public int hashCode() { - int result = (allocationId != null ? allocationId.hashCode() : 0); - result = 31 * result + (primary ? 1 : 0); - result = 31 * result + (storeException != null ? storeException.hashCode() : 0); - result = 31 * result + (replicationCheckpoint != null ? replicationCheckpoint.hashCode() : 0); - return result; - } - - @Override - public String toString() { - StringBuilder buf = new StringBuilder(); - buf.append("NodeGatewayStartedShards[").append("allocationId=").append(allocationId).append(",primary=").append(primary); - if (storeException != null) { - buf.append(",storeException=").append(storeException); - } - if (replicationCheckpoint != null) { - buf.append(",ReplicationCheckpoint=").append(replicationCheckpoint.toString()); - } - buf.append("]"); - return buf.toString(); - } - } - /** * This is the response from a single node, this is used in {@link NodesGatewayStartedShardsBatch} for creating * node to its response mapping for this transport request. 
@@ -376,15 +256,15 @@ public String toString() { * @opensearch.internal */ public static class NodeGatewayStartedShardsBatch extends BaseNodeResponse { - private final Map nodeGatewayStartedShardsBatch; + private final Map nodeGatewayStartedShardsBatch; - public Map getNodeGatewayStartedShardsBatch() { + public Map getNodeGatewayStartedShardsBatch() { return nodeGatewayStartedShardsBatch; } public NodeGatewayStartedShardsBatch(StreamInput in) throws IOException { super(in); - this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, NodeGatewayStartedShard::new); + this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, GatewayStartedShard::new); } @Override @@ -393,7 +273,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o), (o, v) -> v.writeTo(o)); } - public NodeGatewayStartedShardsBatch(DiscoveryNode node, Map nodeGatewayStartedShardsBatch) { + public NodeGatewayStartedShardsBatch(DiscoveryNode node, Map nodeGatewayStartedShardsBatch) { super(node); this.nodeGatewayStartedShardsBatch = nodeGatewayStartedShardsBatch; } diff --git a/server/src/test/java/org/opensearch/gateway/PrimaryShardAllocatorTests.java b/server/src/test/java/org/opensearch/gateway/PrimaryShardAllocatorTests.java index dceda6433575c..e849f12143b4d 100644 --- a/server/src/test/java/org/opensearch/gateway/PrimaryShardAllocatorTests.java +++ b/server/src/test/java/org/opensearch/gateway/PrimaryShardAllocatorTests.java @@ -843,10 +843,12 @@ public TestAllocator addData( node, new TransportNodesListGatewayStartedShards.NodeGatewayStartedShards( node, - allocationId, - primary, - replicationCheckpoint, - storeException + new TransportNodesGatewayStartedShardHelper.GatewayStartedShard( + allocationId, + primary, + replicationCheckpoint, + storeException + ) ) ); return this; diff --git a/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java b/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java new file mode 100644 index 0000000000000..4796def2b8902 --- /dev/null +++ b/server/src/test/java/org/opensearch/gateway/PrimaryShardBatchAllocatorTests.java @@ -0,0 +1,340 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.gateway; + +import org.apache.lucene.codecs.Codec; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.OpenSearchAllocationTestCase; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.routing.RoutingNodes; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.UnassignedInfo; +import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision; +import org.opensearch.cluster.routing.allocation.AllocationDecision; +import org.opensearch.cluster.routing.allocation.RoutingAllocation; +import org.opensearch.cluster.routing.allocation.decider.AllocationDeciders; +import org.opensearch.common.Nullable; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.set.Sets; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.env.Environment; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.codec.CodecService; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import org.opensearch.test.IndexSettingsModule; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.opensearch.cluster.routing.UnassignedInfo.Reason.CLUSTER_RECOVERED; + +public class PrimaryShardBatchAllocatorTests extends OpenSearchAllocationTestCase { + + private final ShardId shardId = new ShardId("test", "_na_", 0); + private static Set shardsInBatch; + private final DiscoveryNode node1 = newNode("node1"); + private final DiscoveryNode node2 = newNode("node2"); + private final DiscoveryNode node3 = newNode("node3"); + private TestBatchAllocator batchAllocator; + + public static void setUpShards(int numberOfShards) { + shardsInBatch = new HashSet<>(); + for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) { + ShardId shardId = new ShardId("test", "_na_", shardNumber); + shardsInBatch.add(shardId); + } + } + + @Before + public void buildTestAllocator() { + this.batchAllocator = new TestBatchAllocator(); + } + + private void allocateAllUnassigned(final RoutingAllocation allocation) { + final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator(); + while (iterator.hasNext()) { + batchAllocator.allocateUnassigned(iterator.next(), allocation, iterator); + } + } + + private void allocateAllUnassignedBatch(final RoutingAllocation allocation) { + final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator(); + List shardsToBatch = new ArrayList<>(); + while (iterator.hasNext()) { + shardsToBatch.add(iterator.next()); + } + batchAllocator.allocateUnassignedBatch(shardsToBatch, allocation); + } + + public void testMakeAllocationDecisionDataFetching() { + final RoutingAllocation allocation = routingAllocationWithOnePrimary(noAllocationDeciders(), CLUSTER_RECOVERED, "allocId1"); + + List shards = new ArrayList<>(); + allocateAllUnassignedBatch(allocation); + ShardRouting shard = 
allocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard(); + shards.add(shard); + HashMap allDecisions = batchAllocator.makeAllocationDecision(shards, allocation, logger); + // verify we get decisions for all the shards + assertEquals(shards.size(), allDecisions.size()); + assertEquals(shards, new ArrayList<>(allDecisions.keySet())); + assertEquals(AllocationDecision.AWAITING_INFO, allDecisions.get(shard).getAllocationDecision()); + } + + public void testMakeAllocationDecisionForReplicaShard() { + final RoutingAllocation allocation = routingAllocationWithOnePrimary(noAllocationDeciders(), CLUSTER_RECOVERED, "allocId1"); + + List replicaShards = allocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).replicaShards(); + List shards = new ArrayList<>(replicaShards); + HashMap allDecisions = batchAllocator.makeAllocationDecision(shards, allocation, logger); + // verify we get decisions for all the shards + assertEquals(shards.size(), allDecisions.size()); + assertEquals(shards, new ArrayList<>(allDecisions.keySet())); + assertFalse(allDecisions.get(replicaShards.get(0)).isDecisionTaken()); + } + + public void testMakeAllocationDecisionDataFetched() { + final RoutingAllocation allocation = routingAllocationWithOnePrimary(noAllocationDeciders(), CLUSTER_RECOVERED, "allocId1"); + + List shards = new ArrayList<>(); + ShardRouting shard = allocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard(); + shards.add(shard); + batchAllocator.addData(node1, "allocId1", true, new ReplicationCheckpoint(shardId, 20, 101, 1, Codec.getDefault().getName())); + HashMap allDecisions = batchAllocator.makeAllocationDecision(shards, allocation, logger); + // verify we get decisions for all the shards + assertEquals(shards.size(), allDecisions.size()); + assertEquals(shards, new ArrayList<>(allDecisions.keySet())); + assertEquals(AllocationDecision.YES, allDecisions.get(shard).getAllocationDecision()); + } + + public void testMakeAllocationDecisionDataFetchedMultipleShards() { + setUpShards(2); + final RoutingAllocation allocation = routingAllocationWithMultiplePrimaries( + noAllocationDeciders(), + CLUSTER_RECOVERED, + 2, + 0, + "allocId-0", + "allocId-1" + ); + List shards = new ArrayList<>(); + for (ShardId shardId : shardsInBatch) { + ShardRouting shard = allocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard(); + allocation.routingTable().getIndicesRouting().get("test").shard(shardId.id()).primaryShard().recoverySource(); + shards.add(shard); + batchAllocator.addShardData( + node1, + "allocId-" + shardId.id(), + shardId, + true, + new ReplicationCheckpoint(shardId, 20, 101, 1, Codec.getDefault().getName()), + null + ); + } + HashMap allDecisions = batchAllocator.makeAllocationDecision(shards, allocation, logger); + // verify we get decisions for all the shards + assertEquals(shards.size(), allDecisions.size()); + assertEquals(new HashSet<>(shards), allDecisions.keySet()); + for (ShardRouting shard : shards) { + assertEquals(AllocationDecision.YES, allDecisions.get(shard).getAllocationDecision()); + } + } + + private RoutingAllocation routingAllocationWithOnePrimary( + AllocationDeciders deciders, + UnassignedInfo.Reason reason, + String... 
activeAllocationIds + ) { + Metadata metadata = Metadata.builder() + .put( + IndexMetadata.builder(shardId.getIndexName()) + .settings(settings(Version.CURRENT)) + .numberOfShards(1) + .numberOfReplicas(1) + .putInSyncAllocationIds(shardId.id(), Sets.newHashSet(activeAllocationIds)) + ) + .build(); + RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); + switch (reason) { + + case INDEX_CREATED: + routingTableBuilder.addAsNew(metadata.index(shardId.getIndex())); + break; + case CLUSTER_RECOVERED: + routingTableBuilder.addAsRecovery(metadata.index(shardId.getIndex())); + break; + case INDEX_REOPENED: + routingTableBuilder.addAsFromCloseToOpen(metadata.index(shardId.getIndex())); + break; + default: + throw new IllegalArgumentException("can't do " + reason + " for you. teach me"); + } + ClusterState state = ClusterState.builder(org.opensearch.cluster.ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTableBuilder.build()) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, null, System.nanoTime()); + } + + private RoutingAllocation routingAllocationWithMultiplePrimaries( + AllocationDeciders deciders, + UnassignedInfo.Reason reason, + int numberOfShards, + int replicas, + String... activeAllocationIds + ) { + Iterator shardIterator = shardsInBatch.iterator(); + Metadata metadata = Metadata.builder() + .put( + IndexMetadata.builder(shardId.getIndexName()) + .settings(settings(Version.CURRENT)) + .numberOfShards(numberOfShards) + .numberOfReplicas(replicas) + .putInSyncAllocationIds(shardIterator.next().id(), Sets.newHashSet(activeAllocationIds[0])) + .putInSyncAllocationIds(shardIterator.next().id(), Sets.newHashSet(activeAllocationIds[1])) + ) + .build(); + + RoutingTable.Builder routingTableBuilder = RoutingTable.builder(); + for (ShardId shardIdFromBatch : shardsInBatch) { + switch (reason) { + case INDEX_CREATED: + routingTableBuilder.addAsNew(metadata.index(shardIdFromBatch.getIndex())); + break; + case CLUSTER_RECOVERED: + routingTableBuilder.addAsRecovery(metadata.index(shardIdFromBatch.getIndex())); + break; + case INDEX_REOPENED: + routingTableBuilder.addAsFromCloseToOpen(metadata.index(shardIdFromBatch.getIndex())); + break; + default: + throw new IllegalArgumentException("can't do " + reason + " for you. 
teach me"); + } + } + ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTableBuilder.build()) + .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)) + .build(); + return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, null, System.nanoTime()); + } + + class TestBatchAllocator extends PrimaryShardBatchAllocator { + + private Map data; + + public TestBatchAllocator clear() { + data = null; + return this; + } + + public TestBatchAllocator addData( + DiscoveryNode node, + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint + ) { + return addData(node, allocationId, primary, replicationCheckpoint, null); + } + + public TestBatchAllocator addData(DiscoveryNode node, String allocationId, boolean primary) { + Settings nodeSettings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir()).build(); + IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", nodeSettings); + return addData( + node, + allocationId, + primary, + ReplicationCheckpoint.empty(shardId, new CodecService(null, indexSettings, null).codec("default").getName()), + null + ); + } + + public TestBatchAllocator addData(DiscoveryNode node, String allocationId, boolean primary, @Nullable Exception storeException) { + Settings nodeSettings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir()).build(); + IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", nodeSettings); + return addData( + node, + allocationId, + primary, + ReplicationCheckpoint.empty(shardId, new CodecService(null, indexSettings, null).codec("default").getName()), + storeException + ); + } + + public TestBatchAllocator addData( + DiscoveryNode node, + String allocationId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint, + @Nullable Exception storeException + ) { + if (data == null) { + data = new HashMap<>(); + } + Map shardData = Map.of( + shardId, + new TransportNodesGatewayStartedShardHelper.GatewayStartedShard( + allocationId, + primary, + replicationCheckpoint, + storeException + ) + ); + data.put(node, new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch(node, shardData)); + return this; + } + + public TestBatchAllocator addShardData( + DiscoveryNode node, + String allocationId, + ShardId shardId, + boolean primary, + ReplicationCheckpoint replicationCheckpoint, + @Nullable Exception storeException + ) { + if (data == null) { + data = new HashMap<>(); + } + Map shardData = new HashMap<>(); + shardData.put( + shardId, + new TransportNodesGatewayStartedShardHelper.GatewayStartedShard( + allocationId, + primary, + replicationCheckpoint, + storeException + ) + ); + if (data.get(node) != null) shardData.putAll(data.get(node).getNodeGatewayStartedShardsBatch()); + data.put(node, new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch(node, shardData)); + return this; + } + + @Override + protected AsyncShardFetch.FetchResult fetchData( + List shardsEligibleForFetch, + List inEligibleShards, + RoutingAllocation allocation + ) { + return new AsyncShardFetch.FetchResult<>(data, Collections.>emptyMap()); + } + } +} diff --git a/test/framework/src/main/java/org/opensearch/test/gateway/TestGatewayAllocator.java b/test/framework/src/main/java/org/opensearch/test/gateway/TestGatewayAllocator.java index f123b926f5bad..b1695ff00e0cc 100644 --- 
a/test/framework/src/main/java/org/opensearch/test/gateway/TestGatewayAllocator.java +++ b/test/framework/src/main/java/org/opensearch/test/gateway/TestGatewayAllocator.java @@ -42,6 +42,7 @@ import org.opensearch.gateway.GatewayAllocator; import org.opensearch.gateway.PrimaryShardAllocator; import org.opensearch.gateway.ReplicaShardAllocator; +import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper; import org.opensearch.gateway.TransportNodesListGatewayStartedShards.NodeGatewayStartedShards; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.indices.store.TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata; @@ -91,9 +92,12 @@ protected AsyncShardFetch.FetchResult fetchData(ShardR routing -> currentNodes.get(routing.currentNodeId()), routing -> new NodeGatewayStartedShards( currentNodes.get(routing.currentNodeId()), - routing.allocationId().getId(), - routing.primary(), - getReplicationCheckpoint(shardId, routing.currentNodeId()) + new TransportNodesGatewayStartedShardHelper.GatewayStartedShard( + routing.allocationId().getId(), + routing.primary(), + getReplicationCheckpoint(shardId, routing.currentNodeId()), + null + ) ) ) );
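
Note (illustrative, not part of this patch): the sketch below shows how a concrete batch allocator can be wired on top of this change. It mirrors the TestBatchAllocator added above by answering the batch fetchData() from a pre-populated map; the class name CachedPrimaryShardBatchAllocator and the cachedResponses field are hypothetical, and a production allocator would populate this data asynchronously rather than serving it from a cache.

// Illustrative sketch only -- mirrors the TestBatchAllocator pattern in this patch.
package org.opensearch.gateway;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.allocation.RoutingAllocation;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.AsyncShardFetch.FetchResult;
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class CachedPrimaryShardBatchAllocator extends PrimaryShardBatchAllocator {

    // Hypothetical cache: per-node responses already collected elsewhere, each
    // holding one GatewayStartedShard entry per ShardId on that node.
    private final Map<DiscoveryNode, NodeGatewayStartedShardsBatch> cachedResponses;

    public CachedPrimaryShardBatchAllocator(Map<DiscoveryNode, NodeGatewayStartedShardsBatch> cachedResponses) {
        this.cachedResponses = cachedResponses;
    }

    @Override
    protected FetchResult<NodeGatewayStartedShardsBatch> fetchData(
        List<ShardRouting> eligibleShards,
        List<ShardRouting> inEligibleShards,
        RoutingAllocation allocation
    ) {
        // A single FetchResult covers the whole batch; per-shard metadata is
        // later looked up by ShardId in adaptToNodeShardStates().
        return new FetchResult<>(cachedResponses, Collections.<ShardId, Set<String>>emptyMap());
    }
}

With such an implementation, a caller can take decisions for a whole batch of unassigned primaries in one pass (allocator, unassignedPrimaries, allocation, and logger are assumed to be in scope):

HashMap<ShardRouting, AllocateUnassignedDecision> decisions =
    allocator.makeAllocationDecision(unassignedPrimaries, allocation, logger);
decisions.forEach((shard, decision) -> logger.debug("{} -> {}", shard.shardId(), decision.getAllocationDecision()));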