diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
index 90eff50fd9b5d..3e2ec27ff3c79 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
@@ -370,7 +370,15 @@ public float weightWithAllocationConstraints(ShardsBalancer balancer, ModelNode
 
     public float weightWithRebalanceConstraints(ShardsBalancer balancer, ModelNode node, String index) {
         float balancerWeight = weight(balancer, node, index);
-        return balancerWeight + rebalanceConstraints.weight(balancer, node, index);
+        float extraWt = 0;
+        int primaryShardCount = node.numPrimaryShards();
+        int allowedPrimaryShardCount = (int) Math.ceil(balancer.avgPrimaryShardsPerNode());
+
+        // Penalise a node whose primary count exceeds the rounded-up value of avgPrimaryShardsPerNode().
+        if (primaryShardCount > allowedPrimaryShardCount) {
+            extraWt += 1000000f;
+        }
+        return balancerWeight + extraWt + rebalanceConstraints.weight(balancer, node, index);
     }
 
     float weight(ShardsBalancer balancer, ModelNode node, String index) {
@@ -396,6 +404,10 @@ void updateRebalanceConstraint(String constraint, boolean add) {
     public static class ModelNode implements Iterable<ModelIndex> {
         private final Map<String, ModelIndex> indices = new HashMap<>();
         private int numShards = 0;
+        // Count of primary shards added to this node; kept in step by addShard/removeShard.
+        private int totalPrimary = 0;
+        // Count of replica shards added to this node; kept in step by addShard/removeShard.
+        private int totalReplica = 0;
         private final RoutingNode routingNode;
 
         ModelNode(RoutingNode routingNode) {
@@ -448,6 +460,12 @@ public void addShard(ShardRouting shard) {
             }
             index.addShard(shard);
             numShards++;
+
+            if (shard.primary()) {
+                totalPrimary++;
+            } else {
+                totalReplica++;
+            }
         }
 
         public void removeShard(ShardRouting shard) {
@@ -459,6 +477,12 @@ public void removeShard(ShardRouting shard) {
                 }
             }
             numShards--;
+
+            if (shard.primary()) {
+                totalPrimary--;
+            } else {
+                totalReplica--;
+            }
         }
 
         @Override
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
index 3365b58d92a63..d44a8b42b42af 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
@@ -40,6 +40,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -71,6 +72,8 @@ public class LocalShardsBalancer extends ShardsBalancer {
     private final BalancedShardsAllocator.NodeSorter sorter;
     private final Set<RoutingNode> inEligibleTargetNode;
 
+    public int totalRelocation;
+
     public LocalShardsBalancer(
         Logger logger,
         RoutingAllocation allocation,
@@ -94,6 +97,7 @@ public LocalShardsBalancer(
         inEligibleTargetNode = new HashSet<>();
         this.preferPrimaryBalance = preferPrimaryBalance;
         this.shardMovementStrategy = shardMovementStrategy;
+        this.totalRelocation = 0;
     }
 
     /**
@@ -103,6 +107,21 @@ private BalancedShardsAllocator.ModelNode[] nodesArray() {
         return nodes.values().toArray(new BalancedShardsAllocator.ModelNode[0]);
     }
 
+    public void increaseRelocationCount() {
+        totalRelocation++;
+    }
+
+    public int getRelocationCount() {
+        return totalRelocation;
+    }
+
+    public void resetRelocationCount() {
+        totalRelocation = 0;
+    }
+
+    public void printRelocationCount() {
+        logger.info("Total relocation count is: {}", totalRelocation);
+    }
     /**
      * Returns the average of shards per node for the given index
      */
@@ -342,6 +361,7 @@ private void balanceByWeights() {
         for (String index : buildWeightOrderedIndices()) {
             IndexMetadata indexMetadata = metadata.index(index);
 
+
             // find nodes that have a shard of this index or where shards of this index are allowed to be allocated to,
             // move these nodes to the front of modelNodes so that we can only balance based on these nodes
             int relevantNodes = 0;
@@ -366,6 +386,11 @@ private void balanceByWeights() {
             while (true) {
                 final BalancedShardsAllocator.ModelNode minNode = modelNodes[lowIdx];
                 final BalancedShardsAllocator.ModelNode maxNode = modelNodes[highIdx];
+                // NOTE(review): scratch example left from debugging; presumably primaries (p) / replicas (r) per node:
+                // n1 --> p 1 r 2
+                // n2 --> p 2 r 1
+                // n3 --> p 3 r 1
+                // --> primary balance, avg - 1
                 advance_range: if (maxNode.numShards(index) > 0) {
                     final float delta = absDelta(weights[lowIdx], weights[highIdx]);
                     if (lessThan(delta, threshold)) {
@@ -897,6 +922,8 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
          * iteration order is different for each run and makes testing hard */
         Map<String, NodeAllocationResult> nodeExplanationMap = explain ? new HashMap<>() : null;
         List<Tuple<String, Float>> nodeWeights = explain ? new ArrayList<>() : null;
+        // Nodes currently tied for the minimum weight; one of them is picked after the loop
+        List<BalancedShardsAllocator.ModelNode> minNodes = new ArrayList<>();
         for (BalancedShardsAllocator.ModelNode node : nodes.values()) {
             if (node.containsShard(shard) && explain == false) {
                 // decision is NO without needing to check anything further, so short circuit
@@ -917,7 +944,9 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
             }
             if (currentDecision.type() == Decision.Type.YES || currentDecision.type() == Decision.Type.THROTTLE) {
                 final boolean updateMinNode;
+                // Besides updateMinNode, each branch below also maintains the minNodes candidate list.
                 if (currentWeight == minWeight) {
+                    // NOTE(review): the minNode chosen by this tie-break is later overridden by the random pick from minNodes.
                     /*  we have an equal weight tie breaking:
                      *  1. if one decision is YES prefer it
                      *  2. prefer the node that holds the primary for this index with the next id in the ring ie.
@@ -935,11 +964,26 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
                         final int minNodeHigh = minNode.highestPrimary(shard.getIndexName());
                         updateMinNode = ((((nodeHigh > repId && minNodeHigh > repId) || (nodeHigh < repId && minNodeHigh < repId))
                             && (nodeHigh < minNodeHigh)) || (nodeHigh > repId && minNodeHigh < repId));
+
+                        // Equal-weight candidate: remember it so the final pick can
+                        // choose among all tied nodes.
+                        minNodes.add(node);
+
                     } else {
                         updateMinNode = currentDecision.type() == Decision.Type.YES;
+                        if (updateMinNode) {
+                            // This node wins the tie outright, so it becomes the only candidate.
+                            minNodes.clear();
+                            minNodes.add(node);
+                        }
                     }
                 } else {
                     updateMinNode = currentWeight < minWeight;
+                    if (updateMinNode) {
+                        // Strictly lower weight: discard earlier candidates and start over.
+                        minNodes.clear();
+                        minNodes.add(node);
+                    }
                 }
                 if (updateMinNode) {
                     minNode = node;
@@ -963,6 +1007,12 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
                 nodeDecisions.add(new NodeAllocationResult(current.getNode(), current.getCanAllocateDecision(), ++weightRanking));
             }
         }
+
+        // Choose randomly among the tied candidates. Randomness (rather than new Random())
+        // keeps the choice reproducible when tests fix the random seed.
+        final Random random = org.opensearch.common.Randomness.get();
+        minNode = minNodes.isEmpty() ? null : minNodes.get(random.nextInt(minNodes.size()));
+
         return AllocateUnassignedDecision.fromDecision(decision, minNode != null ? minNode.getRoutingNode().node() : null, nodeDecisions);
     }
 
@@ -1002,7 +1052,7 @@ private boolean tryRelocateShard(BalancedShardsAllocator.ModelNode minNode, Bala
             // doing such relocation wouldn't help in primary balance.
             if (preferPrimaryBalance == true
                 && shard.primary()
-                && maxNode.numPrimaryShards(shard.getIndexName()) - minNode.numPrimaryShards(shard.getIndexName()) < 2) {
+                && maxNode.numPrimaryShards() - minNode.numPrimaryShards() < 2) {
                 continue;
             }
 
@@ -1012,8 +1062,9 @@ private boolean tryRelocateShard(BalancedShardsAllocator.ModelNode minNode, Bala
 
             if (decision.type() == Decision.Type.YES) {
                 /* only allocate on the cluster if we are not throttled */
-                logger.debug("Relocate [{}] from [{}] to [{}]", shard, maxNode.getNodeId(), minNode.getNodeId());
+                logger.info("Relocate [{}] from [{}] to [{}]", shard, maxNode.getNodeId(), minNode.getNodeId());
                 minNode.addShard(routingNodes.relocateShard(shard, minNode.getNodeId(), shardSize, allocation.changes()).v1());
+                increaseRelocationCount();
                 return true;
             } else {
                 /* allocate on the model even if throttled */
diff --git a/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java b/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java
index f6113860e3907..571bc8493ad5d 100644
--- a/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/cluster/OpenSearchAllocationTestCase.java
@@ -322,6 +322,9 @@ public static class ShardAllocations {
          */
         static TreeMap<String, int[]> nodeToShardCountMap = new TreeMap<>();
 
+        static TreeMap<String, String[]> nodeIdToIndexMap = new TreeMap<>();
+        static TreeMap<String, String[]> nodeIdToIndexReplicaMap = new TreeMap<>();
+
         /**
          * Helper map containing NodeName to NodeId
          */
@@ -342,8 +345,20 @@ private final static String printShardAllocationWithHeader(int[] shardCount) {
             return sb.toString();
         }
 
+        private final static String printShardAllocationWithHeader(String[] indexNames) {
+            StringBuffer sb = new StringBuffer();
+            Formatter formatter = new Formatter(sb, Locale.getDefault());
+            for (String index : indexNames) {
+                formatter.format("%-20s", index);
+            }
+            formatter.format("\n");
+            return sb.toString();
+        }
+
         private static void reset() {
             nodeToShardCountMap.clear();
+            nodeIdToIndexReplicaMap.clear();
+            nodeIdToIndexMap.clear();
             nameToNodeId.clear();
             totalShards[0] = totalShards[1] = 0;
             unassigned[0] = unassigned[1] = 0;
@@ -358,26 +373,42 @@ private static void buildMap(ClusterState inputState) {
                     nameToNodeId.putIfAbsent(node.nodeId(), node.nodeId());
                 }
                 nodeToShardCountMap.putIfAbsent(node.nodeId(), new int[] { 0, 0 });
+                nodeIdToIndexMap.putIfAbsent(node.nodeId(), new String[] {});
+                nodeIdToIndexReplicaMap.putIfAbsent(node.nodeId(), new String[] {});
             }
             for (ShardRouting shardRouting : inputState.routingTable().allShards()) {
                 // Fetch shard to update. Initialize local array
-                updateMap(nodeToShardCountMap, shardRouting);
+                updateMap(nodeToShardCountMap, nodeIdToIndexMap, nodeIdToIndexReplicaMap, shardRouting);
             }
         }
 
-        private static void updateMap(TreeMap<String, int[]> mapToUpdate, ShardRouting shardRouting) {
+        /**
+         * Adds one shard to the per-node counters and, for assigned shards, to the per-node index-name arrays.
+         */
+        private static void updateMap(TreeMap<String, int[]> shardCountMapToUpdate,
+                                      TreeMap<String, String[]> nodeIdToIndexMapToUpdate,
+                                      TreeMap<String, String[]> nodeIdToIndexReplicaMapToUpdate,
+                                      ShardRouting shardRouting) {
             int[] shard;
-            shard = shardRouting.assignedToNode() ? mapToUpdate.get(shardRouting.currentNodeId()) : unassigned;
+            shard = shardRouting.assignedToNode() ? shardCountMapToUpdate.get(shardRouting.currentNodeId()) : unassigned;
             // Update shard type count
             if (shardRouting.primary()) {
                 shard[0]++;
                 totalShards[0]++;
             } else {
                 shard[1]++;
                 totalShards[1]++;
             }
             // For assigned shards, put back counter
-            if (shardRouting.assignedToNode()) mapToUpdate.put(shardRouting.currentNodeId(), shard);
+            if (shardRouting.assignedToNode()) shardCountMapToUpdate.put(shardRouting.currentNodeId(), shard);
+            // Only assigned shards have a node entry; looking up an unassigned shard's node id would throw on the TreeMap.
+            if (shardRouting.assignedToNode()) {
+                TreeMap<String, String[]> target = shardRouting.primary() ? nodeIdToIndexMapToUpdate : nodeIdToIndexReplicaMapToUpdate;
+                String[] indexArray = target.get(shardRouting.currentNodeId());
+                String[] newArray = Arrays.copyOf(indexArray, indexArray.length + 1);
+                newArray[indexArray.length] = shardRouting.getIndexName();
+                target.put(shardRouting.currentNodeId(), newArray);
+            }
         }
 
         private static String allocation() {
@@ -388,6 +419,10 @@ private static String allocation() {
                 String nodeId = nameToNodeId.get(entry.getKey());
                 formatter.format("%-20s\n", entry.getKey().toUpperCase(Locale.getDefault()));
                 sb.append(printShardAllocationWithHeader(nodeToShardCountMap.get(nodeId)));
+                sb.append("Primary Shard indices: " + ONE_LINE_RETURN);
+                sb.append(printShardAllocationWithHeader(nodeIdToIndexMap.get(nodeId)));
+                sb.append("Replica Shard indices: " + ONE_LINE_RETURN);
+                sb.append(printShardAllocationWithHeader(nodeIdToIndexReplicaMap.get(nodeId)));
             }
             sb.append(ONE_LINE_RETURN);
             formatter.format("%-20s (P)%-5s (R)%-5s\n\n", "Unassigned ", unassigned[0], unassigned[1]);