From 6d7ae82180fa30ee4b48bd58721cae2cda809b15 Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 3 Oct 2024 06:44:28 +0100 Subject: [PATCH 1/4] Assert that REST params are consumed iff supported (#113933) REST APIs which declare their supported parameters must consume exactly those parameters: consuming an unsupported parameter means that requests including that parameter will be rejected, whereas failing to consume a supported parameter means that this parameter has no effect and should be removed. This commit adds an assertion to verify that we are consuming the correct parameters. Closes #113854 --- .../rest/RestGetDataStreamsAction.java | 3 ++- .../org/elasticsearch/rest/BaseRestHandler.java | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/rest/RestGetDataStreamsAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/rest/RestGetDataStreamsAction.java index da55376fb403b..7a27eddfaf8c7 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/rest/RestGetDataStreamsAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/rest/RestGetDataStreamsAction.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.datastreams.GetDataStreamAction; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.DataStreamLifecycle; import org.elasticsearch.common.Strings; import org.elasticsearch.common.util.set.Sets; @@ -35,12 +36,12 @@ public class RestGetDataStreamsAction extends BaseRestHandler { Set.of( "name", "include_defaults", - "timeout", "master_timeout", IndicesOptions.WildcardOptions.EXPAND_WILDCARDS, IndicesOptions.ConcreteTargetOptions.IGNORE_UNAVAILABLE, IndicesOptions.WildcardOptions.ALLOW_NO_INDICES, 
IndicesOptions.GatekeeperOptions.IGNORE_THROTTLED, + DataStream.isFailureStoreFeatureFlagEnabled() ? IndicesOptions.FailureStoreOptions.FAILURE_STORE : "name", "verbose" ) ) diff --git a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java index 99fa3e0166963..2f7bb80a8d46a 100644 --- a/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/BaseRestHandler.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.RestApiVersion; @@ -104,6 +105,8 @@ public final void handleRequest(RestRequest request, RestChannel channel, NodeCl // prepare the request for execution; has the side effect of touching the request parameters try (var action = prepareRequest(request, client)) { + assert assertConsumesSupportedParams(supported, request); + // validate unconsumed params, but we must exclude params used to format the response // use a sorted set so the unconsumed parameters appear in a reliable sorted order final SortedSet unconsumedParams = request.unconsumedParams() @@ -148,6 +151,20 @@ public void close() { } } + private boolean assertConsumesSupportedParams(@Nullable Set supported, RestRequest request) { + if (supported != null) { + final var supportedAndCommon = new TreeSet<>(supported); + supportedAndCommon.add("error_trace"); + supportedAndCommon.addAll(ALWAYS_SUPPORTED); + supportedAndCommon.removeAll(RestRequest.INTERNAL_MARKER_REQUEST_PARAMETERS); + final var consumed = new TreeSet<>(request.consumedParams()); + consumed.removeAll(RestRequest.INTERNAL_MARKER_REQUEST_PARAMETERS); + assert supportedAndCommon.equals(consumed) + : getName() + ": consumed params " + 
consumed + " while supporting " + supportedAndCommon; + } + return true; + } + protected static String unrecognized(RestRequest request, Set invalids, Set candidates, String detail) { StringBuilder message = new StringBuilder().append("request [") .append(request.path()) From 539d4fdff56492ca7f6b8e59384006650a1ed946 Mon Sep 17 00:00:00 2001 From: Kostas Krikellas <131142368+kkrik-es@users.noreply.github.com> Date: Thu, 3 Oct 2024 09:24:20 +0300 Subject: [PATCH 2/4] Restore 20_synthetic_source/object array in object with dynamic override (#113990) Fixes #113966 --- muted-tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index 7c1f40e5e5639..68422264221f2 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -344,9 +344,6 @@ tests: - class: org.elasticsearch.kibana.KibanaThreadPoolIT method: testBlockedThreadPoolsRejectUserRequests issue: https://github.com/elastic/elasticsearch/issues/113939 -- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT - method: test {p0=indices.create/20_synthetic_source/object array in object with dynamic override} - issue: https://github.com/elastic/elasticsearch/issues/113966 - class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT method: testPutE5Small_withPlatformAgnosticVariant issue: https://github.com/elastic/elasticsearch/issues/113983 From dc8c20d3b63dc667f20db14b87f16e3dc2db7b8a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 3 Oct 2024 12:39:13 +0300 Subject: [PATCH 3/4] Rework RRF to be evaluated during rewrite phase (#112648) --- docs/reference/search/rrf.asciidoc | 7 +- .../retriever/RankDocRetrieverBuilderIT.java | 90 +- .../org/elasticsearch/TransportVersions.java | 1 + .../action/search/TransportSearchAction.java | 2 + .../elasticsearch/common/lucene/Lucene.java | 3 - .../uhighlight/CustomUnifiedHighlighter.java | 2 +- .../elasticsearch/search/SearchModule.java | 2 - .../search/builder/SearchSourceBuilder.java | 23 +- 
.../search/fetch/StoredFieldsContext.java | 2 +- .../search/fetch/subphase/ExplainPhase.java | 12 +- .../elasticsearch/search/rank/RankDoc.java | 23 +- .../search/rank/feature/RankFeatureDoc.java | 2 +- .../retriever/CompoundRetrieverBuilder.java | 255 ++++ .../search/retriever/KnnRetrieverBuilder.java | 28 +- .../retriever/RankDocsRetrieverBuilder.java | 102 +- .../search/retriever/RetrieverBuilder.java | 61 +- .../retriever/StandardRetrieverBuilder.java | 49 +- .../retriever/rankdoc/RankDocsQuery.java | 396 ++++-- .../rankdoc/RankDocsQueryBuilder.java | 60 +- .../rankdoc/RankDocsSortBuilder.java | 114 -- .../retriever/rankdoc/RankDocsSortField.java | 102 -- .../search/sort/ShardDocSortField.java | 14 + .../action/search/SearchRequestTests.java | 87 +- .../search/rank/RankDocTests.java | 5 - .../KnnRetrieverBuilderParsingTests.java | 19 +- .../RankDocsRetrieverBuilderTests.java | 74 +- .../retriever/RetrieverBuilderErrorTests.java | 18 +- .../rankdoc/RankDocsQueryBuilderTests.java | 119 +- .../rankdoc/RankDocsSortBuilderTests.java | 72 -- .../TestCompoundRetrieverBuilder.java | 52 + .../TextSimilarityRankRetrieverBuilder.java | 10 +- ...xtSimilarityRankRetrieverBuilderTests.java | 3 +- .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 656 ++++++++++ .../rrf/RRFRetrieverBuilderNestedDocsIT.java | 171 +++ .../xpack/rank/rrf/RRFFeatures.java | 4 +- .../RRFQueryPhaseRankCoordinatorContext.java | 6 +- .../rrf/RRFQueryPhaseRankShardContext.java | 6 +- .../xpack/rank/rrf/RRFRankDoc.java | 57 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 150 ++- .../xpack/rank/rrf/RRFRankContextTests.java | 108 +- .../xpack/rank/rrf/RRFRankDocTests.java | 77 +- .../rrf/RRFRetrieverBuilderParsingTests.java | 18 +- .../rank/rrf/RRFRetrieverBuilderTests.java | 87 +- .../rest-api-spec/test/rrf/100_rank_rrf.yml | 8 - .../test/rrf/150_rank_rrf_pagination.yml | 15 - .../test/rrf/200_rank_rrf_script.yml | 20 - .../test/rrf/300_rrf_retriever.yml | 87 +- 
.../test/rrf/350_rrf_retriever_pagination.yml | 1112 +++++++++++++++++ .../test/rrf/400_rrf_retriever_script.yml | 25 +- .../test/rrf/500_rrf_retriever_explain.yml | 16 +- .../test/rrf/600_rrf_retriever_profile.yml | 36 +- ...rrf_retriever_search_api_compatibility.yml | 541 ++++++++ 52 files changed, 3902 insertions(+), 1107 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java delete mode 100644 server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilder.java delete mode 100644 server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortField.java delete mode 100644 server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilderTests.java create mode 100644 test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java create mode 100644 x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java create mode 100644 x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml diff --git a/docs/reference/search/rrf.asciidoc b/docs/reference/search/rrf.asciidoc index 2525dfff23b94..2a676e5fba336 100644 --- a/docs/reference/search/rrf.asciidoc +++ b/docs/reference/search/rrf.asciidoc @@ -300,13 +300,12 @@ We have both the ranker's `score` and the `_rank` option to show our top-ranked "value" : 5, "relation" : "eq" }, - "max_score" : null, + "max_score" : ..., "hits" : [ { "_index" : "example-index", "_id" : "3", "_score" : 0.8333334, - "_rank" : 1, "_source" : { "integer" : 1, "vector" : [ @@ -319,7 +318,6 @@ We have both the ranker's `score` 
and the `_rank` option to show our top-ranked "_index" : "example-index", "_id" : "2", "_score" : 0.5833334, - "_rank" : 2, "_source" : { "integer" : 2, "vector" : [ @@ -332,7 +330,6 @@ We have both the ranker's `score` and the `_rank` option to show our top-ranked "_index" : "example-index", "_id" : "4", "_score" : 0.5, - "_rank" : 3, "_source" : { "integer" : 2, "text" : "rrf rrf rrf rrf" @@ -499,7 +496,6 @@ Working with the example above, and by adding `explain=true` to the search reque "_index": "example-index", "_id": "3", "_score": 0.8333334, - "_rank": 1, "_explanation": { "value": 0.8333334, <1> @@ -608,7 +604,6 @@ The response would now include the named query in the explanation: "_index": "example-index", "_id": "3", "_score": 0.8333334, - "_rank": 1, "_explanation": { "value": 0.8333334, diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java index 26af82cf021f2..891096dfa67a9 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.NestedSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.test.ESIntegTestCase; @@ -189,8 +190,10 @@ public void testRankDocsRetrieverBasicWithPagination() { SearchSourceBuilder source = new SearchSourceBuilder(); StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = 
QueryBuilders.constantScoreQuery(QueryBuilders.queryStringQuery("quick").defaultField(TEXT_FIELD)) - .boost(10L); + standard0.queryBuilder = QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); // this one retrieves docs 2 and 6 due to prefilter standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); @@ -205,8 +208,8 @@ public void testRankDocsRetrieverBasicWithPagination() { null ); // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 4, 7, 3 and with pagination, we'd just omit the first result + // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) + // so ideal rank would be: 6, 2, 1, 3, 4, 7 and with pagination, we'd just omit the first result source.retriever( new CompoundRetrieverWithRankDocs( rankWindowSize, @@ -227,9 +230,9 @@ public void testRankDocsRetrieverBasicWithPagination() { assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_3")); + 
assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); }); } @@ -242,8 +245,10 @@ public void testRankDocsRetrieverWithAggs() { SearchSourceBuilder source = new SearchSourceBuilder(); StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.queryStringQuery("quick").defaultField(TEXT_FIELD)) - .boost(10L); + standard0.queryBuilder = QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); // this one retrieves docs 2 and 6 due to prefilter standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); @@ -267,13 +272,15 @@ public void testRankDocsRetrieverWithAggs() { ) ) ); + source.size(1); source.aggregation(new TermsAggregationBuilder("topic").field(TOPIC_FIELD)); SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); ElasticsearchAssertions.assertResponse(req, resp -> { assertNull(resp.pointInTimeId()); assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(1L)); + assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); assertNotNull(resp.getAggregations()); assertNotNull(resp.getAggregations().get("topic")); @@ -291,8 +298,10 @@ public void testRankDocsRetrieverWithCollapse() { 
SearchSourceBuilder source = new SearchSourceBuilder(); StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.queryStringQuery("quick").defaultField(TEXT_FIELD)) - .boost(10L); + standard0.queryBuilder = QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); // this one retrieves docs 2 and 6 due to prefilter standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); @@ -307,8 +316,8 @@ public void testRankDocsRetrieverWithCollapse() { null ); // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 4, 7, 3 + // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) + // so ideal rank would be: 6, 2, 1, 3, 4, 7 // with collapsing on topic field we would have 6, 2, 1, 7 source.retriever( new CompoundRetrieverWithRankDocs( @@ -338,7 +347,6 @@ public void testRankDocsRetrieverWithCollapse() { assertThat(resp.getHits().getAt(1).field(TOPIC_FIELD).getValue().toString(), equalTo("astronomy")); assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); assertThat(resp.getHits().getAt(2).field(TOPIC_FIELD).getValue().toString(), equalTo("technology")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getHits().length, equalTo(3)); 
assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); @@ -347,17 +355,15 @@ public void testRankDocsRetrieverWithCollapse() { }); } - public void testRankDocsRetrieverWithCollapseAndAggs() { - // same as above, but we only want to bring back the top result from each subsearch - // so that would be 1, 2, and 7 - // and final rank would be (based on score): 2, 1, 7 - // aggs should still account for the same docs as the testRankDocsRetriever test, i.e. all but doc_5 + public void testRankDocsRetrieverWithNestedCollapseAndAggs() { final int rankWindowSize = 10; SearchSourceBuilder source = new SearchSourceBuilder(); StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); // this one retrieves docs 1 and 6 as doc_4 is collapsed to doc_1 - standard0.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.queryStringQuery("quick").defaultField(TEXT_FIELD)) - .boost(10L); + standard0.queryBuilder = QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); standard0.collapseBuilder = new CollapseBuilder(TOPIC_FIELD).setInnerHits( new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) ); @@ -375,8 +381,8 @@ public void testRankDocsRetrieverWithCollapseAndAggs() { null ); // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would 
be: 6, 2, 1, 4, 7, 3 + // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) + // so ideal rank would be: 6, 2, 1, 3, 4, 7 source.retriever( new CompoundRetrieverWithRankDocs( rankWindowSize, @@ -392,7 +398,7 @@ public void testRankDocsRetrieverWithCollapseAndAggs() { ElasticsearchAssertions.assertResponse(req, resp -> { assertNull(resp.pointInTimeId()); assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); assertNotNull(resp.getAggregations()); @@ -427,8 +433,8 @@ public void testRankDocsRetrieverWithNestedQuery() { null ); // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 4, 3, 7 + // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) + // so ideal rank would be: 6, 2, 1, 3, 4, 7 source.retriever( new CompoundRetrieverWithRankDocs( rankWindowSize, @@ -460,8 +466,10 @@ public void testRankDocsRetrieverMultipleCompoundRetrievers() { SearchSourceBuilder source = new SearchSourceBuilder(); StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.queryStringQuery("quick").defaultField(TEXT_FIELD)) - .boost(10L); + standard0.queryBuilder = QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); // this one retrieves docs 2 and 6 due to prefilter standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); @@ -506,11 +514,11 @@ public void testRankDocsRetrieverMultipleCompoundRetrievers() { assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_7")); }); } @@ -545,9 +553,9 @@ public void testRankDocsRetrieverDifferentNestedSorting() { assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); }); } @@ -673,22 +681,14 @@ private RankDoc[] getRankDocs(SearchResponse searchResponse) { for (int i = 0; i < size; i++) { var hit = 
searchResponse.getHits().getAt(i); long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; - int doc = decodeDoc(sortValue); - int shardRequestIndex = decodeShardRequestIndex(sortValue); + int doc = ShardDocSortField.decodeDoc(sortValue); + int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); docs[i].rank = i + 1; } return docs; } - public static int decodeDoc(long value) { - return (int) value; - } - - public static int decodeShardRequestIndex(long value) { - return (int) (value >> 32); - } - record RankDocAndHitRatio(RankDoc rankDoc, float hitRatio) {} /** diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0ced472ea310c..7ff0ed1bbe82c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -231,6 +231,7 @@ static TransportVersion def(int id) { public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0); public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0); public static final TransportVersion REGEX_AND_RANGE_INTERVAL_QUERIES = def(8_757_00_0); + public static final TransportVersion RRF_QUERY_REWRITE = def(8_758_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index e3d663ec13618..553ee1d05a052 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -502,6 +502,8 @@ void executeRequest( }); final SearchSourceBuilder source = original.source(); if (shouldOpenPIT(source)) { + // disabling shard reordering for request + original.setPreFilterShardSize(Integer.MAX_VALUE); openPIT(client, original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> { // We set the keep alive to -1 to indicate that we don't need the pit id in the response. // This is needed since we delete the pit prior to sending the response so the id doesn't exist anymore. diff --git a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java index c526652fc4e67..5043508c781f0 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/Lucene.java @@ -74,7 +74,6 @@ import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.lucene.grouping.TopFieldGroups; -import org.elasticsearch.search.retriever.rankdoc.RankDocsSortField; import org.elasticsearch.search.sort.ShardDocSortField; import java.io.IOException; @@ -553,8 +552,6 @@ private static SortField rewriteMergeSortField(SortField sortField) { return newSortField; } else if (sortField.getClass() == ShardDocSortField.class) { return new SortField(sortField.getField(), SortField.Type.LONG, sortField.getReverse()); - } else if (sortField.getClass() == RankDocsSortField.class) { - return new SortField(sortField.getField(), 
SortField.Type.INT, sortField.getReverse()); } else { return sortField; } diff --git a/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java b/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java index 27e3b264a17e8..d1c7d0415ad15 100644 --- a/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java +++ b/server/src/main/java/org/elasticsearch/lucene/search/uhighlight/CustomUnifiedHighlighter.java @@ -260,7 +260,7 @@ public void visitLeaf(Query leafQuery) { * KnnScoreDocQuery and RankDocsQuery requires the same reader that built the docs * When using {@link HighlightFlag#WEIGHT_MATCHES} different readers are used and isn't supported by this query */ - if (leafQuery instanceof KnnScoreDocQuery || leafQuery instanceof RankDocsQuery) { + if (leafQuery instanceof KnnScoreDocQuery || leafQuery instanceof RankDocsQuery.TopQuery) { hasUnknownLeaf[0] = true; } super.visitLeaf(query); diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 6308b19358410..0bb914a9dbf97 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -239,7 +239,6 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; -import org.elasticsearch.search.retriever.rankdoc.RankDocsSortBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; @@ -868,7 +867,6 @@ private void registerSorts() { namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScoreSortBuilder.NAME, ScoreSortBuilder::new)); namedWriteables.add(new 
NamedWriteableRegistry.Entry(SortBuilder.class, ScriptSortBuilder.NAME, ScriptSortBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, FieldSortBuilder.NAME, FieldSortBuilder::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, RankDocsSortBuilder::new)); } private static void registerFromPlugin(List plugins, Function> producer, Consumer consumer) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index 26780f85a15e0..fc0cb72bb82e0 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -2207,12 +2207,7 @@ public ActionRequestValidationException validate( boolean allowPartialSearchResults ) { if (retriever() != null) { - if (allowPartialSearchResults && retriever().isCompound()) { - validationException = addValidationError( - "cannot specify a compound retriever and [allow_partial_search_results]", - validationException - ); - } + validationException = retriever().validate(this, validationException, allowPartialSearchResults); List specified = new ArrayList<>(); if (subSearches().isEmpty() == false) { specified.add(QUERY_FIELD.getPreferredName()); @@ -2229,9 +2224,6 @@ public ActionRequestValidationException validate( if (sorts() != null) { specified.add(SORT_FIELD.getPreferredName()); } - if (minScore() != null) { - specified.add(MIN_SCORE_FIELD.getPreferredName()); - } if (rankBuilder() != null) { specified.add(RANK_FIELD.getPreferredName()); } @@ -2331,21 +2323,10 @@ public ActionRequestValidationException validate( if (rescores() != null && rescores().isEmpty() == false) { validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); } - if (sorts() != null && sorts().isEmpty() == false) { 
- validationException = addValidationError("[rank] cannot be used with [sort]", validationException); - } - if (collapse() != null) { - validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); - } + if (suggest() != null && suggest().getSuggestions().isEmpty() == false) { validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); } - if (highlighter() != null) { - validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); - } - if (pointInTimeBuilder() != null) { - validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); - } } if (rescores() != null) { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/StoredFieldsContext.java b/server/src/main/java/org/elasticsearch/search/fetch/StoredFieldsContext.java index 62eaadfb15690..3076337d43c84 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/StoredFieldsContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/StoredFieldsContext.java @@ -34,7 +34,7 @@ public class StoredFieldsContext implements Writeable { private final List fieldNames; private final boolean fetchFields; - private StoredFieldsContext(boolean fetchFields) { + public StoredFieldsContext(boolean fetchFields) { this.fetchFields = fetchFields; this.fieldNames = null; } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/ExplainPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/ExplainPhase.java index 7a2913ce56128..0e6172323277d 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/ExplainPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/ExplainPhase.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; +import org.elasticsearch.index.mapper.NestedLookup; import 
org.elasticsearch.search.fetch.FetchContext; import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; @@ -45,8 +46,17 @@ public void process(HitContext hitContext) throws IOException { for (RescoreContext rescore : context.rescore()) { explanation = rescore.rescorer().explain(topLevelDocId, context.searcher(), rescore, explanation); } + if (context.rankBuilder() != null) { - explanation = context.rankBuilder().explainHit(explanation, hitContext.rankDoc(), queryNames); + // if we have nested fields, then the query is wrapped using an additional filter on the _primary_term field + // through the DefaultSearchContext#buildFilteredQuery so we have to extract the actual query + if (context.getSearchExecutionContext().nestedLookup() != NestedLookup.EMPTY) { + explanation = explanation.getDetails()[0]; + } + + if (context.rankBuilder() != null) { + explanation = context.rankBuilder().explainHit(explanation, hitContext.rankDoc(), queryNames); + } } // we use the top level doc id, since we work with the top level searcher hitContext.hit().explanation(explanation); diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java index 94b584607878a..b16a234931115 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java @@ -24,7 +24,7 @@ * {@code RankDoc} is the base class for all ranked results. * Subclasses should extend this with additional information required for their global ranking method. 
*/ -public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment { +public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable { public static final String NAME = "rank_doc"; @@ -40,6 +40,17 @@ public String getWriteableName() { return NAME; } + @Override + public final int compareTo(RankDoc other) { + if (score != other.score) { + return score < other.score ? 1 : -1; + } + if (shardIndex != other.shardIndex) { + return shardIndex < other.shardIndex ? -1 : 1; + } + return doc < other.doc ? -1 : 1; + } + public record RankKey(int doc, int shardIndex) {} public RankDoc(int doc, float score, int shardIndex) { @@ -65,8 +76,12 @@ public final void writeTo(StreamOutput out) throws IOException { /** * Explain the ranking of this document. */ - public Explanation explain() { - return Explanation.match(rank, "doc [" + doc + "] with an original score of [" + score + "] is at rank [" + rank + "]."); + public Explanation explain(Explanation[] sourceExplanations, String[] queryNames) { + return Explanation.match( + rank, + "doc [" + doc + "] with an original score of [" + score + "] is at rank [" + rank + "] from the following source queries.", + sourceExplanations + ); } @Override @@ -104,6 +119,6 @@ protected int doHashCode() { @Override public String toString() { - return "RankDoc{" + "_rank=" + rank + ", _doc=" + doc + ", _shard=" + shardIndex + ", _score=" + score + '}'; + return "RankDoc{" + "_rank=" + rank + ", _doc=" + doc + ", _shard=" + shardIndex + ", _score=" + score + "}"; } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java index aadcb94c4b242..cd8d9392aced8 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java @@ -38,7 +38,7 @@ public RankFeatureDoc(StreamInput 
in) throws IOException { } @Override - public Explanation explain() { + public Explanation explain(Explanation[] sources, String[] queryNames) { throw new UnsupportedOperationException("explain is not supported for {" + getClass() + "}"); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java new file mode 100644 index 0000000000000..1962145d7336d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -0,0 +1,255 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.TransportMultiSearchAction; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; +import org.elasticsearch.search.sort.SortBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; + +/** + * This abstract retriever defines a compound retriever. The idea is that it is not a leaf-retriever, i.e. it does not + * perform actual searches itself. Instead, it is a container for a set of child retrievers and is responsible for combining + * the results of the child retrievers according to the implementation of {@code combineQueryPhaseResults}. 
+ */ +public abstract class CompoundRetrieverBuilder> extends RetrieverBuilder { + + public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} + + protected final int rankWindowSize; + protected final List innerRetrievers; + + protected CompoundRetrieverBuilder(List innerRetrievers, int rankWindowSize) { + this.rankWindowSize = rankWindowSize; + this.innerRetrievers = innerRetrievers; + } + + @SuppressWarnings("unchecked") + public T addChild(RetrieverBuilder retrieverBuilder) { + innerRetrievers.add(new RetrieverSource(retrieverBuilder, null)); + return (T) this; + } + + /** + * Returns a clone of the original retriever, replacing the sub-retrievers with + * the provided {@code newChildRetrievers}. + */ + protected abstract T clone(List newChildRetrievers); + + /** + * Combines the provided {@code rankResults} to return the final top documents. + */ + protected abstract RankDoc[] combineInnerRetrieverResults(List rankResults); + + @Override + public final boolean isCompound() { + return true; + } + + @Override + public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (ctx.getPointInTimeBuilder() == null) { + throw new IllegalStateException("PIT is required"); + } + + // Rewrite prefilters + boolean hasChanged = false; + var newPreFilters = rewritePreFilters(ctx); + hasChanged |= newPreFilters != preFilterQueryBuilders; + + // Rewrite retriever sources + List newRetrievers = new ArrayList<>(); + for (var entry : innerRetrievers) { + RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); + if (newRetriever != entry.retriever) { + newRetrievers.add(new RetrieverSource(newRetriever, null)); + hasChanged |= true; + } else { + var sourceBuilder = entry.source != null + ? 
entry.source + : createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever); + var rewrittenSource = sourceBuilder.rewrite(ctx); + newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource)); + hasChanged |= rewrittenSource != entry.source; + } + } + if (hasChanged) { + return clone(newRetrievers); + } + + // execute searches + final SetOnce results = new SetOnce<>(); + final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (var entry : innerRetrievers) { + SearchRequest searchRequest = new SearchRequest().source(entry.source); + // The can match phase can reorder shards, so we disable it to ensure the stable ordering + searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); + multiSearchRequest.add(searchRequest); + } + ctx.registerAsyncAction((client, listener) -> { + client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { + @Override + public void onResponse(MultiSearchResponse items) { + List topDocs = new ArrayList<>(); + List failures = new ArrayList<>(); + for (int i = 0; i < items.getResponses().length; i++) { + var item = items.getResponses()[i]; + if (item.isFailure()) { + failures.add(item.getFailure()); + } else { + assert item.getResponse() != null; + var rankDocs = getRankDocs(item.getResponse()); + innerRetrievers.get(i).retriever().setRankDocs(rankDocs); + topDocs.add(rankDocs); + } + } + if (false == failures.isEmpty()) { + IllegalStateException ex = new IllegalStateException("Search failed - some nested retrievers returned errors."); + failures.forEach(ex::addSuppressed); + listener.onFailure(ex); + } else { + results.set(combineInnerRetrieverResults(topDocs)); + listener.onResponse(null); + } + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + }); + + return new RankDocsRetrieverBuilder( + rankWindowSize, + newRetrievers.stream().map(s -> s.retriever).toList(), + results::get, + newPreFilters + ); + } + + @Override + public 
final QueryBuilder topDocsQuery() { + throw new IllegalStateException(getName() + " cannot be nested"); + } + + @Override + public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new IllegalStateException("Should not be called, missing a rewrite?"); + } + + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean allowPartialSearchResults + ) { + validationException = super.validate(source, validationException, allowPartialSearchResults); + if (source.size() > rankWindowSize) { + validationException = addValidationError( + "[" + + this.getName() + + "] requires [rank_window_size: " + + rankWindowSize + + "]" + + " be greater than or equal to [size: " + + source.size() + + "]", + validationException + ); + } + if (allowPartialSearchResults) { + validationException = addValidationError( + "cannot specify a compound retriever and [allow_partial_search_results]", + validationException + ); + } + return validationException; + } + + @Override + public boolean doEquals(Object o) { + CompoundRetrieverBuilder that = (CompoundRetrieverBuilder) o; + return rankWindowSize == that.rankWindowSize && Objects.equals(innerRetrievers, that.innerRetrievers); + } + + @Override + public int doHashCode() { + return Objects.hash(innerRetrievers); + } + + private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) + .trackTotalHits(false) + .storedFields(new StoredFieldsContext(false)) + .size(rankWindowSize); + if (preFilterQueryBuilders.isEmpty() == false) { + retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } + retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); + + // apply the pre-filters + if (preFilterQueryBuilders.size() > 0) { + QueryBuilder query 
= sourceBuilder.query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + if (query != null) { + newQuery.must(query); + } + preFilterQueryBuilders.forEach(newQuery::filter); + sourceBuilder.query(newQuery); + } + + // Record the shard id in the sort result + List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); + if (sortBuilders.isEmpty()) { + sortBuilders.add(new ScoreSortBuilder()); + } + sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + sourceBuilder.sort(sortBuilders); + return sourceBuilder; + } + + private RankDoc[] getRankDocs(SearchResponse searchResponse) { + int size = searchResponse.getHits().getHits().length; + RankDoc[] docs = new RankDoc[size]; + for (int i = 0; i < size; i++) { + var hit = searchResponse.getHits().getAt(i); + long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; + int doc = ShardDocSortField.decodeDoc(sortValue); + int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); + docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); + docs[i].rank = i + 1; + } + return docs; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index ceab04ebb55e4..8e564430ef57a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -127,12 +127,30 @@ public String getName() { @Override public QueryBuilder topDocsQuery() { - assert rankDocs != null : "{rankDocs} should have been materialized at this point"; + assert rankDocs != null : "rankDocs should have been materialized by now"; + var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true); + if (preFilterQueryBuilders.isEmpty()) { + return rankDocsQuery.queryName(retrieverName); + } + 
BoolQueryBuilder res = new BoolQueryBuilder().must(rankDocsQuery); + preFilterQueryBuilders.forEach(res::filter); + return res.queryName(retrieverName); + } - BoolQueryBuilder knnTopResultsQuery = new BoolQueryBuilder().filter(new RankDocsQueryBuilder(rankDocs)) - .should(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity)); - preFilterQueryBuilders.forEach(knnTopResultsQuery::filter); - return knnTopResultsQuery; + @Override + public QueryBuilder explainQuery() { + assert rankDocs != null : "rankDocs should have been materialized by now"; + var rankDocsQuery = new RankDocsQueryBuilder( + rankDocs, + new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) }, + true + ); + if (preFilterQueryBuilders.isEmpty()) { + return rankDocsQuery.queryName(retrieverName); + } + BoolQueryBuilder res = new BoolQueryBuilder().must(rankDocsQuery); + preFilterQueryBuilders.forEach(res::filter); + return res.queryName(retrieverName); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index 89daafdf05b4b..535db5c8fe28e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -10,14 +10,11 @@ package org.elasticsearch.search.retriever; import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.DisMaxQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; -import org.elasticsearch.search.retriever.rankdoc.RankDocsSortBuilder; -import 
org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; @@ -32,7 +29,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder { public static final String NAME = "rank_docs_retriever"; - private final int rankWindowSize; + final int rankWindowSize; final List sources; final Supplier rankDocs; @@ -44,6 +41,9 @@ public RankDocsRetrieverBuilder( ) { this.rankWindowSize = rankWindowSize; this.rankDocs = rankDocs; + if (sources == null || sources.isEmpty()) { + throw new IllegalArgumentException("sources must not be null or empty"); + } this.sources = sources; this.preFilterQueryBuilders = preFilterQueryBuilders; } @@ -53,6 +53,10 @@ public String getName() { return NAME; } + private boolean sourceHasMinScore() { + return minScore != null || sources.stream().anyMatch(x -> x.minScore() != null); + } + private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException { for (var source : sources) { if (source.isCompound()) { @@ -80,64 +84,92 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { public QueryBuilder topDocsQuery() { // this is used to fetch all documents from the parent retrievers (i.e.
sources) // so that we can use all the matched documents to compute aggregations, nested hits etc - DisMaxQueryBuilder disMax = new DisMaxQueryBuilder().tieBreaker(0f); + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); for (var retriever : sources) { var query = retriever.topDocsQuery(); if (query != null) { if (retriever.retrieverName() != null) { query.queryName(retriever.retrieverName()); } - disMax.add(query); + boolQuery.should(query); } } // ignore prefilters of this level, they are already propagated to children - return disMax; + return boolQuery; + } + + @Override + public QueryBuilder explainQuery() { + return new RankDocsQueryBuilder( + rankDocs.get(), + sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new), + true + ); } @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - // here we force a custom sort based on the rank of the documents - // TODO: should we adjust to account for other fields sort options just for the top ranked docs? 
- if (searchSourceBuilder.rescores() == null || searchSourceBuilder.rescores().isEmpty()) { - searchSourceBuilder.sort(Arrays.asList(new RankDocsSortBuilder(rankDocs.get()), new ScoreSortBuilder())); - } - if (searchSourceBuilder.explain() != null && searchSourceBuilder.explain()) { - searchSourceBuilder.trackScores(true); - } - BoolQueryBuilder boolQuery = new BoolQueryBuilder(); - RankDocsQueryBuilder rankQuery = new RankDocsQueryBuilder(rankDocs.get()); + final RankDocsQueryBuilder rankQuery; // if we have aggregations we need to compute them based on all doc matches, not just the top hits - // so we just post-filter the top hits based on the rank queries we have - if (searchSourceBuilder.aggregations() != null) { - boolQuery.should(rankQuery); - // compute a disjunction of all the query sources that were executed to compute the top rank docs - QueryBuilder disjunctionOfSources = topDocsQuery(); - if (disjunctionOfSources != null) { - boolQuery.should(disjunctionOfSources); + // similarly, for profile and explain we re-run all parent queries to get all needed information + RankDoc[] rankDocResults = rankDocs.get(); + if (hasAggregations(searchSourceBuilder) + || isExplainRequest(searchSourceBuilder) + || isProfileRequest(searchSourceBuilder) + || shouldTrackTotalHits(searchSourceBuilder)) { + if (false == isExplainRequest(searchSourceBuilder)) { + rankQuery = new RankDocsQueryBuilder( + rankDocResults, + sources.stream().map(RetrieverBuilder::topDocsQuery).toArray(QueryBuilder[]::new), + false + ); + } else { + rankQuery = new RankDocsQueryBuilder( + rankDocResults, + sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new), + false + ); } - // post filter the results so that the top docs are still the same - searchSourceBuilder.postFilter(rankQuery); } else { - boolQuery.must(rankQuery); + rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false); } - // add any prefilters present in the retriever - for (var 
preFilterQueryBuilder : preFilterQueryBuilders) { - boolQuery.filter(preFilterQueryBuilder); + // ignore prefilters of this level, they are already propagated to children + searchSourceBuilder.query(rankQuery); + if (sourceHasMinScore()) { + searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore()); + } + if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) { + searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from())); } - searchSourceBuilder.query(boolQuery); + } + + private boolean hasAggregations(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.aggregations() != null; + } + + private boolean isExplainRequest(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.explain() != null && searchSourceBuilder.explain(); + } + + private boolean isProfileRequest(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.profile(); + } + + private boolean shouldTrackTotalHits(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.trackTotalHitsUpTo() == null || searchSourceBuilder.trackTotalHitsUpTo() > rankDocs.get().length; } @Override protected boolean doEquals(Object o) { RankDocsRetrieverBuilder other = (RankDocsRetrieverBuilder) o; - return Arrays.equals(rankDocs.get(), other.rankDocs.get()) - && sources.equals(other.sources) - && rankWindowSize == other.rankWindowSize; + return rankWindowSize == other.rankWindowSize + && Arrays.equals(rankDocs.get(), other.rankDocs.get()) + && sources.equals(other.sources); } @Override protected int doHashCode() { - return Objects.hash(super.hashCode(), Arrays.hashCode(rankDocs.get()), sources, rankWindowSize); + return Objects.hash(super.hashCode(), rankWindowSize, Arrays.hashCode(rankDocs.get()), sources); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java 
b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index e8f6a2d795724..1328106896bcb 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.retriever; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.SuggestingErrorOnUnknown; @@ -53,6 +54,8 @@ public abstract class RetrieverBuilder implements Rewriteable, public static final ParseField PRE_FILTER_FIELD = new ParseField("filter"); + public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); + public static final ParseField NAME_FIELD = new ParseField("_name"); protected static void declareBaseParserFields( @@ -65,41 +68,25 @@ protected static void declareBaseParserFields( return preFilterQueryBuilder; }, PRE_FILTER_FIELD); parser.declareString(RetrieverBuilder::retrieverName, NAME_FIELD); + parser.declareFloat(RetrieverBuilder::minScore, MIN_SCORE_FIELD); } - private void retrieverName(String retrieverName) { + public RetrieverBuilder retrieverName(String retrieverName) { this.retrieverName = retrieverName; + return this; + } + + public RetrieverBuilder minScore(Float minScore) { + this.minScore = minScore; + return this; } - /** - * This method parsers a top-level retriever within a search and tracks its own depth. Currently, the - * maximum depth allowed is limited to 2 as a compound retriever cannot currently contain another - * compound retriever. 
- */ public static RetrieverBuilder parseTopLevelRetrieverBuilder(XContentParser parser, RetrieverParserContext context) throws IOException { parser = new FilterXContentParserWrapper(parser) { - int nestedDepth = 0; - @Override public T namedObject(Class categoryClass, String name, Object context) throws IOException { - if (categoryClass.equals(RetrieverBuilder.class)) { - nestedDepth++; - - if (nestedDepth > 2) { - throw new IllegalArgumentException( - "the nested depth of the [" + name + "] retriever exceeds the maximum nested depth [2] for retrievers" - ); - } - } - - T namedObject = getXContentRegistry().parseNamedObject(categoryClass, name, this, context); - - if (categoryClass.equals(RetrieverBuilder.class)) { - nestedDepth--; - } - - return namedObject; + return getXContentRegistry().parseNamedObject(categoryClass, name, this, context); } }; @@ -186,6 +173,8 @@ protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser pars protected String retrieverName; + protected Float minScore; + /** * Determines if this retriever contains sub-retrievers that need to be executed prior to search. 
*/ @@ -217,6 +206,14 @@ protected final List rewritePreFilters(QueryRewriteContext ctx) th */ public abstract QueryBuilder topDocsQuery(); + public QueryBuilder explainQuery() { + return topDocsQuery(); + } + + public Float minScore() { + return minScore; + } + public void setRankDocs(RankDoc[] rankDocs) { this.rankDocs = rankDocs; } @@ -239,6 +236,14 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { */ public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed); + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean allowPartialSearchResults + ) { + return validationException; + } + // ---- FOR TESTING XCONTENT PARSING ---- public abstract String getName(); @@ -267,14 +272,16 @@ public final boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RetrieverBuilder that = (RetrieverBuilder) o; - return Objects.equals(preFilterQueryBuilders, that.preFilterQueryBuilders) && doEquals(o); + return Objects.equals(preFilterQueryBuilders, that.preFilterQueryBuilders) + && Objects.equals(minScore, that.minScore) + && doEquals(o); } protected abstract boolean doEquals(Object o); @Override public final int hashCode() { - return Objects.hash(getClass(), preFilterQueryBuilders, doHashCode()); + return Objects.hash(getClass(), preFilterQueryBuilders, minScore, doHashCode()); } protected abstract int doHashCode(); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index 57381bdf558c9..ac329eb293e90 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -43,7 +43,6 @@ public final 
class StandardRetrieverBuilder extends RetrieverBuilder implements public static final ParseField SEARCH_AFTER_FIELD = new ParseField("search_after"); public static final ParseField TERMINATE_AFTER_FIELD = new ParseField("terminate_after"); public static final ParseField SORT_FIELD = new ParseField("sort"); - public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); public static final ParseField COLLAPSE_FIELD = new ParseField("collapse"); public static final ObjectParser PARSER = new ObjectParser<>( @@ -76,12 +75,6 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements return sortBuilders; }, SORT_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); - PARSER.declareField((r, v) -> r.minScore = v, (p, c) -> { - float minScore = p.floatValue(); - c.trackSectionUsage(NAME + ":" + MIN_SCORE_FIELD.getPreferredName()); - return minScore; - }, MIN_SCORE_FIELD, ObjectParser.ValueType.FLOAT); - PARSER.declareField((r, v) -> r.collapseBuilder = v, (p, c) -> { CollapseBuilder collapseBuilder = CollapseBuilder.fromXContent(p); if (collapseBuilder.getField() != null) { @@ -104,23 +97,29 @@ public static StandardRetrieverBuilder fromXContent(XContentParser parser, Retri SearchAfterBuilder searchAfterBuilder; int terminateAfter = SearchContext.DEFAULT_TERMINATE_AFTER; List> sortBuilders; - Float minScore; CollapseBuilder collapseBuilder; + public StandardRetrieverBuilder() {} + + public StandardRetrieverBuilder(QueryBuilder queryBuilder) { + this.queryBuilder = queryBuilder; + } + @Override public QueryBuilder topDocsQuery() { - // TODO: for compound retrievers this will have to be reworked as queries like knn could be executed twice if (preFilterQueryBuilders.isEmpty()) { - return queryBuilder; + QueryBuilder qb = queryBuilder; + qb.queryName(this.retrieverName); + return qb; } - var ret = new BoolQueryBuilder().filter(queryBuilder); + var ret = new BoolQueryBuilder().filter(queryBuilder).queryName(this.retrieverName); 
preFilterQueryBuilders.stream().forEach(ret::filter); return ret; } @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (preFilterQueryBuilders.isEmpty() == false) { + if (preFilterQueryBuilders.isEmpty() == false || minScore != null) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) { @@ -130,7 +129,6 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder if (queryBuilder != null) { boolQueryBuilder.must(queryBuilder); } - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); } else if (queryBuilder != null) { searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(queryBuilder)); @@ -157,32 +155,14 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } if (sortBuilders != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + SORT_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.sort(sortBuilders); } if (minScore != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + MIN_SCORE_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.minScore(minScore); } if (collapseBuilder != null) { - if (compoundUsed) { - throw new IllegalArgumentException( - "[" + COLLAPSE_FIELD.getPreferredName() + "] cannot be used in children of compound retrievers" - ); - } - searchSourceBuilder.collapse(collapseBuilder); } } @@ -212,10 +192,6 @@ public void doToXContent(XContentBuilder builder, ToXContent.Params params) thro builder.field(SORT_FIELD.getPreferredName(), sortBuilders); } - if (minScore != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); - } - if (collapseBuilder != null) { builder.field(COLLAPSE_FIELD.getPreferredName(), collapseBuilder); } @@ 
-228,13 +204,12 @@ public boolean doEquals(Object o) { && Objects.equals(queryBuilder, that.queryBuilder) && Objects.equals(searchAfterBuilder, that.searchAfterBuilder) && Objects.equals(sortBuilders, that.sortBuilders) - && Objects.equals(minScore, that.minScore) && Objects.equals(collapseBuilder, that.collapseBuilder); } @Override public int doHashCode() { - return Objects.hash(queryBuilder, searchAfterBuilder, terminateAfter, sortBuilders, minScore, collapseBuilder); + return Objects.hash(queryBuilder, searchAfterBuilder, terminateAfter, sortBuilders, collapseBuilder); } // ---- END FOR TESTING ---- diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java index 079e725dd375b..fb5015a82dbdb 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java @@ -9,20 +9,27 @@ package org.elasticsearch.search.retriever.rankdoc; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.elasticsearch.search.rank.RankDoc; import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Objects; import static 
org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -33,142 +40,315 @@ * after performing any reranking or filtering. */ public class RankDocsQuery extends Query { + /** + * A {@link Query} that matches only the specified {@link RankDoc}, using the provided {@link Query} sources + * solely for the purpose of explainability. + */ + public static class TopQuery extends Query { + private final RankDoc[] docs; + private final Query[] sources; + private final String[] queryNames; + private final int[] segmentStarts; + private final Object contextIdentity; + + TopQuery(RankDoc[] docs, Query[] sources, String[] queryNames, int[] segmentStarts, Object contextIdentity) { + assert sources.length == queryNames.length; + this.docs = docs; + this.sources = sources; + this.queryNames = queryNames; + this.segmentStarts = segmentStarts; + this.contextIdentity = contextIdentity; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + if (docs.length == 0) { + return new MatchNoDocsQuery(); + } + boolean changed = false; + Query[] newSources = new Query[sources.length]; + for (int i = 0; i < sources.length; i++) { + newSources[i] = sources[i].rewrite(searcher); + changed |= newSources[i] != sources[i]; + } + if (changed) { + return new TopQuery(docs, newSources, queryNames, segmentStarts, contextIdentity); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (searcher.getIndexReader().getContext().id() != contextIdentity) { + throw new IllegalStateException("This RankDocsDocQuery was created by a different reader"); + } + Weight[] weights = new Weight[sources.length]; + for (int i = 0; i < sources.length; i++) { + weights[i] = sources[i].createWeight(searcher, scoreMode, boost); + } + return new Weight(this) { + @Override + public int count(LeafReaderContext context) { + return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + } + + 
@Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + int found = binarySearch(docs, 0, docs.length, doc + context.docBase); + if (found < 0) { + return Explanation.noMatch("doc not found in top " + docs.length + " rank docs"); + } + Explanation[] sourceExplanations = new Explanation[sources.length]; + for (int i = 0; i < sources.length; i++) { + sourceExplanations[i] = weights[i].explain(context, doc); + } + return docs[found].explain(sourceExplanations, queryNames); + } + + @Override + public Scorer scorer(LeafReaderContext context) { + // Segment starts indicate how many docs are in the segment, + // upper equalling lower indicates no documents for this segment + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return null; + } + return new Scorer(this) { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + float score; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return currentDocId(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return currentDocId(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return upper - lower; + } + }; + } + + @Override + public float getMaxScore(int docId) { + if (docId != NO_MORE_DOCS) { + docId += context.docBase; + } + float maxScore = 0; + for (int idx = Math.max(lower, upTo); idx < upper && docs[idx].doc <= docId; idx++) { + maxScore = Math.max(maxScore, docs[idx].score); + } + return maxScore; + } + + @Override + public float score() { + return docs[upTo].score; + } + + @Override + public int docID() { + return currentDocId(); + } + + private int currentDocId() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return 
docs[upTo].doc - context.docBase; + } + + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public String toString(String field) { + return this.getClass().getSimpleName() + "{rank_docs:" + Arrays.toString(docs) + "}"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + TopQuery other = (TopQuery) obj; + return Arrays.equals(docs, other.docs) + && Arrays.equals(segmentStarts, other.segmentStarts) + && contextIdentity == other.contextIdentity; + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), Arrays.hashCode(docs), Arrays.hashCode(segmentStarts), contextIdentity); + } + } private final RankDoc[] docs; - private final int[] segmentStarts; - private final Object contextIdentity; + // topQuery is used to match just the top docs from all the original queries. This match is based on the RankDoc array + // provided when constructing the object. This is the only clause that actually contributes to scoring + private final Query topQuery; + // tailQuery is used to match all the original documents that were used to compute the top docs. + // This is useful if we want to compute aggregations, total hits etc based on all matching documents, and not just the top + // RankDocs provided. This query does not contribute to scoring, as it is set as filter when creating the weight + private final Query tailQuery; + private final boolean onlyRankDocs; /** * Creates a {@code RankDocsQuery} based on the provided docs. * - * @param docs the global doc IDs of documents that match, in ascending order - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. 
There should be a final entry that is always - * docs.length-1. - * @param contextIdentity an object identifying the reader context that was used to build this - * query + * @param rankDocs The global doc IDs of documents that match, in ascending order + * @param sources The original queries that were used to compute the top documents + * @param queryNames The names (if present) of the original retrievers + * @param onlyRankDocs Whether the query should only match the provided rank docs */ - RankDocsQuery(RankDoc[] docs, int[] segmentStarts, Object contextIdentity) { + public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs) { + assert sources.length == queryNames.length; + // clone to avoid side-effect after sorting + this.docs = rankDocs.clone(); + // sort rank docs by doc id + Arrays.sort(docs, Comparator.comparingInt(a -> a.doc)); + this.topQuery = new TopQuery(docs, sources, queryNames, findSegmentStarts(reader, docs), reader.getContext().id()); + if (sources.length > 0 && false == onlyRankDocs) { + var bq = new BooleanQuery.Builder(); + for (var source : sources) { + bq.add(source, BooleanClause.Occur.SHOULD); + } + this.tailQuery = bq.build(); + } else { + this.tailQuery = null; + } + this.onlyRankDocs = onlyRankDocs; + } + + private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs) { this.docs = docs; - this.segmentStarts = segmentStarts; - this.contextIdentity = contextIdentity; + this.topQuery = topQuery; + this.tailQuery = tailQuery; + this.onlyRankDocs = onlyRankDocs; } - @Override - public Query rewrite(IndexSearcher searcher) throws IOException { - if (docs.length == 0) { - return new MatchNoDocsQuery(); + private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) { + return Arrays.binarySearch(docs, fromIndex, toIndex, new RankDoc(key, Float.NaN, -1), Comparator.comparingInt(a -> a.doc)); + } + + private static int[] 
findSegmentStarts(IndexReader reader, RankDoc[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + + resultIndex = binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; } - return this; + return starts; } RankDoc[] rankDocs() { return docs; } + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + if (tailQuery == null) { + return topQuery; + } + boolean hasChanged = false; + var topRewrite = topQuery.rewrite(searcher); + if (topRewrite != topQuery) { + hasChanged = true; + } + var tailRewrite = tailQuery.rewrite(searcher); + if (tailRewrite != tailQuery) { + hasChanged = true; + } + return hasChanged ? new RankDocsQuery(docs, topRewrite, tailRewrite, onlyRankDocs) : this; + } + @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - if (searcher.getIndexReader().getContext().id() != contextIdentity) { - throw new IllegalStateException("This RankDocsDocQuery was created by a different reader"); + if (tailQuery == null) { + throw new IllegalArgumentException("[tailQuery] should not be null; maybe missing a rewrite?"); } + var combined = new BooleanQuery.Builder().add(topQuery, onlyRankDocs ? 
BooleanClause.Occur.MUST : BooleanClause.Occur.SHOULD) + .add(tailQuery, BooleanClause.Occur.FILTER) + .build(); + var topWeight = topQuery.createWeight(searcher, scoreMode, boost); + var combinedWeight = searcher.rewrite(combined).createWeight(searcher, scoreMode, boost); return new Weight(this) { - @Override - public int count(LeafReaderContext context) { - return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + public int count(LeafReaderContext context) throws IOException { + return combinedWeight.count(context); } @Override - public Explanation explain(LeafReaderContext context, int doc) { - int found = Arrays.binarySearch(docs, doc + context.docBase, (a, b) -> Integer.compare(((RankDoc) a).doc, (int) b)); - if (found < 0) { - return Explanation.noMatch("doc not found in top " + docs.length + " rank docs"); - } - return docs[found].explain(); + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return topWeight.explain(context, doc); } @Override - public Scorer scorer(LeafReaderContext context) { - // Segment starts indicate how many docs are in the segment, - // upper equalling lower indicates no documents for this segment - if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { - return null; - } - return new Scorer(this) { - final int lower = segmentStarts[context.ord]; - final int upper = segmentStarts[context.ord + 1]; - int upTo = -1; - float score; - - @Override - public DocIdSetIterator iterator() { - return new DocIdSetIterator() { - @Override - public int docID() { - return currentDocId(); - } - - @Override - public int nextDoc() { - if (upTo == -1) { - upTo = lower; - } else { - ++upTo; - } - return currentDocId(); - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - - @Override - public long cost() { - return upper - lower; - } - }; - } - - @Override - public float getMaxScore(int docId) { - if (docId != NO_MORE_DOCS) { - docId += 
context.docBase; - } - float maxScore = 0; - for (int idx = Math.max(lower, upTo); idx < upper && docs[idx].doc <= docId; idx++) { - maxScore = Math.max(maxScore, docs[idx].score); - } - return maxScore; - } - - @Override - public float score() { - return docs[upTo].score; - } + public Scorer scorer(LeafReaderContext context) throws IOException { + return combinedWeight.scorer(context); + } - @Override - public int docID() { - return currentDocId(); - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return combinedWeight.isCacheable(ctx); + } - private int currentDocId() { - if (upTo == -1) { - return -1; - } - if (upTo >= upper) { - return NO_MORE_DOCS; - } - return docs[upTo].doc - context.docBase; - } + @Override + public Matches matches(LeafReaderContext context, int doc) throws IOException { + return combinedWeight.matches(context, doc); + } - }; + @Override + public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { + return combinedWeight.bulkScorer(context); } @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return combinedWeight.scorerSupplier(context); } }; } @@ -180,7 +360,10 @@ public String toString(String field) { @Override public void visit(QueryVisitor visitor) { - visitor.visitLeaf(this); + topQuery.visit(visitor); + if (tailQuery != null) { + tailQuery.visit(visitor); + } } @Override @@ -188,13 +371,12 @@ public boolean equals(Object obj) { if (sameClassAs(obj) == false) { return false; } - return Arrays.equals(docs, ((RankDocsQuery) obj).docs) - && Arrays.equals(segmentStarts, ((RankDocsQuery) obj).segmentStarts) - && contextIdentity == ((RankDocsQuery) obj).contextIdentity; + RankDocsQuery other = (RankDocsQuery) obj; + return Objects.equals(topQuery, other.topQuery) && Objects.equals(tailQuery, other.tailQuery) && onlyRankDocs == other.onlyRankDocs; } @Override public int hashCode() { - 
return Objects.hash(classHash(), Arrays.hashCode(docs), Arrays.hashCode(segmentStarts), contextIdentity); + return Objects.hash(classHash(), topQuery, tailQuery, onlyRankDocs); } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java index 2b77ae543a86c..86cb27cb7ba7e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java @@ -13,30 +13,45 @@ import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Arrays; -import java.util.Comparator; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE; public class RankDocsQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "rank_docs_query"; private final RankDoc[] rankDocs; + private final QueryBuilder[] queryBuilders; + private final boolean onlyRankDocs; - public RankDocsQueryBuilder(RankDoc[] rankDocs) { + public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, boolean onlyRankDocs) { this.rankDocs = rankDocs; + this.queryBuilders = queryBuilders; + this.onlyRankDocs = onlyRankDocs; } public RankDocsQueryBuilder(StreamInput in) throws IOException { super(in); this.rankDocs = in.readArray(c -> 
c.readNamedWriteable(RankDoc.class), RankDoc[]::new); + if (in.getTransportVersion().onOrAfter(RRF_QUERY_REWRITE)) { + this.queryBuilders = in.readOptionalArray(c -> c.readNamedWriteable(QueryBuilder.class), QueryBuilder[]::new); + this.onlyRankDocs = in.readBoolean(); + } else { + this.queryBuilders = null; + this.onlyRankDocs = false; + } } RankDoc[] rankDocs() { @@ -46,6 +61,10 @@ RankDoc[] rankDocs() { @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); + if (out.getTransportVersion().onOrAfter(RRF_QUERY_REWRITE)) { + out.writeOptionalArray(StreamOutput::writeNamedWriteable, queryBuilders); + out.writeBoolean(onlyRankDocs); + } } @Override @@ -57,29 +76,22 @@ public String getWriteableName() { protected Query doToQuery(SearchExecutionContext context) throws IOException { RankDoc[] shardRankDocs = Arrays.stream(rankDocs) .filter(r -> r.shardIndex == context.getShardRequestIndex()) - .sorted(Comparator.comparingInt(r -> r.doc)) .toArray(RankDoc[]::new); IndexReader reader = context.getIndexReader(); - int[] segmentStarts = findSegmentStarts(reader, shardRankDocs); - return new RankDocsQuery(shardRankDocs, segmentStarts, reader.getContext().id()); - } - - private static int[] findSegmentStarts(IndexReader reader, RankDoc[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper, (a, b) -> Integer.compare(((RankDoc) a).doc, (int) b)); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; + final Query[] queries; + final String[] queryNames; + if (queryBuilders != null) { + queries = new Query[queryBuilders.length]; + queryNames = new String[queryBuilders.length]; + for (int i = 0; i < 
queryBuilders.length; i++) { + queries[i] = queryBuilders[i].toQuery(context); + queryNames[i] = queryBuilders[i].queryName(); } - starts[i] = resultIndex; + } else { + queries = new Query[0]; + queryNames = Strings.EMPTY_ARRAY; } - return starts; + return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs); } @Override @@ -97,12 +109,14 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep @Override protected boolean doEquals(RankDocsQueryBuilder other) { - return Arrays.equals(rankDocs, other.rankDocs); + return Arrays.equals(rankDocs, other.rankDocs) + && Arrays.equals(queryBuilders, other.queryBuilders) + && onlyRankDocs == other.onlyRankDocs; } @Override protected int doHashCode() { - return Arrays.hashCode(rankDocs); + return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilder.java deleted file mode 100644 index cfe307af1767a..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilder.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.search.retriever.rankdoc; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.search.DocValueFormat; -import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.sort.BucketedSort; -import org.elasticsearch.search.sort.SortBuilder; -import org.elasticsearch.search.sort.SortFieldAndFormat; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Objects; - -/** - * Builds a {@code RankDocsSortField} that sorts documents by their rank as computed through the {@code RankDocsRetrieverBuilder}. - */ -public class RankDocsSortBuilder extends SortBuilder { - public static final String NAME = "rank_docs_sort"; - - private RankDoc[] rankDocs; - - public RankDocsSortBuilder(RankDoc[] rankDocs) { - this.rankDocs = rankDocs; - } - - public RankDocsSortBuilder(StreamInput in) throws IOException { - this.rankDocs = in.readArray(c -> c.readNamedWriteable(RankDoc.class), RankDoc[]::new); - } - - public RankDocsSortBuilder(RankDocsSortBuilder original) { - this.rankDocs = original.rankDocs; - } - - public RankDocsSortBuilder rankDocs(RankDoc[] rankDocs) { - this.rankDocs = rankDocs; - return this; - } - - public RankDoc[] rankDocs() { - return this.rankDocs; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public SortBuilder rewrite(QueryRewriteContext ctx) throws IOException { - return this; - } - - @Override - protected SortFieldAndFormat 
build(SearchExecutionContext context) throws IOException { - RankDoc[] shardRankDocs = Arrays.stream(rankDocs) - .filter(r -> r.shardIndex == context.getShardRequestIndex()) - .toArray(RankDoc[]::new); - return new SortFieldAndFormat(new RankDocsSortField(shardRankDocs), DocValueFormat.RAW); - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.RANK_DOCS_RETRIEVER; - } - - @Override - public BucketedSort buildBucketedSort(SearchExecutionContext context, BigArrays bigArrays, int bucketSize, BucketedSort.ExtraData extra) - throws IOException { - throw new UnsupportedOperationException("buildBucketedSort() is not supported for " + this.getClass()); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - throw new UnsupportedOperationException("toXContent() is not supported for " + this.getClass()); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - RankDocsSortBuilder that = (RankDocsSortBuilder) obj; - return Arrays.equals(rankDocs, that.rankDocs) && this.order.equals(that.order); - } - - @Override - public int hashCode() { - return Objects.hash(Arrays.hashCode(this.rankDocs), this.order); - } -} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortField.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortField.java deleted file mode 100644 index 9fd2aceaf7949..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortField.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.retriever.rankdoc; - -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.FieldComparator; -import org.apache.lucene.search.FieldComparatorSource; -import org.apache.lucene.search.LeafFieldComparator; -import org.apache.lucene.search.Pruning; -import org.apache.lucene.search.Scorable; -import org.apache.lucene.search.SortField; -import org.apache.lucene.search.comparators.NumericComparator; -import org.apache.lucene.util.hnsw.IntToIntFunction; -import org.elasticsearch.search.rank.RankDoc; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * A {@link SortField} that sorts documents by their rank as computed through the {@code RankDocsRetrieverBuilder}. - * This is used when we want to score and rank the documents irrespective of their original scores, - * but based on the provided rank they were assigned, e.g. through an RRF retriever. 
- **/ -public class RankDocsSortField extends SortField { - - public static final String NAME = "_rank"; - - public RankDocsSortField(RankDoc[] rankDocs) { - super(NAME, new FieldComparatorSource() { - @Override - public FieldComparator newComparator(String fieldname, int numHits, Pruning pruning, boolean reversed) { - return new RankDocsComparator(numHits, rankDocs); - } - }); - } - - private static class RankDocsComparator extends NumericComparator { - private final int[] values; - private final Map rankDocMap; - private int topValue; - private int bottom; - - private RankDocsComparator(int numHits, RankDoc[] rankDocs) { - super(NAME, Integer.MAX_VALUE, false, Pruning.NONE, Integer.BYTES); - this.values = new int[numHits]; - this.rankDocMap = Arrays.stream(rankDocs).collect(Collectors.toMap(k -> k.doc, v -> v.rank)); - } - - @Override - public int compare(int slot1, int slot2) { - return Integer.compare(values[slot1], values[slot2]); - } - - @Override - public Integer value(int slot) { - return Integer.valueOf(values[slot]); - } - - @Override - public void setTopValue(Integer value) { - topValue = value; - } - - @Override - public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException { - IntToIntFunction docToRank = doc -> rankDocMap.getOrDefault(context.docBase + doc, Integer.MAX_VALUE); - return new LeafFieldComparator() { - @Override - public void setBottom(int slot) throws IOException { - bottom = values[slot]; - } - - @Override - public int compareBottom(int doc) { - return Integer.compare(bottom, docToRank.apply(doc)); - } - - @Override - public int compareTop(int doc) { - return Integer.compare(topValue, docToRank.apply(doc)); - } - - @Override - public void copy(int slot, int doc) { - values[slot] = docToRank.apply(doc); - } - - @Override - public void setScorer(Scorable scorer) {} - }; - } - } -} diff --git a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java 
b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java index 0bdb3ea0cd247..a38a24eb75fca 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java @@ -73,4 +73,18 @@ public LeafFieldComparator getLeafComparator(LeafReaderContext context) { } }; } + + /** + * Get the doc id encoded in the sort value. + */ + public static int decodeDoc(long value) { + return (int) value; + } + + /** + * Get the shard request index encoded in the sort value. + */ + public static int decodeShardRequestIndex(long value) { + return (int) (value >> 32); + } } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index 1d430f2ae1079..23c956e6e52f2 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -28,10 +28,10 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; -import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.rank.TestRankBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.TestCompoundRetrieverBuilder; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; @@ -263,40 +263,9 @@ public void testValidate() throws IOException { } { // allow_partial_results and compound retriever - SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().retriever(new RetrieverBuilder() { - 
@Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - // no-op - } - - @Override - public String getName() { - return "compound_retriever"; - } - - @Override - protected void doToXContent(XContentBuilder builder, Params params) throws IOException {} - - @Override - protected boolean doEquals(Object o) { - return false; - } - - @Override - protected int doHashCode() { - return 0; - } - - @Override - public boolean isCompound() { - return true; - } - - @Override - public QueryBuilder topDocsQuery() { - return null; - } - })); + SearchRequest searchRequest = createSearchRequest().source( + new SearchSourceBuilder().retriever(new TestCompoundRetrieverBuilder(randomIntBetween(1, 10))) + ); searchRequest.allowPartialSearchResults(true); searchRequest.scroll((Scroll) null); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -583,30 +552,6 @@ public QueryBuilder topDocsQuery() { assertEquals(1, validationErrors.validationErrors().size()); assertEquals("[rank] cannot be used with [rescore]", validationErrors.validationErrors().get(0)); } - { - SearchRequest searchRequest = new SearchRequest().source( - new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) - .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) - .sort("test") - ); - ActionRequestValidationException validationErrors = searchRequest.validate(); - assertNotNull(validationErrors); - assertEquals(1, validationErrors.validationErrors().size()); - assertEquals("[rank] cannot be used with [sort]", validationErrors.validationErrors().get(0)); - } - { - SearchRequest searchRequest = new SearchRequest().source( - new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) - .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) - .collapse(new 
CollapseBuilder("field")) - ); - ActionRequestValidationException validationErrors = searchRequest.validate(); - assertNotNull(validationErrors); - assertEquals(1, validationErrors.validationErrors().size()); - assertEquals("[rank] cannot be used with [collapse]", validationErrors.validationErrors().get(0)); - } { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) @@ -619,30 +564,6 @@ public QueryBuilder topDocsQuery() { assertEquals(1, validationErrors.validationErrors().size()); assertEquals("[rank] cannot be used with [suggest]", validationErrors.validationErrors().get(0)); } - { - SearchRequest searchRequest = new SearchRequest().source( - new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) - .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) - .highlighter(new HighlightBuilder().field("field")) - ); - ActionRequestValidationException validationErrors = searchRequest.validate(); - assertNotNull(validationErrors); - assertEquals(1, validationErrors.validationErrors().size()); - assertEquals("[rank] cannot be used with [highlighter]", validationErrors.validationErrors().get(0)); - } - { - SearchRequest searchRequest = new SearchRequest().source( - new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) - .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) - .pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("test"))) - ); - ActionRequestValidationException validationErrors = searchRequest.validate(); - assertNotNull(validationErrors); - assertEquals(1, validationErrors.validationErrors().size()); - assertEquals("[rank] cannot be used with [point in time]", validationErrors.validationErrors().get(0)); - } { SearchRequest searchRequest = new SearchRequest("test").source( new 
SearchSourceBuilder().pointInTimeBuilder(new PointInTimeBuilder(BytesArray.EMPTY)) diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java index db419b4019acf..d190139309c31 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java +++ b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java @@ -50,9 +50,4 @@ protected RankDoc mutateInstance(RankDoc instance) throws IOException { } return mutated; } - - public void testExplain() { - RankDoc instance = createTestRankDoc(); - assertEquals(instance.explain().toString(), instance.explain().toString()); - } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 23a6357fa61be..f3dd86e0b1fa2 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; -import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -122,17 +121,15 @@ public void testTopDocsQuery() { final int preFilters = knnRetriever.preFilterQueryBuilders.size(); QueryBuilder topDocsQuery = knnRetriever.topDocsQuery(); assertNotNull(topDocsQuery); - assertThat(topDocsQuery, instanceOf(BoolQueryBuilder.class)); - assertThat(((BoolQueryBuilder) topDocsQuery).filter().size(), equalTo(1 + preFilters)); - assertThat(((BoolQueryBuilder) topDocsQuery).filter().get(0), 
instanceOf(RankDocsQueryBuilder.class)); - for (int i = 0; i < preFilters; i++) { - assertThat( - ((BoolQueryBuilder) topDocsQuery).filter().get(i + 1), - instanceOf(knnRetriever.preFilterQueryBuilders.get(i).getClass()) - ); + assertThat(topDocsQuery, anyOf(instanceOf(RankDocsQueryBuilder.class), instanceOf(BoolQueryBuilder.class))); + if (topDocsQuery instanceof BoolQueryBuilder bq) { + assertThat(bq.must().size(), equalTo(1)); + assertThat(bq.must().get(0), instanceOf(RankDocsQueryBuilder.class)); + assertThat(bq.filter().size(), equalTo(preFilters)); + for (int i = 0; i < preFilters; i++) { + assertThat(bq.filter().get(i), instanceOf(knnRetriever.preFilterQueryBuilders.get(i).getClass())); + } } - assertThat(((BoolQueryBuilder) topDocsQuery).should().size(), equalTo(1)); - assertThat(((BoolQueryBuilder) topDocsQuery).should().get(0), instanceOf(ExactKnnQueryBuilder.class)); } @Override diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index e8ad7d128dac2..bcb93b100ea48 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -10,9 +10,6 @@ package org.elasticsearch.search.retriever; import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.DisMaxQueryBuilder; -import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; @@ -21,8 +18,6 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; 
-import org.elasticsearch.search.retriever.rankdoc.RankDocsSortBuilder; -import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -30,10 +25,10 @@ import java.util.List; import java.util.function.Supplier; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; @@ -93,7 +88,7 @@ private List preFilters() { } private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder() { - return new RankDocsRetrieverBuilder(randomInt(100), innerRetrievers(), rankDocsSupplier(), preFilters()); + return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(), rankDocsSupplier(), preFilters()); } public void testExtractToSearchSourceBuilder() { @@ -102,32 +97,30 @@ public void testExtractToSearchSourceBuilder() { if (randomBoolean()) { source.aggregation(new TermsAggregationBuilder("name").field("field")); } + source.explain(randomBoolean()); + source.profile(randomBoolean()); + source.trackTotalHits(randomBoolean()); + final int preFilters = retriever.preFilterQueryBuilders.size(); retriever.extractToSearchSourceBuilder(source, randomBoolean()); - assertThat(source.sorts().size(), equalTo(2)); - assertThat(source.sorts().get(0), instanceOf(RankDocsSortBuilder.class)); - assertThat(source.sorts().get(1), instanceOf(ScoreSortBuilder.class)); - assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); - BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); - if (source.aggregations() != null) { - assertThat(bq.must().size(), equalTo(0)); - assertThat(bq.should().size(), greaterThanOrEqualTo(1)); - 
assertThat(bq.should().get(0), instanceOf(RankDocsQueryBuilder.class)); - assertNotNull(source.postFilter()); - assertThat(source.postFilter(), instanceOf(RankDocsQueryBuilder.class)); - } else { + assertNull(source.sorts()); + assertThat(source.query(), anyOf(instanceOf(BoolQueryBuilder.class), instanceOf(RankDocsQueryBuilder.class))); + if (source.query() instanceof BoolQueryBuilder bq) { assertThat(bq.must().size(), equalTo(1)); assertThat(bq.must().get(0), instanceOf(RankDocsQueryBuilder.class)); - assertNull(source.postFilter()); + assertThat(bq.filter().size(), equalTo(preFilters)); + for (int i = 0; i < preFilters; i++) { + assertThat(bq.filter().get(i), instanceOf(retriever.preFilterQueryBuilders.get(i).getClass())); + } } - assertThat(bq.filter().size(), equalTo(retriever.preFilterQueryBuilders.size())); + assertNull(source.postFilter()); } public void testTopDocsQuery() { RankDocsRetrieverBuilder retriever = createRandomRankDocsRetrieverBuilder(); QueryBuilder topDocs = retriever.topDocsQuery(); assertNotNull(topDocs); - assertThat(topDocs, instanceOf(DisMaxQueryBuilder.class)); - assertThat(((DisMaxQueryBuilder) topDocs).innerQueries(), hasSize(retriever.sources.size())); + assertThat(topDocs, instanceOf(BoolQueryBuilder.class)); + assertThat(((BoolQueryBuilder) topDocs).should(), hasSize(retriever.sources.size())); } public void testRewrite() throws IOException { @@ -144,22 +137,27 @@ public boolean isCompound() { } SearchSourceBuilder source = new SearchSourceBuilder().retriever(retriever); QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); - if (compoundAdded) { - expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); + int size = source.size() < 0 ? 
DEFAULT_SIZE : source.size(); + if (retriever.rankWindowSize < size) { + if (compoundAdded) { + expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); + } } else { - SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext); - assertNull(rewrittenSource.retriever()); - assertTrue(rewrittenSource.knnSearch().isEmpty()); - assertThat( - rewrittenSource.query(), - anyOf(instanceOf(BoolQueryBuilder.class), instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class)) - ); - if (rewrittenSource.query() instanceof BoolQueryBuilder) { - BoolQueryBuilder bq = (BoolQueryBuilder) rewrittenSource.query(); - assertThat(bq.filter().size(), equalTo(retriever.preFilterQueryBuilders.size())); - // we don't have any aggregations so the RankDocs query is set as a must clause - assertThat(bq.must().size(), equalTo(1)); - assertThat(bq.must().get(0), instanceOf(RankDocsQueryBuilder.class)); + if (compoundAdded) { + expectThrows(AssertionError.class, () -> Rewriteable.rewrite(source, queryRewriteContext)); + } else { + SearchSourceBuilder rewrittenSource = Rewriteable.rewrite(source, queryRewriteContext); + assertNull(rewrittenSource.retriever()); + assertTrue(rewrittenSource.knnSearch().isEmpty()); + assertThat(rewrittenSource.query(), instanceOf(RankDocsQueryBuilder.class)); + if (rewrittenSource.query() instanceof BoolQueryBuilder) { + BoolQueryBuilder bq = (BoolQueryBuilder) rewrittenSource.query(); + assertThat(bq.filter().size(), equalTo(retriever.preFilterQueryBuilders.size())); + assertThat(bq.must().size(), equalTo(1)); + assertThat(bq.must().get(0), instanceOf(BoolQueryBuilder.class)); + assertThat(bq.should().size(), equalTo(1)); + assertThat(bq.should().get(0), instanceOf(RankDocsQueryBuilder.class)); + } } } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RetrieverBuilderErrorTests.java 
b/server/src/test/java/org/elasticsearch/search/retriever/RetrieverBuilderErrorTests.java index cc8f5fe85d09a..66240e205e26b 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RetrieverBuilderErrorTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RetrieverBuilderErrorTests.java @@ -73,14 +73,6 @@ public void testRetrieverExtractionErrors() throws IOException { assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [terminate_after]")); } - try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"sort\": [\"field\"], \"retriever\":{\"standard\":{}}}")) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - ssb.parseXContent(parser, true, nf -> true); - ActionRequestValidationException iae = ssb.validate(null, false, false); - assertNotNull(iae); - assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [sort]")); - } - try ( XContentParser parser = createParser( JsonXContent.jsonXContent, @@ -94,14 +86,6 @@ public void testRetrieverExtractionErrors() throws IOException { assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [rescore]")); } - try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"min_score\": 2, \"retriever\":{\"standard\":{}}}")) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - ssb.parseXContent(parser, true, nf -> true); - ActionRequestValidationException iae = ssb.validate(null, false, false); - assertNotNull(iae); - assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [min_score]")); - } - try ( XContentParser parser = createParser( JsonXContent.jsonXContent, @@ -112,7 +96,7 @@ public void testRetrieverExtractionErrors() throws IOException { ssb.parseXContent(parser, true, nf -> true); ActionRequestValidationException iae = ssb.validate(null, false, false); assertNotNull(iae); - assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query, 
terminate_after, min_score]")); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query, terminate_after]")); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilderTests.java index 01c915530bc4c..ca05c57b7d733 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilderTests.java @@ -10,8 +10,16 @@ package org.elasticsearch.search.retriever.rankdoc; import org.apache.lucene.document.Document; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopScoreDocCollectorManager; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.elasticsearch.index.query.QueryBuilder; @@ -21,6 +29,10 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Random; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class RankDocsQueryBuilderTests extends AbstractQueryTestCase { @@ -39,7 +51,7 @@ private RankDoc[] generateRandomRankDocs() { @Override protected RankDocsQueryBuilder doCreateTestQueryBuilder() { RankDoc[] rankDocs = generateRandomRankDocs(); - return new RankDocsQueryBuilder(rankDocs); + return new RankDocsQueryBuilder(rankDocs, null, false); } @Override @@ -104,6 +116,111 @@ public void testMustRewrite() throws IOException { } } + public 
void testRankDocsQueryEarlyTerminate() throws IOException { + try (Directory directory = newDirectory()) { + IndexWriterConfig config = new IndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE); + try (IndexWriter iw = new IndexWriter(directory, config)) { + int seg = atLeast(5); + int numDocs = atLeast(20); + for (int i = 0; i < seg; i++) { + for (int j = 0; j < numDocs; j++) { + Document doc = new Document(); + doc.add(new NumericDocValuesField("active", 1)); + iw.addDocument(doc); + } + if (frequently()) { + iw.flush(); + } + } + } + try (IndexReader reader = DirectoryReader.open(directory)) { + int topSize = randomIntBetween(1, reader.maxDoc() / 5); + RankDoc[] rankDocs = new RankDoc[topSize]; + int index = 0; + for (int r : randomSample(random(), reader.maxDoc(), topSize)) { + rankDocs[index++] = new RankDoc(r, randomFloat(), randomIntBetween(0, 5)); + } + Arrays.sort(rankDocs); + for (int i = 0; i < rankDocs.length; i++) { + rankDocs[i].rank = i; + } + IndexSearcher searcher = new IndexSearcher(reader); + for (int totalHitsThreshold = 0; totalHitsThreshold < reader.maxDoc(); totalHitsThreshold += randomIntBetween(1, 10)) { + // Tests that the query matches only the {@link RankDoc} when the hit threshold is reached. + RankDocsQuery q = new RankDocsQuery( + reader, + rankDocs, + new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, + new String[1], + false + ); + var topDocsManager = new TopScoreDocCollectorManager(topSize, null, totalHitsThreshold); + var col = searcher.search(q, topDocsManager); + // depending on the doc-ids of the RankDocs (i.e. 
the actual docs to have score) we could visit them last, + // so worst case is we could end up collecting up to 1 + max(topSize , totalHitsThreshold) + rankDocs.length documents + // as we could have already filled the priority queue with non-optimal docs + assertThat( + col.totalHits.value, + lessThanOrEqualTo((long) (1 + Math.max(topSize, totalHitsThreshold) + rankDocs.length)) + ); + assertEqualTopDocs(col.scoreDocs, rankDocs); + } + + { + // Return all docs (rank + tail) + RankDocsQuery q = new RankDocsQuery( + reader, + rankDocs, + new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, + new String[1], + false + ); + var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE); + var col = searcher.search(q, topDocsManager); + assertThat(col.totalHits.value, equalTo((long) reader.maxDoc())); + assertEqualTopDocs(col.scoreDocs, rankDocs); + } + + { + // Only rank docs + RankDocsQuery q = new RankDocsQuery( + reader, + rankDocs, + new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, + new String[1], + true + ); + var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE); + var col = searcher.search(q, topDocsManager); + assertThat(col.totalHits.value, equalTo((long) topSize)); + assertEqualTopDocs(col.scoreDocs, rankDocs); + } + } + } + } + + private static int[] randomSample(Random rand, int n, int k) { + int[] reservoir = new int[k]; + for (int i = 0; i < k; i++) { + reservoir[i] = i; + } + for (int i = k; i < n; i++) { + int j = rand.nextInt(i + 1); + if (j < k) { + reservoir[j] = i; + } + } + return reservoir; + } + + private static void assertEqualTopDocs(ScoreDoc[] scoreDocs, RankDoc[] rankDocs) { + for (int i = 0; i < scoreDocs.length; i++) { + assertEquals(rankDocs[i].doc, scoreDocs[i].doc); + assertEquals(rankDocs[i].score, scoreDocs[i].score, 0f); + assertEquals(-1, scoreDocs[i].shardIndex); + } + } + @Override public void testFromXContent() throws IOException { // 
no-op since RankDocsQueryBuilder is an internal only API diff --git a/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilderTests.java deleted file mode 100644 index 2c12126769c35..0000000000000 --- a/server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsSortBuilderTests.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.retriever.rankdoc; - -import org.apache.lucene.search.SortField; -import org.elasticsearch.search.DocValueFormat; -import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.sort.AbstractSortTestCase; -import org.elasticsearch.search.sort.SortOrder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -public class RankDocsSortBuilderTests extends AbstractSortTestCase { - - @Override - protected RankDocsSortBuilder createTestItem() { - return randomRankDocsSortBuulder(); - } - - private RankDocsSortBuilder randomRankDocsSortBuulder() { - RankDoc[] rankDocs = randomRankDocs(randomInt(100)); - return new RankDocsSortBuilder(rankDocs); - } - - private RankDoc[] randomRankDocs(int totalDocs) { - RankDoc[] rankDocs = new RankDoc[totalDocs]; - for (int i = 0; i < totalDocs; i++) { - rankDocs[i] = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1)); - 
rankDocs[i].rank = i + 1; - } - return rankDocs; - } - - @Override - protected RankDocsSortBuilder mutate(RankDocsSortBuilder original) throws IOException { - RankDocsSortBuilder mutated = new RankDocsSortBuilder(original); - mutated.rankDocs(randomRankDocs(original.rankDocs().length + randomIntBetween(10, 100))); - return mutated; - } - - @Override - public void testFromXContent() throws IOException { - // no-op - } - - @Override - protected RankDocsSortBuilder fromXContent(XContentParser parser, String fieldName) throws IOException { - throw new UnsupportedOperationException( - "{" + RankDocsSortBuilder.class.getSimpleName() + "} does not support parsing from XContent" - ); - } - - @Override - protected void sortFieldAssertions(RankDocsSortBuilder builder, SortField sortField, DocValueFormat format) throws IOException { - assertThat(builder.order(), equalTo(SortOrder.ASC)); - assertThat(sortField, instanceOf(RankDocsSortField.class)); - assertThat(sortField.getField(), equalTo(RankDocsSortField.NAME)); - assertThat(sortField.getType(), equalTo(SortField.Type.CUSTOM)); - assertThat(sortField.getReverse(), equalTo(false)); - } -} diff --git a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java new file mode 100644 index 0000000000000..9f199aa7f3ef8 --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.net.UnknownServiceException; +import java.util.ArrayList; +import java.util.List; + +public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String NAME = "test_compound_retriever_builder"; + + public TestCompoundRetrieverBuilder(int rankWindowSize) { + this(new ArrayList<>(), rankWindowSize); + } + + TestCompoundRetrieverBuilder(List childRetrievers, int rankWindowSize) { + super(childRetrievers, rankWindowSize); + } + + @Override + protected TestCompoundRetrieverBuilder clone(List newChildRetrievers) { + return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize); + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + return new RankDoc[0]; + } + + @Override + public String getName() { + return NAME; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + throw new UnknownServiceException("should not be called"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 927c708268a49..ab013e0275a69 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -43,7 +43,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text"); public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); - public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> { @@ -52,9 +51,8 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? 
DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; - Float minScore = (Float) args[5]; - return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, minScore); + return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize); }); static { @@ -63,7 +61,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD); PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD); RetrieverBuilder.declareBaseParserFields(TextSimilarityRankBuilder.NAME, PARSER); } @@ -84,22 +81,19 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(XContentParser par private final String inferenceText; private final String field; private final int rankWindowSize; - private final Float minScore; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, String inferenceId, String inferenceText, String field, - int rankWindowSize, - Float minScore + int rankWindowSize ) { this.retrieverBuilder = retrieverBuilder; this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; this.rankWindowSize = rankWindowSize; - this.minScore = minScore; } public TextSimilarityRankRetrieverBuilder( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index c834f58f1134b..1a72cb0da2899 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -62,8 +62,7 @@ public static TextSimilarityRankRetrieverBuilder createRandomTextSimilarityRankR randomAlphaOfLength(10), randomAlphaOfLength(20), randomAlphaOfLength(50), - randomIntBetween(100, 10000), - randomBoolean() ? null : randomFloatBetween(-1.0f, 1.0f, true) + randomIntBetween(100, 10000) ); } diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java new file mode 100644 index 0000000000000..8b924af48c631 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -0,0 +1,656 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.rrf; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.InnerHitBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.collapse.CollapseBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.xcontent.XContentType; +import org.junit.Before; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 3) +public class RRFRetrieverBuilderIT extends ESIntegTestCase { + + protected static String INDEX = "test_index"; + protected static final String ID_FIELD = "_id"; + protected static final String DOC_FIELD = "doc"; + protected static final String TEXT_FIELD = "text"; + protected static final 
String VECTOR_FIELD = "vector"; + protected static final String TOPIC_FIELD = "topic"; + + @Override + protected Collection> nodePlugins() { + return List.of(RRFRankPlugin.class); + } + + @Before + public void setup() throws Exception { + setupIndex(); + } + + protected void setupIndex() { + String mapping = """ + { + "properties": { + "vector": { + "type": "dense_vector", + "dims": 1, + "element_type": "float", + "similarity": "l2_norm", + "index": true, + "index_options": { + "type": "hnsw" + } + }, + "text": { + "type": "text" + }, + "doc": { + "type": "keyword" + }, + "topic": { + "type": "keyword" + }, + "views": { + "type": "nested", + "properties": { + "last30d": { + "type": "integer" + }, + "all": { + "type": "integer" + } + } + } + } + } + """; + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 5).build()); + admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); + indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term"); + indexDoc( + INDEX, + "doc_2", + DOC_FIELD, + "doc_2", + TOPIC_FIELD, + "astronomy", + TEXT_FIELD, + "search term term", + VECTOR_FIELD, + new float[] { 2.0f } + ); + indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f }); + indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term"); + indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); + indexDoc( + INDEX, + "doc_6", + DOC_FIELD, + "doc_6", + TEXT_FIELD, + "search term term term term term term", + VECTOR_FIELD, + new float[] { 6.0f } + ); + indexDoc( + INDEX, + "doc_7", + DOC_FIELD, + "doc_7", + TOPIC_FIELD, + "biology", + TEXT_FIELD, + "term term term term term term term", + VECTOR_FIELD, + new float[] { 7.0f } + ); + refresh(INDEX); + } + + public void testRRFPagination() { + final int rankWindowSize = 100; + final int rankConstant = 10; + final List 
expectedDocIds = List.of("doc_2", "doc_6", "doc_7", "doc_1", "doc_3", "doc_4"); + final int totalDocs = expectedDocIds.size(); + for (int i = 0; i < randomIntBetween(1, 5); i++) { + int from = randomIntBetween(0, totalDocs - 1); + int size = randomIntBetween(1, totalDocs - from); + for (int docs_to_fetch = from; docs_to_fetch < totalDocs; docs_to_fetch += size) { + SearchSourceBuilder source = new SearchSourceBuilder(); + source.from(docs_to_fetch); + source.size(size); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + SearchRequestBuilder req = 
client().prepareSearch(INDEX).setSource(source); + int fDocs_to_fetch = docs_to_fetch; + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, lessThanOrEqualTo(size)); + for (int k = 0; k < Math.min(size, resp.getHits().getHits().length); k++) { + assertThat(resp.getHits().getAt(k).getId(), equalTo(expectedDocIds.get(k + fDocs_to_fetch))); + } + }); + } + } + } + + public void testRRFWithAggs() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new 
RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testRRFWithCollapse() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + 
assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + }); + } + + public void testRankDocsRetrieverWithCollapseAndAggs() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new 
FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testMultipleRRFRetrievers() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource( + // this one returns docs 6, 7, 1, 3, and 4 + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ), + null + ), + // this one bring just doc 7 which should be ranked first eventually + new CompoundRetrieverBuilder.RetrieverSource( + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null), + null + ) + ), + rankWindowSize, + rankConstant + ) + ); + + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); 
+ assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_4")); + }); + } + + public void testRRFExplainWithNamedRetrievers() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new 
CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(rrfDetails.getDetails().length, equalTo(3)); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 2]")); + + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); + assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [1] in query at index [1]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [2] in query at index [2]")); + }); + } + + public void testRRFExplainWithAnotherNestedRRF() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new 
StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 3, 2, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + + RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ); + StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L) + ); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(nestedRRF, null), + new CompoundRetrieverBuilder.RetrieverSource(standard2, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, 
resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var rrfTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(rrfTopLevel.getDetails().length, equalTo(2)); + assertThat(rrfTopLevel.getDescription(), containsString("computed for initial ranks [2, 1]")); + assertThat(rrfTopLevel.getDetails()[0].getDetails()[0].getDescription(), containsString("rrf score")); + assertThat(rrfTopLevel.getDetails()[1].getDetails()[0].getDescription(), containsString("ConstantScore")); + + var rrfDetails = rrfTopLevel.getDetails()[0].getDetails()[0]; + assertThat(rrfDetails.getDetails().length, equalTo(3)); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [4, 2, 3]")); + + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); + assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [2] in query at index [1]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [3] in query at index [2]")); + }); + } + + public void testRRFInnerRetrieverSearchError() { + final int rankWindowSize = 100; + final int rankConstant = 10; + 
SearchSourceBuilder source = new SearchSourceBuilder(); + // this will throw an error during evaluation + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10)) + ); + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize, + rankConstant + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + Exception ex = expectThrows(IllegalStateException.class, req::get); + assertThat(ex, instanceOf(IllegalStateException.class)); + assertThat(ex.getMessage(), containsString("Search failed - some nested retrievers returned errors")); + assertThat(ex.getSuppressed().length, greaterThan(0)); + } + + public void testRRFInnerRetrieverErrorWhenExtractingToSource() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { + @Override + public QueryBuilder topDocsQuery() { + return QueryBuilders.matchAllQuery(); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + 
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.size(1); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } + + public void testRRFInnerRetrieverErrorOnTopDocs() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { + @Override + public QueryBuilder topDocsQuery() { + throw new UnsupportedOperationException("simulated failure"); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } +} diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java 
b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java new file mode 100644 index 0000000000000..3a4ace9b6754a --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.rrf; + +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.xcontent.XContentType; + +import java.util.Arrays; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.hamcrest.Matchers.equalTo; + +public class RRFRetrieverBuilderNestedDocsIT extends RRFRetrieverBuilderIT { + + private static final String LAST_30D_FIELD = "views.last30d"; + private static final String ALL_TIME_FIELD = "views.all"; + + @Override + protected void setupIndex() { + String mapping = """ + { + "properties": { + "vector": { + "type": "dense_vector", + "dims": 1, + "element_type": "float", + "similarity": "l2_norm", + "index": true, + "index_options": { + "type": "hnsw" + } + }, + "text": { + "type": "text" + }, + "doc": { + "type": "keyword" + }, + "topic": 
{ + "type": "keyword" + }, + "views": { + "type": "nested", + "properties": { + "last30d": { + "type": "integer" + }, + "all": { + "type": "integer" + } + } + } + } + } + """; + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 5).build()); + admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); + indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term", LAST_30D_FIELD, 100); + indexDoc( + INDEX, + "doc_2", + DOC_FIELD, + "doc_2", + TOPIC_FIELD, + "astronomy", + TEXT_FIELD, + "search term term", + VECTOR_FIELD, + new float[] { 2.0f }, + LAST_30D_FIELD, + 3 + ); + indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f }); + indexDoc( + INDEX, + "doc_4", + DOC_FIELD, + "doc_4", + TOPIC_FIELD, + "technology", + TEXT_FIELD, + "term term term term", + ALL_TIME_FIELD, + 100, + LAST_30D_FIELD, + 40 + ); + indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); + indexDoc( + INDEX, + "doc_6", + DOC_FIELD, + "doc_6", + TEXT_FIELD, + "search term term term term term term", + VECTOR_FIELD, + new float[] { 6.0f }, + LAST_30D_FIELD, + 15 + ); + indexDoc( + INDEX, + "doc_7", + DOC_FIELD, + "doc_7", + TOPIC_FIELD, + "biology", + TEXT_FIELD, + "term term term term term term term", + VECTOR_FIELD, + new float[] { 7.0f }, + ALL_TIME_FIELD, + 1000 + ); + refresh(INDEX); + } + + public void testRRFRetrieverWithNestedQuery() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 4 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gte(30L), ScoreMode.Avg) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + 
QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 6 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.fetchField(TOPIC_FIELD); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); + }); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java index 816b25d53d375..bbc0f622724a3 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java @@ -12,6 +12,8 @@ import java.util.Set; +import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED; + /** * A set of features specifically for the rrf plugin. 
*/ @@ -19,6 +21,6 @@ public class RRFFeatures implements FeatureSpecification { @Override public Set getFeatures() { - return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED); + return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankCoordinatorContext.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankCoordinatorContext.java index b6a1ad52d5d15..56054955d25e7 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankCoordinatorContext.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankCoordinatorContext.java @@ -115,7 +115,7 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) { final int frank = rank; results.compute(new RankKey(rrfRankDoc.doc, rrfRankDoc.shardIndex), (key, value) -> { if (value == null) { - value = new RRFRankDoc(rrfRankDoc.doc, rrfRankDoc.shardIndex, fqc); + value = new RRFRankDoc(rrfRankDoc.doc, rrfRankDoc.shardIndex, fqc, rankConstant); } value.score += 1.0f / (rankConstant + frank); @@ -171,4 +171,8 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) { // and completion suggesters are not allowed return topResults; } + + public int rankConstant() { + return rankConstant; + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankShardContext.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankShardContext.java index 9843b14df6903..62e261d752d3e 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankShardContext.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFQueryPhaseRankShardContext.java @@ -48,7 +48,7 @@ public RRFRankShardResult combineQueryPhaseResults(List rankResults) { final int 
frank = rank; docsToRankResults.compute(scoreDoc.doc, (key, value) -> { if (value == null) { - value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries); + value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant); } // calculate the current rrf score for this document @@ -100,4 +100,8 @@ public RRFRankShardResult combineQueryPhaseResults(List rankResults) { } return new RRFRankShardResult(rankResults.size(), topResults); } + + public int rankConstant() { + return rankConstant; + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java index 4dbc9a6a54dcf..500ed17395127 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.rank.rrf; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.search.rank.RankDoc; @@ -15,12 +16,15 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; + +import static org.elasticsearch.xpack.rank.rrf.RRFRankBuilder.DEFAULT_RANK_CONSTANT; /** * {@code RRFRankDoc} supports additional ranking information * required for RRF. 
*/ -public class RRFRankDoc extends RankDoc { +public final class RRFRankDoc extends RankDoc { static final String NAME = "rrf_rank_doc"; @@ -42,11 +46,14 @@ public class RRFRankDoc extends RankDoc { */ public final float[] scores; - public RRFRankDoc(int doc, int shardIndex, int queryCount) { + public final int rankConstant; + + public RRFRankDoc(int doc, int shardIndex, int queryCount, int rankConstant) { super(doc, 0f, shardIndex); positions = new int[queryCount]; Arrays.fill(positions, NO_RANK); scores = new float[queryCount]; + this.rankConstant = rankConstant; } public RRFRankDoc(StreamInput in) throws IOException { @@ -54,21 +61,43 @@ public RRFRankDoc(StreamInput in) throws IOException { rank = in.readVInt(); positions = in.readIntArray(); scores = in.readFloatArray(); + if (in.getTransportVersion().onOrAfter(TransportVersions.RRF_QUERY_REWRITE)) { + this.rankConstant = in.readVInt(); + } else { + this.rankConstant = DEFAULT_RANK_CONSTANT; + } } @Override - public Explanation explain() { - // ideally we'd need access to the rank constant to provide score info for this one + public Explanation explain(Explanation[] sources, String[] queryNames) { + assert sources.length == scores.length; int queries = positions.length; Explanation[] details = new Explanation[queries]; for (int i = 0; i < queries; i++) { - final String queryIndex = "at index [" + i + "]"; + final String queryAlias = queryNames[i] == null ? 
"" : " [" + queryNames[i] + "]"; + final String queryIdentifier = "at index [" + i + "]" + queryAlias; if (positions[i] == RRFRankDoc.NO_RANK) { - final String description = "rrf score: [0], result not found in query " + queryIndex; + final String description = "rrf score: [0], result not found in query " + queryIdentifier; details[i] = Explanation.noMatch(description); } else { final int rank = positions[i] + 1; - details[i] = Explanation.match(rank, "rank [" + (rank) + "] in query " + queryIndex); + final float rrfScore = (1f / (rank + rankConstant)); + details[i] = Explanation.match( + rank, + "rrf score: [" + + rrfScore + + "], " + + "for rank [" + + (rank) + + "] in query " + + queryIdentifier + + " computed as [1 / (" + + (rank) + + " + " + + rankConstant + + ")], for matching query with score", + sources[i] + ); } } return Explanation.match( @@ -77,6 +106,8 @@ public Explanation explain() { + score + "] computed for initial ranks " + Arrays.toString(Arrays.stream(positions).map(x -> x + 1).toArray()) + + " with rankConstant: [" + + rankConstant + "] as sum of [1 / (rank + rankConstant)] for each query", details ); @@ -87,17 +118,22 @@ public void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(rank); out.writeIntArray(positions); out.writeFloatArray(scores); + if (out.getTransportVersion().onOrAfter(TransportVersions.RRF_QUERY_REWRITE)) { + out.writeVInt(rankConstant); + } } @Override public boolean doEquals(RankDoc rd) { RRFRankDoc rrfrd = (RRFRankDoc) rd; - return Arrays.equals(positions, rrfrd.positions) && Arrays.equals(scores, rrfrd.scores); + return Arrays.equals(positions, rrfrd.positions) + && Arrays.equals(scores, rrfrd.scores) + && Objects.equals(rankConstant, rrfrd.rankConstant); } @Override public int doHashCode() { - int result = Arrays.hashCode(positions); + int result = Arrays.hashCode(positions) + Objects.hash(rankConstant); result = 31 * result + Arrays.hashCode(scores); return result; } @@ -117,6 +153,8 @@ public String 
toString() { + doc + ", shardIndex=" + shardIndex + + ", rankConstant=" + + rankConstant + '}'; } @@ -129,5 +167,6 @@ public String getWriteableName() { protected void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field("positions", positions); builder.field("scores", scores); + builder.field("rankConstant", rankConstant); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 0d6208e474eea..496af99574431 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -7,24 +7,30 @@ package org.elasticsearch.xpack.rank.rrf; +import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; -import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.XPackPlugin; import java.io.IOException; -import java.util.Collections; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; +import static 
org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.rank.rrf.RRFRankPlugin.NAME; /** @@ -34,30 +40,39 @@ * top docs that will then be combined and ranked according to the rrf * formula. */ -public final class RRFRetrieverBuilder extends RetrieverBuilder { +public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder { + public static final String NAME = "rrf"; public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported"); + public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported"); public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant"); - public static final ObjectParser PARSER = new ObjectParser<>( + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, - RRFRetrieverBuilder::new + false, + args -> { + List childRetrievers = (List) args[0]; + List innerRetrievers = childRetrievers.stream().map(r -> new RetrieverSource(r, null)).toList(); + int rankWindowSize = args[1] == null ? RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; + int rankConstant = args[2] == null ? 
RRFRankBuilder.DEFAULT_RANK_CONSTANT : (int) args[2]; + return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant); + } ); static { - PARSER.declareObjectArray((r, v) -> r.retrieverBuilders = v, (p, c) -> { + PARSER.declareObjectArray(constructorArg(), (p, c) -> { p.nextToken(); String name = p.currentName(); RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c); p.nextToken(); return retrieverBuilder; }, RETRIEVERS_FIELD); - PARSER.declareInt((r, v) -> r.rankWindowSize = v, RANK_WINDOW_SIZE_FIELD); - PARSER.declareInt((r, v) -> r.rankConstant = v, RANK_CONSTANT_FIELD); - + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -65,76 +80,115 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP if (context.clusterSupportsFeature(RRF_RETRIEVER_SUPPORTED) == false) { throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]"); } + if (context.clusterSupportsFeature(RRF_RETRIEVER_COMPOSITION_SUPPORTED) == false) { + throw new UnsupportedOperationException("[rrf] retriever composition feature is not supported by all nodes in the cluster"); + } if (RRFRankPlugin.RANK_RRF_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { throw LicenseUtils.newComplianceException("Reciprocal Rank Fusion (RRF)"); } return PARSER.apply(parser, context); } - List retrieverBuilders = Collections.emptyList(); - int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE; - int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT; + private final int rankConstant; + + public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) { + this(new ArrayList<>(), rankWindowSize, rankConstant); + } + + RRFRetrieverBuilder(List childRetrievers, int rankWindowSize, int rankConstant) { + super(childRetrievers, rankWindowSize); + 
this.rankConstant = rankConstant; + } @Override - public QueryBuilder topDocsQuery() { - throw new IllegalStateException("{" + getName() + "} cannot be nested"); + public String getName() { + return NAME; } @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (compoundUsed) { - throw new IllegalArgumentException("[rank] cannot be used in children of compound retrievers"); - } + protected RRFRetrieverBuilder clone(List newRetrievers) { + return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); + } - for (RetrieverBuilder retrieverBuilder : retrieverBuilders) { - if (preFilterQueryBuilders.isEmpty() == false) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + @Override + protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults) { + // combine the disjointed sets of TopDocs into a single set of RRFRankDocs + // each RRFRankDoc will have both the position and score for each query where + // it was within the result set for that query + // if a doc isn't part of a result set its position will be NO_RANK [0] and + // its score is [0f] + int queries = rankResults.size(); + Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + int index = 0; + for (var rrfRankResult : rankResults) { + int rank = 1; + for (ScoreDoc scoreDoc : rrfRankResult) { + final int findex = index; + final int frank = rank; + docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> { + if (value == null) { + value = new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, rankConstant); + } + + // calculate the current rrf score for this document + // later used to sort and convert to a rank + value.score += 1.0f / (rankConstant + frank); + + // record the position for each query + // for explain and debugging + value.positions[findex] = frank - 1; + + // record the score for each query + // used to 
later re-rank on the coordinator + value.scores[findex] = scoreDoc.score; + + return value; + }); + ++rank; } - - retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, true); + ++index; } - searchSourceBuilder.rankBuilder(new RRFRankBuilder(rankWindowSize, rankConstant)); + // sort the results based on rrf score, tiebreaker based on smaller doc id + RRFRankDoc[] sortedResults = docsToRankResults.values().toArray(RRFRankDoc[]::new); + Arrays.sort(sortedResults); + // trim the results if needed, otherwise each shard will always return `rank_window_size` results. + RRFRankDoc[] topResults = new RRFRankDoc[Math.min(rankWindowSize, sortedResults.length)]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = sortedResults[rank]; + topResults[rank].rank = rank + 1; + } + return topResults; } // ---- FOR TESTING XCONTENT PARSING ---- @Override - public String getName() { - return NAME; + public boolean doEquals(Object o) { + RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; + return super.doEquals(o) && rankConstant == that.rankConstant; + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), rankConstant); } @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { - if (retrieverBuilders.isEmpty() == false) { + if (innerRetrievers.isEmpty() == false) { builder.startArray(RETRIEVERS_FIELD.getPreferredName()); - for (RetrieverBuilder retrieverBuilder : retrieverBuilders) { + for (var entry : innerRetrievers) { builder.startObject(); - builder.field(retrieverBuilder.getName()); - retrieverBuilder.toXContent(builder, params); + builder.field(entry.retriever().getName()); + entry.retriever().toXContent(builder, params); builder.endObject(); } - builder.endArray(); } builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); } - - @Override - public boolean doEquals(Object o) { - 
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; - return rankWindowSize == that.rankWindowSize - && rankConstant == that.rankConstant - && Objects.equals(retrieverBuilders, that.retrieverBuilders); - } - - @Override - public int doHashCode() { - return Objects.hash(retrieverBuilders, rankWindowSize, rankConstant); - } - - // ---- END FOR TESTING ---- } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java index 61859e280acdf..cd6883e1e54fd 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java @@ -71,7 +71,7 @@ public void testShardCombine() { assertEquals(2, result.queryCount); assertEquals(10, result.rrfRankDocs.length); - RRFRankDoc expected = new RRFRankDoc(8, -1, 2); + RRFRankDoc expected = new RRFRankDoc(8, -1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 7; expected.positions[1] = 0; @@ -80,7 +80,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[0]); - expected = new RRFRankDoc(1, -1, 2); + expected = new RRFRankDoc(1, -1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = 0; expected.positions[1] = NO_RANK; @@ -89,7 +89,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[1]); - expected = new RRFRankDoc(9, -1, 2); + expected = new RRFRankDoc(9, -1, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] = 8; expected.positions[1] = 1; @@ -98,7 +98,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[2]); - expected = new RRFRankDoc(10, -1, 2); + expected = new RRFRankDoc(10, -1, 2, context.rankConstant()); expected.rank = 4; 
expected.positions[0] = 9; expected.positions[1] = 2; @@ -107,7 +107,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[3]); - expected = new RRFRankDoc(2, -1, 2); + expected = new RRFRankDoc(2, -1, 2, context.rankConstant()); expected.rank = 5; expected.positions[0] = 1; expected.positions[1] = NO_RANK; @@ -116,7 +116,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[4]); - expected = new RRFRankDoc(3, -1, 2); + expected = new RRFRankDoc(3, -1, 2, context.rankConstant()); expected.rank = 6; expected.positions[0] = 2; expected.positions[1] = NO_RANK; @@ -125,7 +125,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[5]); - expected = new RRFRankDoc(4, -1, 2); + expected = new RRFRankDoc(4, -1, 2, context.rankConstant()); expected.rank = 7; expected.positions[0] = 3; expected.positions[1] = NO_RANK; @@ -134,7 +134,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[6]); - expected = new RRFRankDoc(11, -1, 2); + expected = new RRFRankDoc(11, -1, 2, context.rankConstant()); expected.rank = 8; expected.positions[0] = NO_RANK; expected.positions[1] = 3; @@ -143,7 +143,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[7]); - expected = new RRFRankDoc(5, -1, 2); + expected = new RRFRankDoc(5, -1, 2, context.rankConstant()); expected.rank = 9; expected.positions[0] = 4; expected.positions[1] = NO_RANK; @@ -152,7 +152,7 @@ public void testShardCombine() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[8]); - expected = new RRFRankDoc(12, -1, 2); + expected = new RRFRankDoc(12, -1, 2, context.rankConstant()); expected.rank = 10; expected.positions[0] = NO_RANK; expected.positions[1] = 4; @@ -166,27 +166,27 @@ public void testCoordinatorRank() { 
RRFQueryPhaseRankCoordinatorContext context = new RRFQueryPhaseRankCoordinatorContext(4, 0, 5, 1); QuerySearchResult qsr0 = new QuerySearchResult(); qsr0.setShardIndex(1); - RRFRankDoc rd11 = new RRFRankDoc(1, -1, 2); + RRFRankDoc rd11 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd11.positions[0] = 2; rd11.positions[1] = 0; rd11.scores[0] = 3.0f; rd11.scores[1] = 8.0f; - RRFRankDoc rd12 = new RRFRankDoc(2, -1, 2); + RRFRankDoc rd12 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd12.positions[0] = 3; rd12.positions[1] = 1; rd12.scores[0] = 2.0f; rd12.scores[1] = 7.0f; - RRFRankDoc rd13 = new RRFRankDoc(3, -1, 2); + RRFRankDoc rd13 = new RRFRankDoc(3, -1, 2, context.rankConstant()); rd13.positions[0] = 0; rd13.positions[1] = NO_RANK; rd13.scores[0] = 10.0f; rd13.scores[1] = 0.0f; - RRFRankDoc rd14 = new RRFRankDoc(4, -1, 2); + RRFRankDoc rd14 = new RRFRankDoc(4, -1, 2, context.rankConstant()); rd14.positions[0] = 4; rd14.positions[1] = 2; rd14.scores[0] = 1.0f; rd14.scores[1] = 6.0f; - RRFRankDoc rd15 = new RRFRankDoc(5, -1, 2); + RRFRankDoc rd15 = new RRFRankDoc(5, -1, 2, context.rankConstant()); rd15.positions[0] = 1; rd15.positions[1] = NO_RANK; rd15.scores[0] = 9.0f; @@ -195,27 +195,27 @@ public void testCoordinatorRank() { QuerySearchResult qsr1 = new QuerySearchResult(); qsr1.setShardIndex(2); - RRFRankDoc rd21 = new RRFRankDoc(1, -1, 2); + RRFRankDoc rd21 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd21.positions[0] = 0; rd21.positions[1] = 0; rd21.scores[0] = 9.5f; rd21.scores[1] = 7.5f; - RRFRankDoc rd22 = new RRFRankDoc(2, -1, 2); + RRFRankDoc rd22 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd22.positions[0] = 1; rd22.positions[1] = 1; rd22.scores[0] = 8.5f; rd22.scores[1] = 6.5f; - RRFRankDoc rd23 = new RRFRankDoc(3, -1, 2); + RRFRankDoc rd23 = new RRFRankDoc(3, -1, 2, context.rankConstant()); rd23.positions[0] = 2; rd23.positions[1] = 2; rd23.scores[0] = 7.5f; rd23.scores[1] = 4.5f; - RRFRankDoc rd24 = new RRFRankDoc(4, -1, 
2); + RRFRankDoc rd24 = new RRFRankDoc(4, -1, 2, context.rankConstant()); rd24.positions[0] = 3; rd24.positions[1] = NO_RANK; rd24.scores[0] = 5.5f; rd24.scores[1] = 0.0f; - RRFRankDoc rd25 = new RRFRankDoc(5, -1, 2); + RRFRankDoc rd25 = new RRFRankDoc(5, -1, 2, context.rankConstant()); rd25.positions[0] = NO_RANK; rd25.positions[1] = 3; rd25.scores[0] = 0.0f; @@ -228,7 +228,7 @@ public void testCoordinatorRank() { assertEquals(4, tds.fetchHits); assertEquals(4, scoreDocs.length); - RRFRankDoc expected = new RRFRankDoc(1, 2, 2); + RRFRankDoc expected = new RRFRankDoc(1, 2, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 1; expected.positions[1] = 1; @@ -237,7 +237,7 @@ public void testCoordinatorRank() { expected.score = 0.6666667f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[0]); - expected = new RRFRankDoc(3, 1, 2); + expected = new RRFRankDoc(3, 1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = 0; expected.positions[1] = NO_RANK; @@ -246,7 +246,7 @@ public void testCoordinatorRank() { expected.score = 0.5f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[1]); - expected = new RRFRankDoc(1, 1, 2); + expected = new RRFRankDoc(1, 1, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] = NO_RANK; expected.positions[1] = 0; @@ -255,7 +255,7 @@ public void testCoordinatorRank() { expected.score = 0.5f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[2]); - expected = new RRFRankDoc(2, 2, 2); + expected = new RRFRankDoc(2, 2, 2, context.rankConstant()); expected.rank = 4; expected.positions[0] = 3; expected.positions[1] = 3; @@ -277,7 +277,7 @@ public void testShardTieBreaker() { assertEquals(2, result.queryCount); assertEquals(2, result.rrfRankDocs.length); - RRFRankDoc expected = new RRFRankDoc(1, -1, 2); + RRFRankDoc expected = new RRFRankDoc(1, -1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 0; expected.positions[1] = 1; @@ -286,7 +286,7 @@ public void 
testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[0]); - expected = new RRFRankDoc(2, -1, 2); + expected = new RRFRankDoc(2, -1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = 1; expected.positions[1] = 0; @@ -304,7 +304,7 @@ public void testShardTieBreaker() { assertEquals(2, result.queryCount); assertEquals(4, result.rrfRankDocs.length); - expected = new RRFRankDoc(3, -1, 2); + expected = new RRFRankDoc(3, -1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 2; expected.positions[1] = 1; @@ -313,7 +313,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[0]); - expected = new RRFRankDoc(2, -1, 2); + expected = new RRFRankDoc(2, -1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = 1; expected.positions[1] = 2; @@ -322,7 +322,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[1]); - expected = new RRFRankDoc(1, -1, 2); + expected = new RRFRankDoc(1, -1, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] = 0; expected.positions[1] = -1; @@ -331,7 +331,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[2]); - expected = new RRFRankDoc(4, -1, 2); + expected = new RRFRankDoc(4, -1, 2, context.rankConstant()); expected.rank = 4; expected.positions[0] = -1; expected.positions[1] = 0; @@ -349,7 +349,7 @@ public void testShardTieBreaker() { assertEquals(2, result.queryCount); assertEquals(4, result.rrfRankDocs.length); - expected = new RRFRankDoc(1, -1, 2); + expected = new RRFRankDoc(1, -1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 0; expected.positions[1] = -1; @@ -358,7 +358,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[0]); - expected = new RRFRankDoc(2, -1, 2); + 
expected = new RRFRankDoc(2, -1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = -1; expected.positions[1] = 0; @@ -367,7 +367,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[1]); - expected = new RRFRankDoc(3, -1, 2); + expected = new RRFRankDoc(3, -1, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] = 1; expected.positions[1] = -1; @@ -376,7 +376,7 @@ public void testShardTieBreaker() { expected.score = Float.NaN; assertRDEquals(expected, result.rrfRankDocs[2]); - expected = new RRFRankDoc(4, -1, 2); + expected = new RRFRankDoc(4, -1, 2, context.rankConstant()); expected.rank = 4; expected.positions[0] = -1; expected.positions[1] = 1; @@ -391,7 +391,7 @@ public void testCoordinatorRankTieBreaker() { QuerySearchResult qsr0 = new QuerySearchResult(); qsr0.setShardIndex(1); - RRFRankDoc rd11 = new RRFRankDoc(1, -1, 2); + RRFRankDoc rd11 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd11.positions[0] = 0; rd11.positions[1] = 0; rd11.scores[0] = 10.0f; @@ -400,7 +400,7 @@ public void testCoordinatorRankTieBreaker() { QuerySearchResult qsr1 = new QuerySearchResult(); qsr1.setShardIndex(2); - RRFRankDoc rd21 = new RRFRankDoc(1, -1, 2); + RRFRankDoc rd21 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd21.positions[0] = 0; rd21.positions[1] = 0; rd21.scores[0] = 9.0f; @@ -413,7 +413,7 @@ public void testCoordinatorRankTieBreaker() { assertEquals(2, tds.fetchHits); assertEquals(2, scoreDocs.length); - RRFRankDoc expected = new RRFRankDoc(1, 1, 2); + RRFRankDoc expected = new RRFRankDoc(1, 1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 0; expected.positions[1] = 1; @@ -422,7 +422,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.8333333730697632f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[0]); - expected = new RRFRankDoc(1, 2, 2); + expected = new RRFRankDoc(1, 2, 2, context.rankConstant()); 
expected.rank = 2; expected.positions[0] = 1; expected.positions[1] = 0; @@ -433,12 +433,12 @@ public void testCoordinatorRankTieBreaker() { qsr0 = new QuerySearchResult(); qsr0.setShardIndex(1); - rd11 = new RRFRankDoc(1, -1, 2); + rd11 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd11.positions[0] = 0; rd11.positions[1] = -1; rd11.scores[0] = 10.0f; rd11.scores[1] = 0.0f; - RRFRankDoc rd12 = new RRFRankDoc(2, -1, 2); + RRFRankDoc rd12 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd12.positions[0] = 0; rd12.positions[1] = 1; rd12.scores[0] = 9.0f; @@ -447,12 +447,12 @@ public void testCoordinatorRankTieBreaker() { qsr1 = new QuerySearchResult(); qsr1.setShardIndex(2); - rd21 = new RRFRankDoc(1, -1, 2); + rd21 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd21.positions[0] = -1; rd21.positions[1] = 0; rd21.scores[0] = 0.0f; rd21.scores[1] = 11.0f; - RRFRankDoc rd22 = new RRFRankDoc(2, -1, 2); + RRFRankDoc rd22 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd22.positions[0] = 0; rd22.positions[1] = 1; rd22.scores[0] = 9.0f; @@ -465,7 +465,7 @@ public void testCoordinatorRankTieBreaker() { assertEquals(4, tds.fetchHits); assertEquals(4, scoreDocs.length); - expected = new RRFRankDoc(2, 2, 2); + expected = new RRFRankDoc(2, 2, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 2; expected.positions[1] = 1; @@ -474,7 +474,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.5833333730697632f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[0]); - expected = new RRFRankDoc(2, 1, 2); + expected = new RRFRankDoc(2, 1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = 1; expected.positions[1] = 2; @@ -483,7 +483,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.5833333730697632f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[1]); - expected = new RRFRankDoc(1, 1, 2); + expected = new RRFRankDoc(1, 1, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] 
= 0; expected.positions[1] = -1; @@ -492,7 +492,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.5f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[2]); - expected = new RRFRankDoc(1, 2, 2); + expected = new RRFRankDoc(1, 2, 2, context.rankConstant()); expected.rank = 4; expected.positions[0] = -1; expected.positions[1] = 0; @@ -503,12 +503,12 @@ public void testCoordinatorRankTieBreaker() { qsr0 = new QuerySearchResult(); qsr0.setShardIndex(1); - rd11 = new RRFRankDoc(1, -1, 2); + rd11 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd11.positions[0] = 0; rd11.positions[1] = -1; rd11.scores[0] = 10.0f; rd11.scores[1] = 0.0f; - rd12 = new RRFRankDoc(2, -1, 2); + rd12 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd12.positions[0] = -1; rd12.positions[1] = 0; rd12.scores[0] = 0.0f; @@ -517,12 +517,12 @@ public void testCoordinatorRankTieBreaker() { qsr1 = new QuerySearchResult(); qsr1.setShardIndex(2); - rd21 = new RRFRankDoc(1, -1, 2); + rd21 = new RRFRankDoc(1, -1, 2, context.rankConstant()); rd21.positions[0] = 0; rd21.positions[1] = -1; rd21.scores[0] = 3.0f; rd21.scores[1] = 0.0f; - rd22 = new RRFRankDoc(2, -1, 2); + rd22 = new RRFRankDoc(2, -1, 2, context.rankConstant()); rd22.positions[0] = -1; rd22.positions[1] = 0; rd22.scores[0] = 0.0f; @@ -535,7 +535,7 @@ public void testCoordinatorRankTieBreaker() { assertEquals(4, tds.fetchHits); assertEquals(4, scoreDocs.length); - expected = new RRFRankDoc(1, 1, 2); + expected = new RRFRankDoc(1, 1, 2, context.rankConstant()); expected.rank = 1; expected.positions[0] = 0; expected.positions[1] = -1; @@ -544,7 +544,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.5f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[0]); - expected = new RRFRankDoc(2, 1, 2); + expected = new RRFRankDoc(2, 1, 2, context.rankConstant()); expected.rank = 2; expected.positions[0] = -1; expected.positions[1] = 0; @@ -553,7 +553,7 @@ public void testCoordinatorRankTieBreaker() { 
expected.score = 0.5f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[1]); - expected = new RRFRankDoc(1, 2, 2); + expected = new RRFRankDoc(1, 2, 2, context.rankConstant()); expected.rank = 3; expected.positions[0] = 1; expected.positions[1] = -1; @@ -562,7 +562,7 @@ public void testCoordinatorRankTieBreaker() { expected.score = 0.3333333333333333f; assertRDEquals(expected, (RRFRankDoc) scoreDocs[2]); - expected = new RRFRankDoc(2, 2, 2); + expected = new RRFRankDoc(2, 2, 2, context.rankConstant()); expected.rank = 4; expected.positions[0] = -1; expected.positions[1] = 1; diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java index 0b8ee30fe0680..4b64b6c173c92 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.Writeable.Reader; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -17,7 +18,12 @@ public class RRFRankDocTests extends AbstractWireSerializingTestCase { static RRFRankDoc createTestRRFRankDoc(int queryCount) { - RRFRankDoc instance = new RRFRankDoc(randomNonNegativeInt(), randomBoolean() ? -1 : randomNonNegativeInt(), queryCount); + RRFRankDoc instance = new RRFRankDoc( + randomNonNegativeInt(), + randomBoolean() ? -1 : randomNonNegativeInt(), + queryCount, + randomIntBetween(1, 100) + ); instance.score = randomFloat(); instance.rank = randomBoolean() ? 
NO_RANK : randomIntBetween(1, 10000); for (int qi = 0; qi < queryCount; ++qi) { @@ -46,34 +52,49 @@ protected RRFRankDoc createTestInstance() { @Override protected RRFRankDoc mutateInstance(RRFRankDoc instance) throws IOException { - RRFRankDoc mutated = new RRFRankDoc(instance.doc, instance.shardIndex, instance.positions.length); - mutated.score = instance.score; - mutated.rank = instance.rank; - System.arraycopy(instance.positions, 0, mutated.positions, 0, instance.positions.length); - System.arraycopy(instance.scores, 0, mutated.scores, 0, instance.positions.length); - mutated.rank = mutated.rank == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK; - if (rarely()) { - int ri = randomInt(mutated.positions.length - 1); - mutated.positions[ri] = mutated.positions[ri] == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK; - } - if (rarely()) { - int ri = randomInt(mutated.positions.length - 1); - mutated.scores[ri] = randomFloat(); - } - if (rarely()) { - mutated.doc = randomNonNegativeInt(); - } - if (rarely()) { - mutated.score = randomFloat(); - } - if (frequently()) { - mutated.shardIndex = mutated.shardIndex == -1 ? randomNonNegativeInt() : -1; + int doc = instance.doc; + int shardIndex = instance.shardIndex; + float score = instance.score; + int rankConstant = instance.rankConstant; + int rank = instance.rank; + int queries = instance.positions.length; + int[] positions = new int[queries]; + float[] scores = new float[queries]; + + switch (randomInt(6)) { + case 0: + doc = randomValueOtherThan(doc, ESTestCase::randomNonNegativeInt); + break; + case 1: + shardIndex = shardIndex == -1 ? randomNonNegativeInt() : -1; + break; + case 2: + score = randomValueOtherThan(score, ESTestCase::randomFloat); + break; + case 3: + rankConstant = randomValueOtherThan(rankConstant, () -> randomIntBetween(1, 100)); + break; + case 4: + rank = rank == NO_RANK ? 
randomIntBetween(1, 10000) : NO_RANK; + break; + case 5: + for (int i = 0; i < queries; i++) { + positions[i] = instance.positions[i] == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK; + } + break; + case 6: + for (int i = 0; i < queries; i++) { + scores[i] = randomValueOtherThan(scores[i], ESTestCase::randomFloat); + } + break; + default: + throw new AssertionError(); } + RRFRankDoc mutated = new RRFRankDoc(doc, shardIndex, queries, rankConstant); + System.arraycopy(positions, 0, mutated.positions, 0, instance.positions.length); + System.arraycopy(scores, 0, mutated.scores, 0, instance.scores.length); + mutated.rank = rank; + mutated.score = score; return mutated; } - - public void testExplain() { - RRFRankDoc instance = createTestRRFRankDoc(); - assertEquals(instance.explain().toString(), instance.explain().toString()); - } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index 330c936327b81..e360237371a82 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -29,25 +29,21 @@ public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase(retrieverCount); - while (retrieverCount > 0) { - rrfRetrieverBuilder.retrieverBuilders.add(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); + ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder()); --retrieverCount; } - - return rrfRetrieverBuilder; + return ret; } @Override diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index f5a9f4e9b0c3e..d20f0f88aeb16 100644 --- 
a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -8,9 +8,12 @@ package org.elasticsearch.xpack.rank.rrf; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; @@ -50,7 +53,8 @@ public void testRetrieverExtractionErrors() throws IOException { SearchSourceBuilder ssb = new SearchSourceBuilder(); IllegalArgumentException iae = expectThrows( IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + () -> ssb.parseXContent(parser, true, nf -> true) + .rewrite(new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")))) ); assertEquals("[search_after] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -65,88 +69,11 @@ public void testRetrieverExtractionErrors() throws IOException { SearchSourceBuilder ssb = new SearchSourceBuilder(); IllegalArgumentException iae = expectThrows( IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + () -> ssb.parseXContent(parser, true, nf -> true) + .rewrite(new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")))) ); assertEquals("[terminate_after] cannot be used in children of compound retrievers", iae.getMessage()); } - - try ( - XContentParser parser = createParser( - 
JsonXContent.jsonXContent, - "{\"retriever\":{\"rrf_nl\":{\"retrievers\":" + "[{\"standard\":{\"sort\":[\"f1\"]}},{\"standard\":{\"sort\":[\"f2\"]}}]}}}" - ) - ) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) - ); - assertEquals("[sort] cannot be used in children of compound retrievers", iae.getMessage()); - } - - try ( - XContentParser parser = createParser( - JsonXContent.jsonXContent, - "{\"retriever\":{\"rrf_nl\":{\"retrievers\":" + "[{\"standard\":{\"min_score\":1}},{\"standard\":{\"min_score\":2}}]}}}" - ) - ) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) - ); - assertEquals("[min_score] cannot be used in children of compound retrievers", iae.getMessage()); - } - - try ( - XContentParser parser = createParser( - JsonXContent.jsonXContent, - "{\"retriever\":{\"rrf_nl\":{\"retrievers\":" - + "[{\"standard\":{\"collapse\":{\"field\":\"f0\"}}},{\"standard\":{\"collapse\":{\"field\":\"f1\"}}}]}}}" - ) - ) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) - ); - assertEquals("[collapse] cannot be used in children of compound retrievers", iae.getMessage()); - } - - try ( - XContentParser parser = createParser( - JsonXContent.jsonXContent, - "{\"retriever\":{\"rrf_nl\":{\"retrievers\":[{\"rrf_nl\":{}}]}}}" - ) - ) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) - ); - assertEquals("[rank] cannot be used in children of compound retrievers", 
iae.getMessage()); - } - } - - /** Tests max depth errors related to compound retrievers. These tests require a compound retriever which is why they are here. */ - public void testRetrieverBuilderParsingMaxDepth() throws IOException { - try ( - XContentParser parser = createParser( - JsonXContent.jsonXContent, - "{\"retriever\":{\"rrf_nl\":{\"retrievers\":[{\"rrf_nl\":{\"retrievers\":[{\"standard\":{}}]}}]}}}" - ) - ) { - SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) - ); - assertEquals("[1:65] [rrf] failed to parse field [retrievers]", iae.getMessage()); - assertEquals( - "the nested depth of the [standard] retriever exceeds the maximum nested depth [2] for retrievers", - iae.getCause().getCause().getMessage() - ); - } } @Override diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml index 4f76f52409810..647540644ce9e 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/100_rank_rrf.yml @@ -80,17 +80,14 @@ setup: size: 10 - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term term" } - match: { hits.hits.0.fields.keyword.0: "other" } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term" } - match: { hits.hits.1.fields.keyword.0: "keyword" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -128,12 +125,10 @@ setup: - match: { hits.total.value: 2 } - match: { hits.hits.0._id: "3" } - - match: { 
hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } @@ -176,17 +171,14 @@ setup: - match: { hits.total.value: 3 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml index 575723853f0aa..b4893bfec0849 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/150_rank_rrf_pagination.yml @@ -164,11 +164,8 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 3 } - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0._rank: 2 } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 3 } - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 4 } --- "Standard pagination outside rank_window_size": @@ -378,7 +375,6 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 3 } --- @@ -489,9 +485,7 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } - - match: { 
hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "4" } - - match: { hits.hits.1._rank: 2 } - do: search: @@ -594,9 +588,7 @@ setup: - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 3 } - match: { hits.hits.1._id: "2" } - - match: { hits.hits.1._rank: 4 } --- "Pagination within interleaved results, different result set sizes, rank_window_size covering all results": @@ -690,9 +682,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - do: search: @@ -779,9 +769,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._rank: 3 } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 4 } - do: search: @@ -868,7 +856,6 @@ setup: - match: { hits.total.value: 5 } - length: { hits.hits: 1 } - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0._rank: 5 } --- @@ -965,9 +952,7 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.1._id: "4" } - - match: { hits.hits.1._rank: 2 } - do: search: diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml index 76cedf44d3dbe..bca39dea4ae57 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/200_rank_rrf_script.yml @@ -129,19 +129,10 @@ setup: size: 5 - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - - match: { hits.hits.1._id: "6" } - - match: { hits.hits.1._rank: 2 } - - match: { 
hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 3 } - - match: { hits.hits.3._id: "7" } - - match: { hits.hits.3._rank: 4 } - - match: { hits.hits.4._id: "3" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }} - match: { aggregations.sums.value.text_total: 25 } @@ -196,7 +187,6 @@ setup: - match: { hits.total.value: 6 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - close_to: { aggregations.sums.value.asc_total: { value: 33.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 39.0, error: 0.001 }} @@ -272,20 +262,10 @@ setup: size: 5 - match: { hits.hits.0._id: "6" } - - match: { hits.hits.0._rank: 1 } - - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - - match: { hits.hits.2._id: "7" } - - match: { hits.hits.2._rank: 3 } - - match: { hits.hits.3._id: "4" } - - match: { hits.hits.3._rank: 4 } - - match: { hits.hits.4._id: "8" } - - match: { hits.hits.4._rank: 5 } - - close_to: { aggregations.sums.value.asc_total: { value: 30.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 30.0, error: 0.001 }} - match: { aggregations.sums.value.text_total: 30 } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml index d3d45ef2b18e8..258ab70cd09bd 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/300_rrf_retriever.yml @@ -91,17 +91,14 @@ setup: size: 10 - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term term" } - match: { hits.hits.0.fields.keyword.0: "other" } - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._rank: 2 } - match: { 
hits.hits.1.fields.text.0: "term" } - match: { hits.hits.1.fields.keyword.0: "keyword" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -143,12 +140,10 @@ setup: - match: { hits.total.value : 2 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } @@ -198,17 +193,14 @@ setup: - match: { hits.total.value : 3 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._rank: 2 } - match: { hits.hits.1.fields.text.0: "term term" } - match: { hits.hits.1.fields.keyword.0: "other" } - match: { hits.hits.2._id: "2" } - - match: { hits.hits.2._rank: 3 } - match: { hits.hits.2.fields.text.0: "other" } - match: { hits.hits.2.fields.keyword.0: "other" } @@ -267,7 +259,6 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } @@ -330,6 +321,82 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "3" } - - match: { hits.hits.0._rank: 1 } - match: { hits.hits.0.fields.text.0: "term" } - match: { hits.hits.0.fields.keyword.0: "keyword" } + +--- +"rrf retriever with nested rrf retriever and pagination": + + - do: + search: + index: test + body: + track_total_hits: true + fields: [ "text", "keyword" ] + retriever: + rrf: + retrievers: [ + { + "rrf": + { + "retrievers": [ + { + "knn": { + "field": "vector", + "query_vector": [ 0.0 ], + "k": 3, + 
"num_candidates": 3 + } + }, + { + "standard": + { + "query": + { + "term": + { + "text": "term" + } + } + } + }, + { + "standard": + { + "query": + { + "match": + { + "keyword": "keyword" + } + } + } + } + ], + "rank_window_size": 100, + "rank_constant": 1 + } + }, + { + "knn": { + "field": vector, + "query_vector": [ 0.0 ], + "k": 2, + "num_candidates": 2 + } + } + ] + "rank_window_size": 10 + "rank_constant": 1 + size: 10 + from: 1 + + - match: { hits.total.value : 3 } + + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0.fields.text.0: "other" } + - match: { hits.hits.0.fields.keyword.0: "other" } + + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1.fields.text.0: "term" } + - match: { hits.hits.1.fields.keyword.0: "keyword" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml new file mode 100644 index 0000000000000..47ba3658bb38d --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml @@ -0,0 +1,1112 @@ +setup: + - skip: + features: close_to + + - requires: + cluster_features: 'rrf_retriever_composition_supported' + reason: 'test requires rrf retriever composition support' + + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + properties: + number_val: + type: keyword + char_val: + type: keyword + + - do: + index: + index: test + id: 1 + body: + number_val: "1" + char_val: "A" + + - do: + index: + index: test + id: 2 + body: + number_val: "2" + char_val: "B" + + - do: + index: + index: test + id: 3 + body: + number_val: "3" + char_val: "C" + + - do: + index: + index: test + id: 4 + body: + number_val: "4" + char_val: "D" + + - do: + index: + index: test + id: 5 + body: + number_val: "5" + char_val: "E" + + - do: + indices.refresh: {} + +--- +"Standard 
pagination within rank_window_size": + # this test retrieves the same results from two queries, and applies a simple pagination skipping the first result + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: { + retrievers: [ + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: + { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "1", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "2", + boost: 9.0 + } + } + }, + { + term: { + number_val: { + value: "3", + boost: 8.0 + } + } + }, + { + term: { + number_val: { + value: "4", + boost: 7.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "A", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "D", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from : 1 + size : 10 + + - match: { hits.total.value : 4 } + - length: { hits.hits : 3 } + - match: { hits.hits.0._id: "2" } + # score for doc 2 is (1/12 + 1/12) + - close_to: {hits.hits.0._score: {value: 0.1666, error: 0.001}} + - match: { hits.hits.1._id: "3" } + # score for doc 3 is (1/13 + 1/13) + - close_to: {hits.hits.1._score: {value: 0.1538, error: 0.001}} + - match: { hits.hits.2._id: "4" } + # score for doc 4 is (1/14 + 1/14) + - close_to: {hits.hits.2._score: {value: 0.1428, error: 0.001}} + +--- +"Standard pagination outside rank_window_size": + # in this example, from starts *after* rank_window_size so, we expect 0 results to be returned + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { 
+ bool: { + should: [ + { + term: { + number_val: { + value: "1", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "2", + boost: 9.0 + } + } + }, + { + term: { + number_val: { + value: "3", + boost: 8.0 + } + } + }, + { + term: { + number_val: { + value: "4", + boost: 7.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "A", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "D", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 2, + rank_constant: 10 + } + from : 10 + size : 2 + + - match: { hits.total.value : 4 } + - length: { hits.hits : 0 } + +--- +"Standard pagination partially outside rank_window_size": + # in this example we have that from starts *within* rank_window_size, but "from + size" goes over + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "1", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "2", + boost: 9.0 + } + } + }, + { + term: { + number_val: { + value: "3", + boost: 8.0 + } + } + }, + { + term: { + number_val: { + value: "4", + boost: 7.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "A", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "D", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 
3, + rank_constant: 10 + } + from : 2 + size : 2 + + - match: { hits.total.value : 4 } + - length: { hits.hits : 1 } + - match: { hits.hits.0._id: "3" } + # score for doc 3 is (1/13 + 1/13) + - close_to: {hits.hits.0._score: {value: 0.1538, error: 0.001}} + + +--- +"Pagination within interleaved results": + # perform two searches with different "from" parameter, ensuring that results are consistent + # rank_window_size covers the entire result set for both queries, so pagination should be consistent + # queryA has a result set of [1, 2, 3, 4] and + # queryB has a result set of [4, 3, 1, 2] + # so for rank_constant=10, the expected order is [1, 4, 3, 2] + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "1", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "2", + boost: 9.0 + } + } + }, + { + term: { + number_val: { + value: "3", + boost: 8.0 + } + } + }, + { + term: { + number_val: { + value: "4", + boost: 7.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from : 0 + size : 2 + + - match: { hits.total.value : 4 } + - length: { hits.hits : 2 } + - match: { hits.hits.0._id: "1" } + # score for doc 1 is (1/11 + 1/13) + - close_to: {hits.hits.0._score: {value: 0.1678, error: 0.001}} + - match: { hits.hits.1._id: "4" } + # score for doc 4 is (1/11 + 1/14) + - close_to: {hits.hits.1._score: {value: 0.1623, 
error: 0.001}} + + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [1, 2, 3, 4] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "1", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "2", + boost: 9.0 + } + } + }, + { + term: { + number_val: { + value: "3", + boost: 8.0 + } + } + }, + { + term: { + number_val: { + value: "4", + boost: 7.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from : 2 + size : 2 + + - match: { hits.total.value : 4 } + - length: { hits.hits : 2 } + - match: { hits.hits.0._id: "3" } + # score for doc 3 is (1/12 + 1/13) + - close_to: {hits.hits.0._score: {value: 0.1602, error: 0.001}} + - match: { hits.hits.1._id: "2" } + # score for doc 2 is (1/12 + 1/14) + - close_to: {hits.hits.1._score: {value: 0.1547, error: 0.001}} + +--- +"Pagination within interleaved results, different result set sizes, rank_window_size covering all results": + # perform multiple searches with different "from" parameter, ensuring that results are consistent + # rank_window_size covers the entire result set for both queries, so pagination should be consistent + # queryA has a result set of [5, 1] and + # queryB has a result set of [4, 3, 1, 2] + # so for rank_constant=10, the expected order is [1, 4, 5, 3, 2] + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result 
set [5, 1] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "5", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "1", + boost: 9.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from : 0 + size : 2 + + - match: { hits.total.value : 5 } + - length: { hits.hits : 2 } + - match: { hits.hits.0._id: "1" } + # score for doc 1 is (1/12 + 1/13) + - close_to: {hits.hits.0._score: {value: 0.1602, error: 0.001}} + - match: { hits.hits.1._id: "4" } + # score for doc 4 is (1/11) + - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} + + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [5, 1] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "5", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "1", + boost: 9.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from : 2 + size : 2 + + - match: { hits.total.value : 5 } + - length: { hits.hits : 2 } + - match: { 
hits.hits.0._id: "5" } + # score for doc 5 is (1/11) + - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} + - match: { hits.hits.1._id: "3" } + # score for doc 3 is (1/12) + - close_to: {hits.hits.1._score: {value: 0.0833, error: 0.001}} + + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [5, 1] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "5", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "1", + boost: 9.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + from: 4 + size: 2 + + - match: { hits.total.value: 5 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "2" } + # score for doc 2 is (1/14) + - close_to: {hits.hits.0._score: {value: 0.0714, error: 0.001}} + + +--- +"Pagination within interleaved results, different result set sizes, rank_window_size not covering all results": + # perform multiple searches with different "from" parameter, ensuring that results are consistent + # rank_window_size does not cover the entire result set for both queries, so the results should be different + # from the test above. 
More specifically, we'd get to collect 2 results from each query, so we'd have: + # queryA has a result set of [5, 1] and + # queryB has a result set of [4, 3] + # so for rank_constant=10, the expected order is [4, 5, 1, 3], + # and the rank_window_size-sized result set that we'd paginate over is [4, 5] + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [5, 1] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "5", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "1", + boost: 9.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 2, + rank_constant: 10 + } + from : 0 + size : 2 + + - match: { hits.total.value : 5 } + - length: { hits.hits : 2 } + - match: { hits.hits.0._id: "4" } + # score for doc 4 is (1/11) + - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} + - match: { hits.hits.1._id: "5" } + # score for doc 5 is (1/11) + - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} + + - do: + search: + index: test + body: + track_total_hits: true + retriever: + rrf: + { + retrievers: [ + { + # this should clause would generate the result set [5, 1] + standard: { + query: { + bool: { + should: [ + { + term: { + number_val: { + value: "5", + boost: 10.0 + } + } + }, + { + term: { + number_val: { + value: "1", + boost: 9.0 + } + } + } + ] + } + } + } + }, + { + # this should clause would generate the result set [4, 3, 1, 2] + standard: { + query: { + bool: { + should: [ + { + 
term: { + char_val: { + value: "D", + boost: 10.0 + } + } + }, + { + term: { + char_val: { + value: "C", + boost: 9.0 + } + } + }, + { + term: { + char_val: { + value: "A", + boost: 8.0 + } + } + }, + { + term: { + char_val: { + value: "B", + boost: 7.0 + } + } + } + ] + } + } + } + } + ], + rank_window_size: 2, + rank_constant: 10 + } + from : 2 + size : 2 + + - match: { hits.total.value : 5 } + - length: { hits.hits : 0 } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml index 520389d51b737..bbc1087b05cc3 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/400_rrf_retriever_script.yml @@ -160,19 +160,10 @@ setup: ] - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - - match: { hits.hits.1._id: "6" } - - match: { hits.hits.1._rank: 2 } - - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._rank: 3 } - - match: { hits.hits.3._id: "7" } - - match: { hits.hits.3._rank: 4 } - - match: { hits.hits.4._id: "3" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }} - match: { aggregations.sums.value.text_total: 25 } @@ -228,13 +219,12 @@ setup: 'desc_total': states.stream().mapToDouble(v -> v['desc_total']).sum() ] - - match: { hits.total.value: 6 } + - match: { hits.total.value: 5 } - match: { hits.hits.0._id: "5" } - - match: { hits.hits.0._rank: 1 } - - close_to: { aggregations.sums.value.asc_total: { value: 33.0, error: 0.001 }} - - close_to: { aggregations.sums.value.desc_total: { value: 39.0, error: 0.001 }} + - close_to: { aggregations.sums.value.asc_total: { value: 25.0, error: 0.001 }} + - close_to: { aggregations.sums.value.desc_total: { value: 35.0, error: 0.001 }} 
--- "rrf retriever using multiple knn retrievers and a standard retriever with a scripted metric aggregation": @@ -333,19 +323,10 @@ setup: ] - match: { hits.hits.0._id: "6" } - - match: { hits.hits.0._rank: 1 } - - match: { hits.hits.1._id: "5" } - - match: { hits.hits.1._rank: 2 } - - match: { hits.hits.2._id: "7" } - - match: { hits.hits.2._rank: 3 } - - match: { hits.hits.3._id: "4" } - - match: { hits.hits.3._rank: 4 } - - match: { hits.hits.4._id: "8" } - - match: { hits.hits.4._rank: 5 } - close_to: { aggregations.sums.value.asc_total: { value: 30.0, error: 0.001 }} - close_to: { aggregations.sums.value.desc_total: { value: 30.0, error: 0.001 }} diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/500_rrf_retriever_explain.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/500_rrf_retriever_explain.yml index 8d74ecbccd328..a66d99a922ed0 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/500_rrf_retriever_explain.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/500_rrf_retriever_explain.yml @@ -117,7 +117,7 @@ setup: - match: {hits.hits.0._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.0._explanation.details.1.value: 1} - match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.5\\].*/" } - - match: {hits.hits.0._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - close_to: { hits.hits.1._explanation.value: { value: 0.5833334, error: 0.000001 } } - match: {hits.hits.1._explanation.description: "/rrf.score:.\\[0.5833334\\].*/" } @@ -126,7 +126,7 @@ setup: - match: {hits.hits.1._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.1._explanation.details.1.value: 2} - match: 
{hits.hits.1._explanation.details.1.description: "/rrf.score:.\\[0.33333334\\].*/" } - - match: {hits.hits.1._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.1._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - match: {hits.hits.2._explanation.value: 0.5} - match: {hits.hits.2._explanation.description: "/rrf.score:.\\[0.5\\].*/" } @@ -154,10 +154,10 @@ setup: term: { text: { value: "term", - _name: "my_query" } } - } + }, + _name: "my_query" } }, { @@ -186,7 +186,7 @@ setup: - match: {hits.hits.0._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.0._explanation.details.1.value: 1} - match: {hits.hits.0._explanation.details.1.description: "/.*my_top_knn.*/" } - - match: {hits.hits.0._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - close_to: { hits.hits.1._explanation.value: { value: 0.5833334, error: 0.000001 } } - match: {hits.hits.1._explanation.description: "/rrf.score:.\\[0.5833334\\].*/" } @@ -195,7 +195,7 @@ setup: - match: {hits.hits.1._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.1._explanation.details.1.value: 2} - match: {hits.hits.1._explanation.details.1.description: "/.*my_top_knn.*/" } - - match: {hits.hits.1._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.1._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - match: {hits.hits.2._explanation.value: 0.5} - match: {hits.hits.2._explanation.description: "/rrf.score:.\\[0.5\\].*/" } @@ -254,7 +254,7 @@ setup: - match: {hits.hits.0._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.0._explanation.details.1.value: 1} - match: 
{hits.hits.0._explanation.details.1.description: "/.*my_top_knn.*/" } - - match: {hits.hits.0._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - close_to: { hits.hits.1._explanation.value: { value: 0.5833334, error: 0.000001 } } - match: {hits.hits.1._explanation.description: "/rrf.score:.\\[0.5833334\\].*/" } @@ -263,7 +263,7 @@ setup: - match: {hits.hits.1._explanation.details.0.details.0.description: "/weight\\(text:term.*/" } - match: {hits.hits.1._explanation.details.1.value: 2} - match: {hits.hits.1._explanation.details.1.description: "/.*my_top_knn.*/" } - - match: {hits.hits.1._explanation.details.1.details.0.description: "/within.top.*/" } + - match: {hits.hits.1._explanation.details.1.details.0.details.0.description: "/found.vector.with.calculated.similarity.*/" } - match: {hits.hits.2._explanation.value: 0.5} - match: {hits.hits.2._explanation.description: "/rrf.score:.\\[0.5\\].*/" } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/600_rrf_retriever_profile.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/600_rrf_retriever_profile.yml index 7308ce8947db7..e34885419c7f7 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/600_rrf_retriever_profile.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/600_rrf_retriever_profile.yml @@ -1,7 +1,7 @@ setup: - requires: - cluster_features: "gte_v8.15.0" - reason: 'profile for rrf was enabled in 8.15' + cluster_features: 'rrf_retriever_composition_supported' + reason: 'test requires rrf retriever composition support' test_runner_features: close_to - do: @@ -114,12 +114,13 @@ setup: - match: { hits.hits.2._id: "4" } - not_exists: profile.shards.0.dfs - - match: { profile.shards.0.searches.0.query.0.type: ConstantScoreQuery } - - length: { 
profile.shards.0.searches.0.query.0.children: 1 } - - match: { profile.shards.0.searches.0.query.0.children.0.type: BooleanQuery } - - length: { profile.shards.0.searches.0.query.0.children.0.children: 2 } - - match: { profile.shards.0.searches.0.query.0.children.0.children.0.type: TermQuery } - - match: { profile.shards.0.searches.0.query.0.children.0.children.1.type: DocAndScoreQuery } + - match: { profile.shards.0.searches.0.query.0.type: RankDocsQuery } + - length: { profile.shards.0.searches.0.query.0.children: 2 } + - match: { profile.shards.0.searches.0.query.0.children.0.type: TopQuery } + - match: { profile.shards.0.searches.0.query.0.children.1.type: BooleanQuery } + - length: { profile.shards.0.searches.0.query.0.children.1.children: 2 } + - match: { profile.shards.0.searches.0.query.0.children.1.children.0.type: TermQuery } + - match: { profile.shards.0.searches.0.query.0.children.1.children.1.type: DocAndScoreQuery } --- "profile standard and knn dfs retrievers": @@ -159,17 +160,14 @@ setup: - match: { hits.hits.1._id: "2" } - match: { hits.hits.2._id: "4" } - - exists: profile.shards.0.dfs - - length: { profile.shards.0.dfs.knn: 1 } - - length: { profile.shards.0.dfs.knn.0.query: 1 } - - match: { profile.shards.0.dfs.knn.0.query.0.type: DocAndScoreQuery } - - - match: { profile.shards.0.searches.0.query.0.type: ConstantScoreQuery } - - length: { profile.shards.0.searches.0.query.0.children: 1 } - - match: { profile.shards.0.searches.0.query.0.children.0.type: BooleanQuery } - - length: { profile.shards.0.searches.0.query.0.children.0.children: 2 } - - match: { profile.shards.0.searches.0.query.0.children.0.children.0.type: TermQuery } - - match: { profile.shards.0.searches.0.query.0.children.0.children.1.type: KnnScoreDocQuery } + - not_exists: profile.shards.0.dfs + - match: { profile.shards.0.searches.0.query.0.type: RankDocsQuery } + - length: { profile.shards.0.searches.0.query.0.children: 2 } + - match: { 
profile.shards.0.searches.0.query.0.children.0.type: TopQuery } + - match: { profile.shards.0.searches.0.query.0.children.1.type: BooleanQuery } + - length: { profile.shards.0.searches.0.query.0.children.1.children: 2 } + - match: { profile.shards.0.searches.0.query.0.children.1.children.0.type: TermQuery } + - match: { profile.shards.0.searches.0.query.0.children.1.children.1.type: TopQuery } --- "using query and dfs knn search": diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml new file mode 100644 index 0000000000000..1f7125377b892 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml @@ -0,0 +1,541 @@ +setup: + - skip: + features: close_to + + - requires: + cluster_features: 'rrf_retriever_composition_supported' + reason: 'test requires rrf retriever composition support' + + - do: + indices.create: + index: test + body: + settings: + number_of_shards: 1 + mappings: + properties: + text: + type: text + text_to_highlight: + type: text + keyword: + type: keyword + integer: + type: integer + vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + index_options: + type: hnsw + ef_construction: 100 + m: 16 + nested: + type: nested + properties: + views: + type: long + + - do: + index: + index: test + id: "1" + body: + text: "term term term term term term term term term" + vector: [1.0] + + - do: + index: + index: test + id: "2" + body: + text: "term term term term term term term term" + text_to_highlight: "search for the truth" + keyword: "biology" + vector: [2.0] + + - do: + index: + index: test + id: "3" + body: + text: "term term term term term term term" + text_to_highlight: "nothing related but still a match" + keyword: "technology" + vector: [3.0] + + - 
do: + index: + index: test + id: "4" + body: + text: "term term term term term term" + vector: [4.0] + - do: + index: + index: test + id: "5" + body: + text: "term term term term term" + text_to_highlight: "You know, for Search!" + keyword: "technology" + integer: 5 + vector: [5.0] + - do: + index: + index: test + id: "6" + body: + text: "term term term term" + keyword: "biology" + integer: 6 + vector: [6.0] + - do: + index: + index: test + id: "7" + body: + text: "term term term" + keyword: "astronomy" + vector: [7.0] + nested: { views: 50} + - do: + index: + index: test + id: "8" + body: + text: "term term" + keyword: "technology" + vector: [8.0] + nested: { views: 100} + - do: + index: + index: test + id: "9" + body: + text: "term" + keyword: "technology" + vector: [9.0] + nested: { views: 10} + - do: + indices.refresh: {} + +--- +"rrf retriever with aggs": + + - do: + search: + index: test + body: + track_total_hits: false + retriever: + rrf: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 6.0 ], + k: 3, + num_candidates: 3 + } + }, + { + standard: { + query: { + term: { + text: term + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + aggs: + keyword_aggs: + terms: + field: keyword + + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.2._id: "6" } + + - match: { aggregations.keyword_aggs.buckets.0.key: "technology" } + - match: { aggregations.keyword_aggs.buckets.0.doc_count: 4 } + - match: { aggregations.keyword_aggs.buckets.1.key: "biology" } + - match: { aggregations.keyword_aggs.buckets.1.doc_count: 2 } + - match: { aggregations.keyword_aggs.buckets.2.key: "astronomy" } + - match: { aggregations.keyword_aggs.buckets.2.doc_count: 1 } + +--- +"rrf retriever with aggs - scripted metric using score": + + - do: + search: + index: test + body: + track_total_hits: false + retriever: + rrf: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 6.0 ], + k: 3, + num_candidates: 
3 + } + }, + { + standard: { + query: { + term: { + text: term + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + aggs: + max_score: + max: + script: + lang: painless + source: "_score" + + + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.2._id: "6" } + + - close_to: { aggregations.max_score.value: { value: 0.15, error: 0.001 }} + +--- +"rrf retriever with top-level collapse": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 6.0 ], + k: 3, + num_candidates: 3 + } + }, + { + standard: { + query: { + term: { + text: term + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + collapse: { field: keyword, inner_hits: { name: sub_hits, size: 2 } } + + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.2._id: "6" } + + - match: { hits.hits.0.inner_hits.sub_hits.hits.total : 4 } + - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "5" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "3" } + + - length: { hits.hits.1.inner_hits.sub_hits.hits.hits : 2 } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.0._id: "1" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.1._id: "4" } + + - length: { hits.hits.2.inner_hits.sub_hits.hits.hits: 2 } + - match: { hits.hits.2.inner_hits.sub_hits.hits.hits.0._id: "6" } + - match: { hits.hits.2.inner_hits.sub_hits.hits.hits.1._id: "2" } + +--- +"rrf retriever with inner-level collapse": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 6.0 ], + k: 3, + num_candidates: 10 + } + }, + { + standard: { + query: { + term: { + text: term + } + }, + collapse: { 
field: keyword, inner_hits: { name: sub_hits, size: 1 } } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + + - match: { hits.hits.0._id: "7" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.2._id: "6" } + +--- +"rrf retriever highlighting results": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + standard: { + query: { + match: { + text_to_highlight: "search" + } + } + } + }, + { + standard: { + query: { + term: { + keyword: technology + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + highlight: { + fields: { + "text_to_highlight": { + "fragment_size": 150, + "number_of_fragments": 3 + } + } + } + + - match: { hits.total : 5 } + + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.0.highlight.text_to_highlight.0: "You know, for Search!" } + + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1.highlight.text_to_highlight.0: "search for the truth" } + + - match: { hits.hits.2._id: "3" } + - not_exists: hits.hits.2.highlight + +--- +"rrf retriever with custom nested sort": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + # this one retrievers docs 1, 2, 3, .., 9 + # but due to sorting, it will revert the order to 6, 5, .., 9 which due to + # rank_window_size: 2 will only return 6 and 5 + standard: { + query: { + term: { + text: term + } + }, + sort: [ + { + integer: { + order: desc + } + } + ] + } + }, + { + # this one retrieves doc 2 and 6 + standard: { + query: { + term: { + keyword: biology + } + } + } + } + ] + rank_window_size: 2 + rank_constant: 10 + size: 2 + + - match: { hits.total : 9 } + - length: {hits.hits: 2 } + + - match: { hits.hits.0._id: "6" } + - match: { hits.hits.1._id: "2" } + +--- +"rrf retriever with nested query": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + 
track_total_hits: true + retriever: + rrf: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 7.0 ], + k: 1, + num_candidates: 3 + } + }, + { + standard: { + query: { + nested: { + path: nested, + query: { + range: { + nested.views: { + gte: 50 + } + } + } + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + + - match: { hits.total : 2 } + - match: { hits.hits.0._id: "7" } + - match: { hits.hits.1._id: "8" } + +--- +"rrf retriever with global min_score": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + # this one retrievers docs 1 and 2 + knn: { + field: vector, + query_vector: [ 1.0 ], + k: 2, + num_candidates: 10 + } + }, + { + # this one retrieves docs 2 and 6 + standard: { + query: { + term: { + keyword: biology + } + } + } + } + ] + rank_window_size: 5 + rank_constant: 10 + size: 3 + min_score: 0.1 + + - match: { hits.total : 1 } + + - match: { hits.hits.0._id: "2" } + +--- +"rrf retriever with retriever-level min_score": + + - do: + search: + rest_total_hits_as_int: true + index: test + body: + track_total_hits: true + retriever: + rrf: + retrievers: [ + { + # this one retrieves doc 1 + knn: { + field: vector, + query_vector: [ 1.0 ], + k: 10, + num_candidates: 10, + similarity: 0.0 + } + }, + { + # this one retrieves no docs + standard: { + query: { + term: { + keyword: biology + } + }, + min_score: 100 + } + } + ] + rank_window_size: 10 + rank_constant: 10 + size: 10 + + - length: { hits.hits : 1 } + + - match: { hits.hits.0._id: "1" } From cd427198dc68764e9a36ee4b4b033700f256d736 Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 3 Oct 2024 10:53:47 +0100 Subject: [PATCH 4/4] More verbose logging in `IndicesSegmentsRestCancellationIT` (#113844) Relates #88201 --- ...ockedSearcherRestCancellationTestCase.java | 34 +++++++++++++++++++ .../IndicesSegmentsRestCancellationIT.java | 11 ++++++ 2 files changed, 45 insertions(+) diff 
--git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BlockedSearcherRestCancellationTestCase.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BlockedSearcherRestCancellationTestCase.java index a85ac9aefe694..c7e3a5b1c9a77 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BlockedSearcherRestCancellationTestCase.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BlockedSearcherRestCancellationTestCase.java @@ -9,10 +9,12 @@ package org.elasticsearch.http; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Cancellable; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; @@ -29,6 +31,8 @@ import org.elasticsearch.index.shard.IndexShardTestCase; import org.elasticsearch.index.translog.TranslogStats; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.Task; @@ -141,6 +145,12 @@ public List> getSettings() { private static class SearcherBlockingEngine extends ReadOnlyEngine { + // using a specialized logger for this case because and "logger" means "Engine#logger" + // (relates investigation into https://github.com/elastic/elasticsearch/issues/88201) + private static final Logger blockedSearcherRestCancellationTestCaseLogger = LogManager.getLogger( + BlockedSearcherRestCancellationTestCase.class + ); + final Semaphore searcherBlock = new Semaphore(1); SearcherBlockingEngine(EngineConfig config) { @@ -149,12 +159,36 @@ private static class SearcherBlockingEngine extends 
ReadOnlyEngine { @Override public Searcher acquireSearcher(String source, SearcherScope scope, Function wrapper) throws EngineException { + if (blockedSearcherRestCancellationTestCaseLogger.isDebugEnabled()) { + blockedSearcherRestCancellationTestCaseLogger.debug( + Strings.format( + "in acquireSearcher for shard [%s] on thread [%s], availablePermits=%d", + config().getShardId(), + Thread.currentThread().getName(), + searcherBlock.availablePermits() + ), + new ElasticsearchException("stack trace") + ); + } + try { searcherBlock.acquire(); } catch (InterruptedException e) { throw new AssertionError(e); } searcherBlock.release(); + + if (blockedSearcherRestCancellationTestCaseLogger.isDebugEnabled()) { + blockedSearcherRestCancellationTestCaseLogger.debug( + Strings.format( + "continuing in acquireSearcher for shard [%s] on thread [%s], availablePermits=%d", + config().getShardId(), + Thread.currentThread().getName(), + searcherBlock.availablePermits() + ) + ); + } + return super.acquireSearcher(source, scope, wrapper); } } diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java index a90b04d54649c..92fde6d7765cc 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java @@ -12,12 +12,23 @@ import org.apache.http.client.methods.HttpGet; import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction; import org.elasticsearch.client.Request; +import org.elasticsearch.test.junit.annotations.TestIssueLogging; public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase { + @TestIssueLogging( + issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", + value = 
"org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" + + ",org.elasticsearch.transport.TransportService:TRACE" + ) public void testIndicesSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME); } + @TestIssueLogging( + issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", + value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" + + ",org.elasticsearch.transport.TransportService:TRACE" + ) public void testCatSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME); }