Skip to content

Commit

Permalink
Use failed shards data only when all nodes fetching is done
Browse files Browse the repository at this point in the history
Signed-off-by: Aman Khare <[email protected]>
  • Loading branch information
Aman Khare committed Mar 14, 2024
1 parent f4c4ff3 commit dc86547
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public static Map<ShardId, ShardAttributes> prepareRequestMap(String[] indices,
);
for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
final ShardId shardId = new ShardId(index, shardIdNum);
shardIdShardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
shardIdShardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
}
}
return shardIdShardAttributesMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.common.logging.Loggers;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.indices.store.ShardAttributes;
import reactor.util.annotation.NonNull;

import java.lang.reflect.Array;
import java.util.ArrayList;
Expand All @@ -25,12 +24,13 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import reactor.util.annotation.NonNull;

/**
* Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch
* part using the base class {@link AsyncShardFetch}. Other functionalities needed for a batch are only written here.
Expand Down Expand Up @@ -78,14 +78,27 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
}

public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (failedShards.isEmpty() == false) {
// trigger a reroute if there are any shards failed, to make sure they're picked up in next run
logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
reroute("shards-failed", "shards failed in " + reroutingKey);
FetchResult<T> result = super.fetchData(nodes, ignoreNodes);
if (result.hasData()) {
// trigger reroute for failed shards only when all nodes have completed fetching
if (failedShards.isEmpty() == false) {
// trigger a reroute if there are any shards failed, to make sure they're picked up in next run
logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
reroute("shards-failed", "shards failed in " + reroutingKey);
failedShards.clear();
}
}
return super.fetchData(nodes, ignoreNodes);
return result;
}

/**
* Remove the shard from shardAttributesMap so it's not sent in next asyncFetch.
* Call removeShardFromBatch method to remove the shardId from the batch object created in
* ShardsBatchGatewayAllocator.
* Add shardId to failedShards, so it can be used to trigger another reroute as part of upcoming fetchData call.
*
* @param shardId shardId to be cleaned up from batch and cache.
*/
private void cleanUpFailedShards(ShardId shardId) {
shardAttributesMap.remove(shardId);
removeShardFromBatch.accept(shardId);
Expand All @@ -110,13 +123,12 @@ public void clearShard(ShardId shardId) {
* @param <V> Data type of shard level response.
*/
public static class ShardBatchCache<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardFetchCache<T> {
private final Map<String, NodeEntry<V>> cache = new HashMap<>();
private final Map<ShardId, Integer> shardIdKey = new HashMap<>();
private final AtomicInteger shardIdIndex = new AtomicInteger();
private final Map<String, NodeEntry<V>> cache;
private final Map<ShardId, Integer> shardIdToArray;
private final int batchSize;
private final Class<V> shardResponseClass;
private final BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseConstructor;
private final Map<Integer, ShardId> shardIdReverseKey = new HashMap<>();
private final Map<Integer, ShardId> arrayToShardId;
private final Function<T, Map<ShardId, V>> shardsBatchDataGetter;
private final Supplier<V> emptyResponseBuilder;
private final Consumer<ShardId> handleFailedShard;
Expand All @@ -135,6 +147,9 @@ public ShardBatchCache(
) {
super(Loggers.getLogger(logger, "_" + logKey), type);
this.batchSize = shardAttributesMap.size();
cache = new HashMap<>();
shardIdToArray = new HashMap<>();
arrayToShardId = new HashMap<>();
fillShardIdKeys(shardAttributesMap.keySet());
this.shardResponseClass = clazz;
this.responseConstructor = responseGetter;
Expand All @@ -152,8 +167,8 @@ public ShardBatchCache(

@Override
public void deleteShard(ShardId shardId) {
if (shardIdKey.containsKey(shardId)) {
Integer shardIdIndex = shardIdKey.remove(shardId);
if (shardIdToArray.containsKey(shardId)) {
Integer shardIdIndex = shardIdToArray.remove(shardId);
for (String nodeId : cache.keySet()) {
cache.get(nodeId).clearShard(shardIdIndex);
}
Expand All @@ -171,9 +186,9 @@ public Map<DiscoveryNode, T> getCacheData(DiscoveryNodes nodes, Set<String> fail
* PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for.
*/
private void refreshReverseIdMap() {
shardIdReverseKey.clear();
for (ShardId shardId : shardIdKey.keySet()) {
shardIdReverseKey.putIfAbsent(shardIdKey.get(shardId), shardId);
arrayToShardId.clear();
for (ShardId shardId : shardIdToArray.keySet()) {
arrayToShardId.putIfAbsent(shardIdToArray.get(shardId), shardId);
}
}

Expand All @@ -195,7 +210,7 @@ public void putData(DiscoveryNode node, T response) {
NodeEntry<V> nodeEntry = cache.get(node.getId());
Map<ShardId, V> batchResponse = shardsBatchDataGetter.apply(response);
filterFailedShards(batchResponse);
nodeEntry.doneFetching(batchResponse, shardIdKey);
nodeEntry.doneFetching(batchResponse, shardIdToArray);
}

/**
Expand Down Expand Up @@ -237,22 +252,23 @@ private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
V[] nodeShardEntries = nodeEntry.getData();
boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
HashMap<ShardId, V> shardData = new HashMap<>();
for (Integer shardIdIndex : shardIdKey.values()) {
for (Integer shardIdIndex : shardIdToArray.values()) {
if (emptyResponses[shardIdIndex]) {
shardData.put(shardIdReverseKey.get(shardIdIndex), emptyResponseBuilder.get());
shardData.put(arrayToShardId.get(shardIdIndex), emptyResponseBuilder.get());
} else if (nodeShardEntries[shardIdIndex] != null) {
// ignore null responses here
shardData.put(shardIdReverseKey.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
shardData.put(arrayToShardId.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
}
}
return shardData;
}

private void fillShardIdKeys(Set<ShardId> shardIds) {
int shardIdIndex = 0;
for (ShardId shardId : shardIds) {
this.shardIdKey.putIfAbsent(shardId, shardIdIndex.getAndIncrement());
this.shardIdToArray.putIfAbsent(shardId, shardIdIndex++);
}
this.shardIdKey.keySet().removeIf(shardId -> {
this.shardIdToArray.keySet().removeIf(shardId -> {
if (!shardIds.contains(shardId)) {
deleteShard(shardId);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.transport.TransportResponse;

import java.io.IOException;

Expand All @@ -20,7 +19,7 @@
*
* @opensearch.internal
*/
public abstract class BaseShardResponse extends TransportResponse {
public abstract class BaseShardResponse {

private Exception storeException;

Expand All @@ -42,7 +41,6 @@ public BaseShardResponse(StreamInput in) throws IOException {
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (storeException != null) {
out.writeBoolean(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
public class TransportNodesGatewayStartedShardHelper {

public static final String INDEX_NOT_FOUND = "node doesn't have meta data for index";

public static TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard getShardInfoOnLocalNode(
Logger logger,
final ShardId shardId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,28 +378,26 @@ public Map<ShardId, NodeGatewayStartedShard> getNodeGatewayStartedShardsBatch()

public NodeGatewayStartedShardsBatch(StreamInput in) throws IOException {
super(in);
this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new,
i -> {
if (i.readBoolean()) {
return new NodeGatewayStartedShard(i);
} else {
return null;
}
});
this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, i -> {
if (i.readBoolean()) {
return new NodeGatewayStartedShard(i);
} else {
return null;
}
});
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o),
(o, v) -> {
if (v != null) {
o.writeBoolean(true);
v.writeTo(o);
} else {
o.writeBoolean(false);
}
});
out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o), (o, v) -> {
if (v != null) {
o.writeBoolean(true);
v.writeTo(o);
} else {
o.writeBoolean(false);
}
});
}

public NodeGatewayStartedShardsBatch(DiscoveryNode node, Map<ShardId, NodeGatewayStartedShard> nodeGatewayStartedShardsBatch) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ public void setUp() throws Exception {
HashMap<ShardId, ShardAttributes> shardToCustomDataPath = new HashMap<>();
ShardId shardId0 = new ShardId("index1", "index_uuid1", 0);
ShardId shardId1 = new ShardId("index2", "index_uuid2", 0);
shardToCustomDataPath.put(shardId0, new ShardAttributes(shardId0, ""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(shardId1, ""));
shardToCustomDataPath.put(shardId0, new ShardAttributes(""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(""));
this.test = new TestFetch(threadPool, shardToCustomDataPath);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public class ShardBatchCacheTests extends OpenSearchAllocationTestCase {
private static final String BATCH_ID = "b1";
private final DiscoveryNode node1 = newNode("node1");
private final DiscoveryNode node2 = newNode("node2");
// Compilation would pass once ShardsBatchGatewayAllocator is committed in main
private final Map<ShardId, ShardsBatchGatewayAllocator.ShardEntry> batchInfo = new HashMap<>();
// Needs to be enabled once ShardsBatchGatewayAllocator is pushed
// private final Map<ShardId, ShardsBatchGatewayAllocator.ShardEntry> batchInfo = new HashMap<>();
private AsyncShardBatchFetch.ShardBatchCache<NodeGatewayStartedShardsBatch, NodeGatewayStartedShard> shardCache;
private List<ShardId> shardsInBatch = new ArrayList<>();
private static final int NUMBER_OF_SHARDS_DEFAULT = 10;
Expand Down Expand Up @@ -85,8 +85,7 @@ public void testGetCacheData() {
this.shardCache.initData(node1);
this.shardCache.initData(node2);
this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch,
ResponseType.EMPTY)));
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY)));
assertTrue(
this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null)
.get(node1)
Expand Down Expand Up @@ -118,8 +117,7 @@ public void testPutData() {
this.shardCache.initData(node1);
this.shardCache.initData(node2);
this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch,
ResponseType.VALID)));
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.VALID)));
this.shardCache.putData(node2, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY)));

Map<DiscoveryNode, NodeGatewayStartedShardsBatch> fetchData = shardCache.getCacheData(
Expand All @@ -141,11 +139,12 @@ public void testNullResponses() {
setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT);
this.shardCache.initData(node1);
this.shardCache.markAsFetching(List.of(node1.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch,
ResponseType.NULL)));
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.NULL)));

