Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch Async Fetcher class changes #8742

Merged
merged 12 commits into from
Jan 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void start() {
} else {
for (Tuple<ShardId, String> shard : shards) {
InternalAsyncFetch fetch = new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), listShardStoresInfo);
fetch.fetchData(nodes, Collections.<String>emptySet());
fetch.fetchData(nodes, Collections.emptyMap());
}
}
}
Expand Down Expand Up @@ -223,7 +223,7 @@ protected synchronized void processAsyncFetch(
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardId, responses, failures));
fetchResponses.add(new Response(shardToCustomDataPath.keySet().iterator().next(), responses, failures));
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
if (expectedOps.countDown()) {
finish();
}
Expand Down Expand Up @@ -312,7 +312,7 @@ private boolean shardExistsInNode(final NodeGatewayStartedShards response) {
}

@Override
protected void reroute(ShardId shardId, String reason) {
protected void reroute(String shardId, String reason) {
// no-op
}

Expand Down
124 changes: 88 additions & 36 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;

/**
* Allows asynchronously fetching shard-related data from other nodes for allocation, without blocking
Expand All @@ -77,18 +77,22 @@ public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Rel
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(ShardId shardId, @Nullable String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, String> shardIdsWithCustomDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved

}

protected final Logger logger;
protected final String type;
protected final ShardId shardId;
protected final String customDataPath;

protected final Map<ShardId, String> shardToCustomDataPath;
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final Set<String> nodesToIgnore = new HashSet<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String logKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Expand All @@ -100,9 +104,27 @@ protected AsyncShardFetch(
) {
this.logger = logger;
this.type = type;
this.shardId = Objects.requireNonNull(shardId);
this.customDataPath = Objects.requireNonNull(customDataPath);
shardToCustomDataPath = new HashMap<>();
shardToCustomDataPath.put(Objects.requireNonNull(shardId), Objects.requireNonNull(customDataPath));
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
Logger logger,
String type,
Map<ShardId, String> shardToCustomDataPath,
Lister<? extends BaseNodesResponse<T>, T> action,
String batchId
) {
this.logger = logger;
this.type = type;
this.shardToCustomDataPath = Objects.requireNonNull(shardToCustomDataPath);
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "BatchID=[" + batchId + "]";
enableBatchMode = true;
}

@Override
Expand Down Expand Up @@ -130,11 +152,27 @@ public synchronized int getNumberOfInFlightFetches() {
* The ignoreNodes are nodes that are supposed to be ignored for this round. Since fetching is async, we need
* to keep them around and make sure we add them back when all the responses are fetched and returned.
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> ignoreNodes) {
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
if (closed) {
throw new IllegalStateException(shardId + ": can't fetch data on closed async fetch");
throw new IllegalStateException(logKey + ": can't fetch data on closed async fetch");
}

if (enableBatchMode == false) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
// in non-batch mode, assert that ignoreNodes refers to at most one shard, and that it is the shard this fetcher was initialised with
assert ignoreNodes.size() <= 1 : "Can only have at-most one shard";
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
if (ignoreNodes.size() == 1) {
assert shardToCustomDataPath.containsKey(ignoreNodes.keySet().iterator().next())
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
: "ShardId should be same as initialised in fetcher";
}
}

// add the nodes to ignore to the list of nodes to ignore for each shard
for (Map.Entry<ShardId, Set<String>> ignoreNodesEntry : ignoreNodes.entrySet()) {
Set<String> ignoreNodesSet = shardToIgnoreNodes.getOrDefault(ignoreNodesEntry.getKey(), new HashSet<>());
ignoreNodesSet.addAll(ignoreNodesEntry.getValue());
shardToIgnoreNodes.put(ignoreNodesEntry.getKey(), ignoreNodesSet);
}
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
nodesToIgnore.addAll(ignoreNodes);

fillShardCacheWithDataNodes(cache, nodes);
List<NodeEntry<T>> nodesToFetch = findNodesToFetch(cache);
if (nodesToFetch.isEmpty() == false) {
Expand All @@ -153,7 +191,7 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i

// if we are still fetching, return null to indicate it
if (hasAnyNodeFetching(cache)) {
return new FetchResult<>(shardId, null, emptySet());
return new FetchResult<>(null, emptyMap());
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
} else {
// nothing to fetch, yay, build the return value
Map<DiscoveryNode, T> fetchData = new HashMap<>();
Expand All @@ -177,16 +215,27 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
}
}
}
Set<String> allIgnoreNodes = unmodifiableSet(new HashSet<>(nodesToIgnore));

Map<ShardId, Set<String>> allIgnoreNodesMap = unmodifiableMap(new HashMap<>(shardToIgnoreNodes));
// clear the nodes to ignore: we had a successful run in fetching everything we can,
// and we need to try them again if another full run is needed
nodesToIgnore.clear();
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just in case this round won't find anything and we need to retry fetching data
if (failedNodes.isEmpty() == false || allIgnoreNodes.isEmpty() == false) {
reroute(shardId, "nodes failed [" + failedNodes.size() + "], ignored [" + allIgnoreNodes.size() + "]");
// If there are ignored nodes for even a single shard in the batch of shards, also do a reroute
if (failedNodes.isEmpty() == false
|| allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
reroute(
logKey,
Gaurav614 marked this conversation as resolved.
Show resolved Hide resolved
"nodes failed ["
+ failedNodes.size()
+ "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum()
+ "]"
);
}
return new FetchResult<>(shardId, fetchData, allIgnoreNodes);

return new FetchResult<>(fetchData, allIgnoreNodesMap);
}
}

Expand All @@ -199,10 +248,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Set<String> i
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", shardId, type);
logger.trace("{} ignoring fetched [{}] results, already closed", logKey, type);
return;
}
logger.trace("{} processing fetched [{}] results", shardId, type);
logger.trace("{} processing fetched [{}] results", logKey, type);

if (responses != null) {
for (T response : responses) {
Expand All @@ -212,7 +261,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -221,29 +270,29 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", shardId, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", logKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", shardId, failure, type);
logger.trace("{} processing failure {} for [{}]", logKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
shardId,
logKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
Expand All @@ -261,7 +310,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
shardId,
logKey,
type,
failure.nodeId()
),
Expand All @@ -273,13 +322,13 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
}
}
}
reroute(shardId, "post_response");
reroute(logKey, "post_response");
}

/**
* Implement this in order to schedule another round that causes a call to fetch data.
*/
protected abstract void reroute(ShardId shardId, String reason);
protected abstract void reroute(String reroutingKey, String reason);

/**
* Clear cache for node, ensuring next fetch will fetch a fresh copy.
Expand Down Expand Up @@ -334,8 +383,8 @@ private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", shardId, type, nodes);
action.list(shardId, customDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", logKey, type, nodes);
action.list(shardToCustomDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
Expand All @@ -358,14 +407,12 @@ public void onFailure(Exception e) {
*/
public static class FetchResult<T extends BaseNodeResponse> {

private final ShardId shardId;
private final Map<DiscoveryNode, T> data;
private final Set<String> ignoreNodes;
private final Map<ShardId, Set<String>> ignoredShardToNodes;

public FetchResult(ShardId shardId, Map<DiscoveryNode, T> data, Set<String> ignoreNodes) {
this.shardId = shardId;
public FetchResult(Map<DiscoveryNode, T> data, Map<ShardId, Set<String>> ignoreNodes) {
this.data = data;
this.ignoreNodes = ignoreNodes;
this.ignoredShardToNodes = ignoreNodes;
}

/**
Expand All @@ -389,9 +436,14 @@ public Map<DiscoveryNode, T> getData() {
* Process any changes needed to the allocation based on this fetch result.
*/
public void processAllocation(RoutingAllocation allocation) {
for (String ignoreNode : ignoreNodes) {
allocation.addIgnoreShardForNode(shardId, ignoreNode);
for (Map.Entry<ShardId, Set<String>> entry : ignoredShardToNodes.entrySet()) {
ShardId shardId = entry.getKey();
Set<String> ignoreNodes = entry.getValue();
if (ignoreNodes.isEmpty() == false) {
ignoreNodes.forEach(nodeId -> allocation.addIgnoreShardForNode(shardId, nodeId));
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.indices.store.TransportNodesListShardStoreMetadata;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Spliterators;
Expand Down Expand Up @@ -226,7 +227,9 @@ private static void clearCacheForPrimary(
AsyncShardFetch<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> fetch,
RoutingAllocation allocation
) {
ShardRouting primary = allocation.routingNodes().activePrimary(fetch.shardId);
assert fetch.shardToCustomDataPath.size() == 1 : "expected only one shard";
ShardId shardId = fetch.shardToCustomDataPath.keySet().iterator().next();
ShardRouting primary = allocation.routingNodes().activePrimary(shardId);
if (primary != null) {
fetch.clearCacheForNode(primary.currentNodeId());
}
Expand Down Expand Up @@ -254,15 +257,15 @@ class InternalAsyncFetch<T extends BaseNodeResponse> extends AsyncShardFetch<T>
}

@Override
protected void reroute(ShardId shardId, String reason) {
logger.trace("{} scheduling reroute for {}", shardId, reason);
protected void reroute(String logKey, String reason) {
logger.trace("{} scheduling reroute for {}", logKey, reason);
assert rerouteService != null;
rerouteService.reroute(
"async_shard_fetch",
Priority.HIGH,
ActionListener.wrap(
r -> logger.trace("{} scheduled reroute completed for {}", shardId, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", shardId, reason), e)
r -> logger.trace("{} scheduled reroute completed for {}", logKey, reason),
e -> logger.debug(new ParameterizedMessage("{} scheduled reroute failed for {}", logKey, reason), e)
)
);
}
Expand Down Expand Up @@ -293,7 +296,11 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.Nod
);
AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShards.NodeGatewayStartedShards> shardState = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}
}
);

if (shardState.hasData()) {
Expand Down Expand Up @@ -328,7 +335,11 @@ protected AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeS
);
AsyncShardFetch.FetchResult<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata> shardStores = fetch.fetchData(
allocation.nodes(),
allocation.getIgnoreNodes(shard.shardId())
new HashMap<>() {
{
put(shard.shardId(), allocation.getIgnoreNodes(shard.shardId()));
}
}
);
if (shardStores.hasData()) {
shardStores.processAllocation(allocation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
Expand Down Expand Up @@ -124,7 +125,14 @@ public TransportNodesListGatewayStartedShards(
}

@Override
public void list(ShardId shardId, String customDataPath, DiscoveryNode[] nodes, ActionListener<NodesGatewayStartedShards> listener) {
public void list(
Map<ShardId, String> shardIdsWithCustomDataPath,
DiscoveryNode[] nodes,
ActionListener<NodesGatewayStartedShards> listener
) {
assert shardIdsWithCustomDataPath.size() == 1 : "only one shard should be specified";
final ShardId shardId = shardIdsWithCustomDataPath.keySet().iterator().next();
final String customDataPath = shardIdsWithCustomDataPath.get(shardId);
execute(new Request(shardId, customDataPath, nodes), listener);
}

Expand Down
Loading