Skip to content

Commit

Permalink
Add a rewrite phase that allows retrievers to handle nested retrievers
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Jun 10, 2024
1 parent 74cec42 commit 9ff421b
Show file tree
Hide file tree
Showing 25 changed files with 969 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
Expand Down Expand Up @@ -65,11 +66,13 @@
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.SearchProfileResults;
import org.elasticsearch.search.profile.SearchProfileShardResult;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -150,6 +153,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
private final int defaultPreFilterShardSize;
private final boolean ccsCheckCompatibility;
private final SearchResponseMetrics searchResponseMetrics;
private final NodeClient nodeClient;

@Inject
public TransportSearchAction(
Expand All @@ -165,7 +169,8 @@ public TransportSearchAction(
NamedWriteableRegistry namedWriteableRegistry,
ExecutorSelector executorSelector,
SearchTransportAPMMetrics searchTransportMetrics,
SearchResponseMetrics searchResponseMetrics
SearchResponseMetrics searchResponseMetrics,
NodeClient nodeClient
) {
super(TYPE.name(), transportService, actionFilters, SearchRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
this.threadPool = threadPool;
Expand All @@ -183,6 +188,7 @@ public TransportSearchAction(
this.defaultPreFilterShardSize = DEFAULT_PRE_FILTER_SHARD_SIZE.get(clusterService.getSettings());
this.ccsCheckCompatibility = SearchService.CCS_VERSION_CHECK_SETTING.get(clusterService.getSettings());
this.searchResponseMetrics = searchResponseMetrics;
this.nodeClient = nodeClient;
}

private Map<String, OriginalIndices> buildPerIndexOriginalIndices(
Expand Down Expand Up @@ -311,13 +317,13 @@ protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<

void executeRequest(
SearchTask task,
SearchRequest original,
SearchRequest searchRequest,
ActionListener<SearchResponse> listener,
Function<ActionListener<SearchResponse>, SearchPhaseProvider> searchPhaseProvider
) {
final long relativeStartNanos = System.nanoTime();
final SearchTimeProvider timeProvider = new SearchTimeProvider(
original.getOrCreateAbsoluteStartMillis(),
searchRequest.getOrCreateAbsoluteStartMillis(),
relativeStartNanos,
System::nanoTime
);
Expand All @@ -326,16 +332,16 @@ void executeRequest(
clusterState.blocks().globalBlockedRaiseException(ClusterBlockLevel.READ);

final ResolvedIndices resolvedIndices;
if (original.pointInTimeBuilder() != null) {
if (searchRequest.pointInTimeBuilder() != null) {
resolvedIndices = ResolvedIndices.resolveWithPIT(
original.pointInTimeBuilder(),
original.indicesOptions(),
searchRequest.pointInTimeBuilder(),
searchRequest.indicesOptions(),
clusterState,
namedWriteableRegistry
);
} else {
resolvedIndices = ResolvedIndices.resolveWithIndicesRequest(
original,
searchRequest,
clusterState,
indexNameExpressionResolver,
remoteClusterService,
Expand All @@ -344,7 +350,8 @@ void executeRequest(
frozenIndexCheck(resolvedIndices);
}

ActionListener<SearchRequest> rewriteListener = listener.delegateFailureAndWrap((delegate, rewritten) -> {
var retriever = searchRequest.source().retriever();
ActionListener<SearchRequest> rewriteSearchRequestListener = listener.delegateFailureAndWrap((delegate, rewritten) -> {
if (ccsCheckCompatibility) {
checkCCSVersionCompatibility(rewritten);
}
Expand Down Expand Up @@ -461,12 +468,70 @@ void executeRequest(
}
}
});

Rewriteable.rewriteAndFetch(
original,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices),
rewriteListener
if (retriever == null) {
Rewriteable.rewriteAndFetch(
searchRequest,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices),
rewriteSearchRequestListener
);
return;
}
searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
if (retriever.requiresPointInTime() && searchRequest.source().pointInTimeBuilder() == null) {
rewriteSearchRequestListener = ActionListener.releaseAfter(
rewriteSearchRequestListener,
() -> closePIT(searchRequest.source().pointInTimeBuilder())
);
}
ActionListener<RetrieverBuilder> rewriteRetrieverListener = rewriteSearchRequestListener.delegateFailureAndWrap(
(delegate, newRetriever) -> {
newRetriever.extractToSearchSourceBuilder(searchRequest.source(), false);
Rewriteable.rewriteAndFetch(
searchRequest,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices),
delegate
);
}
);
if (searchRequest.source().pointInTimeBuilder() == null) {
ActionListener<OpenPointInTimeResponse> openPitListener = rewriteRetrieverListener.delegateFailureAndWrap((delegate, resp) -> {
var pit = new PointInTimeBuilder(resp.getPointInTimeId());
searchRequest.source().pointInTimeBuilder(pit);
Rewriteable.rewriteAndFetch(
retriever,
searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, pit),
rewriteRetrieverListener
);
});

OpenPointInTimeRequest pitReq = new OpenPointInTimeRequest(searchRequest.indices()).indicesOptions(
searchRequest.indicesOptions()
).preference(searchRequest.preference()).routing(searchRequest.routing()).keepAlive(TimeValue.ONE_MINUTE);
nodeClient.execute(TransportOpenPointInTimeAction.TYPE, pitReq, openPitListener);
} else {
Rewriteable.rewriteAndFetch(
retriever,
searchService.getRewriteContext(
timeProvider::absoluteStartMillis,
resolvedIndices,
searchRequest.source().pointInTimeBuilder()
),
rewriteRetrieverListener
);
}
}

private void closePIT(PointInTimeBuilder pit) {
if (pit == null) {
return;
}
nodeClient.execute(TransportClosePointInTimeAction.TYPE, new ClosePointInTimeRequest(pit.getEncodedId()), new ActionListener<>() {
@Override
public void onResponse(ClosePointInTimeResponse resp) {}

@Override
public void onFailure(Exception e) {}
});
}

static void adjustSearchType(SearchRequest searchRequest, boolean singleShard) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,7 @@ public QueryRewriteContext newQueryRewriteContext(
valuesSourceRegistry,
allowExpensiveQueries,
scriptService,
null,
null
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public CoordinatorRewriteContext(
null,
null,
null,
null,
null
);
this.indexLongFieldRange = indexLongFieldRange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.script.ScriptCompiler;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;

Expand Down Expand Up @@ -62,6 +63,7 @@ public class QueryRewriteContext {
protected boolean mapUnmappedFieldAsString;
protected Predicate<String> allowedFields;
private final ResolvedIndices resolvedIndices;
private final PointInTimeBuilder pit;

public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration,
Expand All @@ -77,7 +79,8 @@ public QueryRewriteContext(
final ValuesSourceRegistry valuesSourceRegistry,
final BooleanSupplier allowExpensiveQueries,
final ScriptCompiler scriptService,
final ResolvedIndices resolvedIndices
final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit
) {

this.parserConfiguration = parserConfiguration;
Expand All @@ -95,6 +98,7 @@ public QueryRewriteContext(
this.allowExpensiveQueries = allowExpensiveQueries;
this.scriptService = scriptService;
this.resolvedIndices = resolvedIndices;
this.pit = pit;
}

public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) {
Expand All @@ -112,6 +116,7 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration
null,
null,
null,
null,
null
);
}
Expand All @@ -120,7 +125,8 @@ public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration,
final Client client,
final LongSupplier nowInMillis,
final ResolvedIndices resolvedIndices
final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit
) {
this(
parserConfiguration,
Expand All @@ -136,7 +142,8 @@ public QueryRewriteContext(
null,
null,
null,
resolvedIndices
resolvedIndices,
pit
);
}

Expand Down Expand Up @@ -390,4 +397,8 @@ public Iterable<Map.Entry<String, MappedFieldType>> getAllFields() {
public ResolvedIndices getResolvedIndices() {
return resolvedIndices;
}

public PointInTimeBuilder pointInTimeBuilder() {
return pit;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ private SearchExecutionContext(
valuesSourceRegistry,
allowExpensiveQueries,
scriptService,
null,
null
);
this.shardId = shardId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
Expand Down Expand Up @@ -1759,8 +1760,8 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set<String
/**
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices) {
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices);
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit);
}

public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) {
Expand Down
10 changes: 10 additions & 0 deletions server/src/main/java/org/elasticsearch/search/SearchModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.RankDocsQueryBuilder;
import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder;
import org.elasticsearch.search.retriever.RankDocsSortBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
Expand Down Expand Up @@ -840,6 +843,7 @@ private void registerSorts() {
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScoreSortBuilder.NAME, ScoreSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScriptSortBuilder.NAME, ScriptSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, FieldSortBuilder.NAME, FieldSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, FieldSortBuilder::new));
}

private static <T> void registerFromPlugin(List<SearchPlugin> plugins, Function<SearchPlugin, List<T>> producer, Consumer<T> consumer) {
Expand Down Expand Up @@ -1074,6 +1078,9 @@ private void registerFetchSubPhase(FetchSubPhase subPhase) {
private void registerRetrieverParsers(List<SearchPlugin> plugins) {
registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent));
registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent));
registerRetriever(new RetrieverSpec<>(RankDocsRetrieverBuilder.NAME, (p, c) -> {
throw new IllegalArgumentException("[rank_docs] retriever cannot be provided directly");
}));

registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever);
}
Expand Down Expand Up @@ -1173,6 +1180,9 @@ private void registerQueryParsers(List<SearchPlugin> plugins) {
registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
}));
registerQuery(new QuerySpec<>(RankDocsQueryBuilder.NAME, RankDocsQueryBuilder::new, parser -> {
throw new IllegalArgumentException("[rank_docs] queries cannot be provided directly");
}));

registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.elasticsearch.search.aggregations.SearchContextAggregations;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.AggregationContext.ProductionAggregationContext;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseContext;
Expand Down Expand Up @@ -1820,7 +1821,11 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices) {
return indicesService.getRewriteContext(nowInMillis, resolvedIndices);
return getRewriteContext(nowInMillis, resolvedIndices, null);
}

public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit);
}

public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ public static HighlightBuilder highlight() {

private Map<String, Object> runtimeMappings = emptyMap();

private transient RetrieverBuilder retrieverBuilder;

/**
* Constructs a new search source builder.
*/
Expand Down Expand Up @@ -367,6 +369,10 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

public RetrieverBuilder retriever() {
return retrieverBuilder;
}

/**
* Sets the query for this request.
*/
Expand Down Expand Up @@ -1293,7 +1299,6 @@ private SearchSourceBuilder parseXContent(
}
List<KnnSearchBuilder.Builder> knnBuilders = new ArrayList<>();

RetrieverBuilder retrieverBuilder = null;
SearchUsage searchUsage = new SearchUsage();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
Expand Down Expand Up @@ -1657,9 +1662,7 @@ private SearchSourceBuilder parseXContent(
if (specified.isEmpty() == false) {
throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified);
}
retrieverBuilder.extractToSearchSourceBuilder(this, false);
}

searchUsageConsumer.accept(searchUsage);
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.search.rank;

import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -74,4 +75,6 @@ public final int hashCode() {
public String toString() {
return "RankDoc{" + "score=" + score + ", doc=" + doc + ", shardIndex=" + shardIndex + '}';
}

public abstract Explanation explain();
}
Loading

0 comments on commit 9ff421b

Please sign in to comment.