Map<DiscoveryNode, NodeGatewayStartedShardsBatch> fetchData = shardCache.getCacheData(
DiscoveryNodes.builder().add(node1).build(), null);
DiscoveryNodes.builder().add(node1).build(),
null
);
assertTrue(fetchData.get(node1).getNodeGatewayStartedShardsBatch().isEmpty());
}

Expand All @@ -154,12 +153,13 @@ public void testFilterFailedShards() {
this.shardCache.initData(node1);
this.shardCache.initData(node2);
this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1,
getFailedPrimaryResponse(shardsInBatch, 5)));
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getFailedPrimaryResponse(shardsInBatch, 5)));
Map<DiscoveryNode, NodeGatewayStartedShardsBatch> fetchData = shardCache.getCacheData(
DiscoveryNodes.builder().add(node1).add(node2).build(), null);
DiscoveryNodes.builder().add(node1).add(node2).build(),
null
);

assertEquals(5, batchInfo.size());
// assertEquals(5, batchInfo.size());
assertEquals(2, fetchData.size());
assertEquals(5, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size());
assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().isEmpty());
Expand All @@ -186,35 +186,35 @@ private Map<ShardId, NodeGatewayStartedShard> getPrimaryResponse(List<ShardId> s
return shardData;
}

private Map<ShardId, NodeGatewayStartedShard> getFailedPrimaryResponse(List<ShardId> shards,
int failedShardsCount) {
private Map<ShardId, NodeGatewayStartedShard> getFailedPrimaryResponse(List<ShardId> shards, int failedShardsCount) {
int allocationId = 1;
Map<ShardId, NodeGatewayStartedShard> shardData = new HashMap<>();
for (ShardId shard : shards) {
if (failedShardsCount-- > 0) {
shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null,
new OpenSearchRejectedExecutionException()));
shardData.put(
shard,
new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, new OpenSearchRejectedExecutionException())
);
} else {
shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null,
null));
shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, null));
}
}
return shardData;
}

