Skip to content

Commit

Permalink
Add logic for rebalancing and allocation
Browse files Browse the repository at this point in the history
Signed-off-by: Arpit Bandejiya <[email protected]>
  • Loading branch information
Arpit-Bandejiya committed Feb 16, 2024
1 parent 443cfca commit fa38d3d
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,15 @@ public float weightWithAllocationConstraints(ShardsBalancer balancer, ModelNode

public float weightWithRebalanceConstraints(ShardsBalancer balancer, ModelNode node, String index) {
float balancerWeight = weight(balancer, node, index);
return balancerWeight + rebalanceConstraints.weight(balancer, node, index);
float extraWt = 0;
int primaryShardCount = node.numPrimaryShards();
int allowedPrimaryShardCount = (int) Math.ceil(balancer.avgPrimaryShardsPerNode());

if(primaryShardCount > allowedPrimaryShardCount) {
extraWt += 1000000L;
}

return balancerWeight + extraWt + rebalanceConstraints.weight(balancer, node, index);
}

float weight(ShardsBalancer balancer, ModelNode node, String index) {
Expand All @@ -396,6 +404,10 @@ void updateRebalanceConstraint(String constraint, boolean add) {
public static class ModelNode implements Iterable<ModelIndex> {
private final Map<String, ModelIndex> indices = new HashMap<>();
private int numShards = 0;

private int totalPrimary = 0;

private int totalReplica = 0;
private final RoutingNode routingNode;

ModelNode(RoutingNode routingNode) {
Expand Down Expand Up @@ -448,6 +460,12 @@ public void addShard(ShardRouting shard) {
}
index.addShard(shard);
numShards++;

if(shard.primary()) {
totalPrimary++;
} else {
totalReplica ++;
}
}

public void removeShard(ShardRouting shard) {
Expand All @@ -459,6 +477,12 @@ public void removeShard(ShardRouting shard) {
}
}
numShards--;

if(shard.primary()) {
totalPrimary--;
} else {
totalReplica--;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -71,6 +72,8 @@ public class LocalShardsBalancer extends ShardsBalancer {
private final BalancedShardsAllocator.NodeSorter sorter;
private final Set<RoutingNode> inEligibleTargetNode;

public int totalRelocation;

public LocalShardsBalancer(
Logger logger,
RoutingAllocation allocation,
Expand All @@ -94,6 +97,7 @@ public LocalShardsBalancer(
inEligibleTargetNode = new HashSet<>();
this.preferPrimaryBalance = preferPrimaryBalance;
this.shardMovementStrategy = shardMovementStrategy;
this.totalRelocation = 0;
}

/**
Expand All @@ -103,6 +107,21 @@ private BalancedShardsAllocator.ModelNode[] nodesArray() {
return nodes.values().toArray(new BalancedShardsAllocator.ModelNode[0]);
}

public void increaseRelocationCount() {
totalRelocation++;
}

public int getRelocationCount() {
return totalRelocation;
}

public void resetRelocationCount() {
totalRelocation = 0;
}

public void printRelocationCount(){
logger.info("Total relocation count is: {}", Integer.toString(totalRelocation));
}
/**
* Returns the average of shards per node for the given index
*/
Expand Down Expand Up @@ -342,6 +361,7 @@ private void balanceByWeights() {
for (String index : buildWeightOrderedIndices()) {
IndexMetadata indexMetadata = metadata.index(index);


// find nodes that have a shard of this index or where shards of this index are allowed to be allocated to,
// move these nodes to the front of modelNodes so that we can only balance based on these nodes
int relevantNodes = 0;
Expand All @@ -366,6 +386,11 @@ private void balanceByWeights() {
while (true) {
final BalancedShardsAllocator.ModelNode minNode = modelNodes[lowIdx];
final BalancedShardsAllocator.ModelNode maxNode = modelNodes[highIdx];
// n1 --> p 1 r 2
// n2 --> p 2 r 1
// n3 --> p 3 r 1
// --> primary balance, avg - 1

advance_range: if (maxNode.numShards(index) > 0) {
final float delta = absDelta(weights[lowIdx], weights[highIdx]);
if (lessThan(delta, threshold)) {
Expand Down Expand Up @@ -897,6 +922,8 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
* iteration order is different for each run and makes testing hard */
Map<String, NodeAllocationResult> nodeExplanationMap = explain ? new HashMap<>() : null;
List<Tuple<String, Float>> nodeWeights = explain ? new ArrayList<>() : null;
// Maintain the list of node which have min weight
List<BalancedShardsAllocator.ModelNode> minNodes = new ArrayList<>();
for (BalancedShardsAllocator.ModelNode node : nodes.values()) {
if (node.containsShard(shard) && explain == false) {
// decision is NO without needing to check anything further, so short circuit
Expand All @@ -917,7 +944,9 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
}
if (currentDecision.type() == Decision.Type.YES || currentDecision.type() == Decision.Type.THROTTLE) {
final boolean updateMinNode;
final boolean updateMinNodeList;
if (currentWeight == minWeight) {
// debug it more
/* we have an equal weight tie breaking:
* 1. if one decision is YES prefer it
* 2. prefer the node that holds the primary for this index with the next id in the ring ie.
Expand All @@ -935,11 +964,26 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
final int minNodeHigh = minNode.highestPrimary(shard.getIndexName());
updateMinNode = ((((nodeHigh > repId && minNodeHigh > repId) || (nodeHigh < repId && minNodeHigh < repId))
&& (nodeHigh < minNodeHigh)) || (nodeHigh > repId && minNodeHigh < repId));

updateMinNodeList = true;
// Add node to the possible node which can be picked
minNodes.add(node);

} else {
updateMinNode = currentDecision.type() == Decision.Type.YES;
if (updateMinNode) {
updateMinNodeList = true;
minNodes.clear();
minNodes.add(node);
}
}
} else {
updateMinNode = currentWeight < minWeight;
if (updateMinNode) {
updateMinNodeList = true;
minNodes.clear();
minNodes.add(node);
}
}
if (updateMinNode) {
minNode = node;
Expand All @@ -963,6 +1007,12 @@ AllocateUnassignedDecision decideAllocateUnassigned(final ShardRouting shard) {
nodeDecisions.add(new NodeAllocationResult(current.getNode(), current.getCanAllocateDecision(), ++weightRanking));
}
}

if (minNodes.isEmpty()){
minNode = null;
} else {
minNode = minNodes.get(new Random().nextInt(minNodes.size()));
}
return AllocateUnassignedDecision.fromDecision(decision, minNode != null ? minNode.getRoutingNode().node() : null, nodeDecisions);
}

Expand Down Expand Up @@ -1002,7 +1052,7 @@ private boolean tryRelocateShard(BalancedShardsAllocator.ModelNode minNode, Bala
// doing such relocation wouldn't help in primary balance.
if (preferPrimaryBalance == true
&& shard.primary()
&& maxNode.numPrimaryShards(shard.getIndexName()) - minNode.numPrimaryShards(shard.getIndexName()) < 2) {
&& maxNode.numPrimaryShards() - minNode.numPrimaryShards() < 2) {
continue;
}

Expand All @@ -1012,8 +1062,9 @@ private boolean tryRelocateShard(BalancedShardsAllocator.ModelNode minNode, Bala

if (decision.type() == Decision.Type.YES) {
/* only allocate on the cluster if we are not throttled */
logger.debug("Relocate [{}] from [{}] to [{}]", shard, maxNode.getNodeId(), minNode.getNodeId());
logger.info("Relocate [{}] from [{}] to [{}]", shard, maxNode.getNodeId(), minNode.getNodeId());
minNode.addShard(routingNodes.relocateShard(shard, minNode.getNodeId(), shardSize, allocation.changes()).v1());
increaseRelocationCount();
return true;
} else {
/* allocate on the model even if throttled */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ public static class ShardAllocations {
*/
static TreeMap<String, int[]> nodeToShardCountMap = new TreeMap<>();

static TreeMap<String, String[]> nodeIdToIndexMap = new TreeMap<>();
static TreeMap<String, String[]> nodeIdToIndexReplicaMap = new TreeMap<>();

/**
* Helper map containing NodeName to NodeId
*/
Expand All @@ -342,8 +345,20 @@ private final static String printShardAllocationWithHeader(int[] shardCount) {
return sb.toString();
}

private final static String printShardAllocationWithHeader(String[] indexNames) {
StringBuffer sb = new StringBuffer();
Formatter formatter = new Formatter(sb, Locale.getDefault());
for( String index: indexNames) {
formatter.format("%-20s", index);
}
formatter.format("\n");
return sb.toString();
}

private static void reset() {
nodeToShardCountMap.clear();
nodeIdToIndexReplicaMap.clear();
nodeIdToIndexMap.clear();
nameToNodeId.clear();
totalShards[0] = totalShards[1] = 0;
unassigned[0] = unassigned[1] = 0;
Expand All @@ -358,26 +373,42 @@ private static void buildMap(ClusterState inputState) {
nameToNodeId.putIfAbsent(node.nodeId(), node.nodeId());
}
nodeToShardCountMap.putIfAbsent(node.nodeId(), new int[] { 0, 0 });
nodeIdToIndexMap.putIfAbsent(node.nodeId(), new String[] {});
nodeIdToIndexReplicaMap.putIfAbsent(node.nodeId(), new String[] {});
}
for (ShardRouting shardRouting : inputState.routingTable().allShards()) {
// Fetch shard to update. Initialize local array
updateMap(nodeToShardCountMap, shardRouting);
updateMap(nodeToShardCountMap, nodeIdToIndexMap, nodeIdToIndexReplicaMap, shardRouting);
}
}

private static void updateMap(TreeMap<String, int[]> mapToUpdate, ShardRouting shardRouting) {
private static void updateMap(TreeMap<String, int[]> ShardCountMapToUpdate,
TreeMap<String, String[]> nodeIdToIndexMapToUpdate,
TreeMap<String, String[]> nodeIdToIndexReplicaMapToUpdate,
ShardRouting shardRouting) {
int[] shard;
shard = shardRouting.assignedToNode() ? mapToUpdate.get(shardRouting.currentNodeId()) : unassigned;
shard = shardRouting.assignedToNode() ? ShardCountMapToUpdate.get(shardRouting.currentNodeId()) : unassigned;
String indexName = shardRouting.getIndexName();
// Update shard type count
if (shardRouting.primary()) {
shard[0]++;
totalShards[0]++;
String[] indexArray = nodeIdToIndexMapToUpdate.get(shardRouting.currentNodeId());
String[] newArray = Arrays.copyOf(indexArray, indexArray.length + 1);
newArray[indexArray.length] = indexName;
nodeIdToIndexMapToUpdate.put(shardRouting.currentNodeId(), newArray);
} else {
shard[1]++;
totalShards[1]++;
String[] indexArray = nodeIdToIndexReplicaMapToUpdate.get(shardRouting.currentNodeId());
String[] newArray = Arrays.copyOf(indexArray, indexArray.length + 1);
newArray[indexArray.length] = indexName;
nodeIdToIndexReplicaMapToUpdate.put(shardRouting.currentNodeId(), newArray);
}


// For assigned shards, put back counter
if (shardRouting.assignedToNode()) mapToUpdate.put(shardRouting.currentNodeId(), shard);
if (shardRouting.assignedToNode()) ShardCountMapToUpdate.put(shardRouting.currentNodeId(), shard);
}

private static String allocation() {
Expand All @@ -388,6 +419,10 @@ private static String allocation() {
String nodeId = nameToNodeId.get(entry.getKey());
formatter.format("%-20s\n", entry.getKey().toUpperCase(Locale.getDefault()));
sb.append(printShardAllocationWithHeader(nodeToShardCountMap.get(nodeId)));
sb.append("Primary Shard indices: " + ONE_LINE_RETURN);
sb.append(printShardAllocationWithHeader(nodeIdToIndexMap.get(nodeId)));
sb.append("Replica Shard indices: " + ONE_LINE_RETURN);
sb.append(printShardAllocationWithHeader(nodeIdToIndexReplicaMap.get(nodeId)));
}
sb.append(ONE_LINE_RETURN);
formatter.format("%-20s (P)%-5s (R)%-5s\n\n", "Unassigned ", unassigned[0], unassigned[1]);
Expand Down

0 comments on commit fa38d3d

Please sign in to comment.