diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e6b98971ff8cb..325f65bd01fd2 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -187,6 +187,7 @@ static TransportVersion def(int id) { public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0); public static final TransportVersion RANK_DOC_IN_SHARD_FETCH_REQUEST = def(8_679_00_0); public static final TransportVersion SECURITY_SETTINGS_REQUEST_TIMEOUTS = def(8_680_00_0); + public static final TransportVersion RANK_DOCS_RETRIEVER = def(8_681_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index 31a4ca97aad6a..91088f5e24af4 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -73,6 +73,7 @@ import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.lucene.grouping.TopFieldGroups; +import org.elasticsearch.search.retriever.RankDocsSortField; import org.elasticsearch.search.sort.ShardDocSortField; import java.io.IOException; @@ -548,6 +549,8 @@ private static SortField rewriteMergeSortField(SortField sortField) { return newSortField; } else if (sortField.getClass() == ShardDocSortField.class) { return new SortField(sortField.getField(), SortField.Type.LONG, sortField.getReverse()); + } else if (sortField.getClass() == RankDocsSortField.class) { + return new SortField(sortField.getField(), SortField.Type.INT, sortField.getReverse()); } else { return sortField; } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 0a5f839a76315..0eb3ef9543327 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -843,7 +843,7 @@ private void registerSorts() { namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScoreSortBuilder.NAME, ScoreSortBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScriptSortBuilder.NAME, ScriptSortBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, FieldSortBuilder.NAME, FieldSortBuilder::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, FieldSortBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, RankDocsSortBuilder::new)); } private static void registerFromPlugin(List plugins, Function> producer, Consumer consumer) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index bd763e60be715..a44f509158c53 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -12,6 +12,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; @@ -24,7 +25,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -94,12 +94,6 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP } private final KnnSearchBuilder knnSearchBuilder; - private final String field; - private final float[] queryVector; - private final QueryVectorBuilder queryVectorBuilder; - private final int k; - private final int numCands; - private final Float similarity; public KnnRetrieverBuilder( String field, @@ -109,13 +103,20 @@ public KnnRetrieverBuilder( int numCands, Float similarity ) { - this.knnSearchBuilder = new KnnSearchBuilder(field, VectorData.fromFloats(queryVector), queryVectorBuilder, k, numCands, similarity); - this.field = field; - this.queryVector = queryVector; - this.queryVectorBuilder = queryVectorBuilder; - this.k = k; - this.numCands = numCands; - this.similarity = similarity; + this.knnSearchBuilder = new KnnSearchBuilder( + field, + VectorData.fromFloats(queryVector), + queryVectorBuilder, + k, + numCands, + similarity + ); + } + + private KnnRetrieverBuilder(KnnRetrieverBuilder clone, KnnSearchBuilder knnSearchBuilder, List preFilterQueryBuilders) { + super(clone); + this.knnSearchBuilder = knnSearchBuilder; + this.preFilterQueryBuilders = preFilterQueryBuilders; } // ---- FOR TESTING XCONTENT PARSING ---- @@ -126,13 +127,22 @@ public String getName() { } @Override - public QueryBuilder originalQuery() { - // TODO nested + inner_hits - ExactKnnQueryBuilder knn = new ExactKnnQueryBuilder(knnSearchBuilder.getQueryVector(), knnSearchBuilder.getField()); - if (preFilterQueryBuilders.isEmpty()) { - return knn; + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + var rewritten = knnSearchBuilder.rewrite(ctx); + boolean hasChanged = rewritten != knnSearchBuilder; + var rewrittenFilters = rewritePreFilters(ctx); + hasChanged |= rewrittenFilters != preFilterQueryBuilders; + if (hasChanged) { + return new KnnRetrieverBuilder(this, rewritten, rewrittenFilters); } - var ret = new BoolQueryBuilder().should(knn); + return this; + } + + @Override + public QueryBuilder originalQuery(QueryBuilder leadQuery) { + // TODO nested + inner_hits + BoolQueryBuilder ret = new BoolQueryBuilder().must(leadQuery) + .should(new ExactKnnQueryBuilder(knnSearchBuilder.getQueryVector(), knnSearchBuilder.getField())); preFilterQueryBuilders.stream().forEach(ret::filter); return ret; } @@ -152,39 +162,32 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(FIELD_FIELD.getPreferredName(), field); - builder.field(K_FIELD.getPreferredName(), k); - builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); + builder.field(FIELD_FIELD.getPreferredName(), knnSearchBuilder.getField()); + builder.field(K_FIELD.getPreferredName(), knnSearchBuilder.k()); + builder.field(NUM_CANDS_FIELD.getPreferredName(), knnSearchBuilder.k()); - if (queryVector != null) { - builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + if (knnSearchBuilder.getQueryVector() != null) { + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), knnSearchBuilder.getQueryVector()); } - if (queryVectorBuilder != null) { - builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder); + if (knnSearchBuilder.getQueryVectorBuilder() != null) { + builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), knnSearchBuilder.getQueryVectorBuilder()); } - if (similarity != null) { - builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity); + if (knnSearchBuilder.getSimilarity() != null) { + builder.field(VECTOR_SIMILARITY.getPreferredName(), knnSearchBuilder.getSimilarity()); } } @Override public boolean doEquals(Object o) { KnnRetrieverBuilder that = (KnnRetrieverBuilder) o; - return k == that.k - && numCands == that.numCands - && Objects.equals(field, that.field) - && Arrays.equals(queryVector, that.queryVector) - && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) - && Objects.equals(similarity, that.similarity); + return Objects.equals(knnSearchBuilder, that.knnSearchBuilder); } @Override public int doHashCode() { - int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); - result = 31 * result + Arrays.hashCode(queryVector); - return result; + return Objects.hash(knnSearchBuilder); } // ---- END TESTING ---- diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java index 948d86ba2fa81..e08432e96483a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsQueryBuilder.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -23,7 +24,7 @@ import java.util.Comparator; public class RankDocsQueryBuilder extends AbstractQueryBuilder { - public static final String NAME = "rank_docs"; + public static final String NAME = "rank"; private final RankDoc[] rankDocs; @@ -92,7 +93,6 @@ protected int doHashCode() { @Override public TransportVersion getMinimalSupportedVersion() { - // TODO - return TransportVersion.current(); + return TransportVersions.RANK_DOCS_RETRIEVER; } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index 5601fc94f5cb1..68541372b5157 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -13,9 +13,9 @@ import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.DisMaxQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; @@ -25,6 +25,9 @@ import java.util.Objects; import java.util.function.Supplier; +/** + * An {@link RetrieverBuilder} that is used to + */ public class RankDocsRetrieverBuilder extends RetrieverBuilder { private static final Logger logger = LogManager.getLogger(RankDocsRetrieverBuilder.class); @@ -33,10 +36,16 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder { private final List sources; private final Supplier rankDocs; - public RankDocsRetrieverBuilder(int windowSize, List sources, Supplier rankDocs) { + public RankDocsRetrieverBuilder( + int windowSize, + List rewritten, + Supplier rankDocs, + List preFilterQueryBuilders + ) { this.windowSize = windowSize; this.rankDocs = rankDocs; - this.sources = sources; + this.sources = rewritten; + this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override @@ -44,11 +53,37 @@ public String getName() { return NAME; } + private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException { + for (var source : sources) { + var newSource = source.rewrite(ctx); + if (newSource != source) { + return true; + } + } + return false; + } + @Override - public QueryBuilder originalQuery() { + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + assert sourceShouldRewrite(ctx) == false : "Retriever sources should be rewritten first"; + var rewrittenFilters = rewritePreFilters(ctx); + if (rewrittenFilters != preFilterQueryBuilders) { + return new RankDocsRetrieverBuilder(windowSize, sources, rankDocs, rewrittenFilters); + } + return this; + } + + @Override + public QueryBuilder originalQuery(QueryBuilder leadQuery) { DisMaxQueryBuilder disMax = new DisMaxQueryBuilder().tieBreaker(0f); for (var source : sources) { - disMax.add(source.originalQuery()); + var query = source.originalQuery(leadQuery); + if (query != null) { + if (source.retrieverName != null) { + query.queryName(source.retrieverName); + } + disMax.add(query); + } } return disMax; } @@ -70,7 +105,10 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder for (var preFilterQueryBuilder : preFilterQueryBuilders) { bq.filter(preFilterQueryBuilder); } - bq.should(originalQuery()); + QueryBuilder originalQuery = originalQuery(rankQuery); + if (originalQuery != null) { + bq.should(originalQuery); + } searchSourceBuilder.query(bq); } @@ -82,7 +120,7 @@ protected boolean doEquals(Object o) { @Override protected int doHashCode() { - return Objects.hash(super.hashCode(), windowSize); + return Objects.hash(super.hashCode(), Arrays.hashCode(rankDocs.get()), windowSize); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java index 06c89090f18c5..a3e79638a5708 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsSortBuilder.java @@ -9,6 +9,8 @@ package org.elasticsearch.search.retriever; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.query.QueryRewriteContext; @@ -24,7 +26,7 @@ import java.util.Arrays; public class RankDocsSortBuilder extends SortBuilder { - public static final String NAME = "rank_docs"; + public static final String NAME = "rank_sort"; private final RankDoc[] rankDocs; @@ -32,9 +34,8 @@ public RankDocsSortBuilder(RankDoc[] rankDocs) { this.rankDocs = rankDocs; } - @Override - public String getWriteableName() { - return NAME; + public RankDocsSortBuilder(StreamInput in) throws IOException { + this.rankDocs = in.readArray(c -> c.readNamedWriteable(RankDoc.class), RankDoc[]::new); } @Override @@ -42,6 +43,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); } + @Override + public String getWriteableName() { + return NAME; + } + @Override public SortBuilder rewrite(QueryRewriteContext ctx) throws IOException { return this; @@ -57,8 +63,7 @@ protected SortFieldAndFormat build(SearchExecutionContext context) throws IOExce @Override public TransportVersion getMinimalSupportedVersion() { - // TODO - return TransportVersion.current(); + return TransportVersions.RANK_DOCS_RETRIEVER; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index aaacb6a832db0..eccefc2e5ea61 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -152,6 +152,27 @@ public static RetrieverBuilder parseTopLevelRetrieverBuilder(XContentParser pars protected String retrieverName; + public RetrieverBuilder() {} + + protected RetrieverBuilder(RetrieverBuilder clone) { + this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.retrieverName = clone.retrieverName; + } + + protected final List rewritePreFilters(QueryRewriteContext ctx) throws IOException { + List newFilters = new ArrayList<>(preFilterQueryBuilders.size()); + boolean changed = false; + for (var filter : preFilterQueryBuilders) { + var newFilter = filter.rewrite(ctx); + changed |= filter != newFilter; + newFilters.add(newFilter); + } + if (changed) { + return newFilters; + } + return preFilterQueryBuilders; + } + /** * Gets the filters for this retriever. */ @@ -173,12 +194,13 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { /** * Returns the original {@link QueryBuilder} used to compute the top documents. + * @param leadQuery */ - public abstract QueryBuilder originalQuery(); + public abstract QueryBuilder originalQuery(QueryBuilder leadQuery); /** - * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}. - * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}. + * This method is called at the end of rewrite on the final retriever. + * Elements of the search request are expected to be "extracted" into the {@link SearchSourceBuilder}. */ public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index fb1196c0b82e9..6b61bee45c009 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -108,27 +108,42 @@ public static StandardRetrieverBuilder fromXContent(XContentParser parser, Retri StandardRetrieverBuilder() {} - public StandardRetrieverBuilder(StandardRetrieverBuilder clone, QueryBuilder rewritten) { + public StandardRetrieverBuilder(StandardRetrieverBuilder clone, QueryBuilder rewritten, List preFilterQueryBuilders) { + super(clone); this.queryBuilder = rewritten; + this.preFilterQueryBuilders = preFilterQueryBuilders; this.searchAfterBuilder = clone.searchAfterBuilder; this.terminateAfter = clone.terminateAfter; this.sortBuilders = clone.sortBuilders; this.minScore = clone.minScore; this.collapseBuilder = clone.collapseBuilder; - this.preFilterQueryBuilders = clone.preFilterQueryBuilders; } @Override public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - var rewritten = queryBuilder.rewrite(ctx); - if (rewritten != queryBuilder) { - return new StandardRetrieverBuilder(this, rewritten); + // We only rewrite the query to avoid redundant work. + // The other query components are naturally rewritten during the search phase. + var rewritten = queryBuilder != null ? queryBuilder.rewrite(ctx) : null; + boolean hasChanged = rewritten != queryBuilder; + var rewrittenFilters = rewritePreFilters(ctx); + hasChanged |= rewrittenFilters != preFilterQueryBuilders; + if (hasChanged) { + return new StandardRetrieverBuilder(this, rewritten, rewrittenFilters); } return this; } @Override - public QueryBuilder originalQuery() { + public QueryBuilder originalQuery(QueryBuilder leadQuery) { + /** + * What actions should we take with {@link KnnVectorQueryBuilder} or {@link MultiTermQueryBuilder} when a + * compound retriever executes the original queries? Our goal is to retain these queries in scenarios where + * aggregations, highlighting, or inner_hits are used. However, this approach can be costly for compound + * retrievers since they will be executed twice: once as a must clause at this level and a second time as a + * should clause at the upper level (compound retriever). + * Therefore, it would be beneficial to rewrite these queries at the upper level to focus solely on + * scoring/matching similar to what {@link RetrieverBuilder#originalQuery(QueryBuilder)} is doing. + */ if (preFilterQueryBuilders.isEmpty()) { return queryBuilder; } @@ -149,7 +164,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder if (queryBuilder != null) { boolQueryBuilder.must(queryBuilder); } - searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.query(boolQueryBuilder); } else if (queryBuilder != null) { searchSourceBuilder.query(queryBuilder); } diff --git a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java index 898d7fae0bf16..bc6c021bd7504 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java @@ -30,7 +30,11 @@ public ShardDocSortField(int shardRequestIndex, boolean reverse) { this.shardRequestIndex = shardRequestIndex; } - public static int shardRequestIndex(long value) { + public static int decodeDoc(long value) { + return (int) value; + } + + public static int decodeShardRequestIndex(long value) { return (int) (value >> 32); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 3c03d3258ebab..f7e9b62dbc9f9 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -298,6 +298,10 @@ public String getField() { return field; } + public Float getSimilarity() { + return similarity; + } + public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); diff --git a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestRetrieverBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestRetrieverBuilder.java index 40cc1890f69ed..739494e20f8b3 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestRetrieverBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestRetrieverBuilder.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.retriever; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESTestCase; @@ -65,6 +66,11 @@ public TestRetrieverBuilder(String value) { this.value = value; } + @Override + public QueryBuilder originalQuery(QueryBuilder leadQuery) { + throw new UnsupportedOperationException("only used for parsing tests"); + } + @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { throw new UnsupportedOperationException("only used for parsing tests"); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index f02e1ce5cbb2f..e13e21f083081 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -1,9 +1,8 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ package org.elasticsearch.xpack.rank.rrf; @@ -19,8 +18,10 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.StoredFieldsContext; @@ -37,6 +38,7 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.XPackPlugin; import java.io.IOException; import java.util.ArrayList; @@ -86,6 +88,9 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP if (context.clusterSupportsFeature(RRF_RETRIEVER_SUPPORTED) == false) { throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]"); } + if (RRFRankPlugin.RANK_RRF_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { + throw LicenseUtils.newComplianceException("Reciprocal Rank Fusion (RRF)"); + } return PARSER.apply(parser, context); } @@ -94,7 +99,7 @@ private record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder s private final List retrievers; private final int rankWindowSize; private final int rankConstant; - private final SetOnce rankDocs; + private final SetOnce rankDocsSupplier; public RRFRetrieverBuilder(List retrieverBuilders, int rankWindowSize, int rankConstant) { this( @@ -105,11 +110,30 @@ public RRFRetrieverBuilder(List retrieverBuilders, int rankWin ); } - private RRFRetrieverBuilder(List retrievers, int rankWindowSize, int rankConstant, SetOnce rankDocs) { + private RRFRetrieverBuilder( + List retrievers, + int rankWindowSize, + int rankConstant, + SetOnce rankDocsSupplier + ) { this.retrievers = retrievers; this.rankWindowSize = rankWindowSize; this.rankConstant = rankConstant; - this.rankDocs = rankDocs; + this.rankDocsSupplier = rankDocsSupplier; + } + + private RRFRetrieverBuilder( + RRFRetrieverBuilder clone, + List preFilterQueryBuilders, + List retrievers, + SetOnce rankDocsSupplier + ) { + super(clone); + this.preFilterQueryBuilders = preFilterQueryBuilders; + this.rankWindowSize = clone.rankWindowSize; + this.rankConstant = clone.rankConstant; + this.retrievers = retrievers; + this.rankDocsSupplier = rankDocsSupplier; } @Override @@ -127,8 +151,18 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (ctx.pointInTimeBuilder() == null) { throw new IllegalStateException("PIT is required"); } - List newRetrievers = new ArrayList<>(); + + if (rankDocsSupplier != null) { + return this; + } + + // Rewrite prefilters boolean hasChanged = false; + var newPreFilters = rewritePreFilters(ctx); + hasChanged |= newPreFilters != preFilterQueryBuilders; + + // Rewrite retriever sources + List newRetrievers = new ArrayList<>(); for (var entry : retrievers) { RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); if (newRetriever != entry.retriever) { @@ -142,21 +176,18 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { } } if (hasChanged) { - return new RRFRetrieverBuilder(newRetrievers, rankWindowSize, rankConstant, null); + return new RRFRetrieverBuilder(this, newPreFilters, newRetrievers, null); } - if (rankDocs != null) { - return this; - } - - MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + // execute searches + final SetOnce results = new SetOnce<>(); + final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); for (var entry : retrievers) { SearchRequest searchRequest = new SearchRequest().source(entry.source); // The can match phase can reorder shards, so we disable it to ensure the stable ordering searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); multiSearchRequest.add(searchRequest); } - final SetOnce results = new SetOnce<>(); ctx.registerAsyncAction((client, listener) -> { client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { @Override @@ -176,11 +207,17 @@ public void onFailure(Exception e) { } }); }); - return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get); + + return new RankDocsRetrieverBuilder( + rankWindowSize, + newRetrievers.stream().map(s -> s.retriever).toList(), + results::get, + newPreFilters + ); } @Override - public QueryBuilder originalQuery() { + public QueryBuilder originalQuery(QueryBuilder leadQuery) { throw new IllegalStateException(NAME + " cannot be nested"); } @@ -225,6 +262,18 @@ private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, Re .storedFields(new StoredFieldsContext(false)) .size(rankWindowSize); retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, false); + + // apply the pre-filters + if (preFilterQueryBuilders.size() > 0) { + QueryBuilder query = sourceBuilder.query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + if (query != null) { + newQuery.must(query); + } + preFilterQueryBuilders.stream().forEach(newQuery::filter); + sourceBuilder.query(newQuery); + } + // Record the shard id in the sort result List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); if (sortBuilders.isEmpty()) { @@ -240,8 +289,10 @@ private ScoreDoc[] getTopDocs(SearchResponse searchResponse) { ScoreDoc[] docs = new ScoreDoc[size]; for (int i = 0; i < size; i++) { var hit = searchResponse.getHits().getAt(i); - int shardRequestIndex = ShardDocSortField.shardRequestIndex((long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]); - docs[i] = new ScoreDoc(hit.docId(), hit.getScore(), shardRequestIndex); + long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; + int doc = ShardDocSortField.decodeDoc(sortValue); + int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); + docs[i] = new ScoreDoc(doc, hit.getScore(), shardRequestIndex); } return docs; } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml index a4972d0557dab..24f0de4653918 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml @@ -80,17 +80,14 @@ setup: size: 10 - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term term" } - match: { hits.hits.0.fields.keyword.0: "other" } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term" } - match: { hits.hits.1.fields.keyword.0: "keyword" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -128,12 +125,10 @@ setup: - match: { hits.total.value : 2 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } @@ -176,17 +171,14 @@ setup: - match: { hits.total.value : 3 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml index 575723853f0aa..b4893bfec0849 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml @@ -164,11 +164,8 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 3 } - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0._rank: 2 } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 3 } - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 4 } --- "Standard pagination outside rank_window_size": @@ -378,7 +375,6 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 3 } --- @@ -489,9 +485,7 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "4" } - - match: { hits.hits.1._rank: 2 } - do: search: @@ -594,9 +588,7 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 3 } - match: { hits.hits.1._id: "2" } - - match: { hits.hits.1._rank: 4 } --- "Pagination within interleaved results, different result set sizes, rank_window_size covering all results": @@ -690,9 +682,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - do: search: @@ -779,9 +769,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._rank: 3 } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 4 } - do: search: @@ -868,7 +856,6 @@ setup: - match: { hits.total.value: 5 } - length: { hits.hits: 1 } - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0._rank: 5 } --- @@ -965,9 +952,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "4" } - - match: { hits.hits.1._rank: 2 } - do: search: diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml index 76cedf44d3dbe..fdb2ad3fd2206 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml @@ -129,19 +129,14 @@ setup: size: 5 - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "6" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.3._id: "7" } - - match: { hits.hits.3._rank: 4 } - match: { hits.hits.4._id: "3" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }} - match: { aggregations.sums.value.text_total: 25 } @@ -196,7 +191,6 @@ setup: - match: { hits.total.value: 6 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - close_to: { aggregations.sums.value.asc_total: { value: 33.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 39.0, error: 0.001 }} @@ -272,19 +266,14 @@ setup: size: 5 - match: { hits.hits.0._id: "6" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.2._id: "7" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.3._id: "4" } - - match: { hits.hits.3._rank: 4 } - match: { hits.hits.4._id: "8" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 30.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 30.0, error: 0.001 }} diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml index d3d45ef2b18e8..42a37284cdca2 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml @@ -91,17 +91,14 @@ setup: size: 10 - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term term" } - match: { hits.hits.0.fields.keyword.0: "other" } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term" } - match: { hits.hits.1.fields.keyword.0: "keyword" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -143,12 +140,10 @@ setup: - match: { hits.total.value : 2 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } @@ -198,17 +193,14 @@ setup: - match: { hits.total.value : 3 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -263,11 +255,10 @@ setup: rank_window_size: 2 rank_constant: 1 - - match: { hits.total.value : 3 } + - match: { hits.total.value : 2 } - length: { hits.hits: 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } @@ -330,6 +321,5 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml index 520389d51b737..76322b0bd13d1 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml @@ -160,19 +160,14 @@ setup: ] - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "6" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.3._id: "7" } - - match: { hits.hits.3._rank: 4 } - match: { hits.hits.4._id: "3" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }} - match: { aggregations.sums.value.text_total: 25 } @@ -228,13 +223,12 @@ setup: 'desc_total': states.stream().mapToDouble(v -> v['desc_total']).sum() ] - - match: { hits.total.value: 6 } + - match: { hits.total.value: 3 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - - close_to: { aggregations.sums.value.asc_total: { value: 33.0, error: 0.001 }} - - close_to: { aggregations.sums.value.desc_total: { value: 39.0, error: 0.001 }} + - close_to: { aggregations.sums.value.asc_total: { value: 15.0, error: 0.001 }} + - close_to: { aggregations.sums.value.desc_total: { value: 21.0, error: 0.001 }} --- "rrf retriever using multiple knn retrievers and a standard retriever with a scripted metric aggregation": @@ -333,19 +327,14 @@ setup: ] - match: { hits.hits.0._id: "6" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.2._id: "7" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.3._id: "4" } - - match: { hits.hits.3._rank: 4 } - match: { hits.hits.4._id: "8" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 30.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 30.0, error: 0.001 }}