diff --git a/docs/changelog/118585.yaml b/docs/changelog/118585.yaml new file mode 100644 index 0000000000000..4caa5efabbd33 --- /dev/null +++ b/docs/changelog/118585.yaml @@ -0,0 +1,7 @@ +pr: 118585 +summary: Add a generic `rescorer` retriever based on the search request's rescore + functionality +area: Ranking +type: feature +issues: + - 118327 diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index cb04d4fb6fbf1..345f80d2d7c34 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -22,6 +22,9 @@ A <> that replaces the functionality of a traditi `knn`:: A <> that replaces the functionality of a <>. +`rescorer`:: +A <> that replaces the functionality of the <>. + `rrf`:: A <> that produces top documents from <>. @@ -371,6 +374,122 @@ GET movies/_search ---- // TEST[skip:uses ELSER] +[[rescorer-retriever]] +==== Rescorer Retriever + +The `rescorer` retriever re-scores only the results produced by its child retriever. +For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard. + +For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally. + +When using the `rescorer`, an error is returned if the following conditions are not met: + +* The minimum configured rescore's `window_size` is: +** Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups. +** Greater than or equal to the `size` of the search request when used as the primary retriever in the tree. + +* And the maximum rescore's `window_size` is: +** Smaller than or equal to the `size` or `rank_window_size` of the child retriever. + +[discrete] +[[rescorer-retriever-parameters]] +===== Parameters + +`rescore`:: +(Required. <>) ++ +Defines the <> applied sequentially to the top documents returned by the child retriever. + +`retriever`:: +(Required. <>) ++ +Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked. + +`filter`:: +(Optional. <>) ++ +Applies a <> to the retriever, ensuring that all documents match the filter criteria without affecting their scores. + +[discrete] +[[rescorer-retriever-example]] +==== Example + +The `rescorer` retriever can be placed at any level within the retriever tree. +The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever: + +[source,console] +---- +GET movies/_search +{ + "size": 10, <1> + "retriever": { + "rescorer": { <2> + "rescore": { + "query": { <3> + "window_size": 50, <4> + "rescore_query": { + "script_score": { + "script": { + "source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0", + "params": { + "queryVector": [-0.5, 90.0, -10, 14.8, -156.0] + } + } + } + } + } + }, + "retriever": { <5> + "rrf": { + "rank_window_size": 100, <6> + "retrievers": [ + { + "standard": { + "query": { + "sparse_vector": { + "field": "plot_embedding", + "inference_id": "my-elser-model", + "query": "films that explore psychological depths" + } + } + } + }, + { + "standard": { + "query": { + "multi_match": { + "query": "crime", + "fields": [ + "plot", + "title" + ] + } + } + } + }, + { + "knn": { + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + } + ] + } + } + } + } +} +---- +// TEST[skip:uses ELSER] +<1> Specifies the number of top documents to return in the final response. +<2> A `rescorer` retriever applied as the final step. +<3> The definition of the `query` rescorer. +<4> Defines the number of documents to rescore from the child retriever. +<5> Specifies the child retriever definition. +<6> Defines the number of documents returned by the `rrf` retriever, which limits the available documents to + [[text-similarity-reranker-retriever]] ==== Text Similarity Re-ranker Retriever @@ -772,4 +891,4 @@ When a retriever is specified as part of a search, the following elements are no * <> * <> * <> -* <> +* <> use a <> instead diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index 5c61a93d50a60..242460da9d424 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -258,4 +258,5 @@ tasks.named("yamlRestTestV7CompatTransform").configure({ task -> task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") task.skipTest("synonyms/90_synonyms_reloading_for_synset/Reload analyzers for specific synonym set", "Can't work until auto-expand replicas is 0-1 for synonyms index") + task.skipTest("search/90_search_after/_shard_doc sort", "restriction has been lifted in latest versions") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml new file mode 100644 index 0000000000000..2c16de61c6b15 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml @@ -0,0 +1,225 @@ +setup: + - requires: + cluster_features: [ "search.retriever.rescorer.enabled" ] + reason: "Support for rescorer retriever" + + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + available: + type: boolean + features: + type: rank_features + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"features": { "first_stage": 1, "second_stage": 10}, "available": true, "group": 1}' + - '{"index": {"_id": 2 }}' + - '{"features": { "first_stage": 2, "second_stage": 9}, "available": false, "group": 1}' + - '{"index": {"_id": 3 }}' + - '{"features": { "first_stage": 3, "second_stage": 8}, "available": false, "group": 3}' + - '{"index": {"_id": 4 }}' + - '{"features": { "first_stage": 4, "second_stage": 7}, "available": true, "group": 1}' + - '{"index": {"_id": 5 }}' + - '{"features": { "first_stage": 5, "second_stage": 6}, "available": true, "group": 3}' + - '{"index": {"_id": 6 }}' + - '{"features": { "first_stage": 6, "second_stage": 5}, "available": false, "group": 2}' + - '{"index": {"_id": 7 }}' + - '{"features": { "first_stage": 7, "second_stage": 4}, "available": true, "group": 3}' + - '{"index": {"_id": 8 }}' + - '{"features": { "first_stage": 8, "second_stage": 3}, "available": true, "group": 1}' + - '{"index": {"_id": 9 }}' + - '{"features": { "first_stage": 9, "second_stage": 2}, "available": true, "group": 2}' + - '{"index": {"_id": 10 }}' + - '{"features": { "first_stage": 10, "second_stage": 1}, "available": false, "group": 1}' + +--- +"Rescorer retriever basic": + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 10 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1._score: 9.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 3 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: {} + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: {} + size: 2 + + - match: {hits.total.value: 10} + - match: {hits.hits.0._id: "8"} + - match: { hits.hits.0._score: 3.0 } + - match: {hits.hits.1._id: "9"} + - match: { hits.hits.1._score: 2.0 } + +--- +"Rescorer retriever with pre-filters": + - do: + search: + index: test + body: + retriever: + rescorer: + filter: + match: + available: true + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 6 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "4" } + - match: { hits.hits.1._score: 7.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 4 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + filter: + match: + available: true + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 6 } + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.0._score: 6.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 4.0 } + +--- +"Rescorer retriever and collapsing": + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + collapse: + field: group + size: 3 + + - match: { hits.total.value: 10 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1._score: 8.0 } + - match: { hits.hits.2._id: "6" } + - match: { hits.hits.2._score: 5.0 } + +--- +"Rescorer retriever and invalid window size": + - do: + catch: "/\\[rescorer\\] requires \\[window_size: 5\\] be greater than or equal to \\[size: 10\\]/" + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 5 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 10 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/90_search_after.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/90_search_after.yml index 1fefc8bffffa1..d3b2b5a412717 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/90_search_after.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/90_search_after.yml @@ -218,31 +218,6 @@ - match: {hits.hits.0._source.timestamp: "2019-10-21 00:30:04.828740" } - match: {hits.hits.0.sort: [1571617804828740000] } - ---- -"_shard_doc sort": - - requires: - cluster_features: ["gte_v7.12.0"] - reason: _shard_doc sort was added in 7.12 - - - do: - indices.create: - index: test - - do: - index: - index: test - id: "1" - body: { id: 1, foo: bar, age: 18 } - - - do: - catch: /\[_shard_doc\] sort field cannot be used without \[point in time\]/ - search: - index: test - body: - size: 1 - sort: ["_shard_doc"] - search_after: [ 0L ] - --- "Format sort values": - requires: diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java index 6043688b7670a..93b155dbbecdb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java @@ -38,6 +38,7 @@ import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.ParseField; @@ -862,6 +863,20 @@ public void testRescorePhaseWithInvalidSort() throws Exception { } } ); + + assertResponse( + prepareSearch().addSort(SortBuilders.scoreSort()) + .addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)) + .setTrackScores(true) + .addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50), + response -> { + assertThat(response.getHits().getTotalHits().value(), equalTo(5L)); + assertThat(response.getHits().getHits().length, equalTo(5)); + for (SearchHit hit : response.getHits().getHits()) { + assertThat(hit.getScore(), equalTo(101f)); + } + } + ); } record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {} @@ -901,6 +916,10 @@ public void testRescoreAfterCollapse() throws Exception { .setQuery(fieldValueScoreQuery("firstPassScore")) .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore"))) .setCollapse(new CollapseBuilder("group")); + if (randomBoolean()) { + request.addSort(SortBuilders.scoreSort()); + request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + } assertResponse(request, resp -> { assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); assertThat(resp.getHits().getHits().length, equalTo(3)); @@ -980,6 +999,10 @@ public void testRescoreAfterCollapseRandom() throws Exception { .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups)) .setCollapse(new CollapseBuilder("group")) .setSize(Math.min(numGroups, 10)); + if (randomBoolean()) { + request.addSort(SortBuilders.scoreSort()); + request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + } long expectedNumHits = numHits; assertResponse(request, resp -> { assertThat(resp.getHits().getTotalHits().value, equalTo(expectedNumHits)); diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index 65a4ce8e2fe89..b3e8dc3b90dfe 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -73,6 +73,7 @@ import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.RescorePhase; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -410,7 +411,7 @@ public void preProcess() { ); } if (rescore != null) { - if (sort != null) { + if (RescorePhase.validateSort(sort) == false) { throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore]."); } int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow(); diff --git a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java index 6a89d66bb3411..9e03e5f1fe553 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java @@ -20,4 +20,11 @@ public final class SearchFeatures implements FeatureSpecification { public Set getFeatures() { return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED); } + + public static final NodeFeature RETRIEVER_RESCORER_ENABLED = new NodeFeature("search.retriever.rescorer.enabled"); + + @Override + public Set getTestFeatures() { + return Set.of(RETRIEVER_RESCORER_ENABLED); + } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index b8f50c6f9a62f..e8cc76f69872b 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -236,6 +236,7 @@ import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.RescorerRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; @@ -1103,6 +1104,7 @@ private void registerFetchSubPhase(FetchSubPhase subPhase) { private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); + registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent)); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index aa71d1ff069de..a9f2579c8838c 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -49,9 +49,7 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.slice.SliceBuilder; -import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -2360,18 +2358,6 @@ public ActionRequestValidationException validate( validationException = rescorer.validate(this, validationException); } } - - if (pointInTimeBuilder() == null && sorts() != null) { - for (var sortBuilder : sorts()) { - if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder - && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) { - validationException = addValidationError( - "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", - validationException - ); - } - } - } return validationException; } } diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java index cbc04dd460ff5..3d793a164f40a 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java @@ -58,6 +58,7 @@ import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.InternalProfileCollector; import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.RescorePhase; import org.elasticsearch.search.sort.SortAndFormats; import java.io.IOException; @@ -238,7 +239,7 @@ static CollectorManager createQueryPhaseCollectorMa int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); final boolean rescore = searchContext.rescore().isEmpty() == false; if (rescore) { - assert searchContext.sort() == null; + assert RescorePhase.validateSort(searchContext.sort()); for (RescoreContext rescoreContext : searchContext.rescore()) { numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index 7e3646e7689cc..c23df9cdfa441 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; @@ -22,9 +23,12 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.SearchTimeoutException; +import org.elasticsearch.search.sort.ShardDocSortField; +import org.elasticsearch.search.sort.SortAndFormats; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -39,15 +43,27 @@ public static void execute(SearchContext context) { if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) { return; } - + if (validateSort(context.sort()) == false) { + throw new IllegalStateException("Cannot use [sort] option in conjunction with [rescore], missing a validate?"); + } TopDocs topDocs = context.queryResult().topDocs().topDocs; if (topDocs.scoreDocs.length == 0) { return; } + // Populate FieldDoc#score using the primary sort field (_score) to ensure compatibility with top docs rescoring + Arrays.stream(topDocs.scoreDocs).forEach(t -> { + if (t instanceof FieldDoc fieldDoc) { + fieldDoc.score = (float) fieldDoc.fields[0]; + } + }); TopFieldGroups topGroups = null; + TopFieldDocs topFields = null; if (topDocs instanceof TopFieldGroups topFieldGroups) { - assert context.collapse() != null; + assert context.collapse() != null && validateSortFields(topFieldGroups.fields); topGroups = topFieldGroups; + } else if (topDocs instanceof TopFieldDocs topFieldDocs) { + assert validateSortFields(topFieldDocs.fields); + topFields = topFieldDocs; } try { Runnable cancellationCheck = getCancellationChecks(context); @@ -56,17 +72,18 @@ public static void execute(SearchContext context) { topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx); // It is the responsibility of the rescorer to sort the resulted top docs, // here we only assert that this condition is met. - assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; + assert topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; ctx.setCancellationChecker(null); } + /** + * Since rescorers are building top docs with score only, we must reconstruct the {@link TopFieldGroups} + * or {@link TopFieldDocs} using their original version before rescoring. + */ if (topGroups != null) { assert context.collapse() != null; - /** - * Since rescorers don't preserve collapsing, we must reconstruct the group and field - * values from the originalTopGroups to create a new {@link TopFieldGroups} from the - * rescored top documents. - */ - topDocs = rewriteTopGroups(topGroups, topDocs); + topDocs = rewriteTopFieldGroups(topGroups, topDocs); + } else if (topFields != null) { + topDocs = rewriteTopFieldDocs(topFields, topDocs); } context.queryResult() .topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats()); @@ -81,29 +98,84 @@ public static void execute(SearchContext context) { } } - private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { - assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0]) - : "rescore must always sort by score descending"; + /** + * Returns whether the provided {@link SortAndFormats} can be used to rescore + * top documents. + */ + public static boolean validateSort(SortAndFormats sortAndFormats) { + if (sortAndFormats == null) { + return true; + } + return validateSortFields(sortAndFormats.sort.getSort()); + } + + private static boolean validateSortFields(SortField[] fields) { + if (fields[0].equals(SortField.FIELD_SCORE) == false) { + return false; + } + if (fields.length == 1) { + return true; + } + + // The ShardDocSortField can be used as a tiebreaker because it maintains + // the natural document ID order within the shard. + if (fields[1] instanceof ShardDocSortField == false || fields[1].getReverse()) { + return false; + } + return true; + } + + private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) { + Map docIdToFieldDoc = Maps.newMapWithExpectedSize(originalTopFieldDocs.scoreDocs.length); + for (int i = 0; i < originalTopFieldDocs.scoreDocs.length; i++) { + docIdToFieldDoc.put(originalTopFieldDocs.scoreDocs[i].doc, (FieldDoc) originalTopFieldDocs.scoreDocs[i]); + } + var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length]; + int pos = 0; + for (var doc : rescoredTopDocs.scoreDocs) { + newScoreDocs[pos] = docIdToFieldDoc.get(doc.doc); + newScoreDocs[pos].score = doc.score; + newScoreDocs[pos].fields[0] = newScoreDocs[pos].score; + pos++; + } + return new TopFieldDocs(originalTopFieldDocs.totalHits, newScoreDocs, originalTopFieldDocs.fields); + } + + private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { + var newFieldDocs = rewriteFieldDocs((FieldDoc[]) originalTopGroups.scoreDocs, rescoredTopDocs.scoreDocs); + Map docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length); for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) { docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]); } - var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length]; var newGroupValues = new Object[originalTopGroups.groupValues.length]; int pos = 0; for (var doc : rescoredTopDocs.scoreDocs) { - newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score }); newGroupValues[pos++] = docIdToGroupValue.get(doc.doc); } return new TopFieldGroups( originalTopGroups.field, originalTopGroups.totalHits, - newScoreDocs, + newFieldDocs, originalTopGroups.fields, newGroupValues ); } + private static FieldDoc[] rewriteFieldDocs(FieldDoc[] originalTopDocs, ScoreDoc[] rescoredTopDocs) { + Map docIdToFieldDoc = Maps.newMapWithExpectedSize(rescoredTopDocs.length); + Arrays.stream(originalTopDocs).forEach(d -> docIdToFieldDoc.put(d.doc, d)); + var newDocs = new FieldDoc[rescoredTopDocs.length]; + int pos = 0; + for (var doc : rescoredTopDocs) { + newDocs[pos] = docIdToFieldDoc.get(doc.doc); + newDocs[pos].score = doc.score; + newDocs[pos].fields[0] = doc.score; + pos++; + } + return newDocs; + } + /** * Returns true if the provided docs are sorted by score. */ diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index f624961515389..38a319321207f 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -39,7 +39,7 @@ public abstract class RescorerBuilder> protected Integer windowSize; - private static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size"); + public static final ParseField WINDOW_SIZE_FIELD = new ParseField("window_size"); /** * Construct an empty RescoreBuilder. diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 2ab6395db73b5..298340e5c579e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -32,10 +32,12 @@ import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; +import org.elasticsearch.xcontent.ParseField; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -49,6 +51,8 @@ public abstract class CompoundRetrieverBuilder rankWindowSize) { validationException = addValidationError( - "[" - + this.getName() - + "] requires [rank_window_size: " - + rankWindowSize - + "]" - + " be greater than or equal to [size: " - + source.size() - + "]", + String.format( + Locale.ROOT, + "[%s] requires [%s: %d] be greater than or equal to [size: %d]", + getName(), + getRankWindowSizeField().getPreferredName(), + rankWindowSize, + source.size() + ), validationException ); } @@ -231,6 +243,21 @@ public ActionRequestValidationException validate( } for (RetrieverSource innerRetriever : innerRetrievers) { validationException = innerRetriever.retriever().validate(source, validationException, isScroll, allowPartialSearchResults); + if (innerRetriever.retriever() instanceof CompoundRetrieverBuilder compoundChild) { + if (rankWindowSize > compoundChild.rankWindowSize) { + String errorMessage = String.format( + Locale.ROOT, + "[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]", + this.getName(), + getRankWindowSizeField().getPreferredName(), + rankWindowSize, + compoundChild.getName(), + compoundChild.getRankWindowSizeField(), + compoundChild.rankWindowSize + ); + validationException = addValidationError(errorMessage, validationException); + } + } } return validationException; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java new file mode 100644 index 0000000000000..09688b5b9b001 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.search.builder.SearchSourceBuilder.RESCORE_FIELD; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +/** + * A {@link CompoundRetrieverBuilder} that re-scores only the results produced by its child retriever. + */ +public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String NAME = "rescorer"; + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + args -> new RescorerRetrieverBuilder((RetrieverBuilder) args[0], (List>) args[1]) + ); + + static { + PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> { + RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context); + context.trackRetrieverUsage(innerRetriever.getName()); + return innerRetriever; + }, RETRIEVER_FIELD); + PARSER.declareField(constructorArg(), (parser, context) -> { + if (parser.currentToken() == XContentParser.Token.START_ARRAY) { + List> rescorers = new ArrayList<>(); + while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) { + rescorers.add(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name))); + } + return rescorers; + } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + return List.of(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name))); + } else { + throw new IllegalArgumentException( + "Unknown format for [rescorer.rescore], expects an object or an array of objects, got: " + parser.currentToken() + ); + } + }, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); + RetrieverBuilder.declareBaseParserFields(NAME, PARSER); + } + + public static RescorerRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + try { + return PARSER.apply(parser, context); + } catch (Exception e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + private final List> rescorers; + + public RescorerRetrieverBuilder(RetrieverBuilder retriever, List> rescorers) { + super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers)); + if (rescorers.isEmpty()) { + throw new IllegalArgumentException("Missing rescore definition"); + } + this.rescorers = rescorers; + } + + private RescorerRetrieverBuilder(RetrieverSource retriever, List> rescorers) { + super(List.of(retriever), extractMinWindowSize(rescorers)); + this.rescorers = rescorers; + } + + /** + * The minimum window size is used as the {@link CompoundRetrieverBuilder#rankWindowSize}, + * the final number of top documents to return in this retriever. + */ + private static int extractMinWindowSize(List> rescorers) { + int windowSize = Integer.MAX_VALUE; + for (var rescore : rescorers) { + windowSize = Math.min(rescore.windowSize() == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : rescore.windowSize(), windowSize); + } + return windowSize; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public ParseField getRankWindowSizeField() { + return RescorerBuilder.WINDOW_SIZE_FIELD; + } + + @Override + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) { + /** + * The re-scorer is passed downstream because this query operates only on + * the top documents retrieved by the child retriever. + * + * - If the sub-retriever is a {@link CompoundRetrieverBuilder}, only the top + * documents are re-scored since they are already determined at this stage. + * - For other retrievers that do not require a rewrite, the re-scorer's window + * size is applied per shard. As a result, more documents are re-scored + * compared to the final top documents produced by these retrievers in isolation. + */ + for (var rescorer : rescorers) { + source.addRescorer(rescorer); + } + return source; + } + + @Override + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); + builder.startArray(RESCORE_FIELD.getPreferredName()); + for (RescorerBuilder rescorer : rescorers) { + rescorer.toXContent(builder, params); + } + builder.endArray(); + } + + @Override + protected RescorerRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers); + newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders; + return newInstance; + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + assert rankResults.size() == 1; + ScoreDoc[] scoreDocs = rankResults.getFirst(); + RankDoc[] rankDocs = new RankDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = scoreDocs[i]; + rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + rankDocs[i].rank = i + 1; + } + return rankDocs; + } + + @Override + public boolean doEquals(Object o) { + RescorerRetrieverBuilder that = (RescorerRetrieverBuilder) o; + return super.doEquals(o) && Objects.equals(rescorers, that.rescorers); + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), rescorers); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index d52c354cad69e..b9bfdfdf3402f 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -63,7 +63,7 @@ protected static void declareBaseParserFields( AbstractObjectParser parser ) { parser.declareObjectArray( - (r, v) -> r.preFilterQueryBuilders = v, + (r, v) -> r.preFilterQueryBuilders = new ArrayList<>(v), (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage), PRE_FILTER_FIELD ); diff --git a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java index 6dc4c62b5b397..853f99261e110 100644 --- a/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java @@ -23,6 +23,7 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.store.BaseDirectoryWrapper; @@ -262,7 +263,10 @@ protected Engine.Searcher acquireSearcherInternal(String source) { // resultWindow not greater than maxResultWindow and both rescore and sort are not null context1.from(0); DocValueFormat docValueFormat = mock(DocValueFormat.class); - SortAndFormats sortAndFormats = new SortAndFormats(new Sort(), new DocValueFormat[] { docValueFormat }); + SortAndFormats sortAndFormats = new SortAndFormats( + new Sort(new SortField[] { SortField.FIELD_DOC }), + new DocValueFormat[] { docValueFormat } + ); context1.sort(sortAndFormats); RescoreContext rescoreContext = mock(RescoreContext.class); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java new file mode 100644 index 0000000000000..fa83246d90cb2 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.rescore.QueryRescorerBuilderTests; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Collections.emptyList; + +public class RescorerRetrieverBuilderParsingTests extends AbstractXContentTestCase { + private static List xContentRegistryEntries; + + @BeforeClass + public static void init() { + xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents(); + } + + @AfterClass + public static void afterClass() throws Exception { + xContentRegistryEntries = null; + } + + @Override + protected RescorerRetrieverBuilder createTestInstance() { + int num = randomIntBetween(1, 3); + List> rescorers = new ArrayList<>(); + for (int i = 0; i < num; i++) { + rescorers.add(QueryRescorerBuilderTests.randomRescoreBuilder()); + } + return new RescorerRetrieverBuilder(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), rescorers); + } + + @Override + protected RescorerRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (RescorerRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), n -> true) + ); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List entries = new ArrayList<>(xContentRegistryEntries); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + TestRetrieverBuilder.TEST_SPEC.getName(), + (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c), + TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() + ) + ); + return new NamedXContentRegistry(entries); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 6ea8afd3e51c2..a4434ac8f7330 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -50,7 +50,6 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder PARSER = new ConstructingObjectParser<>( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 0cc1a300a02b2..f0751dc5d09e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -47,7 +47,6 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); - public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index 2c3f217243aa4..b2d470c6618ea 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -22,6 +22,7 @@ dependencies { testImplementation(testArtifact(project(xpackModule('core')))) testImplementation(testArtifact(project(':server'))) + clusterModules project(':modules:mapper-extras') clusterModules project(xpackModule('rank-rrf')) clusterModules project(xpackModule('inference')) clusterModules project(':modules:lang-painless') diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index f1171b74f7468..c1447623dd5b1 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -48,7 +48,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder