Move ShardBatchCache as inner class of AsyncShardBatchFetch
Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Mar 14, 2024
1 parent 8a29f28 commit dcc15d0
Showing 3 changed files with 224 additions and 270 deletions.
245 changes: 221 additions & 24 deletions server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java
@@ -13,26 +13,34 @@
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.logging.Loggers;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.indices.store.ShardAttributes;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

/**
* Implementation of AsyncShardFetchAbstract with batching support.
* cache will be created using ShardBatchCache class as that can handle the caching strategy correctly for a
* batch of shards. Other necessary functions are also stored so cache can store or get the data for both primary
* and replicas.
* Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch
* part using the base class {@link AsyncShardFetch}; functionality needed only for batches is written here.
* Failed shards must be cleaned out of a batch, and a reroute triggered so that they are taken care of in the
* next run. This separation also takes care of the extra generic type V, which is only needed for batch
* transport actions like {@link TransportNodesListGatewayStartedShardsBatch}.
*
* @param <T> Response type of the transport action.
* @param <V> Data type of shard level response.
*
* @opensearch.internal
*/
public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardFetch<T> {

@@ -47,7 +55,7 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
AsyncShardFetch.Lister<? extends BaseNodesResponse<T>, T> action,
String batchId,
Class<V> clazz,
BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseConstructor,
BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseGetter,
Function<T, Map<ShardId, V>> shardsBatchDataGetter,
Supplier<V> emptyResponseBuilder,
Consumer<ShardId> handleFailedShard
@@ -61,21 +69,13 @@ public abstract class AsyncShardBatchFetch<T extends BaseNodeResponse, V extends
shardAttributesMap,
"BatchID=[" + batchId + "]",
clazz,
responseConstructor,
responseGetter,
shardsBatchDataGetter,
emptyResponseBuilder,
handleFailedShard
this::cleanUpFailedShards
);
}

/**
* Fetch the data for a batch of shards, this uses the already written {@link AsyncShardFetch} fetchData method.
* Based on the shards failed in last round, it makes sure to trigger a reroute for them.
*
* @param nodes all the nodes where transport call should be sent
* @param ignoreNodes nodes to update based on failures received from transport actions
* @return data received from the transport actions
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (failedShards.isEmpty() == false) {
// trigger a reroute if there are any shards failed, to make sure they're picked up in next run
@@ -85,13 +85,7 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId,
return super.fetchData(nodes, ignoreNodes);
}

/**
* Remove the shard from shardAttributesMap, so we don't send it in next fetching round.
* Remove shard from the batch, so it gets picked up in a new batch in next reroute.
*
* @param shardId shardId to be cleaned up
*/
private void cleanUpFailedShard(ShardId shardId) {
private void cleanUpFailedShards(ShardId shardId) {
shardAttributesMap.remove(shardId);
removeShardFromBatch.accept(shardId);
failedShards.add(shardId);
@@ -104,7 +98,210 @@ private void cleanUpFailedShard(ShardId shardId) {
* @param shardId shardId to be removed from the batch.
*/
public void clearShard(ShardId shardId) {
shardAttributesMap.remove(shardId);
cache.deleteData(shardId);
this.shardAttributesMap.remove(shardId);
this.cache.deleteShard(shardId);
}

/**
* Cache implementation for transport actions that return data for a batch of shards in their response. It
* stores one entry per node, with a fixed slot for every shard in the batch.
*
* @param <T> Response type of transport action.
* @param <V> Data type of shard level response.
*/
public static class ShardBatchCache<T extends BaseNodeResponse, V extends BaseShardResponse> extends AsyncShardFetchCache<T> {
private final Map<String, NodeEntry<V>> cache = new HashMap<>();
private final Map<ShardId, Integer> shardIdKey = new HashMap<>();
private final AtomicInteger shardIdIndex = new AtomicInteger();
private final int batchSize;
private final Class<V> shardResponseClass;
private final BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseConstructor;
private final Map<Integer, ShardId> shardIdReverseKey = new HashMap<>();
private final Function<T, Map<ShardId, V>> shardsBatchDataGetter;
private final Supplier<V> emptyResponseBuilder;
private final Consumer<ShardId> handleFailedShard;

public ShardBatchCache(
Logger logger,
String type,
Map<ShardId, ShardAttributes> shardAttributesMap,
String logKey,
Class<V> clazz,
BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseGetter,
Function<T, Map<ShardId, V>> shardsBatchDataGetter,
Supplier<V> emptyResponseBuilder,
Consumer<ShardId> handleFailedShard
) {
super(Loggers.getLogger(logger, "_" + logKey), type);
this.batchSize = shardAttributesMap.size();
fillShardIdKeys(shardAttributesMap.keySet());
this.shardResponseClass = clazz;
this.responseConstructor = responseGetter;
this.shardsBatchDataGetter = shardsBatchDataGetter;
this.emptyResponseBuilder = emptyResponseBuilder;
this.handleFailedShard = handleFailedShard;
}

@Override
public Map<String, ? extends BaseNodeEntry> getCache() {
return cache;
}

@Override
public void deleteShard(ShardId shardId) {
if (shardIdKey.containsKey(shardId)) {
Integer shardIdIndex = shardIdKey.remove(shardId);
for (String nodeId : cache.keySet()) {
cache.get(nodeId).clearShard(shardIdIndex);
}
}
}

@Override
public Map<DiscoveryNode, T> getCacheData(DiscoveryNodes nodes, Set<String> failedNodes) {
refreshReverseIdMap();
return super.getCacheData(nodes, failedNodes);
}

/**
* Build a reverse map to get a shardId back from its array index. It is used to construct the response
* expected by PrimaryShardBatchAllocator or ReplicaShardBatchAllocator.
*/
private void refreshReverseIdMap() {
shardIdReverseKey.clear();
for (ShardId shardId : shardIdKey.keySet()) {
shardIdReverseKey.putIfAbsent(shardIdKey.get(shardId), shardId);
}
}

@Override
public void initData(DiscoveryNode node) {
cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize));
}

/**
* Put the response received from a data node into the cache.
* Extract the shard-level data from the batch and filter out any shards that reported failures.
* Then store the remaining data at the node level and mark fetching as done.
*
* @param node node from which we got the response.
* @param response shard metadata coming from node.
*/
@Override
public void putData(DiscoveryNode node, T response) {
NodeEntry<V> nodeEntry = cache.get(node.getId());
Map<ShardId, V> batchResponse = shardsBatchDataGetter.apply(response);
filterFailedShards(batchResponse);
nodeEntry.doneFetching(batchResponse, shardIdKey);
}

/**
* Remove from the batch response any shards that hit retryable failures, so that they can be
* handled and fetched again in a later round.
*
* @param batchResponse response from one node for the batch.
*/
private void filterFailedShards(Map<ShardId, V> batchResponse) {
logger.trace("filtering failed shards");
for (Iterator<ShardId> it = batchResponse.keySet().iterator(); it.hasNext();) {
ShardId shardId = it.next();
if (batchResponse.get(shardId) != null) {
if (batchResponse.get(shardId).getException() != null) {
// handle per shard level exceptions, process other shards, only throw out this shard from
// the batch
Exception shardException = batchResponse.get(shardId).getException();
// if the request got rejected or timed out, we need to try it again next time...
if (retryableException(shardException)) {
logger.trace(
"got unhandled retryable exception for shard {} {}",
shardId.toString(),
shardException.toString()
);
handleFailedShard.accept(shardId);
// remove this failed entry. So, while storing the data, we don't need to re-process it.
it.remove();
}
}
}
}
}

@Override
public T getData(DiscoveryNode node) {
return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId())));
}

private HashMap<ShardId, V> getBatchData(NodeEntry<V> nodeEntry) {
V[] nodeShardEntries = nodeEntry.getData();
boolean[] emptyResponses = nodeEntry.getEmptyShardResponse();
HashMap<ShardId, V> shardData = new HashMap<>();
for (Integer shardIdIndex : shardIdKey.values()) {
if (emptyResponses[shardIdIndex]) {
shardData.put(shardIdReverseKey.get(shardIdIndex), emptyResponseBuilder.get());
} else if (nodeShardEntries[shardIdIndex] != null) {
// ignore null responses here
shardData.put(shardIdReverseKey.get(shardIdIndex), nodeShardEntries[shardIdIndex]);
}
}
return shardData;
}

private void fillShardIdKeys(Set<ShardId> shardIds) {
for (ShardId shardId : shardIds) {
this.shardIdKey.putIfAbsent(shardId, shardIdIndex.getAndIncrement());
}
this.shardIdKey.keySet().removeIf(shardId -> {
if (!shardIds.contains(shardId)) {
deleteShard(shardId);
return true;
} else {
return false;
}
});
}

/**
* A node entry, holding the state of the fetched data for every shard in the batch
* for a given node.
*/
static class NodeEntry<V extends BaseShardResponse> extends BaseNodeEntry {
private final V[] shardData;
private final boolean[] emptyShardResponse;

NodeEntry(String nodeId, Class<V> clazz, int batchSize) {
super(nodeId);
this.shardData = (V[]) Array.newInstance(clazz, batchSize);
this.emptyShardResponse = new boolean[batchSize];
}

void doneFetching(Map<ShardId, V> shardDataFromNode, Map<ShardId, Integer> shardIdKey) {
fillShardData(shardDataFromNode, shardIdKey);
super.doneFetching();
}

void clearShard(Integer shardIdIndex) {
this.shardData[shardIdIndex] = null;
}

V[] getData() {
return this.shardData;
}

boolean[] getEmptyShardResponse() {
return emptyShardResponse;
}

private void fillShardData(Map<ShardId, V> shardDataFromNode, Map<ShardId, Integer> shardIdKey) {
for (ShardId shardId : shardDataFromNode.keySet()) {
if (shardDataFromNode.get(shardId) != null) {
if (shardDataFromNode.get(shardId).isEmpty()) {
this.emptyShardResponse[shardIdKey.get(shardId)] = true;
this.shardData[shardIdKey.get(shardId)] = null;
} else if (shardDataFromNode.get(shardId).getException() == null) {
this.shardData[shardIdKey.get(shardId)] = shardDataFromNode.get(shardId);
}
// if exception is not null, we got unhandled failure for the shard which needs to be ignored
}
}
}
}
}
}
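
A note on the design being moved in here: ShardBatchCache keeps one NodeEntry per node, and each NodeEntry holds a typed array with one slot per shard of the batch (allocated via Array.newInstance), plus a parallel boolean[] marking shards whose response was successful but empty; all node entries share a single shardId -> index map. Below is a minimal, self-contained sketch of that slot scheme. It assumes String shard ids for simplicity, and BatchSlots, put and snapshot are illustrative names, not OpenSearch API.

import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

final class BatchSlots<V> {
    private final Map<String, Integer> slotByShard = new HashMap<>(); // shardId -> array index
    private final V[] data;        // one slot per shard in the batch
    private final boolean[] empty; // shard answered successfully but had no data

    @SuppressWarnings("unchecked")
    BatchSlots(Class<V> clazz, List<String> shardIds) {
        // a typed array needs reflection, mirroring Array.newInstance(clazz, batchSize) above
        this.data = (V[]) Array.newInstance(clazz, shardIds.size());
        this.empty = new boolean[shardIds.size()];
        for (int i = 0; i < shardIds.size(); i++) {
            slotByShard.put(shardIds.get(i), i);
        }
    }

    void put(String shardId, V value, boolean isEmpty) {
        int slot = slotByShard.get(shardId);
        empty[slot] = isEmpty;
        data[slot] = isEmpty ? null : value; // empty responses are tracked by the flag alone
    }

    // Rebuild a shardId -> value view, substituting a caller-supplied default for empty slots
    // and skipping slots that never received a usable response.
    Map<String, V> snapshot(Supplier<V> emptyValue) {
        Map<String, V> out = new HashMap<>();
        for (Map.Entry<String, Integer> e : slotByShard.entrySet()) {
            int slot = e.getValue();
            if (empty[slot]) {
                out.put(e.getKey(), emptyValue.get());
            } else if (data[slot] != null) {
                out.put(e.getKey(), data[slot]);
            }
        }
        return out;
    }
}

Keeping emptiness in a separate flag array is what lets getBatchData distinguish "this shard replied with nothing" (rebuilt through emptyResponseBuilder) from "no usable reply yet" (the shard is skipped), matching the two branches in getBatchData above.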
@@ -65,11 +65,9 @@
* Allows to asynchronously fetch shard related data from other nodes for allocation, without blocking
* the cluster update thread.
* <p>
* The async fetch logic maintains a map of which nodes are being fetched from in an async manner,
* and once the results are back, it makes sure to schedule a reroute to make sure those results will
* be taken into account.
* The async fetch logic maintains a cache {@link AsyncShardFetchCache} which is filled asynchronously as nodes respond.
* It also schedules a reroute to make sure those results will be taken into account.
*
* It comes in two modes, to single fetch a shard or fetch a batch of shards.
* @opensearch.internal
*/
public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Releasable {
@@ -86,7 +84,7 @@ public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, N
protected final String type;
protected final Map<ShardId, ShardAttributes> shardAttributesMap;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
protected final String reroutingKey;
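For context on the reworked javadoc above, here is a hedged sketch of the fetch-and-reroute cycle it describes: a per-node cache entry is marked in flight, filled when the node responds, and a reroute is scheduled so allocation picks up the new data. The names AsyncNodeFetch, NodeCacheEntry and RerouteScheduler are illustrative stand-ins, not the actual OpenSearch types.

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

class AsyncNodeFetch<T> {
    interface RerouteScheduler {
        void schedule(String reason);
    }

    static final class NodeCacheEntry<T> {
        boolean fetching;
        T value; // stays null until the node responds successfully
    }

    private final Map<String, NodeCacheEntry<T>> cache = new HashMap<>();
    private final RerouteScheduler reroute;

    AsyncNodeFetch(RerouteScheduler reroute) {
        this.reroute = reroute;
    }

    // Start a fetch for a node unless one is already in flight or cached.
    synchronized void fetchFrom(String nodeId, CompletableFuture<T> transportCall) {
        NodeCacheEntry<T> entry = cache.computeIfAbsent(nodeId, id -> new NodeCacheEntry<>());
        if (entry.fetching || entry.value != null) {
            return;
        }
        entry.fetching = true;
        transportCall.whenComplete((response, failure) -> {
            synchronized (this) {
                entry.fetching = false;
                entry.value = response; // null on failure, so a later round retries this node
            }
            // results are back: schedule a reroute so they are taken into account
            reroute.schedule("async fetch completed for node [" + nodeId + "]");
        });
    }
}

The batch subclass layers one more rule on top of this cycle: if any shards failed with retryable errors in the previous round, fetchData first triggers a reroute for them (the failedShards check in AsyncShardBatchFetch.fetchData above) before delegating to the base fetch.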
