From 4a1ff5471a0c754eb6fb6df6d5490b30ad0c8665 Mon Sep 17 00:00:00 2001 From: Aman Khare Date: Tue, 12 Mar 2024 00:19:30 +0530 Subject: [PATCH] Use failed shards data only when all nodes fetching is done Signed-off-by: Aman Khare --- .../gateway/GatewayRecoveryTestUtils.java | 2 +- .../gateway/AsyncShardBatchFetch.java | 59 ++++++++++++------- .../opensearch/gateway/BaseShardResponse.java | 4 +- ...ansportNodesGatewayStartedShardHelper.java | 1 + ...ortNodesListGatewayStartedShardsBatch.java | 38 ++++++------ .../indices/store/ShardAttributes.java | 11 +--- .../gateway/AsyncShardFetchTests.java | 4 +- .../gateway/ShardBatchCacheTests.java | 55 +++++++++-------- .../indices/store/ShardAttributesTests.java | 6 +- 9 files changed, 93 insertions(+), 87 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/gateway/GatewayRecoveryTestUtils.java b/server/src/internalClusterTest/java/org/opensearch/gateway/GatewayRecoveryTestUtils.java index 2b6a5b4ee6867..dc157681be6fa 100644 --- a/server/src/internalClusterTest/java/org/opensearch/gateway/GatewayRecoveryTestUtils.java +++ b/server/src/internalClusterTest/java/org/opensearch/gateway/GatewayRecoveryTestUtils.java @@ -54,7 +54,7 @@ public static Map prepareRequestMap(String[] indices, ); for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) { final ShardId shardId = new ShardId(index, shardIdNum); - shardIdShardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath)); + shardIdShardAttributesMap.put(shardId, new ShardAttributes(customDataPath)); } } return shardIdShardAttributesMap; diff --git a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java index 481f1835034f2..8da24edf5a9a0 100644 --- a/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java +++ b/server/src/main/java/org/opensearch/gateway/AsyncShardBatchFetch.java @@ -24,12 +24,13 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import reactor.util.annotation.NonNull; + /** * Implementation of AsyncShardFetch with batching support. This class is responsible for executing the fetch * part using the base class {@link AsyncShardFetch}. Other functionalities needed for a batch are only written here. @@ -77,14 +78,27 @@ public abstract class AsyncShardBatchFetch fetchData(DiscoveryNodes nodes, Map> ignoreNodes) { - if (failedShards.isEmpty() == false) { - // trigger a reroute if there are any shards failed, to make sure they're picked up in next run - logger.trace("triggering another reroute for failed shards in {}", reroutingKey); - reroute("shards-failed", "shards failed in " + reroutingKey); + FetchResult result = super.fetchData(nodes, ignoreNodes); + if (result.hasData()) { + // trigger reroute for failed shards only when all nodes have completed fetching + if (failedShards.isEmpty() == false) { + // trigger a reroute if there are any shards failed, to make sure they're picked up in next run + logger.trace("triggering another reroute for failed shards in {}", reroutingKey); + reroute("shards-failed", "shards failed in " + reroutingKey); + failedShards.clear(); + } } - return super.fetchData(nodes, ignoreNodes); + return result; } + /** + * Remove the shard from shardAttributesMap so it's not sent in next asyncFetch. + * Call removeShardFromBatch method to remove the shardId from the batch object created in + * ShardsBatchGatewayAllocator. + * Add shardId to failedShards, so it can be used to trigger another reroute as part of upcoming fetchData call. + * + * @param shardId shardId to be cleaned up from batch and cache. + */ private void cleanUpFailedShards(ShardId shardId) { shardAttributesMap.remove(shardId); removeShardFromBatch.accept(shardId); @@ -109,13 +123,12 @@ public void clearShard(ShardId shardId) { * @param Data type of shard level response. */ public static class ShardBatchCache extends AsyncShardFetchCache { - private final Map> cache = new HashMap<>(); - private final Map shardIdKey = new HashMap<>(); - private final AtomicInteger shardIdIndex = new AtomicInteger(); + private final Map> cache; + private final Map shardIdToArray; private final int batchSize; private final Class shardResponseClass; private final BiFunction, T> responseConstructor; - private final Map shardIdReverseKey = new HashMap<>(); + private final Map arrayToShardId; private final Function> shardsBatchDataGetter; private final Supplier emptyResponseBuilder; private final Consumer handleFailedShard; @@ -133,6 +146,9 @@ public ShardBatchCache( ) { super(Loggers.getLogger(logger, "_" + logKey), type); this.batchSize = shardAttributesMap.size(); + cache = new HashMap<>(); + shardIdToArray = new HashMap<>(); + arrayToShardId = new HashMap<>(); fillShardIdKeys(shardAttributesMap.keySet()); this.shardResponseClass = clazz; this.responseConstructor = responseGetter; @@ -148,8 +164,8 @@ public ShardBatchCache( @Override public void deleteShard(ShardId shardId) { - if (shardIdKey.containsKey(shardId)) { - Integer shardIdIndex = shardIdKey.remove(shardId); + if (shardIdToArray.containsKey(shardId)) { + Integer shardIdIndex = shardIdToArray.remove(shardId); for (String nodeId : cache.keySet()) { cache.get(nodeId).clearShard(shardIdIndex); } @@ -167,9 +183,9 @@ public Map getCacheData(DiscoveryNodes nodes, Set fail * PrimaryShardBatchAllocator or ReplicaShardBatchAllocator are looking for. */ private void refreshReverseIdMap() { - shardIdReverseKey.clear(); - for (ShardId shardId : shardIdKey.keySet()) { - shardIdReverseKey.putIfAbsent(shardIdKey.get(shardId), shardId); + arrayToShardId.clear(); + for (ShardId shardId : shardIdToArray.keySet()) { + arrayToShardId.putIfAbsent(shardIdToArray.get(shardId), shardId); } } @@ -191,7 +207,7 @@ public void putData(DiscoveryNode node, T response) { NodeEntry nodeEntry = cache.get(node.getId()); Map batchResponse = shardsBatchDataGetter.apply(response); filterFailedShards(batchResponse); - nodeEntry.doneFetching(batchResponse, shardIdKey); + nodeEntry.doneFetching(batchResponse, shardIdToArray); } /** @@ -233,22 +249,23 @@ private HashMap getBatchData(NodeEntry nodeEntry) { V[] nodeShardEntries = nodeEntry.getData(); boolean[] emptyResponses = nodeEntry.getEmptyShardResponse(); HashMap shardData = new HashMap<>(); - for (Integer shardIdIndex : shardIdKey.values()) { + for (Integer shardIdIndex : shardIdToArray.values()) { if (emptyResponses[shardIdIndex]) { - shardData.put(shardIdReverseKey.get(shardIdIndex), emptyResponseBuilder.get()); + shardData.put(arrayToShardId.get(shardIdIndex), emptyResponseBuilder.get()); } else if (nodeShardEntries[shardIdIndex] != null) { // ignore null responses here - shardData.put(shardIdReverseKey.get(shardIdIndex), nodeShardEntries[shardIdIndex]); + shardData.put(arrayToShardId.get(shardIdIndex), nodeShardEntries[shardIdIndex]); } } return shardData; } private void fillShardIdKeys(Set shardIds) { + int shardIdIndex = 0; for (ShardId shardId : shardIds) { - this.shardIdKey.putIfAbsent(shardId, shardIdIndex.getAndIncrement()); + this.shardIdToArray.putIfAbsent(shardId, shardIdIndex++); } - this.shardIdKey.keySet().removeIf(shardId -> { + this.shardIdToArray.keySet().removeIf(shardId -> { if (!shardIds.contains(shardId)) { deleteShard(shardId); return true; diff --git a/server/src/main/java/org/opensearch/gateway/BaseShardResponse.java b/server/src/main/java/org/opensearch/gateway/BaseShardResponse.java index 0922abf2c942f..876d9632b5ed8 100644 --- a/server/src/main/java/org/opensearch/gateway/BaseShardResponse.java +++ b/server/src/main/java/org/opensearch/gateway/BaseShardResponse.java @@ -10,7 +10,6 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.transport.TransportResponse; import java.io.IOException; @@ -20,7 +19,7 @@ * * @opensearch.internal */ -public abstract class BaseShardResponse extends TransportResponse { +public abstract class BaseShardResponse { private Exception storeException; @@ -42,7 +41,6 @@ public BaseShardResponse(StreamInput in) throws IOException { } } - @Override public void writeTo(StreamOutput out) throws IOException { if (storeException != null) { out.writeBoolean(true); diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java b/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java index cac88275ce0d9..bea98335d03e8 100644 --- a/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesGatewayStartedShardHelper.java @@ -39,6 +39,7 @@ public class TransportNodesGatewayStartedShardHelper { public static final String INDEX_NOT_FOUND = "node doesn't have meta data for index"; + public static TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard getShardInfoOnLocalNode( Logger logger, final ShardId shardId, diff --git a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java index 86829cfbd12db..10a9538a37d18 100644 --- a/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java +++ b/server/src/main/java/org/opensearch/gateway/TransportNodesListGatewayStartedShardsBatch.java @@ -136,8 +136,8 @@ protected NodesGatewayStartedShardsBatch newResponse( @Override protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) { Map shardsOnNode = new HashMap<>(); - for (ShardAttributes shardAttr : request.shardAttributes.values()) { - final ShardId shardId = shardAttr.getShardId(); + for (Map.Entry shardAttr : request.shardAttributes.entrySet()) { + final ShardId shardId = shardAttr.getKey(); try { shardsOnNode.put( shardId, @@ -147,7 +147,7 @@ protected NodeGatewayStartedShardsBatch nodeOperation(NodeRequest request) { namedXContentRegistry, nodeEnv, indicesService, - shardAttr.getCustomDataPath(), + shardAttr.getValue().getCustomDataPath(), settings, clusterService ) @@ -378,28 +378,26 @@ public Map getNodeGatewayStartedShardsBatch() public NodeGatewayStartedShardsBatch(StreamInput in) throws IOException { super(in); - this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, - i -> { - if (i.readBoolean()) { - return new NodeGatewayStartedShard(i); - } else { - return null; - } - }); + this.nodeGatewayStartedShardsBatch = in.readMap(ShardId::new, i -> { + if (i.readBoolean()) { + return new NodeGatewayStartedShard(i); + } else { + return null; + } + }); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o), - (o, v) -> { - if (v != null) { - o.writeBoolean(true); - v.writeTo(o); - } else { - o.writeBoolean(false); - } - }); + out.writeMap(nodeGatewayStartedShardsBatch, (o, k) -> k.writeTo(o), (o, v) -> { + if (v != null) { + o.writeBoolean(true); + v.writeTo(o); + } else { + o.writeBoolean(false); + } + }); } public NodeGatewayStartedShardsBatch(DiscoveryNode node, Map nodeGatewayStartedShardsBatch) { diff --git a/server/src/main/java/org/opensearch/indices/store/ShardAttributes.java b/server/src/main/java/org/opensearch/indices/store/ShardAttributes.java index 4ef4e91f7af8c..2bff1043d0239 100644 --- a/server/src/main/java/org/opensearch/indices/store/ShardAttributes.java +++ b/server/src/main/java/org/opensearch/indices/store/ShardAttributes.java @@ -12,7 +12,6 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.gateway.AsyncShardFetch; import java.io.IOException; @@ -24,24 +23,17 @@ * @opensearch.internal */ public class ShardAttributes implements Writeable { - private final ShardId shardId; @Nullable private final String customDataPath; - public ShardAttributes(ShardId shardId, String customDataPath) { - this.shardId = shardId; + public ShardAttributes(String customDataPath) { this.customDataPath = customDataPath; } public ShardAttributes(StreamInput in) throws IOException { - shardId = new ShardId(in); customDataPath = in.readString(); } - public ShardId getShardId() { - return shardId; - } - /** * Returns the custom data path that is used to look up information for this shard. * Returns an empty string if no custom data path is used for this index. @@ -53,7 +45,6 @@ public String getCustomDataPath() { } public void writeTo(StreamOutput out) throws IOException { - shardId.writeTo(out); out.writeString(customDataPath); } } diff --git a/server/src/test/java/org/opensearch/gateway/AsyncShardFetchTests.java b/server/src/test/java/org/opensearch/gateway/AsyncShardFetchTests.java index 4e5e9c71e1fe4..3502cc8996fa2 100644 --- a/server/src/test/java/org/opensearch/gateway/AsyncShardFetchTests.java +++ b/server/src/test/java/org/opensearch/gateway/AsyncShardFetchTests.java @@ -92,8 +92,8 @@ public void setUp() throws Exception { HashMap shardToCustomDataPath = new HashMap<>(); ShardId shardId0 = new ShardId("index1", "index_uuid1", 0); ShardId shardId1 = new ShardId("index2", "index_uuid2", 0); - shardToCustomDataPath.put(shardId0, new ShardAttributes(shardId0, "")); - shardToCustomDataPath.put(shardId1, new ShardAttributes(shardId1, "")); + shardToCustomDataPath.put(shardId0, new ShardAttributes("")); + shardToCustomDataPath.put(shardId1, new ShardAttributes("")); this.test = new TestFetch(threadPool, shardToCustomDataPath); } } diff --git a/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java b/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java index 598de3d98f9b1..07bf70e9d98c6 100644 --- a/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java +++ b/server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java @@ -29,8 +29,11 @@ public class ShardBatchCacheTests extends OpenSearchAllocationTestCase { private static final String BATCH_ID = "b1"; private final DiscoveryNode node1 = newNode("node1"); private final DiscoveryNode node2 = newNode("node2"); - private final Map batchInfo = new HashMap<>(); - private ShardBatchCache shardCache; + + // Needs to be enabled once ShardsBatchGatewayAllocator is pushed + // private final Map batchInfo = new HashMap<>(); + private AsyncShardBatchFetch.ShardBatchCache shardCache; + private List shardsInBatch = new ArrayList<>(); private static final int NUMBER_OF_SHARDS_DEFAULT = 10; @@ -44,7 +47,7 @@ private enum ResponseType { public void setupShardBatchCache(String batchId, int numberOfShards) { Map shardAttributesMap = new HashMap<>(); fillShards(shardAttributesMap, numberOfShards); - this.shardCache = new ShardBatchCache<>( + this.shardCache = new AsyncShardBatchFetch.ShardBatchCache<>( logger, "batch_shards_started", shardAttributesMap, @@ -69,7 +72,7 @@ public void testClearShardCache() { .getNodeGatewayStartedShardsBatch() .containsKey(shard) ); - this.shardCache.deleteData(shard); + this.shardCache.deleteShard(shard); assertFalse( this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null) .get(node1) @@ -84,8 +87,7 @@ public void testGetCacheData() { this.shardCache.initData(node1); this.shardCache.initData(node2); this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); - this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, - ResponseType.EMPTY))); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY))); assertTrue( this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null) .get(node1) @@ -117,8 +119,7 @@ public void testPutData() { this.shardCache.initData(node1); this.shardCache.initData(node2); this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); - this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, - ResponseType.VALID))); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.VALID))); this.shardCache.putData(node2, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.EMPTY))); Map fetchData = shardCache.getCacheData( @@ -140,11 +141,12 @@ public void testNullResponses() { setupShardBatchCache(BATCH_ID, NUMBER_OF_SHARDS_DEFAULT); this.shardCache.initData(node1); this.shardCache.markAsFetching(List.of(node1.getId()), 1); - this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, - ResponseType.NULL))); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getPrimaryResponse(shardsInBatch, ResponseType.NULL))); Map fetchData = shardCache.getCacheData( - DiscoveryNodes.builder().add(node1).build(), null); + DiscoveryNodes.builder().add(node1).build(), + null + ); assertTrue(fetchData.get(node1).getNodeGatewayStartedShardsBatch().isEmpty()); } @@ -153,12 +155,13 @@ public void testFilterFailedShards() { this.shardCache.initData(node1); this.shardCache.initData(node2); this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1); - this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, - getFailedPrimaryResponse(shardsInBatch, 5))); + this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getFailedPrimaryResponse(shardsInBatch, 5))); Map fetchData = shardCache.getCacheData( - DiscoveryNodes.builder().add(node1).add(node2).build(), null); + DiscoveryNodes.builder().add(node1).add(node2).build(), + null + ); - assertEquals(5, batchInfo.size()); + // assertEquals(5, batchInfo.size()); assertEquals(2, fetchData.size()); assertEquals(5, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size()); assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().isEmpty()); @@ -185,24 +188,24 @@ private Map getPrimaryResponse(List s return shardData; } - private Map getFailedPrimaryResponse(List shards, - int failedShardsCount) { + private Map getFailedPrimaryResponse(List shards, int failedShardsCount) { int allocationId = 1; Map shardData = new HashMap<>(); for (ShardId shard : shards) { if (failedShardsCount-- > 0) { - shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, - new OpenSearchRejectedExecutionException())); + shardData.put( + shard, + new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, new OpenSearchRejectedExecutionException()) + ); } else { - shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, - null)); + shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, null)); } } return shardData; } public void removeShard(ShardId shardId) { - batchInfo.remove(shardId); + // batchInfo.remove(shardId); } private void fillShards(Map shardAttributesMap, int numberOfShards) { @@ -210,10 +213,10 @@ private void fillShards(Map shardAttributesMap, int nu for (ShardId shardId : shardsInBatch) { ShardAttributes attr = new ShardAttributes(""); shardAttributesMap.put(shardId, attr); - batchInfo.put( - shardId, - new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id())) - ); + // batchInfo.put( + // shardId, + // new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id())) + // ); } } diff --git a/server/src/test/java/org/opensearch/indices/store/ShardAttributesTests.java b/server/src/test/java/org/opensearch/indices/store/ShardAttributesTests.java index 7fa95fefe72fd..94834bab1d98b 100644 --- a/server/src/test/java/org/opensearch/indices/store/ShardAttributesTests.java +++ b/server/src/test/java/org/opensearch/indices/store/ShardAttributesTests.java @@ -28,13 +28,12 @@ public class ShardAttributesTests extends OpenSearchTestCase { String customDataPath = "/path/to/data"; public void testShardAttributesConstructor() { - ShardAttributes attributes = new ShardAttributes(shardId, customDataPath); - assertEquals(attributes.getShardId(), shardId); + ShardAttributes attributes = new ShardAttributes(customDataPath); assertEquals(attributes.getCustomDataPath(), customDataPath); } public void testSerialization() throws IOException { - ShardAttributes attributes1 = new ShardAttributes(shardId, customDataPath); + ShardAttributes attributes1 = new ShardAttributes(customDataPath); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); StreamOutput output = new DataOutputStreamOutput(new DataOutputStream(bytes)); attributes1.writeTo(output); @@ -42,7 +41,6 @@ public void testSerialization() throws IOException { StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(bytes.toByteArray())); ShardAttributes attributes2 = new ShardAttributes(input); input.close(); - assertEquals(attributes1.getShardId(), attributes2.getShardId()); assertEquals(attributes1.getCustomDataPath(), attributes2.getCustomDataPath()); }