Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache the shard routings with no weight for faster access #12989

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -63,7 +62,6 @@
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Collections.emptyMap;

Expand Down Expand Up @@ -96,8 +94,8 @@
private volatile Map<AttributesKey, AttributesRoutings> initializingShardsByAttributes = emptyMap();
private final Object shardsByAttributeMutex = new Object();
private final Object shardsByWeightMutex = new Object();
private volatile Map<WeightedRoutingKey, List<ShardRouting>> activeShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, List<ShardRouting>> initializingShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> activeShardsByWeight = emptyMap();
private volatile Map<WeightedRoutingKey, WeightedShardRoutings> initializingShardsByWeight = emptyMap();

private static final Logger logger = LogManager.getLogger(IndexShardRoutingTable.class);

Expand Down Expand Up @@ -249,7 +247,7 @@
return this.assignedShards;
}

public Map<WeightedRoutingKey, List<ShardRouting>> getActiveShardsByWeight() {
public Map<WeightedRoutingKey, WeightedShardRoutings> getActiveShardsByWeight() {
return activeShardsByWeight;
}

Expand Down Expand Up @@ -338,23 +336,7 @@
// append shards for attribute value with weight zero, so that shard search requests can be tried on
// shard copies in case of request failure from other attribute values.
if (isFailOpenEnabled) {
try {
Stream<String> keys = weightedRouting.weights()
.entrySet()
.stream()
.filter(entry -> entry.getValue().intValue() == WeightedRoutingMetadata.WEIGHED_AWAY_WEIGHT)
.map(Map.Entry::getKey);
keys.forEach(key -> {
ShardIterator iterator = onlyNodeSelectorActiveInitializingShardsIt(weightedRouting.attributeName() + ":" + key, nodes);
while (iterator.remaining() > 0) {
ordered.add(iterator.nextOrNull());
}
});
} catch (IllegalArgumentException e) {
// this exception is thrown by {@link onlyNodeSelectorActiveInitializingShardsIt} in case count of shard
// copies found is zero
logger.debug("no shard copies found for shard id [{}] for node attribute with weight zero", shardId);
}
ordered.addAll(activeInitializingShardsWithoutWeights(weightedRouting, nodes, defaultWeight));
}

return new PlainShardIterator(shardId, ordered);
Expand All @@ -378,6 +360,18 @@
return orderedListWithDistinctShards;
}

private List<ShardRouting> activeInitializingShardsWithoutWeights(
WeightedRouting weightedRouting,
DiscoveryNodes nodes,
double defaultWeight
) {
List<ShardRouting> ordered = new ArrayList<>(getActiveShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
if (!allInitializingShards.isEmpty()) {
ordered.addAll(getInitializingShardsWithoutWeight(weightedRouting, nodes, defaultWeight));
}
return ordered.stream().distinct().collect(Collectors.toList());
}

/**
* Returns a list containing shard routings ordered using weighted round-robin scheduling.
*/
Expand Down Expand Up @@ -949,20 +943,60 @@
}
}

/**
* Holder class for shard routing(s) which are classified and stored based on their weights.
*
* @opensearch.api
*/
@PublicApi(since = "2.14.0")
public static class WeightedShardRoutings {
backslasht marked this conversation as resolved.
Show resolved Hide resolved
private final List<ShardRouting> shardRoutingsWithWeight;
private final List<ShardRouting> shardRoutingWithoutWeight;

public WeightedShardRoutings(List<ShardRouting> shardRoutingsWithWeight, List<ShardRouting> shardRoutingWithoutWeight) {
this.shardRoutingsWithWeight = Collections.unmodifiableList(shardRoutingsWithWeight);
this.shardRoutingWithoutWeight = Collections.unmodifiableList(shardRoutingWithoutWeight);
}

public List<ShardRouting> getShardRoutingsWithWeight() {
return shardRoutingsWithWeight;
}

public List<ShardRouting> getShardRoutingWithoutWeight() {
return shardRoutingWithoutWeight;
}
}

