diff --git a/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java b/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java index 36149d014ea84..2c250f6a5d86e 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java +++ b/server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java @@ -34,7 +34,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.cluster.metadata.WeightedRoutingMetadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.Nullable; @@ -63,7 +62,6 @@ import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; -import java.util.stream.Stream; import static java.util.Collections.emptyMap; @@ -96,8 +94,8 @@ public class IndexShardRoutingTable implements Iterable { private volatile Map initializingShardsByAttributes = emptyMap(); private final Object shardsByAttributeMutex = new Object(); private final Object shardsByWeightMutex = new Object(); - private volatile Map> activeShardsByWeight = emptyMap(); - private volatile Map> initializingShardsByWeight = emptyMap(); + private volatile Map activeShardsByWeight = emptyMap(); + private volatile Map initializingShardsByWeight = emptyMap(); private static final Logger logger = LogManager.getLogger(IndexShardRoutingTable.class); @@ -249,7 +247,7 @@ public List assignedShards() { return this.assignedShards; } - public Map> getActiveShardsByWeight() { + public Map getActiveShardsByWeight() { return activeShardsByWeight; } @@ -338,23 +336,7 @@ public ShardIterator activeInitializingShardsWeightedIt( // append shards for attribute value with weight zero, so that shard search requests can be tried on // shard copies in case of request failure from other attribute values. if (isFailOpenEnabled) { - try { - Stream keys = weightedRouting.weights() - .entrySet() - .stream() - .filter(entry -> entry.getValue().intValue() == WeightedRoutingMetadata.WEIGHED_AWAY_WEIGHT) - .map(Map.Entry::getKey); - keys.forEach(key -> { - ShardIterator iterator = onlyNodeSelectorActiveInitializingShardsIt(weightedRouting.attributeName() + ":" + key, nodes); - while (iterator.remaining() > 0) { - ordered.add(iterator.nextOrNull()); - } - }); - } catch (IllegalArgumentException e) { - // this exception is thrown by {@link onlyNodeSelectorActiveInitializingShardsIt} in case count of shard - // copies found is zero - logger.debug("no shard copies found for shard id [{}] for node attribute with weight zero", shardId); - } + ordered.addAll(activeInitializingShardsWithoutWeights(weightedRouting, nodes, defaultWeight)); } return new PlainShardIterator(shardId, ordered); @@ -378,6 +360,18 @@ private List activeInitializingShardsWithWeights( return orderedListWithDistinctShards; } + private List activeInitializingShardsWithoutWeights( + WeightedRouting weightedRouting, + DiscoveryNodes nodes, + double defaultWeight + ) { + List ordered = new ArrayList<>(getActiveShardsWithoutWeight(weightedRouting, nodes, defaultWeight)); + if (!allInitializingShards.isEmpty()) { + ordered.addAll(getInitializingShardsWithoutWeight(weightedRouting, nodes, defaultWeight)); + } + return ordered.stream().distinct().collect(Collectors.toList()); + } + /** * Returns a list containing shard routings ordered using weighted round-robin scheduling. */ @@ -949,20 +943,60 @@ public int hashCode() { } } + /** + * Holder class for shard routing(s) which are classified and stored based on their weights. + * + * @opensearch.api + */ + @PublicApi(since = "2.14.0") + public static class WeightedShardRoutings { + private final List shardRoutingsWithWeight; + private final List shardRoutingWithoutWeight; + + public WeightedShardRoutings(List shardRoutingsWithWeight, List shardRoutingWithoutWeight) { + this.shardRoutingsWithWeight = Collections.unmodifiableList(shardRoutingsWithWeight); + this.shardRoutingWithoutWeight = Collections.unmodifiableList(shardRoutingWithoutWeight); + } + + public List getShardRoutingsWithWeight() { + return shardRoutingsWithWeight; + } + + public List getShardRoutingWithoutWeight() { + return shardRoutingWithoutWeight; + } + } + /** * * * Gets active shard routing from memory if available, else calculates and put it in memory. */ private List getActiveShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); - List shardRoutings = activeShardsByWeight.get(key); - if (shardRoutings == null) { - synchronized (shardsByWeightMutex) { - shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight); - activeShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap(); - } + if (activeShardsByWeight.get(key) == null) { + populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight); + } + return activeShardsByWeight.get(key).getShardRoutingsWithWeight(); + } + + private List getActiveShardsWithoutWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + if (activeShardsByWeight.get(key) == null) { + populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight); + } + return activeShardsByWeight.get(key).getShardRoutingWithoutWeight(); + } + + private void populateActiveShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + List weightedRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight); + List nonWeightedRoutings = activeShards.stream() + .filter(shard -> !weightedRoutings.contains(shard)) + .collect(Collectors.toUnmodifiableList()); + synchronized (shardsByWeightMutex) { + activeShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings)) + .immutableMap(); } - return shardRoutings; } /** @@ -971,14 +1005,34 @@ private List getActiveShardsByWeight(WeightedRouting weightedRouti */ private List getInitializingShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); - List shardRoutings = initializingShardsByWeight.get(key); - if (shardRoutings == null) { - synchronized (shardsByWeightMutex) { - shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight); - initializingShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap(); - } + if (initializingShardsByWeight.get(key) == null) { + populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight); + } + return initializingShardsByWeight.get(key).getShardRoutingsWithWeight(); + } + + private List getInitializingShardsWithoutWeight( + WeightedRouting weightedRouting, + DiscoveryNodes nodes, + double defaultWeight + ) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + if (initializingShardsByWeight.get(key) == null) { + populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight); + } + return initializingShardsByWeight.get(key).getShardRoutingWithoutWeight(); + } + + private void populateInitializingShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) { + WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting); + List weightedRoutings = shardsOrderedByWeight(allInitializingShards, weightedRouting, nodes, defaultWeight); + List nonWeightedRoutings = allInitializingShards.stream() + .filter(shard -> !weightedRoutings.contains(shard)) + .collect(Collectors.toUnmodifiableList()); + synchronized (shardsByWeightMutex) { + initializingShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings)) + .immutableMap(); } - return shardRoutings; } /** diff --git a/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java b/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java index 86adfc1279d9a..190ad3283dcfc 100644 --- a/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java +++ b/server/src/test/java/org/opensearch/cluster/structure/RoutingIteratorTests.java @@ -700,9 +700,18 @@ public void testWeightedRoutingWithDifferentWeights() { .shard(0) .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null); assertEquals(1, shardIterator.size()); - shardRouting = shardIterator.nextOrNull(); - assertNotNull(shardRouting); - assertFalse(Arrays.asList("node2", "node1").contains(shardRouting.currentNodeId())); + assertEquals("node3", shardIterator.nextOrNull().currentNodeId()); + + weights = Map.of("zone1", -1.0, "zone2", 0.0, "zone3", 1.0); + weightedRouting = new WeightedRouting("zone", weights); + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null); + assertEquals(3, shardIterator.size()); + assertEquals("node3", shardIterator.nextOrNull().currentNodeId()); + assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId()); + assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId()); weights = Map.of("zone1", 3.0, "zone2", 2.0, "zone3", 0.0); weightedRouting = new WeightedRouting("zone", weights); @@ -711,8 +720,138 @@ public void testWeightedRoutingWithDifferentWeights() { .shard(0) .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null); assertEquals(3, shardIterator.size()); - shardRouting = shardIterator.nextOrNull(); - assertNotNull(shardRouting); + assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId()); + assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId()); + assertEquals("node3", shardIterator.nextOrNull().currentNodeId()); + } finally { + terminate(threadPool); + } + } + + public void testWeightedRoutingWithInitializingShards() { + TestThreadPool threadPool = null; + try { + Settings.Builder settings = Settings.builder() + .put("cluster.routing.allocation.node_concurrent_recoveries", 10) + .put("cluster.routing.allocation.awareness.attributes", "zone"); + AllocationService strategy = createAllocationService(settings.build()); + + Metadata metadata = Metadata.builder() + .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2)) + .build(); + + RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build(); + + ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)) + .metadata(metadata) + .routingTable(routingTable) + .build(); + + threadPool = new TestThreadPool("testThatOnlyNodesSupport"); + ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); + + Map node1Attributes = new HashMap<>(); + node1Attributes.put("zone", "zone1"); + Map node2Attributes = new HashMap<>(); + node2Attributes.put("zone", "zone2"); + Map node3Attributes = new HashMap<>(); + node3Attributes.put("zone", "zone3"); + + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(newNode("node1", unmodifiableMap(node1Attributes))) + .add(newNode("node2", unmodifiableMap(node2Attributes))) + .add(newNode("node3", unmodifiableMap(node3Attributes))) + .localNodeId("node1") + .build(); + clusterState = ClusterState.builder(clusterState).nodes(nodes).build(); + clusterState = strategy.reroute(clusterState, "reroute"); + + // Making the first shard as active + clusterState = startInitializingShardsAndReroute(strategy, clusterState); + // Making the second shard as active + clusterState = startRandomInitializingShard(clusterState, strategy); + + String[] startedNodes = new String[2]; + String[] startedZones = new String[2]; + String initializingNode = null; + String initializingZone = null; + int i = 0; + for (ShardRouting shard : clusterState.routingTable().allShards()) { + if (shard.initializing()) { + initializingNode = shard.currentNodeId(); + initializingZone = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone"); + + } else { + startedNodes[i] = shard.currentNodeId(); + startedZones[i++] = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone"); + } + } + + Map weights = Map.of(startedZones[0], 1.0, initializingZone, 1.0, startedZones[1], 0.0); + WeightedRouting weightedRouting = new WeightedRouting("zone", weights); + + // With fail open enabled set to false, we expect 2 shard routing, first one started, followed by initializing + ShardIterator shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null); + + assertEquals(2, shardIterator.size()); + assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId()); + assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + + // With fail open enabled set to true, we expect 3 shard routing, first one started, followed by initializing, third one started + // with zero weight + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null); + + assertEquals(3, shardIterator.size()); + assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId()); + assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId()); + + weights = Map.of(initializingZone, 1.0, startedZones[0], 0.0, startedZones[1], 0.0); + weightedRouting = new WeightedRouting("zone", weights); + + // only initializing shard has weight with fail open true + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null); + assertEquals(1, shardIterator.size()); + assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + + // only initializing shard has weight with fail open false + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null); + assertEquals(3, shardIterator.size()); + assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + + weights = Map.of(initializingZone, 0.0, startedZones[0], 1.0, startedZones[1], 0.0); + weightedRouting = new WeightedRouting("zone", weights); + + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null); + assertEquals(1, shardIterator.size()); + assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId()); + + shardIterator = clusterState.routingTable() + .index("test") + .shard(0) + .activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null); + assertEquals(3, shardIterator.size()); + assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId()); + assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId()); + assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId()); + } finally { terminate(threadPool); }