Filter shards for sliced search at coordinator
Prior to this commit, a sliced search would fan out to every shard,
then apply a MatchNoDocsQuery filter on shards that don't correspond
to the current slice. This still creates a (useless) search context
on each shard for every slice, though. For a long-running sliced
scroll, this can quickly exhaust the number of available scroll
contexts.

This change avoids fanning out to all the shards by checking at the
coordinator if a shard is matched by the current slice. This should
reduce the number of open scroll contexts to max(numShards, numSlices)
instead of numShards * numSlices.

Signed-off-by: Michael Froh <[email protected]>
msfroh committed Dec 10, 2024
1 parent 5ba909a commit 541979e
Showing 8 changed files with 87 additions and 42 deletions.
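
To make the new routing rule concrete: the matching logic added to SliceBuilder (shown in the diff below) assigns slices to shards with simple modular arithmetic. Here is a minimal standalone sketch of that rule; the method and parameter names are chosen for illustration and are not part of the OpenSearch API:

    // Sketch of the slice-to-shard matching rule introduced by this commit.
    // When there are at least as many slices as shards, slice k is served
    // entirely by shard (k % numShards); otherwise shard s is served
    // entirely by slice (s % maxSlices).
    static boolean sliceMatchesShard(int sliceId, int maxSlices, int shardId, int numShards) {
        if (maxSlices >= numShards) {
            return sliceId % numShards == shardId;
        }
        return shardId % maxSlices == sliceId;
    }

For example, with 2 slices over 5 shards, slice 0 targets shards {0, 2, 4} and slice 1 targets shards {1, 3}: each shard is visited by exactly one slice, so a sliced scroll opens max(numShards, numSlices) = 5 contexts in total instead of 5 * 2 = 10 under the old fan-out.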
@@ -41,6 +41,7 @@
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.slice.SliceBuilder;

import java.io.IOException;
import java.util.Objects;
@@ -61,6 +62,8 @@ public class ClusterSearchShardsRequest extends ClusterManagerNodeReadRequest<Cl
@Nullable
private String preference;
private IndicesOptions indicesOptions = IndicesOptions.lenientExpandOpen();
@Nullable
private SliceBuilder sliceBuilder;

public ClusterSearchShardsRequest() {}

@@ -166,4 +169,13 @@ public ClusterSearchShardsRequest preference(String preference) {
public String preference() {
return this.preference;
}

public ClusterSearchShardsRequest slice(SliceBuilder sliceBuilder) {
this.sliceBuilder = sliceBuilder;
return this;
}

public SliceBuilder slice() {
return this.sliceBuilder;
}
}
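
The new fluent setter above lets the coordinator attach the slice to the shard-resolution request. A hypothetical call site, for illustration only (the index name and slice parameters are made up, and this assumes SliceBuilder's two-argument (id, max) constructor that slices on the default field):

    ClusterSearchShardsRequest request = new ClusterSearchShardsRequest("logs")
        .routing("user-123")
        .slice(new SliceBuilder(0, 4)); // resolve only the shards serving slice 0 of 4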
@@ -133,7 +133,7 @@ protected void clusterManagerOperation(

Set<String> nodeIds = new HashSet<>();
GroupShardsIterator<ShardIterator> groupShardsIterator = clusterService.operationRouting()
.searchShards(clusterState, concreteIndices, routingMap, request.preference());
.searchShards(clusterState, concreteIndices, routingMap, request.preference(), null, null, request.slice());
ShardRouting shard;
ClusterSearchShardsGroup[] groupResponses = new ClusterSearchShardsGroup[groupShardsIterator.size()];
int currentGroup = 0;
@@ -247,8 +247,7 @@ private AsyncShardsAction(FieldCapabilitiesIndexRequest request, ActionListener<
throw blockException;
}

shardsIt = clusterService.operationRouting()
.searchShards(clusterService.state(), new String[] { request.index() }, null, null, null, null);
shardsIt = clusterService.operationRouting().searchShards(clusterService.state(), new String[] { request.index() }, null, null);
}

public void start() {
@@ -85,6 +85,7 @@
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.slice.SliceBuilder;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
@@ -551,6 +552,7 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
);
} else {
AtomicInteger skippedClusters = new AtomicInteger(0);
SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
collectSearchShards(
searchRequest.indicesOptions(),
searchRequest.preference(),
Expand All @@ -559,6 +561,7 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
remoteClusterIndices,
remoteClusterService,
threadPool,
slice,
ActionListener.wrap(searchShardsResponses -> {
final BiFunction<String, String, DiscoveryNode> clusterNodeLookup = getRemoteClusterNodeLookup(
searchShardsResponses
@@ -787,6 +790,7 @@ static void collectSearchShards(
Map<String, OriginalIndices> remoteIndicesByCluster,
RemoteClusterService remoteClusterService,
ThreadPool threadPool,
SliceBuilder slice,
ActionListener<Map<String, ClusterSearchShardsResponse>> listener
) {
final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
@@ -800,7 +804,8 @@
ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices).indicesOptions(indicesOptions)
.local(true)
.preference(preference)
.routing(routing);
.routing(routing)
.slice(slice);
clusterClient.admin()
.cluster()
.searchShards(
@@ -1042,14 +1047,16 @@ private void executeSearch(
concreteLocalIndices[i] = indices[i].getName();
}
Map<String, Long> nodeSearchCounts = searchTransportService.getPendingSearchRequests();
SliceBuilder slice = searchRequest.source() == null ? null : searchRequest.source().slice();
GroupShardsIterator<ShardIterator> localShardRoutings = clusterService.operationRouting()
.searchShards(
clusterState,
concreteLocalIndices,
routingMap,
searchRequest.preference(),
searchService.getResponseCollectorService(),
nodeSearchCounts
nodeSearchCounts,
slice
);
localShardIterators = StreamSupport.stream(localShardRoutings.spliterator(), false)
.map(it -> new SearchShardIterator(searchRequest.getLocalClusterAlias(), it.shardId(), it.getShardRoutings(), localIndices))
@@ -32,6 +32,7 @@

package org.opensearch.cluster.routing;

import org.apache.lucene.util.CollectionUtil;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.WeightedRoutingMetadata;
@@ -48,10 +49,12 @@
import org.opensearch.index.IndexModule;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.node.ResponseCollectorService;
import org.opensearch.search.slice.SliceBuilder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -230,7 +233,7 @@ public GroupShardsIterator<ShardIterator> searchShards(
@Nullable Map<String, Set<String>> routing,
@Nullable String preference
) {
return searchShards(clusterState, concreteIndices, routing, preference, null, null);
return searchShards(clusterState, concreteIndices, routing, preference, null, null, null);
}

public GroupShardsIterator<ShardIterator> searchShards(
@@ -239,9 +242,10 @@ public GroupShardsIterator<ShardIterator> searchShards(
@Nullable Map<String, Set<String>> routing,
@Nullable String preference,
@Nullable ResponseCollectorService collectorService,
@Nullable Map<String, Long> nodeCounts
@Nullable Map<String, Long> nodeCounts,
@Nullable SliceBuilder slice
) {
final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing);
final Set<IndexShardRoutingTable> shards = computeTargetedShards(clusterState, concreteIndices, routing, slice);
final Set<ShardIterator> set = new HashSet<>(shards.size());
for (IndexShardRoutingTable shard : shards) {
IndexMetadata indexMetadataForShard = indexMetadata(clusterState, shard.shardId.getIndex().getName());
@@ -290,25 +294,36 @@ public static ShardIterator getShards(ClusterState clusterState, ShardId shardId
private Set<IndexShardRoutingTable> computeTargetedShards(
ClusterState clusterState,
String[] concreteIndices,
@Nullable Map<String, Set<String>> routing
@Nullable Map<String, Set<String>> routing,
@Nullable SliceBuilder slice
) {
routing = routing == null ? EMPTY_ROUTING : routing; // just use an empty map
final Set<IndexShardRoutingTable> set = new HashSet<>();
// we use set here and not list since we might get duplicates
for (String index : concreteIndices) {
Set<IndexShardRoutingTable> indexSet = new HashSet<>();
final IndexRoutingTable indexRouting = indexRoutingTable(clusterState, index);
final IndexMetadata indexMetadata = indexMetadata(clusterState, index);
final Set<String> effectiveRouting = routing.get(index);
if (effectiveRouting != null) {
for (String r : effectiveRouting) {
final int routingPartitionSize = indexMetadata.getRoutingPartitionSize();
for (int partitionOffset = 0; partitionOffset < routingPartitionSize; partitionOffset++) {
set.add(RoutingTable.shardRoutingTable(indexRouting, calculateScaledShardId(indexMetadata, r, partitionOffset)));
indexSet.add(
RoutingTable.shardRoutingTable(indexRouting, calculateScaledShardId(indexMetadata, r, partitionOffset))
);
}
}
} else {
for (IndexShardRoutingTable indexShard : indexRouting) {
set.add(indexShard);
indexSet.add(indexShard);
}
}
List<IndexShardRoutingTable> shards = new ArrayList<>(indexSet);
CollectionUtil.timSort(shards, Comparator.comparing(s -> s.shardId));
for (int i = 0; i < shards.size(); i++) {
if (slice == null || slice.shardMatches(i, shards.size())) {
set.add(shards.get(i));
}
}
}
31 changes: 18 additions & 13 deletions server/src/main/java/org/opensearch/search/slice/SliceBuilder.java
@@ -214,6 +214,15 @@ public int hashCode() {
return Objects.hash(this.field, this.id, this.max);
}

public boolean shardMatches(int shardId, int numShards) {
if (max >= numShards) {
// Slices are distributed over shards
return id % numShards == shardId;
}
// Shards are distributed over slices
return shardId % max == id;
}

/**
* Converts this QueryBuilder to a lucene {@link Query}.
*
@@ -255,7 +264,12 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
}
}

String field = this.field;
if (shardMatches(shardId, numShards) == false) {
// We should have already excluded this shard before routing to it.
// If we somehow land here, then we match nothing.
return new MatchNoDocsQuery("this shard is not part of the slice");
}

boolean useTermQuery = false;
if ("_uid".equals(field)) {
throw new IllegalArgumentException("Computing slices on the [_uid] field is illegal for 7.x indices, use [_id] instead");
@@ -277,12 +291,7 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
// the number of slices is greater than the number of shards
// in such case we can reduce the number of requested shards by slice

// first we check if the slice is responsible of this shard
int targetShard = id % numShards;
if (targetShard != shardId) {
// the shard is not part of this slice, we can skip it.
return new MatchNoDocsQuery("this shard is not part of the slice");
}
// compute the number of slices where this shard appears
int numSlicesInShard = max / numShards;
int rest = max % numShards;
@@ -301,14 +310,8 @@ public Query toFilter(ClusterService clusterService, ShardSearchRequest request,
? new TermsSliceQuery(field, shardSlice, numSlicesInShard)
: new DocValuesSliceQuery(field, shardSlice, numSlicesInShard);
}
// the number of shards is greater than the number of slices
// the number of shards is greater than the number of slices. If we target this shard, we target all of it.

// check if the shard is assigned to the slice
int targetSlice = shardId % max;
if (id != targetSlice) {
// the shard is not part of this slice, we can skip it.
return new MatchNoDocsQuery("this shard is not part of the slice");
}
return new MatchAllDocsQuery();
}

Expand All @@ -321,6 +324,8 @@ private GroupShardsIterator<ShardIterator> buildShardIterator(ClusterService clu
Map<String, Set<String>> routingMap = request.indexRoutings().length > 0
? Collections.singletonMap(indices[0], Sets.newHashSet(request.indexRoutings()))
: null;
// Note that we do *not* want to filter this set of shard IDs based on the slice, since we want the
// full set of shards matched by the routing parameters.
return clusterService.operationRouting().searchShards(state, indices, routingMap, request.preference());
}

@@ -809,6 +809,7 @@ public void testCollectSearchShards() throws Exception {
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -835,6 +836,7 @@
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -880,6 +882,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
Expand Down Expand Up @@ -907,6 +910,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -949,6 +953,7 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti
remoteIndicesByCluster,
remoteClusterService,
threadPool,
null,
new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
);
awaitLatch(latch, 5, TimeUnit.SECONDS);