From 541979e3ae1ba01b652a84fb3a84d78c8c514dbc Mon Sep 17 00:00:00 2001
From: Michael Froh
Date: Tue, 3 Dec 2024 12:55:21 -0800
Subject: [PATCH] Filter shards for sliced search at coordinator

Prior to this commit, a sliced search would fan out to every shard,
then apply a MatchNoDocsQuery filter on shards that don't correspond
to the current slice. This still creates a (useless) search context
on each shard for every slice, though. For a long-running sliced
scroll, this can quickly exhaust the number of available scroll
contexts.

This change avoids fanning out to all the shards by checking at the
coordinator whether a shard is matched by the current slice. This
should reduce the number of open scroll contexts to
max(numShards, numSlices) instead of numShards * numSlices.

Signed-off-by: Michael Froh
---
 .../shards/ClusterSearchShardsRequest.java    | 12 ++++++
 .../TransportClusterSearchShardsAction.java   |  2 +-
 ...TransportFieldCapabilitiesIndexAction.java |  3 +-
 .../action/search/TransportSearchAction.java  | 11 +++++-
 .../cluster/routing/OperationRouting.java     | 27 ++++++++++---
 .../opensearch/search/slice/SliceBuilder.java | 31 ++++++++-------
 .../search/TransportSearchActionTests.java    |  5 +++
 .../routing/OperationRoutingTests.java        | 38 ++++++++++---------
 8 files changed, 87 insertions(+), 42 deletions(-)
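
NOTE (editorial illustration, not part of this patch): the scroll-context
arithmetic from the commit message, with example counts chosen arbitrarily.
Before this change every slice opened a context on every shard; after it,
a slice only opens contexts on the shards it maps to.

    // Back-of-the-envelope sketch; numShards and numSlices are assumed values.
    int numShards = 8;
    int numSlices = 4;
    int contextsBefore = numShards * numSlices;          // 32: each slice fans out to every shard
    int contextsAfter = Math.max(numShards, numSlices);  // 8: shards and slices are matched up front
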
diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/shards/ClusterSearchShardsRequest.java b/server/src/main/java/org/opensearch/action/admin/cluster/shards/ClusterSearchShardsRequest.java
index 62e05ebb37e28..bf2c31d842b86 100644
--- a/server/src/main/java/org/opensearch/action/admin/cluster/shards/ClusterSearchShardsRequest.java
+++ b/server/src/main/java/org/opensearch/action/admin/cluster/shards/ClusterSearchShardsRequest.java
@@ -41,6 +41,7 @@
 import org.opensearch.core.common.Strings;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.search.slice.SliceBuilder;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -61,6 +62,8 @@ public class ClusterSearchShardsRequest extends ClusterManagerNodeReadRequest<ClusterSearchShardsRequest>
diff --git a/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportClusterSearchShardsAction.java b/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportClusterSearchShardsAction.java
--- a/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportClusterSearchShardsAction.java
+++ b/server/src/main/java/org/opensearch/action/admin/cluster/shards/TransportClusterSearchShardsAction.java
         final Set<String> nodeIds = new HashSet<>();
         GroupShardsIterator<ShardIterator> groupShardsIterator = clusterService.operationRouting()
-            .searchShards(clusterState, concreteIndices, routingMap, request.preference());
+            .searchShards(clusterState, concreteIndices, routingMap, request.preference(), null, null, request.slice());
         ShardRouting shard;
         ClusterSearchShardsGroup[] groupResponses = new ClusterSearchShardsGroup[groupShardsIterator.size()];
         int currentGroup = 0;
diff --git a/server/src/main/java/org/opensearch/action/fieldcaps/TransportFieldCapabilitiesIndexAction.java b/server/src/main/java/org/opensearch/action/fieldcaps/TransportFieldCapabilitiesIndexAction.java
index 10bf4975311d6..52937182e6a63 100644
--- a/server/src/main/java/org/opensearch/action/fieldcaps/TransportFieldCapabilitiesIndexAction.java
+++ b/server/src/main/java/org/opensearch/action/fieldcaps/TransportFieldCapabilitiesIndexAction.java
@@ -247,8 +247,7 @@ private AsyncShardsAction(FieldCapabilitiesIndexRequest request, ActionListener<
                 throw blockException;
             }
 
-            shardsIt = clusterService.operationRouting()
-                .searchShards(clusterService.state(), new String[] { request.index() }, null, null, null, null);
+            shardsIt = clusterService.operationRouting().searchShards(clusterService.state(), new String[] { request.index() }, null, null);
         }
 
         public void start() {
diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
index 8c4927afa9a14..dfec2e1fda738 100644
--- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
+++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java
@@ -85,6 +85,7 @@
 import org.opensearch.search.pipeline.SearchPipelineService;
 import org.opensearch.search.profile.ProfileShardResult;
 import org.opensearch.search.profile.SearchProfileShardResults;
+import org.opensearch.search.slice.SliceBuilder;
 import org.opensearch.tasks.CancellableTask;
 import org.opensearch.tasks.Task;
 import org.opensearch.tasks.TaskResourceTrackingService;
@@ -551,6 +552,7 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
             );
         } else {
             AtomicInteger skippedClusters = new AtomicInteger(0);
+            SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
             collectSearchShards(
                 searchRequest.indicesOptions(),
                 searchRequest.preference(),
@@ -559,6 +561,7 @@
                 remoteClusterIndices,
                 remoteClusterService,
                 threadPool,
+                slice,
                 ActionListener.wrap(searchShardsResponses -> {
                     final BiFunction<String, String, DiscoveryNode> clusterNodeLookup = getRemoteClusterNodeLookup(
                         searchShardsResponses
@@ -787,6 +790,7 @@ static void collectSearchShards(
         Map<String, OriginalIndices> remoteIndicesByCluster,
         RemoteClusterService remoteClusterService,
         ThreadPool threadPool,
+        SliceBuilder slice,
         ActionListener<Map<String, ClusterSearchShardsResponse>> listener
     ) {
         final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
@@ -800,7 +804,8 @@ static void collectSearchShards(
             ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices).indicesOptions(indicesOptions)
                 .local(true)
                 .preference(preference)
-                .routing(routing);
+                .routing(routing)
+                .slice(slice);
             clusterClient.admin()
                 .cluster()
                 .searchShards(
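
NOTE (editorial illustration, not part of this patch): the new slice() setter
chains with the existing builder-style methods, mirroring the construction in
collectSearchShards above. The index name, routing value, and slice numbers
below are hypothetical.

    // Resolve only the shards that slice 2 of 5 would touch.
    ClusterSearchShardsRequest shardsRequest = new ClusterSearchShardsRequest("my-index")
        .routing("user-42")
        .slice(new SliceBuilder(2, 5)); // id = 2, max = 5
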
@@ -1042,6 +1047,7 @@ private void executeSearch(
             concreteLocalIndices[i] = indices[i].getName();
         }
         Map<String, Long> nodeSearchCounts = searchTransportService.getPendingSearchRequests();
+        SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
         GroupShardsIterator<ShardIterator> localShardRoutings = clusterService.operationRouting()
             .searchShards(
                 clusterState,
@@ -1049,7 +1055,8 @@ private void executeSearch(
                 routingMap,
                 searchRequest.preference(),
                 searchService.getResponseCollectorService(),
-                nodeSearchCounts
+                nodeSearchCounts,
+                slice
             );
         localShardIterators = StreamSupport.stream(localShardRoutings.spliterator(), false)
             .map(it -> new SearchShardIterator(searchRequest.getLocalClusterAlias(), it.shardId(), it.getShardRoutings(), localIndices))
diff --git a/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java b/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java
index fe9e00b250e70..0d82013445940 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/OperationRouting.java
@@ -32,6 +32,7 @@
 
 package org.opensearch.cluster.routing;
 
+import org.apache.lucene.util.CollectionUtil;
 import org.opensearch.cluster.ClusterState;
 import org.opensearch.cluster.metadata.IndexMetadata;
 import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
@@ -48,10 +49,12 @@
 import org.opensearch.index.IndexModule;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.node.ResponseCollectorService;
+import org.opensearch.search.slice.SliceBuilder;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -230,7 +233,7 @@ public GroupShardsIterator<ShardIterator> searchShards(
         @Nullable Map<String, Set<String>> routing,
         @Nullable String preference
     ) {
-        return searchShards(clusterState, concreteIndices, routing, preference, null, null);
+        return searchShards(clusterState, concreteIndices, routing, preference, null, null, null);
     }
 
     public GroupShardsIterator<ShardIterator> searchShards(
@@ -239,9 +242,10 @@ public GroupShardsIterator<ShardIterator> searchShards(
         @Nullable Map<String, Set<String>> routing,
         @Nullable String preference,
         @Nullable ResponseCollectorService collectorService,
-        @Nullable Map<String, Long> nodeCounts
+        @Nullable Map<String, Long> nodeCounts,
+        @Nullable SliceBuilder slice
     ) {
-        final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing);
+        final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing, slice);
         final Set<ShardIterator> set = new HashSet<>(shards.size());
         for (IndexShardRoutingTable shard : shards) {
             IndexMetadata indexMetadataForShard = indexMetadata(clusterState, shard.shardId.getIndex().getName());
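
NOTE (editorial illustration, not part of this patch): a minimal sketch of the
widened searchShards overload, assuming an OperationRouting instance and the
cluster state from the surrounding code. Passing null for the new trailing
argument preserves the old, unsliced behavior.

    // Unsliced: all shards of the indices. Sliced: only shards matching slice 0 of 4.
    GroupShardsIterator<ShardIterator> all = operationRouting
        .searchShards(clusterState, indices, null, null, null, null, null);
    GroupShardsIterator<ShardIterator> sliced = operationRouting
        .searchShards(clusterState, indices, null, null, null, null, new SliceBuilder(0, 4));
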
@@ -290,12 +294,14 @@ public static ShardIterator getShards(ClusterState clusterState, ShardId shardId
     private Set<IndexShardRoutingTable> computeTargetedShards(
         ClusterState clusterState,
         String[] concreteIndices,
-        @Nullable Map<String, Set<String>> routing
+        @Nullable Map<String, Set<String>> routing,
+        @Nullable SliceBuilder slice
     ) {
         routing = routing == null ? EMPTY_ROUTING : routing; // just use an empty map
         final Set<IndexShardRoutingTable> set = new HashSet<>();
         // we use set here and not list since we might get duplicates
         for (String index : concreteIndices) {
+            Set<IndexShardRoutingTable> indexSet = new HashSet<>();
             final IndexRoutingTable indexRouting = indexRoutingTable(clusterState, index);
             final IndexMetadata indexMetadata = indexMetadata(clusterState, index);
             final Set<String> effectiveRouting = routing.get(index);
@@ -303,12 +309,21 @@ private Set<IndexShardRoutingTable> computeTargetedShards(
             for (String r : effectiveRouting) {
                 final int routingPartitionSize = indexMetadata.getRoutingPartitionSize();
                 for (int partitionOffset = 0; partitionOffset < routingPartitionSize; partitionOffset++) {
-                    set.add(RoutingTable.shardRoutingTable(indexRouting, calculateScaledShardId(indexMetadata, r, partitionOffset)));
+                    indexSet.add(
+                        RoutingTable.shardRoutingTable(indexRouting, calculateScaledShardId(indexMetadata, r, partitionOffset))
+                    );
                 }
             }
         } else {
             for (IndexShardRoutingTable indexShard : indexRouting) {
-                set.add(indexShard);
+                indexSet.add(indexShard);
+            }
+        }
+        List<IndexShardRoutingTable> shards = new ArrayList<>(indexSet);
+        CollectionUtil.timSort(shards, Comparator.comparing(s -> s.shardId));
+        for (int i = 0; i < shards.size(); i++) {
+            if (slice == null || slice.shardMatches(i, shards.size())) {
+                set.add(shards.get(i));
             }
         }
     }
diff --git a/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java b/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java
index c9b8a896ed525..723f07cc4edf0 100644
--- a/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java
+++ b/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java
@@ -214,6 +214,15 @@ public int hashCode() {
         return Objects.hash(this.field, this.id, this.max);
     }
 
+    public boolean shardMatches(int shardId, int numShards) {
+        if (max >= numShards) {
+            // Slices are distributed over shards
+            return id % numShards == shardId;
+        }
+        // Shards are distributed over slices
+        return shardId % max == id;
+    }
+
     /**
      * Converts this QueryBuilder to a lucene {@link Query}.
      *
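
NOTE (editorial illustration, not part of this patch): a worked example of the
shardMatches() contract added above. With more slices than shards, each slice
maps to exactly one shard; with more shards than slices, each shard maps to
exactly one slice. That is where the max(numShards, numSlices) bound in the
commit message comes from.

    // 5 slices over 2 shards (max >= numShards): the slice id picks the shard.
    new SliceBuilder(4, 5).shardMatches(0, 2); // true: 4 % 2 == 0
    new SliceBuilder(4, 5).shardMatches(1, 2); // false
    // 2 slices over 6 shards (max < numShards): the shard id picks the slice.
    new SliceBuilder(1, 2).shardMatches(3, 6); // true: 3 % 2 == 1
    new SliceBuilder(1, 2).shardMatches(2, 6); // false
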
@@ -255,7 +264,12 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
             }
         }
 
-        String field = this.field;
+        if (shardMatches(shardId, numShards) == false) {
+            // We should have already excluded this shard before routing to it.
+            // If we somehow land here, then we match nothing.
+            return new MatchNoDocsQuery("this shard is not part of the slice");
+        }
+
         boolean useTermQuery = false;
         if ("_uid".equals(field)) {
             throw new IllegalArgumentException("Computing slices on the [_uid] field is illegal for 7.x indices, use [_id] instead");
@@ -277,12 +291,7 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
             // the number of slices is greater than the number of shards
             // in such case we can reduce the number of requested shards by slice
 
-            // first we check if the slice is responsible of this shard
             int targetShard = id % numShards;
-            if (targetShard != shardId) {
-                // the shard is not part of this slice, we can skip it.
-                return new MatchNoDocsQuery("this shard is not part of the slice");
-            }
             // compute the number of slices where this shard appears
             int numSlicesInShard = max / numShards;
             int rest = max % numShards;
@@ -301,14 +310,8 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
                 ? new TermsSliceQuery(field, shardSlice, numSlicesInShard)
                 : new DocValuesSliceQuery(field, shardSlice, numSlicesInShard);
         }
-        // the number of shards is greater than the number of slices
+        // the number of shards is greater than the number of slices. If we target this shard, we target all of it.
 
-        // check if the shard is assigned to the slice
-        int targetSlice = shardId % max;
-        if (id != targetSlice) {
-            // the shard is not part of this slice, we can skip it.
-            return new MatchNoDocsQuery("this shard is not part of the slice");
-        }
         return new MatchAllDocsQuery();
     }
 
@@ -321,6 +324,8 @@ private GroupShardsIterator<ShardIterator> buildShardIterator(ClusterService clu
         Map<String, Set<String>> routingMap = request.indexRoutings().length > 0
             ? Collections.singletonMap(indices[0], Sets.newHashSet(request.indexRoutings()))
             : null;
+        // Note that we do *not* want to filter this set of shard IDs based on the slice, since we want the
+        // full set of shards matched by the routing parameters.
         return clusterService.operationRouting().searchShards(state, indices, routingMap, request.preference());
     }
 
diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
index 84955d01a59ce..0a0015ae8cbf6 100644
--- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
+++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java
@@ -809,6 +809,7 @@ public void testCollectSearchShards() throws Exception {
                 remoteIndicesByCluster,
                 remoteClusterService,
                 threadPool,
+                null,
                 new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
             );
             awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -835,6 +836,7 @@ public void testCollectSearchShards() throws Exception {
                 remoteIndicesByCluster,
                 remoteClusterService,
                 threadPool,
+                null,
                 new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
             );
             awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -880,6 +882,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
                 remoteIndicesByCluster,
                 remoteClusterService,
                 threadPool,
+                null,
                 new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
             );
             awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -907,6 +910,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
                 remoteIndicesByCluster,
                 remoteClusterService,
                 threadPool,
+                null,
                 new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
             );
             awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -949,6 +953,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
                 remoteIndicesByCluster,
                 remoteClusterService,
                 threadPool,
+                null,
                 new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
             );
             awaitLatch(latch, 5, TimeUnit.SECONDS);
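
NOTE (editorial illustration, not part of this patch): the tests above pass
null for the new slice argument. A slice-aware test could assert the invariant
the coordinator now relies on: every shard is matched by at least one slice,
so no documents are silently dropped. Sketch only, with assumed counts.

    int numShards = 7;
    int numSlices = 3;
    for (int shard = 0; shard < numShards; shard++) {
        boolean covered = false;
        for (int sliceId = 0; sliceId < numSlices; sliceId++) {
            covered |= new SliceBuilder(sliceId, numSlices).shardMatches(shard, numShards);
        }
        assertTrue("shard " + shard + " must belong to some slice", covered);
    }
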
assertThat("One group per index shard", groupIterator.size(), equalTo(numIndices * numShards)); @@ -616,7 +617,7 @@ public void testAdaptiveReplicaSelection() throws Exception { searchedShards.add(firstChoice); selectedNodes.add(firstChoice.currentNodeId()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); assertThat(groupIterator.size(), equalTo(numIndices * numShards)); ShardRouting secondChoice = groupIterator.get(0).nextOrNull(); @@ -624,7 +625,7 @@ public void testAdaptiveReplicaSelection() throws Exception { searchedShards.add(secondChoice); selectedNodes.add(secondChoice.currentNodeId()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); assertThat(groupIterator.size(), equalTo(numIndices * numShards)); ShardRouting thirdChoice = groupIterator.get(0).nextOrNull(); @@ -643,26 +644,26 @@ public void testAdaptiveReplicaSelection() throws Exception { outstandingRequests.put("node_1", 1L); outstandingRequests.put("node_2", 1L); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); ShardRouting shardChoice = groupIterator.get(0).nextOrNull(); // node 1 should be the lowest ranked node to start assertThat(shardChoice.currentNodeId(), equalTo("node_1")); // node 1 starts getting more loaded... collector.addNodeStatistics("node_1", 2, TimeValue.timeValueMillis(200).nanos(), TimeValue.timeValueMillis(150).nanos()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); shardChoice = groupIterator.get(0).nextOrNull(); assertThat(shardChoice.currentNodeId(), equalTo("node_1")); // and more loaded... 
collector.addNodeStatistics("node_1", 3, TimeValue.timeValueMillis(250).nanos(), TimeValue.timeValueMillis(200).nanos()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); shardChoice = groupIterator.get(0).nextOrNull(); assertThat(shardChoice.currentNodeId(), equalTo("node_1")); // and even more collector.addNodeStatistics("node_1", 4, TimeValue.timeValueMillis(300).nanos(), TimeValue.timeValueMillis(250).nanos()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); shardChoice = groupIterator.get(0).nextOrNull(); // finally, node 2 is chosen instead assertThat(shardChoice.currentNodeId(), equalTo("node_2")); @@ -709,7 +710,8 @@ public void testAdaptiveReplicaSelectionWithZoneAwarenessIgnored() throws Except null, null, collector, - outstandingRequests + outstandingRequests, + null ); assertThat("One group per index shard", groupIterator.size(), equalTo(numIndices * numShards)); @@ -722,7 +724,7 @@ public void testAdaptiveReplicaSelectionWithZoneAwarenessIgnored() throws Except searchedShards.add(firstChoice); selectedNodes.add(firstChoice.currentNodeId()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); assertThat(groupIterator.size(), equalTo(numIndices * numShards)); assertThat(groupIterator.get(0).size(), equalTo(numReplicas + 1)); @@ -745,18 +747,18 @@ public void testAdaptiveReplicaSelectionWithZoneAwarenessIgnored() throws Except outstandingRequests.put("node_a1", 1L); outstandingRequests.put("node_b2", 1L); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); // node_a0 or node_a1 should be the lowest ranked node to start groupIterator.forEach(shardRoutings -> assertThat(shardRoutings.nextOrNull().currentNodeId(), containsString("node_a"))); // Adding more load to node_a0 collector.addNodeStatistics("node_a0", 10, TimeValue.timeValueMillis(200).nanos(), TimeValue.timeValueMillis(150).nanos()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); // Adding more load to node_a0 and node_a1 from zone-a collector.addNodeStatistics("node_a1", 100, TimeValue.timeValueMillis(300).nanos(), TimeValue.timeValueMillis(250).nanos()); collector.addNodeStatistics("node_a0", 100, TimeValue.timeValueMillis(300).nanos(), TimeValue.timeValueMillis(250).nanos()); - groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests); + groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null); // ARS should pick node_b2 from zone-b since both node_a0 and node_a1 are overloaded groupIterator.forEach(shardRoutings -> assertThat(shardRoutings.nextOrNull().currentNodeId(), containsString("node_b"))); @@ -842,8 +844,8 @@ public void testWeightedOperationRouting() throws Exception { null, null, collector, - 
-            outstandingRequests
-
+            outstandingRequests,
+            null
         );
 
         for (ShardIterator it : groupIterator) {
@@ -871,7 +873,7 @@ public void testWeightedOperationRouting() throws Exception {
         opRouting = new OperationRouting(setting, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));
 
         // search shards call
-        groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests);
+        groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null);
 
         for (ShardIterator it : groupIterator) {
             List<ShardRouting> shardRoutings = Collections.singletonList(it.nextOrNull());
@@ -935,8 +937,8 @@ public void testWeightedOperationRoutingWeightUndefinedForOneZone() throws Excep
             null,
             null,
             collector,
-            outstandingRequests
-
+            outstandingRequests,
+            null
         );
 
         for (ShardIterator it : groupIterator) {
@@ -969,7 +971,7 @@ public void testWeightedOperationRoutingWeightUndefinedForOneZone() throws Excep
         opRouting = new OperationRouting(setting, new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));
 
         // search shards call
-        groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests);
+        groupIterator = opRouting.searchShards(state, indexNames, null, null, collector, outstandingRequests, null);
 
         for (ShardIterator it : groupIterator) {
             while (it.remaining() > 0) {