/**
* *
* Gets active shard routing from memory if available, else calculates and put it in memory.
*/
private List<ShardRouting> getActiveShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> shardRoutings = activeShardsByWeight.get(key);
if (shardRoutings == null) {
synchronized (shardsByWeightMutex) {
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
activeShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
}
if (activeShardsByWeight.get(key) == null) {
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return activeShardsByWeight.get(key).getShardRoutingsWithWeight();
}

private List<ShardRouting> getActiveShardsWithoutWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
if (activeShardsByWeight.get(key) == null) {
populateActiveShardWeightsMap(weightedRouting, nodes, defaultWeight);

Check warning on line 985 in server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java#L985

Added line #L985 was not covered by tests
}
return activeShardsByWeight.get(key).getShardRoutingWithoutWeight();
}

private void populateActiveShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
backslasht marked this conversation as resolved.
Show resolved Hide resolved
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
List<ShardRouting> nonWeightedRoutings = activeShards.stream()
.filter(shard -> !weightedRoutings.contains(shard))
.collect(Collectors.toUnmodifiableList());
synchronized (shardsByWeightMutex) {
activeShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
.immutableMap();
}
return shardRoutings;
}

/**
Expand All @@ -971,14 +1005,34 @@
*/
private List<ShardRouting> getInitializingShardsByWeight(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> shardRoutings = initializingShardsByWeight.get(key);
if (shardRoutings == null) {
synchronized (shardsByWeightMutex) {
shardRoutings = shardsOrderedByWeight(activeShards, weightedRouting, nodes, defaultWeight);
initializingShardsByWeight = new MapBuilder().put(key, shardRoutings).immutableMap();
}
if (initializingShardsByWeight.get(key) == null) {
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);
}
return initializingShardsByWeight.get(key).getShardRoutingsWithWeight();
}

private List<ShardRouting> getInitializingShardsWithoutWeight(
WeightedRouting weightedRouting,
DiscoveryNodes nodes,
double defaultWeight
) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
if (initializingShardsByWeight.get(key) == null) {
populateInitializingShardWeightsMap(weightedRouting, nodes, defaultWeight);

Check warning on line 1021 in server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/cluster/routing/IndexShardRoutingTable.java#L1021

Added line #L1021 was not covered by tests
}
return initializingShardsByWeight.get(key).getShardRoutingWithoutWeight();
}

private void populateInitializingShardWeightsMap(WeightedRouting weightedRouting, DiscoveryNodes nodes, double defaultWeight) {
WeightedRoutingKey key = new WeightedRoutingKey(weightedRouting);
List<ShardRouting> weightedRoutings = shardsOrderedByWeight(allInitializingShards, weightedRouting, nodes, defaultWeight);
List<ShardRouting> nonWeightedRoutings = allInitializingShards.stream()
.filter(shard -> !weightedRoutings.contains(shard))
.collect(Collectors.toUnmodifiableList());
synchronized (shardsByWeightMutex) {
initializingShardsByWeight = new MapBuilder().put(key, new WeightedShardRoutings(weightedRoutings, nonWeightedRoutings))
.immutableMap();
}
return shardRoutings;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,18 @@ public void testWeightedRoutingWithDifferentWeights() {
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
shardRouting = shardIterator.nextOrNull();
assertNotNull(shardRouting);
assertFalse(Arrays.asList("node2", "node1").contains(shardRouting.currentNodeId()));
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());

weights = Map.of("zone1", -1.0, "zone2", 0.0, "zone3", 1.0);
weightedRouting = new WeightedRouting("zone", weights);
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());

weights = Map.of("zone1", 3.0, "zone2", 2.0, "zone3", 0.0);
weightedRouting = new WeightedRouting("zone", weights);
Expand All @@ -711,8 +720,138 @@ public void testWeightedRoutingWithDifferentWeights() {
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
shardRouting = shardIterator.nextOrNull();
assertNotNull(shardRouting);
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertNotEquals("node3", shardIterator.nextOrNull().currentNodeId());
assertEquals("node3", shardIterator.nextOrNull().currentNodeId());
} finally {
terminate(threadPool);
}
}

