Commit 4a1ff54

Use failed shards data only when all nodes fetching is done

Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Mar 12, 2024
1 parent 7f55cb7 commit 4a1ff54
Showing 9 changed files with 93 additions and 87 deletions.
@@ -54,7 +54,7 @@ public static Map<ShardId, ShardAttributes> prepareRequestMap(String[] indices,
);
for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
final ShardId shardId = new ShardId(index, shardIdNum);
shardIdShardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
shardIdShardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
}
}
return shardIdShardAttributesMap;
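The hunk above removes the ShardId argument from the ShardAttributes constructor: the map key is now the single source of shard identity, and the value carries only per-shard attributes. A minimal sketch of that shape, using hypothetical stand-in records (ShardKey, Attributes) rather than the real OpenSearch types:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-ins; the real types are ShardId and ShardAttributes.
record ShardKey(String index, int shard) {}
record Attributes(String customDataPath) {}

class RequestMapDemo {
    static Map<ShardKey, Attributes> prepareRequestMap(String index, int primaryShardCount, String customDataPath) {
        Map<ShardKey, Attributes> map = new HashMap<>();
        for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
            // Identity lives in the key; the value carries only per-shard attributes.
            map.put(new ShardKey(index, shardIdNum), new Attributes(customDataPath));
        }
        return map;
    }
}
```

Keeping the identity out of the value removes the risk of key and value drifting apart and slightly shrinks every serialized request.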
@@ -24,12 +24,13 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import reactor.util.annotation.NonNull;

/**
* Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch
* part using the base class {@link AsyncShardFetch}; functionality needed only for a batch is implemented here.
@@ -77,14 +78,27 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
}

public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (failedShards.isEmpty() == false) {
// trigger a reroute if any shards failed, to make sure they're picked up in the next run
logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
reroute("shards-failed", "shards failed in " + reroutingKey);
FetchResult<T> result = super.fetchData(nodes, ignoreNodes);
if (result.hasData()) {
// trigger reroute for failed shards only when all nodes have completed fetching
if (failedShards.isEmpty() == false) {
// trigger a reroute if any shards failed, to make sure they're picked up in the next run
logger.trace("triggering another reroute for failed shards in {}", reroutingKey);
reroute("shards-failed", "shards failed in " + reroutingKey);
failedShards.clear();
}
}
return super.fetchData(nodes, ignoreNodes);
return result;
}

/**
* Removes the shard from shardAttributesMap so that it is not sent in the next asyncFetch.
* Calls removeShardFromBatch to remove the shardId from the batch object created in
* ShardsBatchGatewayAllocator.
* Adds the shardId to failedShards, so it can be used to trigger another reroute as part of an upcoming fetchData call.
*
* @param shardId shardId to be cleaned up from batch and cache.
*/
private void cleanUpFailedShards(ShardId shardId) {
shardAttributesMap.remove(shardId);
removeShardFromBatch.accept(shardId);
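The fetchData change in the hunk above is the heart of the commit. Previously the reroute for failed shards was scheduled before delegating to super.fetchData, so it could fire while some nodes were still fetching; now the reroute fires only when result.hasData() reports that every node has completed. A minimal sketch of the pattern, with simplified hypothetical types (BatchFetchRound, onNodeResponse) standing in for the real fetcher and reroute hook:

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical simplified model of the new fetchData flow: failures observed
// during a round are acted on only once every node has responded.
class BatchFetchRound {
    private final Set<String> pendingNodes;
    private final Set<String> failedShards = new HashSet<>();

    BatchFetchRound(Set<String> nodeIds) {
        this.pendingNodes = new HashSet<>(nodeIds);
    }

    // Called once per node response; shardFailures may be empty.
    synchronized void onNodeResponse(String nodeId, Set<String> shardFailures) {
        pendingNodes.remove(nodeId);
        failedShards.addAll(shardFailures);
        // Equivalent of result.hasData(): all nodes are done fetching,
        // so a single reroute covers every failure seen in this round.
        if (pendingNodes.isEmpty() && !failedShards.isEmpty()) {
            reroute("shards failed: " + failedShards);
            failedShards.clear();  // mirrors failedShards.clear() after the reroute
        }
    }

    private void reroute(String reason) {
        System.out.println("triggering reroute, " + reason);
    }
}
```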
@@ -109,13 +123,12 @@ public void clearShard(ShardId shardId) {
* @param <V> Data type of shard level response.
*/
public static class ShardBatchCache<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardFetchCache<T> {
private final Map<String, NodeEntry<V>> cache = new HashMap<>();
private final Map<ShardId, Integer> shardIdKey = new HashMap<>();
private final AtomicInteger shardIdIndex = new AtomicInteger();
private final Map<String, NodeEntry<V>> cache;
private final Map<ShardId, Integer> shardIdToArray;
private final int batchSize;
private final Class<V> shardResponseClass;
private final BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseConstructor;
private final Map<Integer, ShardId> shardIdReverseKey = new HashMap<>();
private final Map<Integer, ShardId> arrayToShardId;
private final Function<T, Map<ShardId, V>> shardsBatchDataGetter;
private final Supplier<V> emptyResponseBuilder;
private final Consumer<ShardId> handleFailedShard;
@@ -133,6 +146,9 @@ public ShardBatchCache(
) {
super(Loggers.getLogger(logger, "_" + logKey), type);
this.batchSize = shardAttributesMap.size();
cache = new HashMap<>();
shardIdToArray = new HashMap<>();
arrayToShardId = new HashMap<>();
fillShardIdKeys(shardAttributesMap.keySet());
this.shardResponseClass = clazz;
this.responseConstructor = responseGetter;
@@ -148,8 +164,8 @@ public ShardBatchCache(

@Override
public void deleteShard(ShardId shardId) {
if (shardIdKey.containsKey(shardId)) {
Integer shardIdIndex = shardIdKey.remove(shardId);
if (shardIdToArray.containsKey(shardId)) {
Integer shardIdIndex = shardIdToArray.remove(shardId);
for (String nodeId : cache.keySet()) {
cache.get(nodeId).clearShard(shardIdIndex);
}
@@ -167,9 +183,9 @@ public Map<DiscoveryNode, T> getCacheData(DiscoveryNodes nodes, Set<String> fail
* PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for.
*/
private void refreshReverseIdMap() {
shardIdReverseKey.clear();
for (ShardId shardId : shardIdKey.keySet()) {
shardIdReverseKey.putIfAbsent(shardIdKey.get(shardId), shardId);
arrayToShardId.clear();
for (ShardId shardId : shardIdToArray.keySet()) {
arrayToShardId.putIfAbsent(shardIdToArray.get(shardId), shardId);
}
}

@@ -191,7 +207,7 @@ public void putData(DiscoveryNode node, T response) {
NodeEntry<V> nodeEntry = cache.get(node.getId());
Map<ShardId, V> batchResponse = shardsBatchDataGetter.apply(response);
filterFailedShards(batchResponse);
nodeEntry.doneFetching(batchResponse, shardIdKey);
nodeEntry.doneFetching(batchResponse, shardIdToArray);
}

/**
@@ -233,22 +249,23 @@ private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
V[] nodeShardEntries = nodeEntry.getData();
boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
HashMap<ShardId, V> shardData = new HashMap<>();
for (Integer shardIdIndex : shardIdKey.values()) {
for (Integer shardIdIndex : shardIdToArray.values()) {
if (emptyResponses[shardIdIndex]) {
shardData.put(shardIdReverseKey.get(shardIdIndex), emptyResponseBuilder.get());
shardData.put(arrayToShardId.get(shardIdIndex), emptyResponseBuilder.get());
} else if (nodeShardEntries[shardIdIndex] != null) {
// ignore null responses here
shardData.put(shardIdReverseKey.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
shardData.put(arrayToShardId.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
}
}
return shardData;
}

private void fillShardIdKeys(Set<ShardId> shardIds) {
int shardIdIndex = 0;
for (ShardId shardId : shardIds) {
this.shardIdKey.putIfAbsent(shardId, shardIdIndex.getAndIncrement());
this.shardIdToArray.putIfAbsent(shardId, shardIdIndex++);
}
this.shardIdKey.keySet().removeIf(shardId -> {
this.shardIdToArray.keySet().removeIf(shardId -> {
if (!shardIds.contains(shardId)) {
deleteShard(shardId);
return true;
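ShardBatchCache keeps a two-way mapping between each ShardId in the batch and a stable array slot (shardIdToArray and arrayToShardId, renamed from shardIdKey and shardIdReverseKey), so per-node responses can be stored in flat arrays indexed by slot. A standalone sketch of that index, with String ids standing in for ShardId:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Standalone sketch of the two-way slot index kept by ShardBatchCache.
class ShardSlotIndex {
    private final Map<String, Integer> shardIdToArray = new HashMap<>();
    private final Map<Integer, String> arrayToShardId = new HashMap<>();

    ShardSlotIndex(Set<String> shardIds) {
        int slot = 0;
        for (String shardId : shardIds) {
            shardIdToArray.put(shardId, slot);   // shard id -> array position
            arrayToShardId.put(slot, shardId);   // array position -> shard id
            slot++;
        }
    }

    Integer slotOf(String shardId) {
        return shardIdToArray.get(shardId);
    }

    String shardAt(int slot) {
        return arrayToShardId.get(slot);
    }
}
```

Flat arrays indexed by slot avoid allocating a per-shard map for every node response, which presumably matters when a batch spans many shards.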
@@ -10,7 +10,6 @@

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.transport.TransportResponse;

import java.io.IOException;

@@ -20,7 +19,7 @@
*
* @opensearch.internal
*/
public abstract class BaseShardResponse extends TransportResponse {
public abstract class BaseShardResponse {

private Exception storeException;

@@ -42,7 +41,6 @@ public BaseShardResponse(StreamInput in) throws IOException {
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (storeException != null) {
out.writeBoolean(true);
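With the TransportResponse inheritance removed, BaseShardResponse keeps a hand-written writeTo that serializes the optional store exception behind a boolean presence flag. The round trip below illustrates the same flag pattern; plain java.io streams and a String payload are stand-ins for StreamOutput/StreamInput and the serialized exception:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Minimal sketch of the presence-flag pattern for an optional field.
final class OptionalFieldDemo {
    static void writeOptional(DataOutputStream out, String value) throws IOException {
        if (value != null) {
            out.writeBoolean(true);   // flag: payload follows
            out.writeUTF(value);
        } else {
            out.writeBoolean(false);  // flag: nothing follows
        }
    }

    static String readOptional(DataInputStream in) throws IOException {
        return in.readBoolean() ? in.readUTF() : null;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writeOptional(new DataOutputStream(bytes), "corrupted shard store");
        String roundTripped = readOptional(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(roundTripped);  // prints: corrupted shard store
    }
}
```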
@@ -39,6 +39,7 @@
public class TransportNodesGatewayStartedShardHelper {

public static final String INDEX_NOT_FOUND = "node doesn't have meta data for index";

public static TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard getShardInfoOnLocalNode(
Logger logger,
final ShardId shardId,
@@ -136,8 +136,8 @@ protected NodesGatewayStartedShardsBatch newResponse(
@Override
protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) {
Map<ShardId, NodeGatewayStartedShard> shardsOnNode = new HashMap<>();
for (ShardAttributes shardAttr : request.shardAttributes.values()) {
final ShardId shardId = shardAttr.getShardId();
for (Map.Entry<ShardId, ShardAttributes> shardAttr : request.shardAttributes.entrySet()) {
final ShardId shardId = shardAttr.getKey();
try {
shardsOnNode.put(
shardId,
@@ -147,7 +147,7 @@ protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) {
namedXContentRegistry,
nodeEnv,
indicesService,
shardAttr.getCustomDataPath(),
shardAttr.getValue().getCustomDataPath(),
settings,
clusterService
)
@@ -378,28 +378,26 @@ public Map<ShardId, NodeGatewayStartedShard> getNodeGatewayStartedShardsBatch()

public NodeGatewayStartedShardsBatch(StreamInput in) throws IOException {
super(in);
this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new,
i -> {
if (i.readBoolean()) {
return new NodeGatewayStartedShard(i);
} else {
return null;
}
});
this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, i -> {
if (i.readBoolean()) {
return new NodeGatewayStartedShard(i);
} else {
return null;
}
});
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o),
(o, v) -> {
if (v != null) {
o.writeBoolean(true);
v.writeTo(o);
} else {
o.writeBoolean(false);
}
});
out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o), (o, v) -> {
if (v != null) {
o.writeBoolean(true);
v.writeTo(o);
} else {
o.writeBoolean(false);
}
});
}

public NodeGatewayStartedShardsBatch(DiscoveryNode node, Map<ShardId, NodeGatewayStartedShard> nodeGatewayStartedShardsBatch) {
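The reformatted writeMap/readMap pair above frames the per-node map as a size followed by, for each entry, the key, a presence flag, and the value when present, so null values for failed shards are tolerated. A self-contained sketch of that framing, with plain java.io and String keys and values standing in for ShardId and NodeGatewayStartedShard:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

// Sketch of serializing a map whose values may be null.
final class NullableValueMapDemo {
    static void writeMap(DataOutputStream out, Map<String, String> map) throws IOException {
        out.writeInt(map.size());
        for (Map.Entry<String, String> e : map.entrySet()) {
            out.writeUTF(e.getKey());
            if (e.getValue() != null) {
                out.writeBoolean(true);   // value follows
                out.writeUTF(e.getValue());
            } else {
                out.writeBoolean(false);  // null value, nothing follows
            }
        }
    }

    static Map<String, String> readMap(DataInputStream in) throws IOException {
        int size = in.readInt();
        Map<String, String> map = new HashMap<>();
        for (int i = 0; i < size; i++) {
            String key = in.readUTF();
            map.put(key, in.readBoolean() ? in.readUTF() : null);
        }
        return map;
    }
}
```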
@@ -12,7 +12,6 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.AsyncShardFetch;

import java.io.IOException;
@@ -24,24 +23,17 @@
* @opensearch.internal
*/
public class ShardAttributes implements Writeable {
private final ShardId shardId;
@Nullable
private final String customDataPath;

public ShardAttributes(ShardId shardId, String customDataPath) {
this.shardId = shardId;
public ShardAttributes(String customDataPath) {
this.customDataPath = customDataPath;
}

public ShardAttributes(StreamInput in) throws IOException {
shardId = new ShardId(in);
customDataPath = in.readString();
}

public ShardId getShardId() {
return shardId;
}

/**
* Returns the custom data path that is used to look up information for this shard.
* Returns an empty string if no custom data path is used for this index.
@@ -53,7 +45,6 @@ public String getCustomDataPath() {
}

public void writeTo(StreamOutput out) throws IOException {
shardId.writeTo(out);
out.writeString(customDataPath);
}
}
@@ -92,8 +92,8 @@ public void setUp() throws Exception {
HashMap<ShardId, ShardAttributes> shardToCustomDataPath = new HashMap<>();
ShardId shardId0 = new ShardId("index1", "index_uuid1", 0);
ShardId shardId1 = new ShardId("index2", "index_uuid2", 0);
shardToCustomDataPath.put(shardId0, new ShardAttributes(shardId0, ""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(shardId1, ""));
shardToCustomDataPath.put(shardId0, new ShardAttributes(""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(""));
this.test = new TestFetch(threadPool, shardToCustomDataPath);
}
}