Skip to content

Commit

Permalink
Batch Async Fetcher class changes (opensearch-project#8742)
Browse files Browse the repository at this point in the history
* Async Fetcher class changes

Signed-off-by: Gaurav Chandani <[email protected]>
  • Loading branch information
Gaurav614 authored and rayshrey committed Mar 18, 2024
1 parent 30dff5d commit 6930f68
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void start() {
} else {
for (Tuple<ShardId, String> shard : shards) {
InternalAsyncFetch fetch = new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), listShardStoresInfo);
fetch.fetchData(nodes, Collections.<String>emptySet());
fetch.fetchData(nodes, Collections.emptyMap());
}
}
}
Expand Down Expand Up @@ -223,7 +223,7 @@ protected synchronized void processAsyncFetch(
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardId, responses, failures));
fetchResponses.add(new Response(shardAttributesMap.keySet().iterator().next(), responses, failures));
if (expectedOps.countDown()) {
finish();
}
Expand Down Expand Up @@ -312,7 +312,7 @@ private boolean shardExistsInNode(final NodeGatewayStartedShards response) {
}

@Override
protected void reroute(ShardId shardId, String reason) {
protected void reroute(String shardId, String reason) {
// no-op
}

Expand Down
140 changes: 103 additions & 37 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.transport.ReceiveTimeoutTransportException;

import java.util.ArrayList;
Expand All @@ -54,12 +55,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;

/**
* Allows to asynchronously fetch shard related data from other nodes for allocation, without blocking
Expand All @@ -69,6 +69,7 @@
* and once the results are back, it makes sure to schedule a reroute to make sure those results will
* be taken into account.
*
* It comes in two modes, to single fetch a shard or fetch a batch of shards.
* @opensearch.internal
*/
public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Releasable {
Expand All @@ -77,18 +78,21 @@ public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Rel
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(ShardId shardId, @Nullable String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, ShardAttributes> shardAttributesMap, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);

}

protected final Logger logger;
protected final String type;
protected final ShardId shardId;
protected final String customDataPath;
protected final Map<ShardId, ShardAttributes> shardAttributesMap;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final Set<String> nodesToIgnore = new HashSet<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String reroutingKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Expand All @@ -100,9 +104,36 @@ protected AsyncShardFetch(
) {
this.logger = logger;
this.type = type;
this.shardId = Objects.requireNonNull(shardId);
this.customDataPath = Objects.requireNonNull(customDataPath);
shardAttributesMap = new HashMap<>();
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

/**
* Added to fetch a batch of shards from nodes
*
* @param logger Logger
* @param type type of action
* @param shardAttributesMap Map of {@link ShardId} to {@link ShardAttributes} to perform fetching on them a
* @param action Transport Action
* @param batchId For the given ShardAttributesMap, we expect them to tie with a single batch id for logging and later identification
*/
@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Logger logger,
String type,
Map<ShardId, ShardAttributes> shardAttributesMap,
Lister<? extends BaseNodesResponse<T>, T> action,
String batchId
) {
this.logger = logger;
this.type = type;
this.shardAttributesMap = shardAttributesMap;
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.reroutingKey = "BatchID=[" + batchId + "]";
enableBatchMode = true;
}

@Override
Expand Down Expand Up @@ -130,11 +161,32 @@ public synchronized int getNumberOfInFlightFetches() {
* The ignoreNodes are nodes that are supposed to be ignored for this round, since fetching is async, we need
* to keep them around and make sure we add them back when all the responses are fetched and returned.
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> ignoreNodes) {
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (closed) {
throw new IllegalStateException(shardId + ": can't fetch data on closed async fetch");
throw new IllegalStateException(reroutingKey + ": can't fetch data on closed async fetch");
}
nodesToIgnore.addAll(ignoreNodes);

if (enableBatchMode == false) {
// we will do assertions here on ignoreNodes
if (ignoreNodes.size() > 1) {
throw new IllegalStateException(
"Fetching Shard Data, " + reroutingKey + "Can only have atmost one shard" + "for non-batch mode"
);
}
if (ignoreNodes.size() == 1) {
if (shardAttributesMap.containsKey(ignoreNodes.keySet().iterator().next()) == false) {
throw new IllegalStateException("Shard Id must be same as initialized in AsyncShardFetch. Expecting = " + reroutingKey);
}
}
}

// add the nodes to ignore to the list of nodes to ignore for each shard
for (Map.Entry<ShardId, Set<String>> ignoreNodesEntry : ignoreNodes.entrySet()) {
Set<String> ignoreNodesSet = shardToIgnoreNodes.getOrDefault(ignoreNodesEntry.getKey(), new HashSet<>());
ignoreNodesSet.addAll(ignoreNodesEntry.getValue());
shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
}

fillShardCacheWithDataNodes(cache, nodes);
List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
if (nodesToFetch.isEmpty() == false) {
Expand All @@ -153,7 +205,7 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i

// if we are still fetching, return null to indicate it
if (hasAnyNodeFetching(cache)) {
return new FetchResult<>(shardId, null, emptySet());
return new FetchResult<>(null, emptyMap());
} else {
// nothing to fetch, yay, build the return value
Map<DiscoveryNode, T> fetchData = new HashMap<>();
Expand All @@ -177,16 +229,27 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
}
}
}
Set<String> allIgnoreNodes = unmodifiableSet(new HashSet<>(nodesToIgnore));

Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
// clear the nodes to ignore, we had a successful run in fetching everything we can
// we need to try them if another full run is needed
nodesToIgnore.clear();
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just case this round won't find anything, and we need to retry fetching data
if (failedNodes.isEmpty() == false || allIgnoreNodes.isEmpty() == false) {
reroute(shardId, "nodes failed [" + failedNodes.size() + "], ignored [" + allIgnoreNodes.size() + "]");

if (failedNodes.isEmpty() == false
|| allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
reroute(
reroutingKey,
"nodes failed ["
+ failedNodes.size()
+ "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum()
+ "]"
);
}
return new FetchResult<>(shardId, fetchData, allIgnoreNodes);

return new FetchResult<>(fetchData, allIgnoreNodesMap);
}
}