public void testWeightedRoutingWithInitializingShards() {
TestThreadPool threadPool = null;
try {
Settings.Builder settings = Settings.builder()
.put("cluster.routing.allocation.node_concurrent_recoveries", 10)
.put("cluster.routing.allocation.awareness.attributes", "zone");
AllocationService strategy = createAllocationService(settings.build());

Metadata metadata = Metadata.builder()
.put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
.build();

RoutingTable routingTable = RoutingTable.builder().addAsNew(metadata.index("test")).build();

ClusterState clusterState = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
.metadata(metadata)
.routingTable(routingTable)
.build();

threadPool = new TestThreadPool("testThatOnlyNodesSupport");
ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);

Map<String, String> node1Attributes = new HashMap<>();
node1Attributes.put("zone", "zone1");
Map<String, String> node2Attributes = new HashMap<>();
node2Attributes.put("zone", "zone2");
Map<String, String> node3Attributes = new HashMap<>();
node3Attributes.put("zone", "zone3");

DiscoveryNodes nodes = DiscoveryNodes.builder()
.add(newNode("node1", unmodifiableMap(node1Attributes)))
.add(newNode("node2", unmodifiableMap(node2Attributes)))
.add(newNode("node3", unmodifiableMap(node3Attributes)))
.localNodeId("node1")
.build();
clusterState = ClusterState.builder(clusterState).nodes(nodes).build();
clusterState = strategy.reroute(clusterState, "reroute");

// Making the first shard as active
clusterState = startInitializingShardsAndReroute(strategy, clusterState);
// Making the second shard as active
clusterState = startRandomInitializingShard(clusterState, strategy);

String[] startedNodes = new String[2];
String[] startedZones = new String[2];
String initializingNode = null;
String initializingZone = null;
int i = 0;
for (ShardRouting shard : clusterState.routingTable().allShards()) {
if (shard.initializing()) {
initializingNode = shard.currentNodeId();
initializingZone = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone");

} else {
startedNodes[i] = shard.currentNodeId();
startedZones[i++] = nodes.resolveNode(shard.currentNodeId()).getAttributes().get("zone");
}
}

Map<String, Double> weights = Map.of(startedZones[0], 1.0, initializingZone, 1.0, startedZones[1], 0.0);
WeightedRouting weightedRouting = new WeightedRouting("zone", weights);

// With fail open enabled set to false, we expect 2 shard routing, first one started, followed by initializing
ShardIterator shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);

assertEquals(2, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

// With fail open enabled set to true, we expect 3 shard routing, first one started, followed by initializing, third one started
// with zero weight
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);

assertEquals(3, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId());

weights = Map.of(initializingZone, 1.0, startedZones[0], 0.0, startedZones[1], 0.0);
weightedRouting = new WeightedRouting("zone", weights);

// only initializing shard has weight with fail open true
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

// only initializing shard has weight with fail open false
shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());
assertNotEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

weights = Map.of(initializingZone, 0.0, startedZones[0], 1.0, startedZones[1], 0.0);
weightedRouting = new WeightedRouting("zone", weights);

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, false, null);
assertEquals(1, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());

shardIterator = clusterState.routingTable()
.index("test")
.shard(0)
.activeInitializingShardsWeightedIt(weightedRouting, clusterState.nodes(), 1, true, null);
assertEquals(3, shardIterator.size());
assertEquals(startedNodes[0], shardIterator.nextOrNull().currentNodeId());
assertEquals(startedNodes[1], shardIterator.nextOrNull().currentNodeId());
assertEquals(initializingNode, shardIterator.nextOrNull().currentNodeId());

} finally {
terminate(threadPool);
}
Expand Down
Loading