Addressed PR comments (opensearch-project#8742)
1. Instead of using Map<ShardId, String> in the Fetcher class, now use Map<ShardId, ShardAttributes> to make the code more extensible.

2. Added a UT for the newly added constructor in the Fetcher class.

3. Renamed logKey to reroutingKey.

4. Added an IllegalStateException for non-batched fetch requests.

Signed-off-by: Gaurav Chandani <[email protected]>
Gaurav614 committed Dec 5, 2023
1 parent a720578 commit b55d34c
Showing 9 changed files with 140 additions and 140 deletions.
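For context on the changes below: the fetcher now carries a Map<ShardId, ShardAttributes> instead of a Map<ShardId, String> of custom data paths. The diff only shows ShardAttributes being constructed from a shard id and a custom data path, so the following is a minimal sketch of what such a holder could look like under that assumption; the real org.opensearch.indices.store.ShardAttributes class may carry more (for example transport serialization).

import org.opensearch.core.index.shard.ShardId;

// Hypothetical minimal sketch of the ShardAttributes value object assumed by this diff.
// The actual org.opensearch.indices.store.ShardAttributes may differ (e.g. implement Writeable).
public class ShardAttributes {
    private final ShardId shardId;
    private final String customDataPath; // custom data path of the shard's index; may be empty

    public ShardAttributes(ShardId shardId, String customDataPath) {
        this.shardId = shardId;
        this.customDataPath = customDataPath;
    }

    public ShardId getShardId() {
        return shardId;
    }

    public String getCustomDataPath() {
        return customDataPath;
    }
}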
@@ -20,6 +20,7 @@
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.index.shard.ShardPath;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.io.IOException;
@@ -39,14 +40,14 @@ public class TransportNodesListGatewayStartedBatchShardsIT extends OpenSearchInt

public void testSingleShardFetch() throws Exception {
String indexName = "test";
Map<ShardId, String> shardIdCustomDataPathMap = prepareRequestMap(new String[] { indexName }, 1);
Map<ShardId, ShardAttributes> shardAttributesMap = prepareRequestMap(new String[] { indexName }, 1);

ClusterSearchShardsResponse searchShardsResponse = client().admin().cluster().prepareSearchShards(indexName).get();

TransportNodesListGatewayStartedBatchShards.NodesGatewayStartedShardsBatch response;
response = ActionTestUtils.executeBlocking(
internalCluster().getInstance(TransportNodesListGatewayStartedBatchShards.class),
new TransportNodesListGatewayStartedBatchShards.Request(searchShardsResponse.getNodes(), shardIdCustomDataPathMap)
new TransportNodesListGatewayStartedBatchShards.Request(searchShardsResponse.getNodes(), shardAttributesMap)
);
final Index index = resolveIndex(indexName);
final ShardId shardId = new ShardId(index, 0);
@@ -63,7 +64,7 @@ public void testShardFetchMultiNodeMultiIndexes() throws Exception {
String indexName1 = "test1";
String indexName2 = "test2";
// assign one primary shard each to the data nodes
Map<ShardId, String> shardIdCustomDataPathMap = prepareRequestMap(
Map<ShardId, ShardAttributes> shardAttributesMap = prepareRequestMap(
new String[] { indexName1, indexName2 },
internalCluster().numDataNodes()
);
@@ -72,7 +73,7 @@ public void testShardFetchMultiNodeMultiIndexes() throws Exception {
TransportNodesListGatewayStartedBatchShards.NodesGatewayStartedShardsBatch response;
response = ActionTestUtils.executeBlocking(
internalCluster().getInstance(TransportNodesListGatewayStartedBatchShards.class),
new TransportNodesListGatewayStartedBatchShards.Request(searchShardsResponse.getNodes(), shardIdCustomDataPathMap)
new TransportNodesListGatewayStartedBatchShards.Request(searchShardsResponse.getNodes(), shardAttributesMap)
);
for (ClusterSearchShardsGroup clusterSearchShardsGroup : searchShardsResponse.getGroups()) {
ShardId shardId = clusterSearchShardsGroup.getShardId();
@@ -88,7 +89,7 @@ public void testShardFetchMultiNodeMultiIndexes() throws Exception {

public void testShardFetchCorruptedShards() throws Exception {
String indexName = "test";
Map<ShardId, String> shardIdCustomDataPathMap = prepareRequestMap(new String[] { indexName }, 1);
Map<ShardId, ShardAttributes> shardAttributes = prepareRequestMap(new String[] { indexName }, 1);
ClusterSearchShardsResponse searchShardsResponse = client().admin().cluster().prepareSearchShards(indexName).get();
final Index index = resolveIndex(indexName);
final ShardId shardId = new ShardId(index, 0);
@@ -97,7 +98,7 @@ public void testShardFetchCorruptedShards() throws Exception {
internalCluster().restartNode(searchShardsResponse.getNodes()[0].getName());
response = ActionTestUtils.executeBlocking(
internalCluster().getInstance(TransportNodesListGatewayStartedBatchShards.class),
new TransportNodesListGatewayStartedBatchShards.Request(getDiscoveryNodes(), shardIdCustomDataPathMap)
new TransportNodesListGatewayStartedBatchShards.Request(getDiscoveryNodes(), shardAttributes)
);
DiscoveryNode[] discoveryNodes = getDiscoveryNodes();
TransportNodesListGatewayStartedBatchShards.NodeGatewayStartedShards nodeGatewayStartedShards = response.getNodesMap()
@@ -137,8 +138,8 @@ private void prepareIndex(String indexName, int numberOfPrimaryShards) {
flush(indexName);
}

private Map<ShardId, String> prepareRequestMap(String[] indices, int primaryShardCount) {
Map<ShardId, String> shardIdCustomDataPathMap = new HashMap<>();
private Map<ShardId, ShardAttributes> prepareRequestMap(String[] indices, int primaryShardCount) {
Map<ShardId, ShardAttributes> shardAttributesMap = new HashMap<>();
for (String indexName : indices) {
prepareIndex(indexName, primaryShardCount);
final Index index = resolveIndex(indexName);
@@ -147,10 +148,10 @@ private Map<ShardId, String> prepareRequestMap(String[] indices, int primaryShar
);
for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
final ShardId shardId = new ShardId(index, shardIdNum);
shardIdCustomDataPathMap.put(shardId, customDataPath);
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
}
}
return shardIdCustomDataPathMap;
return shardAttributesMap;
}

private void corruptShard(String nodeName, ShardId shardId) throws IOException, InterruptedException {
@@ -223,7 +223,7 @@ protected synchronized void processAsyncFetch(
List<FailedNodeException> failures,
long fetchingRound
) {
fetchResponses.add(new Response(shardToCustomDataPath.keySet().iterator().next(), responses, failures));
fetchResponses.add(new Response(shardAttributesMap.keySet().iterator().next(), responses, failures));
if (expectedOps.countDown()) {
finish();
}
80 changes: 44 additions & 36 deletions server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java
@@ -46,6 +46,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.transport.ReceiveTimeoutTransportException;

import java.util.ArrayList;
@@ -68,6 +69,7 @@
* and once the results are back, it makes sure to schedule a reroute to make sure those results will
* be taken into account.
*
* It comes in two modes: fetching a single shard or fetching a batch of shards.
* @opensearch.internal
*/
public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Releasable {
@@ -76,19 +78,18 @@ public abstract class AsyncShardFetch<T extends BaseNodeResponse> implements Rel
* An action that lists the relevant shard data that needs to be fetched.
*/
public interface Lister<NodesResponse extends BaseNodesResponse<NodeResponse>, NodeResponse extends BaseNodeResponse> {
void list(Map<ShardId, String> shardIdsWithCustomDataPath, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);
void list(Map<ShardId, ShardAttributes> shardAttributesMap, DiscoveryNode[] nodes, ActionListener<NodesResponse> listener);

}

protected final Logger logger;
protected final String type;

protected final Map<ShardId, String> shardToCustomDataPath;
protected final Map<ShardId, ShardAttributes> shardAttributesMap;
private final Lister<BaseNodesResponse<T>, T> action;
private final Map<String, NodeEntry<T>> cache = new HashMap<>();
private final AtomicLong round = new AtomicLong();
private boolean closed;
private final String logKey;
private final String reroutingKey;
private final Map<ShardId, Set<String>> shardToIgnoreNodes = new HashMap<>();

private final boolean enableBatchMode;
@@ -103,26 +104,35 @@ protected AsyncShardFetch(
) {
this.logger = logger;
this.type = type;
shardToCustomDataPath = new HashMap<>();
shardToCustomDataPath.put(shardId, customDataPath);
shardAttributesMap = new HashMap<>();
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "ShardId=[" + shardId.toString() + "]";
this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
enableBatchMode = false;
}

/**
* Added to fetch a batch of shards from nodes
*
* @param logger Logger
* @param type type of action
* @param shardAttributesMap Map of {@link ShardId} to {@link ShardAttributes} for the shards to fetch
* @param action Transport Action
* @param batchId A single batch id tying the given ShardAttributes map together, used for logging and later identification
*/
@SuppressWarnings("unchecked")
protected AsyncShardFetch(
Logger logger,
String type,
Map<ShardId, String> shardToCustomDataPath,
Map<ShardId, ShardAttributes> shardAttributesMap,
Lister<? extends BaseNodesResponse<T>, T> action,
String batchId
) {
this.logger = logger;
this.type = type;
this.shardToCustomDataPath = shardToCustomDataPath;
this.shardAttributesMap = shardAttributesMap;
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.logKey = "BatchID=[" + batchId + "]";
this.reroutingKey = "BatchID=[" + batchId+ "]";
enableBatchMode = true;
}

@@ -153,15 +163,19 @@ public synchronized int getNumberOfInFlightFetches() {
*/
public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId, Set<String>> ignoreNodes) {
if (closed) {
throw new IllegalStateException(logKey + ": can't fetch data on closed async fetch");
throw new IllegalStateException(reroutingKey + ": can't fetch data on closed async fetch");
}

if (enableBatchMode == false) {
// we will do assertions here on ignoreNodes
assert ignoreNodes.size() <= 1 : "Can only have at-most one shard";
if (ignoreNodes.size() == 1) {
assert shardToCustomDataPath.containsKey(ignoreNodes.keySet().iterator().next())
: "ShardId should be same as initialised in fetcher";
if (ignoreNodes.size() > 1) {
throw new IllegalStateException(
"Fetching Shard Data, " + reroutingKey + ": can only have at most one shard for non-batch mode"
);
}
if (ignoreNodes.size() == 1) {
if (shardAttributesMap.containsKey(ignoreNodes.keySet().iterator().next()) == false) {
throw new IllegalStateException("Shard Id must be same as initialized in AsyncShardFetch. Expecting = " + reroutingKey);
}
}
}

@@ -221,16 +235,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId,
shardToIgnoreNodes.clear();
// if at least one node failed, make sure to have a protective reroute
// here, just case this round won't find anything, and we need to retry fetching data
if (failedNodes.isEmpty() == false
|| allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
reroute(
logKey,
"nodes failed ["
+ failedNodes.size()
+ "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum()
+ "]"
);

if (failedNodes.isEmpty() == false || allIgnoreNodesMap.values().stream().anyMatch(ignoreNodeSet -> ignoreNodeSet.isEmpty() == false)) {
reroute(reroutingKey, "nodes failed [" + failedNodes.size() + "], ignored ["
+ allIgnoreNodesMap.values().stream().mapToInt(Set::size).sum() + "]");
}

return new FetchResult<>(fetchData, allIgnoreNodesMap);
@@ -246,10 +254,10 @@ public synchronized FetchResult<T> fetchData(DiscoveryNodes nodes, Map<ShardId,
protected synchronized void processAsyncFetch(List<T> responses, List<FailedNodeException> failures, long fetchingRound) {
if (closed) {
// we are closed, no need to process this async fetch at all
logger.trace("{} ignoring fetched [{}] results, already closed", logKey, type);
logger.trace("{} ignoring fetched [{}] results, already closed", reroutingKey, type);
return;
}
logger.trace("{} processing fetched [{}] results", logKey, type);
logger.trace("{} processing fetched [{}] results", reroutingKey, type);

if (responses != null) {
for (T response : responses) {
@@ -259,7 +267,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received response for [{}] from node {} for an older fetching round (expected: {} but was: {})",
logKey,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
@@ -268,29 +276,29 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
} else if (nodeEntry.isFailed()) {
logger.trace(
"{} node {} has failed for [{}] (failure [{}])",
logKey,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFailure()
);
} else {
// if the entry is there, for the right fetching round and not marked as failed already, process it
logger.trace("{} marking {} as done for [{}], result is [{}]", logKey, nodeEntry.getNodeId(), type, response);
logger.trace("{} marking {} as done for [{}], result is [{}]", reroutingKey, nodeEntry.getNodeId(), type, response);
nodeEntry.doneFetching(response);
}
}
}
}
if (failures != null) {
for (FailedNodeException failure : failures) {
logger.trace("{} processing failure {} for [{}]", logKey, failure, type);
logger.trace("{} processing failure {} for [{}]", reroutingKey, failure, type);
NodeEntry<T> nodeEntry = cache.get(failure.nodeId());
if (nodeEntry != null) {
if (nodeEntry.getFetchingRound() != fetchingRound) {
assert nodeEntry.getFetchingRound() > fetchingRound : "node entries only replaced by newer rounds";
logger.trace(
"{} received failure for [{}] from node {} for an older fetching round (expected: {} but was: {})",
logKey,
reroutingKey,
nodeEntry.getNodeId(),
type,
nodeEntry.getFetchingRound(),
@@ -308,7 +316,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
logger.warn(
() -> new ParameterizedMessage(
"{}: failed to list shard for {} on node [{}]",
logKey,
reroutingKey,
type,
failure.nodeId()
),
@@ -320,7 +328,7 @@ protected synchronized void processAsyncFetch(List<T> responses, List<FailedNode
}
}
}
reroute(logKey, "post_response");
reroute(reroutingKey, "post_response");
}

/**
Expand Down Expand Up @@ -381,8 +389,8 @@ private boolean hasAnyNodeFetching(Map<String, NodeEntry<T>> shardCache) {
*/
// visible for testing
void asyncFetch(final DiscoveryNode[] nodes, long fetchingRound) {
logger.trace("{} fetching [{}] from {}", logKey, type, nodes);
action.list(shardToCustomDataPath, nodes, new ActionListener<BaseNodesResponse<T>>() {
logger.trace("{} fetching [{}] from {}", reroutingKey, type, nodes);
action.list(shardAttributesMap, nodes, new ActionListener<BaseNodesResponse<T>>() {
@Override
public void onResponse(BaseNodesResponse<T> response) {
processAsyncFetch(response.getNodes(), response.failures(), fetchingRound);
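To illustrate how the new batch constructor is meant to be used, here is a sketch of a subclass wiring it up. This is illustrative only: the class name and type string are made up, it assumes reroute(String reroutingKey, String reason) is the sole abstract hook (inferred from the reroute(reroutingKey, ...) calls above, not from a signature shown in this diff), and it is not part of the commit.

import java.util.Map;

import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.action.support.nodes.BaseNodesResponse;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.AsyncShardFetch;
import org.opensearch.indices.store.ShardAttributes;

// Hypothetical example subclass; not part of this commit.
public class ExampleBatchShardFetcher<T extends BaseNodeResponse> extends AsyncShardFetch<T> {

    protected ExampleBatchShardFetcher(
        Logger logger,
        String type,
        Map<ShardId, ShardAttributes> shardAttributesMap,
        Lister<? extends BaseNodesResponse<T>, T> action,
        String batchId
    ) {
        // Batch mode: reroutingKey becomes "BatchID=[<batchId>]" and enableBatchMode is set to true.
        super(logger, type, shardAttributesMap, action, batchId);
    }

    @Override
    protected void reroute(String reroutingKey, String reason) {
        // A real implementation would schedule a cluster reroute here; signature assumed from the calls above.
    }
}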