diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java index b787bae5e5057..481f1835034f2 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java @@ -13,26 +13,34 @@ import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.logging.Loggers; import org.opensearch.core.index.shard.ShardId; import org.opensearch.indices.store.ShardAttributes; +import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; /** - * Implementation of AsyncShardFetchAbstract with batching support. - * cache will be created using ShardBatchCache class as that can handle the caching strategy correctly for a - * batch of shards. Other necessary functions are also stored so cache can store or get the data for both primary - * and replicas. + * Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch + * part using the base class {@link AsyncShardFetch}. Other functionalities needed for a batch are only written here. + * Cleanup of failed shards is necessary in a batch and based on that a reroute should be triggered to take care of + * those in the next run. This separation also takes care of the extra generic type V which is only needed for batch + * transport actions like {@link TransportNodesListGatewayStartedShardsBatch}. * * @param Response type of the transport action. 
* @param Data type of shard level response. + * + * @opensearch.internal */ public abstract class AsyncShardBatchFetch extends AsyncShardFetch { @@ -47,7 +55,7 @@ public abstract class AsyncShardBatchFetch, T> action, String batchId, Class clazz, - BiFunction, T> responseConstructor, + BiFunction, T> responseGetter, Function> shardsBatchDataGetter, Supplier emptyResponseBuilder, Consumer handleFailedShard @@ -61,21 +69,13 @@ public abstract class AsyncShardBatchFetch fetchData(DiscoveryNodes nodes, Map> ignoreNodes) { if (failedShards.isEmpty() == false) { // trigger a reroute if there are any shards failed, to make sure they're picked up in next run @@ -85,13 +85,7 @@ public synchronized FetchResult fetchData(DiscoveryNodes nodes, Map Response type of transport action. + * @param Data type of shard level response. + */ + public static class ShardBatchCache extends AsyncShardFetchCache { + private final Map> cache = new HashMap<>(); + private final Map shardIdKey = new HashMap<>(); + private final AtomicInteger shardIdIndex = new AtomicInteger(); + private final int batchSize; + private final Class shardResponseClass; + private final BiFunction, T> responseConstructor; + private final Map shardIdReverseKey = new HashMap<>(); + private final Function> shardsBatchDataGetter; + private final Supplier emptyResponseBuilder; + private final Consumer handleFailedShard; + + public ShardBatchCache( + Logger logger, + String type, + Map shardAttributesMap, + String logKey, + Class clazz, + BiFunction, T> responseGetter, + Function> shardsBatchDataGetter, + Supplier emptyResponseBuilder, + Consumer handleFailedShard + ) { + super(Loggers.getLogger(logger, "_" + logKey), type); + this.batchSize = shardAttributesMap.size(); + fillShardIdKeys(shardAttributesMap.keySet()); + this.shardResponseClass = clazz; + this.responseConstructor = responseGetter; + this.shardsBatchDataGetter = shardsBatchDataGetter; + this.emptyResponseBuilder = emptyResponseBuilder; + this.handleFailedShard 
= handleFailedShard; + } + + @Override + public Map getCache() { + return cache; + } + + @Override + public void deleteShard(ShardId shardId) { + if (shardIdKey.containsKey(shardId)) { + Integer shardIdIndex = shardIdKey.remove(shardId); + for (String nodeId : cache.keySet()) { + cache.get(nodeId).clearShard(shardIdIndex); + } + } + } + + @Override + public Map getCacheData(DiscoveryNodes nodes, Set failedNodes) { + refreshReverseIdMap(); + return super.getCacheData(nodes, failedNodes); + } + + /** + * Build a reverse map to get shardId from the array index, this will be used to construct the response which + * PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for. + */ + private void refreshReverseIdMap() { + shardIdReverseKey.clear(); + for (ShardId shardId : shardIdKey.keySet()) { + shardIdReverseKey.putIfAbsent(shardIdKey.get(shardId), shardId); + } + } + + @Override + public void initData(DiscoveryNode node) { + cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize)); + } + + /** + * Put the response received from data nodes into the cache. + * Get shard level data from batch, then filter out if any shards received failures. + * After that complete storing the data at node level and mark fetching as done. + * + * @param node node from which we got the response. + * @param response shard metadata coming from node. + */ + @Override + public void putData(DiscoveryNode node, T response) { + NodeEntry nodeEntry = cache.get(node.getId()); + Map batchResponse = shardsBatchDataGetter.apply(response); + filterFailedShards(batchResponse); + nodeEntry.doneFetching(batchResponse, shardIdKey); + } + + /** + * Return the shard for which we got unhandled exceptions. + * + * @param batchResponse response from one node for the batch. 
+ */ + private void filterFailedShards(Map batchResponse) { + logger.trace("filtering failed shards"); + for (Iterator it = batchResponse.keySet().iterator(); it.hasNext();) { + ShardId shardId = it.next(); + if (batchResponse.get(shardId) != null) { + if (batchResponse.get(shardId).getException() != null) { + // handle per shard level exceptions, process other shards, only throw out this shard from + // the batch + Exception shardException = batchResponse.get(shardId).getException(); + // if the request got rejected or timed out, we need to try it again next time... + if (retryableException(shardException)) { + logger.trace( + "got unhandled retryable exception for shard {} {}", + shardId.toString(), + shardException.toString() + ); + handleFailedShard.accept(shardId); + // remove this failed entry. So, while storing the data, we don't need to re-process it. + it.remove(); + } + } + } + } + } + + @Override + public T getData(DiscoveryNode node) { + return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId()))); + } + + private HashMap getBatchData(NodeEntry nodeEntry) { + V[] nodeShardEntries = nodeEntry.getData(); + boolean[] emptyResponses = nodeEntry.getEmptyShardResponse(); + HashMap shardData = new HashMap<>(); + for (Integer shardIdIndex : shardIdKey.values()) { + if (emptyResponses[shardIdIndex]) { + shardData.put(shardIdReverseKey.get(shardIdIndex), emptyResponseBuilder.get()); + } else if (nodeShardEntries[shardIdIndex] != null) { + // ignore null responses here + shardData.put(shardIdReverseKey.get(shardIdIndex), nodeShardEntries[shardIdIndex]); + } + } + return shardData; + } + + private void fillShardIdKeys(Set shardIds) { + for (ShardId shardId : shardIds) { + this.shardIdKey.putIfAbsent(shardId, shardIdIndex.getAndIncrement()); + } + this.shardIdKey.keySet().removeIf(shardId -> { + if (!shardIds.contains(shardId)) { + deleteShard(shardId); + return true; + } else { + return false; + } + }); + } + + /** + * A node entry, holding 
the state of the fetched data for a specific shard + * for a given node. + */ + static class NodeEntry extends BaseNodeEntry { + private final V[] shardData; + private final boolean[] emptyShardResponse; + + NodeEntry(String nodeId, Class clazz, int batchSize) { + super(nodeId); + this.shardData = (V[]) Array.newInstance(clazz, batchSize); + this.emptyShardResponse = new boolean[batchSize]; + } + + void doneFetching(Map shardDataFromNode, Map shardIdKey) { + fillShardData(shardDataFromNode, shardIdKey); + super.doneFetching(); + } + + void clearShard(Integer shardIdIndex) { + this.shardData[shardIdIndex] = null; + } + + V[] getData() { + return this.shardData; + } + + boolean[] getEmptyShardResponse() { + return emptyShardResponse; + } + + private void fillShardData(Map shardDataFromNode, Map shardIdKey) { + for (ShardId shardId : shardDataFromNode.keySet()) { + if (shardDataFromNode.get(shardId) != null) { + if (shardDataFromNode.get(shardId).isEmpty()) { + this.emptyShardResponse[shardIdKey.get(shardId)] = true; + this.shardData[shardIdKey.get(shardId)] = null; + } else if (shardDataFromNode.get(shardId).getException() == null) { + this.shardData[shardIdKey.get(shardId)] = shardDataFromNode.get(shardId); + } + // if exception is not null, we got unhandled failure for the shard which needs to be ignored + } + } + } + } } } diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java index bcadc5be1d1e2..5abcef37be42f 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardFetch.java @@ -65,11 +65,9 @@ * Allows to asynchronously fetch shard related data from other nodes for allocation, without blocking * the cluster update thread. *

- * The async fetch logic maintains a map of which nodes are being fetched from in an async manner, - * and once the results are back, it makes sure to schedule a reroute to make sure those results will - * be taken into account. + * The async fetch logic maintains a cache {@link AsyncShardFetchCache} which is filled in async manner when nodes respond back. + * It also schedules a reroute to make sure those results will be taken into account. * - * It comes in two modes, to single fetch a shard or fetch a batch of shards. * @opensearch.internal */ public abstract class AsyncShardFetch implements Releasable { @@ -86,7 +84,7 @@ public interface Lister, N protected final String type; protected final Map shardAttributesMap; private final Lister, T> action; - private final Map> cache = new HashMap<>(); + final Map> cache = new HashMap<>(); private final AtomicLong round = new AtomicLong(); private boolean closed; protected final String reroutingKey; diff --git a/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java b/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java deleted file mode 100644 index 336a5b5c94a42..0000000000000 --- a/server/src/main/java/org/opensearch/gateway/ShardBatchCache.java +++ /dev/null @@ -1,241 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ - -package org.opensearch.gateway; - -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchTimeoutException; -import org.opensearch.action.support.nodes.BaseNodeResponse; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.common.Nullable; -import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.indices.store.ShardAttributes; -import org.opensearch.transport.ReceiveTimeoutTransportException; - -import java.lang.reflect.Array; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.BiFunction; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Cache implementation of transport actions returning batch of shards data in the response. Cache uses a specific - * NodeEntry class that stores the data in array format. To keep the class generic for primary or replica, all - * functions are stored during object creation. - * - * @param Response type of transport action. - * @param Data type of shard level response. 
- */ -public class ShardBatchCache extends BaseShardCache { - private final Map> cache; - private final Map shardIdToArray; // used for mapping array index for a shard - private final AtomicInteger shardIdIndex; - private final int batchSize; - private final Class shardResponseClass; - private final BiFunction, T> responseConstructor; - private final Map arrayToShardId; - private final Function> shardsBatchDataGetter; - private final Supplier emptyResponseBuilder; - private final Consumer handleFailedShard; - - public ShardBatchCache( - Logger logger, - String type, - Map shardToCustomDataPath, - String logKey, - Class clazz, - BiFunction, T> responseConstructor, - Function> shardsBatchDataGetter, - Supplier emptyResponseBuilder, - Consumer handleFailedShard - ) { - super(logger, logKey, type); - this.batchSize = shardToCustomDataPath.size(); - fillShardIdKeys(shardToCustomDataPath.keySet()); - this.shardResponseClass = clazz; - this.responseConstructor = responseConstructor; - this.shardsBatchDataGetter = shardsBatchDataGetter; - this.emptyResponseBuilder = emptyResponseBuilder; - cache = new HashMap<>(); - shardIdToArray = new HashMap<>(); - arrayToShardId = new HashMap<>(); - shardIdIndex = new AtomicInteger(); - this.handleFailedShard = handleFailedShard; - } - - @Override - public Map getCache() { - return cache; - } - - @Override - public void deleteData(ShardId shardId) { - if (shardIdToArray.containsKey(shardId)) { - Integer shardIdIndex = shardIdToArray.remove(shardId); - for (String nodeId : cache.keySet()) { - cache.get(nodeId).clearShard(shardIdIndex); - } - } - } - - @Override - public Map getCacheData(DiscoveryNodes nodes, Set failedNodes) { - refreshReverseIdMap(); - return super.getCacheData(nodes, failedNodes); - } - - /** - * Build a reverse map to get shardId from the array index, this will be used to construct the response which - * PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for. 
- */ - private void refreshReverseIdMap() { - arrayToShardId.clear(); - for (ShardId shardId : shardIdToArray.keySet()) { - arrayToShardId.putIfAbsent(shardIdToArray.get(shardId), shardId); - } - } - - @Override - public void initData(DiscoveryNode node) { - cache.put(node.getId(), new NodeEntry<>(node.getId(), shardResponseClass, batchSize)); - } - - /** - * Put the response received from data nodes into the cache. - * Get shard level data from batch, then filter out if any shards received failures. - * After that, complete storing the data at node level and mark fetching as done. - * @param node node from which we got the response. - * @param response shard metadata coming from node. - */ - @Override - public void putData(DiscoveryNode node, T response) { - NodeEntry nodeEntry = cache.get(node.getId()); - Map batchResponse = shardsBatchDataGetter.apply(response); - filterFailedShards(batchResponse); - nodeEntry.doneFetching(batchResponse, shardIdToArray); - } - - /** - * Return the shard for which we got unhandled exceptions. - * - * @param batchResponse response from one node for the batch. - */ - private void filterFailedShards(Map batchResponse) { - for (Iterator it = batchResponse.keySet().iterator(); it.hasNext();) { - ShardId shardId = it.next(); - if (batchResponse.get(shardId) != null) { - if (batchResponse.get(shardId).getException() != null) { - // handle per shard level exceptions, process other shards, only throw out this shard from - // the batch - Exception shardException = batchResponse.get(shardId).getException(); - // if the request got rejected or timed out, we need to try it again next time... - if (retryableException(shardException)) { - logger.trace("got unhandled retryable exception for shard {} {}", shardId.toString(), - shardException.toString()); - handleFailedShard.accept(shardId); - // remove this failed entry. So, while storing the data, we don't need to re-process it. 
- it.remove(); - } - } - } - } - } - - @Override - public T getData(DiscoveryNode node) { - return this.responseConstructor.apply(node, getBatchData(cache.get(node.getId()))); - } - - private HashMap getBatchData(NodeEntry nodeEntry) { - V[] nodeShardEntries = nodeEntry.getData(); - boolean[] emptyResponses = nodeEntry.getEmptyShardResponse(); - HashMap shardData = new HashMap<>(); - for (Integer shardIdIndex : shardIdToArray.values()) { - if (emptyResponses[shardIdIndex]) { - shardData.put(arrayToShardId.get(shardIdIndex), emptyResponseBuilder.get()); - } else if (nodeShardEntries[shardIdIndex] != null) { - // ignore null responses here - shardData.put(arrayToShardId.get(shardIdIndex), nodeShardEntries[shardIdIndex]); - } - } - return shardData; - } - - private void fillShardIdKeys(Set shardIds) { - for (ShardId shardId : shardIds) { - this.shardIdToArray.putIfAbsent(shardId, shardIdIndex.getAndIncrement()); - } - this.shardIdToArray.keySet().removeIf(shardId -> { - if (!shardIds.contains(shardId)) { - deleteData(shardId); - return true; - } else { - return false; - } - }); - } - - /** - * A node entry, holding the state of the fetched data for a specific shard - * for a giving node. This will only store the data from TransportNodesListGatewayStartedShardsBatch or - * TransportNodesListShardStoreMetadataBatch transport actions. 
- */ - static class NodeEntry extends BaseShardCache.BaseNodeEntry { - @Nullable - private final V[] shardData; - private final boolean[] emptyShardResponse; - - NodeEntry(String nodeId, Class clazz, int batchSize) { - super(nodeId); - this.shardData = (V[]) Array.newInstance(clazz, batchSize); - this.emptyShardResponse = new boolean[batchSize]; - } - - void doneFetching(Map shardDataFromNode, Map shardIdKey) { - fillShardData(shardDataFromNode, shardIdKey); - super.doneFetching(); - } - - void clearShard(Integer shardIdIndex) { - this.shardData[shardIdIndex] = null; - } - - V[] getData() { - return this.shardData; - } - - boolean[] getEmptyShardResponse() { - return emptyShardResponse; - } - - private void fillShardData(Map shardDataFromNode, Map shardIdKey) { - for (ShardId shardId : shardDataFromNode.keySet()) { - if (shardDataFromNode.get(shardId) != null) { - if (shardDataFromNode.get(shardId).isEmpty()) { - this.emptyShardResponse[shardIdKey.get(shardId)] = true; - this.shardData[shardIdKey.get(shardId)] = null; - } else if (shardDataFromNode.get(shardId).getException() == null) { - this.shardData[shardIdKey.get(shardId)] = shardDataFromNode.get(shardId); - } - // if exception is not null, we got unhandled failure for the shard which needs to be ignored - } - } - } - } - -}