From 9ff421beae3a755f71d8b889192d662cd13ca346 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 10 Jun 2024 09:25:53 +0100 Subject: [PATCH] Add a rewrite phase that allows retrievers to handle nested retrievers --- .../action/search/TransportSearchAction.java | 91 +++++- .../org/elasticsearch/index/IndexService.java | 1 + .../query/CoordinatorRewriteContext.java | 1 + .../index/query/QueryRewriteContext.java | 17 +- .../index/query/SearchExecutionContext.java | 1 + .../elasticsearch/indices/IndicesService.java | 5 +- .../elasticsearch/search/SearchModule.java | 10 + .../elasticsearch/search/SearchService.java | 7 +- .../search/builder/SearchSourceBuilder.java | 9 +- .../elasticsearch/search/rank/RankDoc.java | 3 + .../search/rank/feature/RankFeatureDoc.java | 6 + .../search/retriever/RankDocsQuery.java | 181 ++++++++++++ .../retriever/RankDocsQueryBuilder.java | 98 +++++++ .../retriever/RankDocsRetrieverBuilder.java | 94 ++++++ .../search/retriever/RankDocsSortBuilder.java | 74 +++++ .../search/retriever/RankDocsSortField.java | 93 ++++++ .../search/retriever/RetrieverBuilder.java | 44 +-- .../retriever/StandardRetrieverBuilder.java | 36 +-- .../search/TransportSearchActionTests.java | 3 +- .../snapshots/SnapshotResiliencyTests.java | 3 +- .../search/rank/TestRankDoc.java | 6 + .../test/AbstractBuilderTestCase.java | 3 +- .../xpack/rank/rrf/RRFRankDoc.java | 42 +++ .../xpack/rank/rrf/RRFRetrieverBuilder.java | 271 ++++++++++++++---- .../rrf/RRFRetrieverBuilderParsingTests.java | 25 +- 25 files changed, 969 insertions(+), 155 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortField.java diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index a12d149bbe34..392c8cf11258 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -65,11 +66,13 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationReduceContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.SearchProfileResults; import org.elasticsearch.search.profile.SearchProfileShardResult; +import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; @@ -150,6 +153,7 @@ public class TransportSearchAction extends HandledTransportAction buildPerIndexOriginalIndices( @@ -311,13 +317,13 @@ protected void doExecute(Task task, SearchRequest searchRequest, ActionListener< void executeRequest( SearchTask task, - SearchRequest original, + SearchRequest searchRequest, ActionListener listener, Function, SearchPhaseProvider> searchPhaseProvider ) { final long relativeStartNanos = System.nanoTime(); final SearchTimeProvider timeProvider = new SearchTimeProvider( - original.getOrCreateAbsoluteStartMillis(), + searchRequest.getOrCreateAbsoluteStartMillis(), relativeStartNanos, System::nanoTime ); @@ -326,16 +332,16 @@ void executeRequest( clusterState.blocks().globalBlockedRaiseException(ClusterBlockLevel.READ); final ResolvedIndices resolvedIndices; - if (original.pointInTimeBuilder() != null) { + if (searchRequest.pointInTimeBuilder() != null) { resolvedIndices = ResolvedIndices.resolveWithPIT( - original.pointInTimeBuilder(), - original.indicesOptions(), + searchRequest.pointInTimeBuilder(), + searchRequest.indicesOptions(), clusterState, namedWriteableRegistry ); } else { resolvedIndices = ResolvedIndices.resolveWithIndicesRequest( - original, + searchRequest, clusterState, indexNameExpressionResolver, remoteClusterService, @@ -344,7 +350,8 @@ void executeRequest( frozenIndexCheck(resolvedIndices); } - ActionListener rewriteListener = listener.delegateFailureAndWrap((delegate, rewritten) -> { + var retriever = searchRequest.source().retriever(); + ActionListener rewriteSearchRequestListener = listener.delegateFailureAndWrap((delegate, rewritten) -> { if (ccsCheckCompatibility) { checkCCSVersionCompatibility(rewritten); } @@ -461,12 +468,70 @@ void executeRequest( } } }); - - Rewriteable.rewriteAndFetch( - original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices), - rewriteListener + if (retriever == null) { + Rewriteable.rewriteAndFetch( + searchRequest, + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices), + rewriteSearchRequestListener + ); + return; + } + searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); + if (retriever.requiresPointInTime() && searchRequest.source().pointInTimeBuilder() == null) { + rewriteSearchRequestListener = ActionListener.releaseAfter( + rewriteSearchRequestListener, + () -> closePIT(searchRequest.source().pointInTimeBuilder()) + ); + } + ActionListener rewriteRetrieverListener = rewriteSearchRequestListener.delegateFailureAndWrap( + (delegate, newRetriever) -> { + newRetriever.extractToSearchSourceBuilder(searchRequest.source(), false); + Rewriteable.rewriteAndFetch( + searchRequest, + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices), + delegate + ); + } ); + if (searchRequest.source().pointInTimeBuilder() == null) { + ActionListener openPitListener = rewriteRetrieverListener.delegateFailureAndWrap((delegate, resp) -> { + var pit = new PointInTimeBuilder(resp.getPointInTimeId()); + searchRequest.source().pointInTimeBuilder(pit); + Rewriteable.rewriteAndFetch( + retriever, + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, pit), + rewriteRetrieverListener + ); + }); + + OpenPointInTimeRequest pitReq = new OpenPointInTimeRequest(searchRequest.indices()).indicesOptions( + searchRequest.indicesOptions() + ).preference(searchRequest.preference()).routing(searchRequest.routing()).keepAlive(TimeValue.ONE_MINUTE); + nodeClient.execute(TransportOpenPointInTimeAction.TYPE, pitReq, openPitListener); + } else { + Rewriteable.rewriteAndFetch( + retriever, + searchService.getRewriteContext( + timeProvider::absoluteStartMillis, + resolvedIndices, + searchRequest.source().pointInTimeBuilder() + ), + rewriteRetrieverListener + ); + } + } + + private void closePIT(PointInTimeBuilder pit) { + if (pit == null) { + return; + } + nodeClient.execute(TransportClosePointInTimeAction.TYPE, new ClosePointInTimeRequest(pit.getEncodedId()), new ActionListener<>() { + @Override + public void onResponse(ClosePointInTimeResponse resp) {} + + @Override + public void onFailure(Exception e) {} + }); } static void adjustSearchType(SearchRequest searchRequest, boolean singleShard) { diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index 0605e36b2ea4..0d90df9778a5 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -790,6 +790,7 @@ public QueryRewriteContext newQueryRewriteContext( valuesSourceRegistry, allowExpensiveQueries, scriptService, + null, null ); } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index 2a1062f8876d..ac6512b0839e 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -51,6 +51,7 @@ public CoordinatorRewriteContext( null, null, null, + null, null ); this.indexLongFieldRange = indexLongFieldRange; diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index 81adfee36f92..6a34a991197e 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -62,6 +63,7 @@ public class QueryRewriteContext { protected boolean mapUnmappedFieldAsString; protected Predicate allowedFields; private final ResolvedIndices resolvedIndices; + private final PointInTimeBuilder pit; public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, @@ -77,7 +79,8 @@ public QueryRewriteContext( final ValuesSourceRegistry valuesSourceRegistry, final BooleanSupplier allowExpensiveQueries, final ScriptCompiler scriptService, - final ResolvedIndices resolvedIndices + final ResolvedIndices resolvedIndices, + final PointInTimeBuilder pit ) { this.parserConfiguration = parserConfiguration; @@ -95,6 +98,7 @@ public QueryRewriteContext( this.allowExpensiveQueries = allowExpensiveQueries; this.scriptService = scriptService; this.resolvedIndices = resolvedIndices; + this.pit = pit; } public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) { @@ -112,6 +116,7 @@ public QueryRewriteContext(final XContentParserConfiguration parserConfiguration null, null, null, + null, null ); } @@ -120,7 +125,8 @@ public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis, - final ResolvedIndices resolvedIndices + final ResolvedIndices resolvedIndices, + final PointInTimeBuilder pit ) { this( parserConfiguration, @@ -136,7 +142,8 @@ public QueryRewriteContext( null, null, null, - resolvedIndices + resolvedIndices, + pit ); } @@ -390,4 +397,8 @@ public Iterable> getAllFields() { public ResolvedIndices getResolvedIndices() { return resolvedIndices; } + + public PointInTimeBuilder pointInTimeBuilder() { + return pit; + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index 9d3aa9905c74..315ce3724421 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -269,6 +269,7 @@ private SearchExecutionContext( valuesSourceRegistry, allowExpensiveQueries, scriptService, + null, null ); this.shardId = shardId; diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesService.java b/server/src/main/java/org/elasticsearch/indices/IndicesService.java index c0483ee2c820..5414706d1aeb 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesService.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesService.java @@ -138,6 +138,7 @@ import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; @@ -1759,8 +1760,8 @@ public AliasFilter buildAliasFilter(ClusterState state, String index, Set void registerFromPlugin(List plugins, Function> producer, Consumer consumer) { @@ -1074,6 +1078,9 @@ private void registerFetchSubPhase(FetchSubPhase subPhase) { private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); + registerRetriever(new RetrieverSpec<>(RankDocsRetrieverBuilder.NAME, (p, c) -> { + throw new IllegalArgumentException("[rank_docs] retriever cannot be provided directly"); + })); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } @@ -1173,6 +1180,9 @@ private void registerQueryParsers(List plugins) { registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> { throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly"); })); + registerQuery(new QuerySpec<>(RankDocsQueryBuilder.NAME, RankDocsQueryBuilder::new, parser -> { + throw new IllegalArgumentException("[rank_docs] queries cannot be provided directly"); + })); registerFromPlugin(plugins, SearchPlugin::getQueries, this::registerQuery); diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index b45a2e2e2ca1..02dcf1fee464 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -84,6 +84,7 @@ import org.elasticsearch.search.aggregations.SearchContextAggregations; import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.aggregations.support.AggregationContext.ProductionAggregationContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseContext; @@ -1820,7 +1821,11 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re * Returns a new {@link QueryRewriteContext} with the given {@code now} provider */ public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices) { - return indicesService.getRewriteContext(nowInMillis, resolvedIndices); + return getRewriteContext(nowInMillis, resolvedIndices, null); + } + + public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) { + return indicesService.getRewriteContext(nowInMillis, resolvedIndices, pit); } public CoordinatorRewriteContextProvider getCoordinatorRewriteContextProvider(LongSupplier nowInMillis) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index cb2c53a97fbc..a9062a0b626b 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -203,6 +203,8 @@ public static HighlightBuilder highlight() { private Map runtimeMappings = emptyMap(); + private transient RetrieverBuilder retrieverBuilder; + /** * Constructs a new search source builder. */ @@ -367,6 +369,10 @@ public void writeTo(StreamOutput out) throws IOException { } } + public RetrieverBuilder retriever() { + return retrieverBuilder; + } + /** * Sets the query for this request. */ @@ -1293,7 +1299,6 @@ private SearchSourceBuilder parseXContent( } List knnBuilders = new ArrayList<>(); - RetrieverBuilder retrieverBuilder = null; SearchUsage searchUsage = new SearchUsage(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -1657,9 +1662,7 @@ private SearchSourceBuilder parseXContent( if (specified.isEmpty() == false) { throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified); } - retrieverBuilder.extractToSearchSourceBuilder(this, false); } - searchUsageConsumer.accept(searchUsage); return this; } diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java index 50b3ddc0f370..4553da27b277 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.search.rank; +import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.StreamInput; @@ -74,4 +75,6 @@ public final int hashCode() { public String toString() { return "RankDoc{" + "score=" + score + ", doc=" + doc + ", shardIndex=" + shardIndex + '}'; } + + public abstract Explanation explain(); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java index d8b4ec10410f..ecd71e1d20b4 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.search.rank.feature; +import org.apache.lucene.search.Explanation; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.RankDoc; @@ -54,6 +55,11 @@ protected int doHashCode() { return Objects.hashCode(featureData); } + @Override + public Explanation explain() { + return Explanation.noMatch("No match"); + } + @Override public String getWriteableName() { return NAME; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java new file mode 100644 index 000000000000..a766a51cba95 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQuery.java @@ -0,0 +1,181 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.elasticsearch.search.rank.RankDoc; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +public class RankDocsQuery extends Query { + private final RankDoc[] docs; + private final int[] segmentStarts; + private final Object contextIdentity; + + /** + * Creates a query. + * + * @param docs the global doc IDs of documents that match, in ascending order + * @param segmentStarts the indexes in docs and scores corresponding to the first matching + * document in each segment. If a segment has no matching documents, it should be assigned + * the index of the next segment that does. There should be a final entry that is always + * docs.length-1. + * @param contextIdentity an object identifying the reader context that was used to build this + * query + */ + RankDocsQuery(RankDoc[] docs, int[] segmentStarts, Object contextIdentity) { + this.docs = docs; + this.segmentStarts = segmentStarts; + this.contextIdentity = contextIdentity; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + if (docs.length == 0) { + return new MatchNoDocsQuery(); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (searcher.getIndexReader().getContext().id() != contextIdentity) { + throw new IllegalStateException("This RankDocsDocQuery was created by a different reader"); + } + return new Weight(this) { + + @Override + public int count(LeafReaderContext context) { + return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + int found = Arrays.binarySearch(docs, doc + context.docBase, (a, b) -> Integer.compare(((RankDoc) a).doc, (int) b)); + if (found < 0) { + return Explanation.noMatch("not in top k documents"); + } + return docs[found].explain(); + } + + @Override + public Scorer scorer(LeafReaderContext context) { + // Segment starts indicate how many docs are in the segment, + // upper equalling lower indicates no documents for this segment + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return null; + } + return new Scorer(this) { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return currentDocId(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return currentDocId(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; + } + + @Override + public float getMaxScore(int docId) { + return 0; + } + + @Override + public float score() { + return 0; + } + + @Override + public int docID() { + return currentDocId(); + } + + private int currentDocId() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo].doc - context.docBase; + } + + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public String toString(String field) { + return "ScoreAndDocQuery"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + return Arrays.equals(docs, ((RankDocsQuery) obj).docs) + && Arrays.equals(segmentStarts, ((RankDocsQuery) obj).segmentStarts) + && contextIdentity == ((RankDocsQuery) obj).contextIdentity; + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), Arrays.hashCode(docs), Arrays.hashCode(segmentStarts), contextIdentity); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java new file mode 100644 index 000000000000..948d86ba2fa8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; + +public class RankDocsQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "rank_docs"; + + private final RankDoc[] rankDocs; + + public RankDocsQueryBuilder(RankDoc[] rankDocs) { + this.rankDocs = rankDocs; + } + + public RankDocsQueryBuilder(StreamInput in) throws IOException { + super(in); + this.rankDocs = in.readArray(c -> c.readNamedWriteable(RankDoc.class), RankDoc[]::new); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + RankDoc[] shardRankDocs = Arrays.stream(rankDocs) + .filter(r -> r.shardIndex == context.getShardRequestIndex()) + .toArray(RankDoc[]::new); + Arrays.sort(shardRankDocs, Comparator.comparingInt(s -> s.doc)); + IndexReader reader = context.getIndexReader(); + int[] segmentStarts = findSegmentStarts(reader, shardRankDocs); + return new RankDocsQuery(shardRankDocs, segmentStarts, reader.getContext().id()); + } + + private static int[] findSegmentStarts(IndexReader reader, RankDoc[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper, (a, b) -> Integer.compare(((RankDoc) a).doc, (int) b)); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + protected boolean doEquals(RankDocsQueryBuilder other) { + return Arrays.equals(rankDocs, other.rankDocs); + } + + @Override + protected int doHashCode() { + return Arrays.hashCode(rankDocs); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + // TODO + return TransportVersion.current(); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java new file mode 100644 index 000000000000..1e41d22d521e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.DisMaxQueryBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class RankDocsRetrieverBuilder extends RetrieverBuilder { + private static final Logger logger = LogManager.getLogger(RankDocsRetrieverBuilder.class); + + public static final String NAME = "rank_docs"; + private final int windowSize; + private final List sources; + private final Supplier rankDocs; + + public RankDocsRetrieverBuilder(int windowSize, List sources, Supplier rankDocs) { + this.windowSize = windowSize; + this.rankDocs = rankDocs; + this.sources = sources; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.sort(Collections.singletonList(new RankDocsSortBuilder(rankDocs.get()))); + if (searchSourceBuilder.explain() != null && searchSourceBuilder.explain()) { + searchSourceBuilder.trackScores(true); + } + var bq = new BoolQueryBuilder(); + var rankQuery = new RankDocsQueryBuilder(rankDocs.get()); + if (searchSourceBuilder.aggregations() != null) { + bq.must(rankQuery); + searchSourceBuilder.postFilter(rankQuery); + } else { + bq.should(rankQuery); + } + for (var preFilterQueryBuilder : preFilterQueryBuilders) { + bq.filter(preFilterQueryBuilder); + } + + DisMaxQueryBuilder disMax = new DisMaxQueryBuilder().tieBreaker(0f); + for (var originalSource : sources) { + if (originalSource.query() != null) { + disMax.add(originalSource.query()); + } + for (var knnSearch : originalSource.knnSearch()) { + // TODO nested + inner_hits + ExactKnnQueryBuilder knn = new ExactKnnQueryBuilder(knnSearch.getQueryVector(), knnSearch.getField()); + disMax.add(knn); + } + } + bq.should(disMax); + searchSourceBuilder.query(bq); + } + + @Override + protected boolean doEquals(Object o) { + RankDocsRetrieverBuilder other = (RankDocsRetrieverBuilder) o; + return Arrays.equals(rankDocs.get(), other.rankDocs.get()); + } + + @Override + protected int doHashCode() { + return Objects.hash(super.hashCode(), windowSize, rankDocs.get()); + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java new file mode 100644 index 000000000000..06c89090f18c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortBuilder; +import org.elasticsearch.search.sort.SortFieldAndFormat; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; + +public class RankDocsSortBuilder extends SortBuilder { + public static final String NAME = "rank_docs"; + + private final RankDoc[] rankDocs; + + public RankDocsSortBuilder(RankDoc[] rankDocs) { + this.rankDocs = rankDocs; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); + } + + @Override + public SortBuilder rewrite(QueryRewriteContext ctx) throws IOException { + return this; + } + + @Override + protected SortFieldAndFormat build(SearchExecutionContext context) throws IOException { + RankDoc[] shardRankDocs = Arrays.stream(rankDocs) + .filter(r -> r.shardIndex == context.getShardRequestIndex()) + .toArray(RankDoc[]::new); + return new SortFieldAndFormat(new RankDocsSortField(shardRankDocs), DocValueFormat.RAW); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + // TODO + return TransportVersion.current(); + } + + @Override + public BucketedSort buildBucketedSort(SearchExecutionContext context, BigArrays bigArrays, int bucketSize, BucketedSort.ExtraData extra) + throws IOException { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortField.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortField.java new file mode 100644 index 000000000000..0aef421a873a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortField.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldComparatorSource; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.comparators.NumericComparator; +import org.apache.lucene.util.hnsw.IntToIntFunction; +import org.elasticsearch.search.rank.RankDoc; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; + +public class RankDocsSortField extends SortField { + public RankDocsSortField(RankDoc[] rankDocs) { + super("rank", new FieldComparatorSource() { + @Override + public FieldComparator newComparator(String fieldname, int numHits, Pruning pruning, boolean reversed) { + return new RankDocsComparator(numHits, rankDocs); + } + }); + } + + private static class RankDocsComparator extends NumericComparator { + private final int[] values; + private final Map rankDocMap; + private int topValue; + private int bottom; + + private RankDocsComparator(int numHits, RankDoc[] rankDocs) { + super("rank", Integer.MAX_VALUE, false, Pruning.NONE, Integer.BYTES); + this.values = new int[numHits]; + this.rankDocMap = Arrays.stream(rankDocs).collect(Collectors.toMap(k -> k.doc, v -> v.rank)); + } + + @Override + public int compare(int slot1, int slot2) { + return Integer.compare(values[slot1], values[slot2]); + } + + @Override + public Integer value(int slot) { + return Integer.valueOf(values[slot]); + } + + @Override + public void setTopValue(Integer value) { + topValue = value; + } + + @Override + public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { + IntToIntFunction docToRank = doc -> rankDocMap.getOrDefault(context.docBase + doc, Integer.MAX_VALUE); + return new LeafFieldComparator() { + @Override + public void setBottom(int slot) throws IOException { + bottom = values[slot]; + } + + @Override + public int compareBottom(int doc) throws IOException { + return Integer.compare(bottom, docToRank.apply(doc)); + } + + @Override + public int compareTop(int doc) throws IOException { + return Integer.compare(topValue, docToRank.apply(doc)); + } + + @Override + public void copy(int slot, int doc) throws IOException { + values[slot] = docToRank.apply(doc); + } + + @Override + public void setScorer(Scorable scorer) throws IOException {} + }; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 6e3d2a58dbd5..cbd523b1b744 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -14,9 +14,10 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.AbstractObjectParser; -import org.elasticsearch.xcontent.FilterXContentParserWrapper; import org.elasticsearch.xcontent.NamedObjectNotFoundException; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; @@ -42,7 +43,7 @@ * serialization and is expected to be fully extracted to a {@link SearchSourceBuilder} * prior to any transport calls. */ -public abstract class RetrieverBuilder implements ToXContent { +public abstract class RetrieverBuilder implements Rewriteable, ToXContent { public static final NodeFeature RETRIEVERS_SUPPORTED = new NodeFeature("retrievers_supported"); @@ -72,36 +73,6 @@ private void retrieverName(String retrieverName) { * compound retriever. */ public static RetrieverBuilder parseTopLevelRetrieverBuilder(XContentParser parser, RetrieverParserContext context) throws IOException { - parser = new FilterXContentParserWrapper(parser) { - - int nestedDepth = 0; - - @Override - public T namedObject(Class categoryClass, String name, Object context) throws IOException { - if (categoryClass.equals(RetrieverBuilder.class)) { - nestedDepth++; - - if (nestedDepth > 2) { - throw new IllegalArgumentException( - "the nested depth of the [" + name + "] retriever exceeds the maximum nested depth [2] for retrievers" - ); - } - } - - T namedObject = getXContentRegistry().parseNamedObject(categoryClass, name, this, context); - - if (categoryClass.equals(RetrieverBuilder.class)) { - nestedDepth--; - } - - return namedObject; - } - }; - - return parseInnerRetrieverBuilder(parser, context); - } - - protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser parser, RetrieverParserContext context) throws IOException { Objects.requireNonNull(context); if (parser.currentToken() != XContentParser.Token.START_OBJECT && parser.nextToken() != XContentParser.Token.START_OBJECT) { @@ -188,6 +159,15 @@ public List getPreFilterQueryBuilders() { return preFilterQueryBuilders; } + public boolean requiresPointInTime() { + return false; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + return this; + } + /** * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}. * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}. diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index 469478077061..f494db0c92fe 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -14,7 +14,6 @@ import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; @@ -118,59 +117,28 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder if (queryBuilder != null) { boolQueryBuilder.must(queryBuilder); } - - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); + searchSourceBuilder.query(queryBuilder); } else if (queryBuilder != null) { - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(queryBuilder)); + searchSourceBuilder.query(queryBuilder); } if (searchAfterBuilder != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + SEARCH_AFTER_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.searchAfter(searchAfterBuilder.getSortValues()); } if (terminateAfter != SearchContext.DEFAULT_TERMINATE_AFTER) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + TERMINATE_AFTER_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.terminateAfter(terminateAfter); } if (sortBuilders != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + SORT_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.sort(sortBuilders); } if (minScore != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + MIN_SCORE_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.minScore(minScore); } if (collapseBuilder != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + COLLAPSE_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.collapse(collapseBuilder); } } diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index a35dac815751..98de321d792e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1763,7 +1763,8 @@ protected void doWriteTo(StreamOutput out) throws IOException { null, null, new SearchTransportAPMMetrics(TelemetryProvider.NOOP.getMeterRegistry()), - new SearchResponseMetrics(TelemetryProvider.NOOP.getMeterRegistry()) + new SearchResponseMetrics(TelemetryProvider.NOOP.getMeterRegistry()), + client ); CountDownLatch latch = new CountDownLatch(1); diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 6419759ab596..ecb9409ff6b5 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -2426,7 +2426,8 @@ public RecyclerBytesStreamOutput newNetworkBytesStream() { namedWriteableRegistry, EmptySystemIndices.INSTANCE.getExecutorSelector(), new SearchTransportAPMMetrics(TelemetryProvider.NOOP.getMeterRegistry()), - new SearchResponseMetrics(TelemetryProvider.NOOP.getMeterRegistry()) + new SearchResponseMetrics(TelemetryProvider.NOOP.getMeterRegistry()), + client ) ); actions.put( diff --git a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankDoc.java b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankDoc.java index f2f3cb82d203..d265b194b09b 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankDoc.java +++ b/test/framework/src/main/java/org/elasticsearch/search/rank/TestRankDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.search.rank; +import org.apache.lucene.search.Explanation; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -38,6 +39,11 @@ public int doHashCode() { return 0; } + @Override + public Explanation explain() { + return null; + } + @Override public String getWriteableName() { return "test_rank_doc"; diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 271df2a971fb..2a3cc3a248f4 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -617,7 +617,8 @@ QueryRewriteContext createQueryRewriteContext() { null, () -> true, scriptService, - createMockResolvedIndices() + createMockResolvedIndices(), + null ); } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java index 8f078c0c4d11..a3cc472c6611 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.rank.rrf; +import org.apache.lucene.search.Explanation; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.RankDoc; @@ -92,6 +93,47 @@ public String toString() { + '}'; } + @Override + public Explanation explain() { + int queries = positions.length; + Explanation[] details = new Explanation[queries]; + int rankConstant = 60; + for (int i = 0; i < queries; i++) { + if (positions[i] == RRFRankDoc.NO_RANK) { + final String description = "rrf score: [0], result not found in query " + i; + details[i] = Explanation.noMatch(description); + } else { + final int rank = positions[i] + 1; + details[i] = Explanation.match( + rank, + "rrf score: [" + + (1f / (rank + rankConstant)) + + "], " + + "for rank [" + + (rank) + + "] in query " + + i + + " computed as [1 / (" + + (rank) + + " + " + + rankConstant + + "]), for matching query" + ); + } + } + return Explanation.match( + score, + "rrf score: [" + + score + + "] computed for initial ranks " + + Arrays.toString(Arrays.stream(positions).map(x -> x + 1).toArray()) + + " with rankConstant: [" + + rankConstant + + "] as sum of [1 / (rank + rankConstant)] for each query", + details + ); + } + @Override public String getWriteableName() { return NAME; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index e5a798310727..c23df5d28154 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -1,62 +1,78 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ package org.elasticsearch.xpack.rank.rrf; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.SortBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.XPackPlugin; import java.io.IOException; -import java.util.Collections; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; -import static org.elasticsearch.xpack.rank.rrf.RRFRankPlugin.NAME; - -/** - * An rrf retriever is used to represent an rrf rank element, but - * as a tree-like structure. This retriever is a compound retriever - * meaning it has a set of child retrievers that each return a set of - * top docs that will then be combined and ranked according to the rrf - * formula. - */ -public final class RRFRetrieverBuilder extends RetrieverBuilder { +import static org.elasticsearch.search.rank.RankDoc.NO_RANK; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +public class RRFRetrieverBuilder extends RetrieverBuilder { + public static final String NAME = "rrf"; public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported"); public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); - public static final ObjectParser PARSER = new ObjectParser<>( + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, - RRFRetrieverBuilder::new + false, + args -> { + int rankWindowSize = args[1] == null ? RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; + int rankConstant = args[2] == null ? RRFRankBuilder.DEFAULT_RANK_CONSTANT : (int) args[2]; + return new RRFRetrieverBuilder((List) args[0], rankWindowSize, rankConstant); + } ); static { - PARSER.declareObjectArray((r, v) -> r.retrieverBuilders = v, (p, c) -> { + PARSER.declareObjectArray(constructorArg(), (p, c) -> { p.nextToken(); String name = p.currentName(); RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c); p.nextToken(); return retrieverBuilder; }, RETRIEVERS_FIELD); - PARSER.declareInt((r, v) -> r.rankWindowSize = v, RANK_WINDOW_SIZE_FIELD); - PARSER.declareInt((r, v) -> r.rankConstant = v, RANK_CONSTANT_FIELD); - + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -64,49 +80,126 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP if (context.clusterSupportsFeature(RRF_RETRIEVER_SUPPORTED) == false) { throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]"); } - if (RRFRankPlugin.RANK_RRF_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { - throw LicenseUtils.newComplianceException("Reciprocal Rank Fusion (RRF)"); - } return PARSER.apply(parser, context); } - List retrieverBuilders = Collections.emptyList(); - int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; - int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT; + private record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} + + private final List retrievers; + private final int rankWindowSize; + private final int rankConstant; + private final SetOnce rankDocs; + + public RRFRetrieverBuilder(List retrieverBuilders, int rankWindowSize, int rankConstant) { + this( + retrieverBuilders.stream().map(r -> new RetrieverSource(r, null)).collect(Collectors.toList()), + rankWindowSize, + rankConstant, + null + ); + } + + private RRFRetrieverBuilder(List retrievers, int rankWindowSize, int rankConstant, SetOnce rankDocs) { + this.retrievers = retrievers; + this.rankWindowSize = rankWindowSize; + this.rankConstant = rankConstant; + this.rankDocs = rankDocs; + } @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (compoundUsed) { - throw new IllegalArgumentException("[rank] cannot be used in children of compound retrievers"); - } + public String getName() { + return NAME; + } - for (RetrieverBuilder retrieverBuilder : retrieverBuilders) { - if (preFilterQueryBuilders.isEmpty() == false) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + @Override + public boolean requiresPointInTime() { + return true; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (ctx.pointInTimeBuilder() == null) { + throw new IllegalStateException("PIT is required"); + } + List newRetrievers = new ArrayList<>(); + boolean hasChanged = false; + for (var source : retrievers) { + RetrieverBuilder rewritten = source.retriever.rewrite(ctx); + if (rewritten != source.retriever) { + newRetrievers.add(new RetrieverSource(rewritten, null)); + hasChanged |= rewritten != source.retriever; + } else if (rewritten == source.retriever) { + SearchSourceBuilder sourceBuilder; + if (source.source == null) { + sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(ctx.pointInTimeBuilder()).size(rankWindowSize); + rewritten.extractToSearchSourceBuilder(sourceBuilder, false); + List> sortBuilders = sourceBuilder.sorts() != null + ? new ArrayList<>(sourceBuilder.sorts()) + : new ArrayList<>(); + if (sortBuilders.isEmpty()) { + sortBuilders.add(new ScoreSortBuilder()); + } + sourceBuilder.sort(sortBuilders); + } else { + sourceBuilder = source.source; + } + var rewrittenSource = sourceBuilder.rewrite(ctx); + newRetrievers.add(new RetrieverSource(rewritten, rewrittenSource)); + hasChanged |= rewrittenSource != source.source; } + } + if (hasChanged) { + return new RRFRetrieverBuilder(newRetrievers, rankWindowSize, rankConstant, null); + } - retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, true); + if (rankDocs != null) { + return this; } - searchSourceBuilder.rankBuilder(new RRFRankBuilder(rankWindowSize, rankConstant)); - } + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (var ret : retrievers) { + SearchRequest searchRequest = new SearchRequest().source(ret.source); + searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); + multiSearchRequest.add(searchRequest); + } + final SetOnce results = new SetOnce<>(); + ctx.registerAsyncAction((client, listener) -> { + client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { + @Override + public void onResponse(MultiSearchResponse items) { + List topDocs = new ArrayList<>(); + for (int i = 0; i < items.getResponses().length; i++) { + var item = items.getResponses()[i]; + topDocs.add(getTopDocs(item.getResponse())); + } + results.set(combineQueryPhaseResults(topDocs)); + listener.onResponse(null); + } - // ---- FOR TESTING XCONTENT PARSING ---- + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + }); + return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.source).toList(), results::get); + } @Override - public String getName() { - return NAME; + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new IllegalStateException("Should not be called, missing a rewrite?"); } + // ---- FOR TESTING XCONTENT PARSING ---- @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { - if (retrieverBuilders.isEmpty() == false) { + if (retrievers.isEmpty() == false) { builder.startArray(RETRIEVERS_FIELD.getPreferredName()); - for (RetrieverBuilder retrieverBuilder : retrieverBuilders) { + for (var entry : retrievers) { builder.startObject(); - builder.field(retrieverBuilder.getName()); - retrieverBuilder.toXContent(builder, params); + builder.field(entry.retriever.getName()); + entry.retriever.toXContent(builder, params); builder.endObject(); } @@ -120,15 +213,93 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept @Override public boolean doEquals(Object o) { RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; - return rankWindowSize == that.rankWindowSize - && rankConstant == that.rankConstant - && Objects.equals(retrieverBuilders, that.retrieverBuilders); + return rankWindowSize == that.rankWindowSize && rankConstant == that.rankConstant && Objects.equals(retrievers, that.retrievers); } @Override public int doHashCode() { - return Objects.hash(retrieverBuilders, rankWindowSize, rankConstant); + return Objects.hash(retrievers, rankWindowSize, rankConstant); + } + + private ScoreDoc[] getTopDocs(SearchResponse searchResponse) { + int size = Math.min(rankWindowSize, searchResponse.getHits().getHits().length); + ScoreDoc[] docs = new ScoreDoc[size]; + for (int i = 0; i < size; i++) { + var hit = searchResponse.getHits().getAt(i); + long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; + int shardIndex = (int) (sortValue >> 32); + docs[i] = new ScoreDoc(hit.docId(), hit.getScore(), shardIndex); + } + return docs; } - // ---- END FOR TESTING ---- + public RRFRankDoc[] combineQueryPhaseResults(List rankResults) { + // combine the disjointed sets of TopDocs into a single set or RRFRankDocs + // each RRFRankDoc will have both the position and score for each query where + // it was within the result set for that query + // if a doc isn't part of a result set its position will be NO_RANK [0] and + // its score is [0f] + int queries = rankResults.size(); + Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + int index = 0; + for (var rrfRankResult : rankResults) { + int rank = 1; + for (ScoreDoc scoreDoc : rrfRankResult) { + final int findex = index; + final int frank = rank; + long docAndShard = (((long) scoreDoc.shardIndex) << 32) | (scoreDoc.doc & 0xFFFFFFFFL); + docsToRankResults.compute(docAndShard, (key, value) -> { + if (value == null) { + value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries); + } + + // calculate the current rrf score for this document + // later used to sort and covert to a rank + value.score += 1.0f / (rankConstant + frank); + + // record the position for each query + // for explain and debugging + value.positions[findex] = frank - 1; + + // record the score for each query + // used to later re-rank on the coordinator + value.scores[findex] = scoreDoc.score; + + return value; + }); + ++rank; + } + ++index; + } + + // sort the results based on rrf score, tiebreaker based on smaller doc id + RRFRankDoc[] sortedResults = docsToRankResults.values().toArray(RRFRankDoc[]::new); + Arrays.sort(sortedResults, (RRFRankDoc rrf1, RRFRankDoc rrf2) -> { + if (rrf1.score != rrf2.score) { + return rrf1.score < rrf2.score ? 1 : -1; + } + assert rrf1.positions.length == rrf2.positions.length; + for (int qi = 0; qi < rrf1.positions.length; ++qi) { + if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) { + if (rrf1.scores[qi] != rrf2.scores[qi]) { + return rrf1.scores[qi] < rrf2.scores[qi] ? 1 : -1; + } + } else if (rrf1.positions[qi] != NO_RANK) { + return -1; + } else if (rrf2.positions[qi] != NO_RANK) { + return 1; + } + } + return rrf1.doc < rrf2.doc ? -1 : 1; + }); + // trim the results if needed, otherwise each shard will always return `rank_window_size` results. + // pagination and all else will happen on the coordinator when combining the shard responses + RRFRankDoc[] topResults = new RRFRankDoc[Math.min(rankWindowSize, sortedResults.length)]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = sortedResults[rank]; + topResults[rank].rank = rank + 1; + topResults[rank].score = Float.NaN; + } + return topResults; + } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index 330c936327b8..c350e3f0d6ed 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -29,25 +29,22 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase(retrieverCount); + List retrieverBuilders = new ArrayList<>(retrieverCount); while (retrieverCount > 0) { - rrfRetrieverBuilder.retrieverBuilders.add(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); + retrieverBuilders.add(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); --retrieverCount; } - - return rrfRetrieverBuilder; + int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; + if (randomBoolean()) { + rankWindowSize = randomIntBetween(1, 10000); + } + int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT; + if (randomBoolean()) { + rankConstant = randomIntBetween(1, 1000000); + } + return new RRFRetrieverBuilder(retrieverBuilders, rankWindowSize, rankConstant); } @Override