From 9feae3340cd05de3ebd572e54621fd48bc2b6d7f Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 12 Dec 2024 15:59:07 +0000 Subject: [PATCH] Add a new `rescorer` retriever This change adds a new `rescorer` retriever that re-scores only the top documents returned by its child retriever. --- docs/reference/search/retriever.asciidoc | 117 ++++++- .../30_rescorer_retriever.yml | 200 +++++++++++ .../elasticsearch/search/SearchFeatures.java | 7 + .../elasticsearch/search/SearchModule.java | 2 + .../retriever/RescorerRetrieverBuilder.java | 156 +++++++++ .../search/retriever/RetrieverBuilder.java | 2 +- .../RescorerRetrieverBuilderParsingTests.java | 66 ++++ x-pack/plugin/rank-rrf/build.gradle | 1 + .../rrf/RRFRankClientYamlTestSuiteIT.java | 1 + .../test/rrf/900_rrf_with_rescorer.yml | 324 ++++++++++++++++++ 10 files changed, 874 insertions(+), 2 deletions(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java create mode 100644 server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/900_rrf_with_rescorer.yml diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index cb04d4fb6fbf1..55a2af37cfd8a 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -22,6 +22,9 @@ A <> that replaces the functionality of a traditi `knn`:: A <> that replaces the functionality of a <>. +`rescorer`:: +A <> that replaces the functionality of the <>. + `rrf`:: A <> that produces top documents from <>. @@ -371,6 +374,118 @@ GET movies/_search ---- // TEST[skip:uses ELSER] +[[rescorer-retriever]] +==== Rescorer Retriever + +The `rescorer` retriever re-scores only the results produced by its child retriever. +For the `standard` and `knn` retrievers, the `window_size` parameter specifies the number of documents examined per shard. + +For compound retrievers like `rrf`, the `window_size` parameter defines the total number of documents examined globally. + +When using the `rescorer`, ensure its minimum rescore's `window_size` is: +- Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups. +- Greater than or equal to the `size` of the search request when used as the primary retriever in the tree. + +And that its maximum rescore's `window_size` is: +- Smaller than or equal to the `size` or `rank_window_size` of the child retriever. + +===== Parameters + +`rescore`:: +(Required. <>) ++ +Defines the <> applied sequentially to the top documents returned by the child retriever. + +`retriever`:: +(Required. <>) ++ +Specifies the child retriever responsible for generating the initial set of top documents to be re-ranked. + +`filter`:: +(Optional. <>) ++ +Applies a <> to the retriever, ensuring that all documents match the filter criteria without affecting their scores. + +[discrete] +[[rescorer-retriever-example]] +==== Example + +The `rescorer` retriever can be placed at any level within the retriever tree. +The following example demonstrates a `rescorer` applied to the results produced by an `rrf` retriever: + +[source,console] +---- +GET movies/_search +{ + "size": 10, <1> + "retriever": { + "rescorer": { <2> + "rescore": { + "query": { <3> + "window_size": 50, <4> + "rescore_query": { + "script_score": { + "script": { + "source": "cosineSimilarity(params.queryVector, 'product-vector_final_stage') + 1.0", + "params": { + "queryVector": [-0.5, 90.0, -10, 14.8, -156.0] + } + } + } + } + } + }, + "retriever": { <5> + "rrf": { + "rank_window_size": 100, <6> + "retrievers": [ + { + "standard": { + "query": { + "sparse_vector": { + "field": "plot_embedding", + "inference_id": "my-elser-model", + "query": "films that explore psychological depths" + } + } + } + }, + { + "standard": { + "query": { + "multi_match": { + "query": "crime", + "fields": [ + "plot", + "title" + ] + } + } + } + }, + { + "knn": { + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + } + ] + } + } + } + } +} +---- + +<1> Specifies the number of top documents to return in the final response. +<2> A `rescorer` retriever applied as the final step. +<3> The definition of the `query` rescorer. +<4> Defines the number of documents to rescore from the child retriever. +<5> Specifies the child retriever definition. +<6> Defines the number of documents returned by the `rrf` retriever, which limits the available documents to + [[text-similarity-reranker-retriever]] ==== Text Similarity Re-ranker Retriever @@ -772,4 +887,4 @@ When a retriever is specified as part of a search, the following elements are no * <> * <> * <> -* <> +* <> use a <> instead diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml new file mode 100644 index 0000000000000..61661a835e1d4 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/30_rescorer_retriever.yml @@ -0,0 +1,200 @@ +setup: + - requires: + cluster_features: [ "search.retriever.rescorer.enabled" ] + reason: "Support for rescorer retriever" + + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + available: + type: boolean + features: + type: rank_features + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"features": { "first_stage": 1, "second_stage": 10}, "available": true, "group": 1}' + - '{"index": {"_id": 2 }}' + - '{"features": { "first_stage": 2, "second_stage": 9}, "available": false, "group": 1}' + - '{"index": {"_id": 3 }}' + - '{"features": { "first_stage": 3, "second_stage": 8}, "available": false, "group": 3}' + - '{"index": {"_id": 4 }}' + - '{"features": { "first_stage": 4, "second_stage": 7}, "available": true, "group": 1}' + - '{"index": {"_id": 5 }}' + - '{"features": { "first_stage": 5, "second_stage": 6}, "available": true, "group": 3}' + - '{"index": {"_id": 6 }}' + - '{"features": { "first_stage": 6, "second_stage": 5}, "available": false, "group": 2}' + - '{"index": {"_id": 7 }}' + - '{"features": { "first_stage": 7, "second_stage": 4}, "available": true, "group": 3}' + - '{"index": {"_id": 8 }}' + - '{"features": { "first_stage": 8, "second_stage": 3}, "available": true, "group": 1}' + - '{"index": {"_id": 9 }}' + - '{"features": { "first_stage": 9, "second_stage": 2}, "available": true, "group": 2}' + - '{"index": {"_id": 10 }}' + - '{"features": { "first_stage": 10, "second_stage": 1}, "available": false, "group": 1}' + +--- +"Rescorer retriever basic": + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 10 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1._score: 9.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 3 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: {} + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: {} + size: 2 + + - match: {hits.total.value: 10} + - match: {hits.hits.0._id: "8"} + - match: { hits.hits.0._score: 3.0 } + - match: {hits.hits.1._id: "9"} + - match: { hits.hits.1._score: 2.0 } + +--- +"Rescorer retriever with pre-filters": + - do: + search: + index: test + body: + retriever: + rescorer: + filter: + match: + available: true + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 6 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "4" } + - match: { hits.hits.1._score: 7.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 4 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + filter: + match: + available: true + query: + rank_feature: + field: "features.first_stage" + linear: { } + size: 2 + + - match: { hits.total.value: 6 } + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.0._score: 6.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 4.0 } + +--- +"Rescorer retriever and collapsing": + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.second_stage" + linear: { } + query_weight: 0 + retriever: + standard: + query: + rank_feature: + field: "features.first_stage" + linear: { } + collapse: + field: group + size: 3 + + - match: { hits.total.value: 10 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1._score: 8.0 } + - match: { hits.hits.2._id: "6" } + - match: { hits.hits.2._score: 5.0 } diff --git a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java index beac39c2de304..553511346b182 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/SearchFeatures.java @@ -23,4 +23,11 @@ public final class SearchFeatures implements FeatureSpecification { public Set getFeatures() { return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED, LUCENE_10_0_0_UPGRADE); } + + public static final NodeFeature RETRIEVER_RESCORER_ENABLED = new NodeFeature("search.retriever.rescorer.enabled"); + + @Override + public Set getTestFeatures() { + return Set.of(RETRIEVER_RESCORER_ENABLED); + } } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index d282ba425b126..3294e1ba03f6b 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -231,6 +231,7 @@ import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.RescorerRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; @@ -1080,6 +1081,7 @@ private void registerFetchSubPhase(FetchSubPhase subPhase) { private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); + registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent)); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java new file mode 100644 index 0000000000000..efb9ad0089a39 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.search.builder.SearchSourceBuilder.RESCORE_FIELD; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +/** + * A {@link CompoundRetrieverBuilder} that re-scores only the results produced by its child retriever. + */ +public final class RescorerRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String NAME = "rescorer"; + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + args -> new RescorerRetrieverBuilder((RetrieverBuilder) args[0], (List>) args[1]) + ); + + static { + PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> { + RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context); + context.trackRetrieverUsage(innerRetriever.getName()); + return innerRetriever; + }, RETRIEVER_FIELD); + PARSER.declareField(constructorArg(), (parser, context) -> { + if (parser.currentToken() == XContentParser.Token.START_ARRAY) { + List> rescorers = new ArrayList<>(); + while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) { + rescorers.add(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name))); + } + return rescorers; + } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + return List.of(RescorerBuilder.parseFromXContent(parser, name -> context.trackRescorerUsage(name))); + } else { + throw new IllegalArgumentException( + "Unknown format for [rescorer.rescore], expects an object or an array of objects, got: " + parser.currentToken() + ); + } + }, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); + RetrieverBuilder.declareBaseParserFields(NAME, PARSER); + } + + public static RescorerRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + try { + return PARSER.apply(parser, context); + } catch (Exception e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + private final List> rescorers; + + public RescorerRetrieverBuilder(RetrieverBuilder retriever, List> rescorers) { + super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers)); + if (rescorers.isEmpty()) { + throw new IllegalArgumentException("Missing rescore definition"); + } + this.rescorers = rescorers; + } + + private RescorerRetrieverBuilder(RetrieverSource retriever, List> rescorers) { + super(List.of(retriever), extractMinWindowSize(rescorers)); + this.rescorers = rescorers; + } + + /** + * The minimum window size is used as the {@link CompoundRetrieverBuilder#rankWindowSize}, + * the final number of top documents to return in this retriever. + */ + private static int extractMinWindowSize(List> rescorers) { + int windowSize = Integer.MAX_VALUE; + for (var rescore : rescorers) { + windowSize = Math.min(rescore.windowSize() == null ? RescorerBuilder.DEFAULT_WINDOW_SIZE : rescore.windowSize(), windowSize); + } + return windowSize; + } + + @Override + public String getName() { + return NAME; + } + + @Override + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) { + for (var rescorer : rescorers) { + source.addRescorer(rescorer); + } + return source; + } + + @Override + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); + builder.startArray(RESCORE_FIELD.getPreferredName()); + for (RescorerBuilder rescorer : rescorers) { + rescorer.toXContent(builder, params); + } + builder.endArray(); + } + + @Override + protected RescorerRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers); + newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders; + return newInstance; + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + assert rankResults.size() == 1; + ScoreDoc[] scoreDocs = rankResults.getFirst(); + RankDoc[] rankDocs = new RankDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = scoreDocs[i]; + rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + rankDocs[i].rank = i + 1; + } + return rankDocs; + } + + @Override + public boolean doEquals(Object o) { + RescorerRetrieverBuilder that = (RescorerRetrieverBuilder) o; + return super.doEquals(o) && Objects.equals(rescorers, that.rescorers); + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), rescorers); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index d52c354cad69e..b9bfdfdf3402f 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -63,7 +63,7 @@ protected static void declareBaseParserFields( AbstractObjectParser parser ) { parser.declareObjectArray( - (r, v) -> r.preFilterQueryBuilders = v, + (r, v) -> r.preFilterQueryBuilders = new ArrayList<>(v), (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage), PRE_FILTER_FIELD ); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java new file mode 100644 index 0000000000000..91875049029d3 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilderParsingTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.rescore.QueryRescorerBuilderTests; +import org.elasticsearch.search.rescore.RescorerBuilder; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + +public class RescorerRetrieverBuilderParsingTests extends AbstractXContentTestCase { + @Override + protected RescorerRetrieverBuilder createTestInstance() { + int num = randomIntBetween(1, 3); + List> rescorers = new ArrayList<>(); + for (int i = 0; i < num; i++) { + rescorers.add(QueryRescorerBuilderTests.randomRescoreBuilder()); + } + final RetrieverBuilder retriever; + if (randomBoolean()) { + retriever = KnnRetrieverBuilderParsingTests.createRandomKnnRetrieverBuilder(); + } else { + retriever = StandardRetrieverBuilderParsingTests.createRandomStandardRetrieverBuilder((xContent, data) -> { + try { + return createParser(xContent, data); + } catch (IOException ioe) { + throw new UncheckedIOException(ioe); + } + }); + } + return new RescorerRetrieverBuilder(retriever, rescorers); + } + + @Override + protected RescorerRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (RescorerRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), n -> true) + ); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()); + } +} diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index 2c3f217243aa4..b2d470c6618ea 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -22,6 +22,7 @@ dependencies { testImplementation(testArtifact(project(xpackModule('core')))) testImplementation(testArtifact(project(':server'))) + clusterModules project(':modules:mapper-extras') clusterModules project(xpackModule('rank-rrf')) clusterModules project(xpackModule('inference')) clusterModules project(':modules:lang-painless') diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java index 32b5aedd5d99a..1a22f8738a26a 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java @@ -21,6 +21,7 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { @ClassRule public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .nodes(2) + .module("mapper-extras") .module("rank-rrf") .module("lang-painless") .module("x-pack-inference") diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/900_rrf_with_rescorer.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/900_rrf_with_rescorer.yml new file mode 100644 index 0000000000000..44f475e350eb4 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/900_rrf_with_rescorer.yml @@ -0,0 +1,324 @@ +setup: + - requires: + cluster_features: [ "search.retriever.rescorer.enabled" ] + reason: "Support for rescorer retriever" + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 3 + mappings: + properties: + available: + type: boolean + features: + type: rank_features + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"features": { "first_query": 1, "second_query": 3, "final_score": 7}, "available": true}' + - '{"index": {"_id": 2 }}' + - '{"features": { "first_query": 5, "second_query": 7, "final_score": 4}, "available": false}' + - '{"index": {"_id": 3 }}' + - '{"features": { "first_query": 6, "second_query": 5, "final_score": 3}, "available": false}' + - '{"index": {"_id": 4 }}' + - '{"features": { "first_query": 3, "second_query": 2, "final_score": 2}, "available": true}' + - '{"index": {"_id": 5 }}' + - '{"features": { "first_query": 2, "second_query": 1, "final_score": 1}, "available": true}' + - '{"index": {"_id": 6 }}' + - '{"features": { "first_query": 4, "second_query": 4, "final_score": 8}, "available": false}' + - '{"index": {"_id": 7 }}' + - '{"features": { "first_query": 7, "second_query": 10, "final_score": 9}, "available": true}' + - '{"index": {"_id": 8 }}' + - '{"features": { "first_query": 8, "second_query": 8, "final_score": 10}, "available": true}' + - '{"index": {"_id": 9 }}' + - '{"features": { "first_query": 9, "second_query": 9, "final_score": 5}, "available": true}' + - '{"index": {"_id": 10 }}' + - '{"features": { "first_query": 10, "second_query": 6, "final_score": 6}, "available": false}' + +--- +"RRF with rescorer retriever basic": + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 10 + query: + rescore_query: + rank_feature: + field: "features.final_score" + linear: { } + query_weight: 0 + retriever: + rrf: + rank_window_size: 10 + retrievers: [ + { + standard: { + query: { + rank_feature: { + field: "features.first_query", + linear: { } + } + } + } + }, + { + standard: { + query: { + rank_feature: { + field: "features.second_query", + linear: { } + } + } + } + } + ] + size: 3 + + - match: { hits.total.value: 10 } + - length: { hits.hits: 3} + - match: { hits.hits.0._id: "8" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 9.0 } + - match: { hits.hits.2._id: "6" } + - match: { hits.hits.2._score: 8.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + rescore: + window_size: 5 + query: + rescore_query: + rank_feature: + field: "features.final_score" + linear: { } + query_weight: 0 + retriever: + rrf: + rank_window_size: 5 + retrievers: [ + { + standard: { + query: { + rank_feature: { + field: "features.first_query", + linear: { } + } + } + } + }, + { + standard: { + query: { + rank_feature: { + field: "features.second_query", + linear: { } + } + } + } + } + ] + size: 3 + + - match: { hits.total.value: 10 } + - length: { hits.hits: 3} + - match: { hits.hits.0._id: "8" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 9.0 } + - match: { hits.hits.2._id: "10" } + - match: { hits.hits.2._score: 6.0 } + +--- +"RRF with rescorer retriever and prefilters": + - do: + search: + index: test + body: + retriever: + rescorer: + filter: + match: + available: true + rescore: + window_size: 5 + query: + rescore_query: + rank_feature: + field: "features.final_score" + linear: { } + query_weight: 0 + retriever: + rrf: + rank_window_size: 5 + retrievers: [ + { + standard: { + query: { + rank_feature: { + field: "features.first_query", + linear: { } + } + } + } + }, + { + standard: { + query: { + rank_feature: { + field: "features.second_query", + linear: { } + } + } + } + } + ] + size: 3 + + - match: { hits.total.value: 6 } + - length: { hits.hits: 3} + - match: { hits.hits.0._id: "8" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 9.0 } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2._score: 7.0 } + + - do: + search: + index: test + body: + retriever: + rescorer: + filter: + match: + available: true + rescore: + window_size: 5 + query: + rescore_query: + rank_feature: + field: "features.final_score" + linear: { } + query_weight: 0 + retriever: + rrf: + rank_window_size: 5 + retrievers: [ + { + standard: { + query: { + rank_feature: { + field: "features.first_query", + linear: { } + } + } + } + }, + { + standard: { + filter: { + match: { + available: true + } + }, + query: { + rank_feature: { + field: "features.second_query", + linear: { } + } + } + } + } + ] + size: 3 + + - match: { hits.total.value: 6 } + - length: { hits.hits: 3} + - match: { hits.hits.0._id: "8" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 9.0 } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2._score: 7.0 } + +--- +"RRF with rescorer retriever and aggs": + - do: + search: + index: test + body: + aggs: + 1: + terms: + field: available + retriever: + rescorer: + rescore: + window_size: 5 + query: + rescore_query: + rank_feature: + field: "features.final_score" + linear: { } + query_weight: 0 + retriever: + rrf: + rank_window_size: 5 + retrievers: [ + { + standard: { + query: { + rank_feature: { + field: "features.first_query", + linear: { } + } + } + } + }, + { + standard: { + filter: { + match: { + available: true + } + }, + query: { + rank_feature: { + field: "features.second_query", + linear: { } + } + } + } + } + ] + size: 3 + + - match: { hits.total.value: 10 } + - length: { hits.hits: 3} + - match: { hits.hits.0._id: "8" } + - match: { hits.hits.0._score: 10.0 } + - match: { hits.hits.1._id: "7" } + - match: { hits.hits.1._score: 9.0 } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2._score: 7.0 } + - length: { aggregations.1.buckets: 2} + - match: { aggregations.1.buckets.0.key: 1} + - match: { aggregations.1.buckets.0.doc_count: 6} + - match: { aggregations.1.buckets.1.key: 0 } + - match: { aggregations.1.buckets.1.doc_count: 4 } +