public void removeShard(ShardId shardId) {
batchInfo.remove(shardId);
// batchInfo.remove(shardId);
}

private void fillShards(Map<ShardId, ShardAttributes> shardAttributesMap, int numberOfShards) {
shardsInBatch = BatchTestUtil.setUpShards(numberOfShards);
for (ShardId shardId : shardsInBatch) {
ShardAttributes attr = new ShardAttributes("");
shardAttributesMap.put(shardId, attr);
batchInfo.put(
shardId,
new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id()))
);
// batchInfo.put(
// shardId,
// new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id()))
// );
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,19 @@ public class ShardAttributesTests extends OpenSearchTestCase {
String customDataPath = "/path/to/data";

public void testShardAttributesConstructor() {
ShardAttributes attributes = new ShardAttributes(shardId, customDataPath);
assertEquals(attributes.getShardId(), shardId);
ShardAttributes attributes = new ShardAttributes(customDataPath);
assertEquals(attributes.getCustomDataPath(), customDataPath);
}

public void testSerialization() throws IOException {
ShardAttributes attributes1 = new ShardAttributes(shardId, customDataPath);
ShardAttributes attributes1 = new ShardAttributes(customDataPath);
ByteArrayOutputStream bytes = new ByteArrayOutputStream();
StreamOutput output = new DataOutputStreamOutput(new DataOutputStream(bytes));
attributes1.writeTo(output);
output.close();
StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(bytes.toByteArray()));
ShardAttributes attributes2 = new ShardAttributes(input);
input.close();
assertEquals(attributes1.getShardId(), attributes2.getShardId());
assertEquals(attributes1.getCustomDataPath(), attributes2.getCustomDataPath());
}

Expand Down

0 comments on commit dc86547

Please sign in to comment.