Expand All @@ -199,10 +262,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", shardId, type);
logger.trace("{} ignoring fetched [{}] results, already closed", reroutingKey, type);
return;
}
logger.trace("{} processing fetched [{}] results", shardId, type);
logger.trace("{} processing fetched [{}] results", reroutingKey, type);

if (responses != null) {
for (T response : responses) {
Expand All @@ -212,7 +275,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -221,29 +284,29 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", shardId, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", reroutingKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", shardId, failure, type);
logger.trace("{} processing failure {} for [{}]", reroutingKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -261,7 +324,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
shardId,
reroutingKey,
type,
failure.nodeId()
),
Expand All @@ -273,13 +336,13 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
}
}
}
reroute(shardId, "post_response");
reroute(reroutingKey, "post_response");
}

/**
* Implement this in order to scheduled another round that causes a call to fetch data.
*/
protected abstract void reroute(ShardId shardId, String reason);
protected abstract void reroute(String reroutingKey, String reason);

/**
* Clear cache for node, ensuring next fetch will fetch a fresh copy.
Expand Down Expand Up @@ -334,8 +397,8 @@ private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", shardId, type, nodes);
action.list(shardId, customDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", reroutingKey, type, nodes);
action.list(shardAttributesMap, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
Expand All @@ -358,14 +421,12 @@ public void onFailure(Exception e) {
*/
public static class FetchResult<T extends BaseNodeResponse> {

private final ShardId shardId;
private final Map<DiscoveryNode, T> data;
private final Set<String> ignoreNodes;
private final Map<ShardId, Set<String>> ignoredShardToNodes;

public FetchResult(ShardId shardId, Map<DiscoveryNode, T> data, Set<String> ignoreNodes) {
this.shardId = shardId;
public FetchResult(Map<DiscoveryNode, T> data, Map<ShardId, Set<String>> ignoreNodes) {
this.data = data;
this.ignoreNodes = ignoreNodes;
this.ignoredShardToNodes = ignoreNodes;
}

/**
Expand All @@ -389,9 +450,14 @@ public Map<DiscoveryNode, T> getData() {
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for (String ignoreNode : ignoreNodes) {
allocation.addIgnoreShardForNode(shardId, ignoreNode);
for (Map.Entry<ShardId, Set<String>> entry : ignoredShardToNodes.entrySet()) {
ShardId shardId = entry.getKey();
Set<String> ignoreNodes = entry.getValue();
if (ignoreNodes.isEmpty() == false) {
ignoreNodes.forEach(nodeId -> allocation.addIgnoreShardForNode(shardId, nodeId));
}
}

}
}

Expand Down
25 changes: 18 additions & 7 deletions server/src/main/java/org/opensearch/gateway/GatewayAllocator.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.indices.store.TransportNodesListShardStoreMetadata;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Spliterators;
Expand Down Expand Up @@ -226,7 +227,9 @@ private static void clearCacheForPrimary(
AsyncShardFetch<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> fetch,
RoutingAllocation allocation
) {
ShardRouting primary = allocation.routingNodes().activePrimary(fetch.shardId);
assert fetch.shardAttributesMap.size() == 1 : "expected only one shard";
ShardId shardId = fetch.shardAttributesMap.keySet().iterator().next();
ShardRouting primary = allocation.routingNodes().activePrimary(shardId);
if (primary != null) {
fetch.clearCacheForNode(primary.currentNodeId());
}
Expand Down Expand Up @@ -254,15 +257,15 @@ class InternalAsyncFetch<T extends BaseNodeResponse> extends AsyncShardFetch<T>
}

@Override
protected void reroute(ShardId shardId, String reason) {
logger.trace("{} scheduling reroute for {}", shardId, reason);
protected void reroute(String reroutingKey, String reason) {
logger.trace("{} scheduling reroute for {}", reroutingKey, reason);
assert rerouteService != null;
rerouteService.reroute(
"async_shard_fetch",
Priority.HIGH,
ActionListener.wrap(
r -> logger.trace("{} scheduled reroute completed for {}", shardId, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", shardId, reason), e)
r -> logger.trace("{} scheduled reroute completed for {}", reroutingKey, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", reroutingKey, reason), e)
)
);
}
Expand Down Expand Up @@ -293,7 +296,11 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.Nod
);
AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.NodeGatewayStartedShards> shardState = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}
}
);

if (shardState.hasData()) {
Expand Down Expand Up @@ -328,7 +335,11 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeS
);
AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> shardStores = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}
}
);
if (shardStores.hasData()) {
shardStores.processAllocation(allocation);
Expand Down
Loading

0 comments on commit 6930f68

Please sign in to comment.