From 3117cf6fd6965c5ade7b9aa39578046bdc48efa8 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 12 Dec 2024 15:58:00 +0000 Subject: [PATCH] Support `_shard_doc` as a sort tiebreaker for query rescoring This commit introduces support for using the `_shard_doc` field as a sort tiebreaker during query rescoring. This change is a prerequisite to add support for rescorers in retriever workflows. --- .../search/functionscore/QueryRescorerIT.java | 23 ++++ .../search/DefaultSearchContext.java | 3 +- .../search/builder/SearchSourceBuilder.java | 14 --- .../query/QueryPhaseCollectorManager.java | 3 +- .../search/rescore/RescorePhase.java | 101 +++++++++++++++--- 5 files changed, 113 insertions(+), 31 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java index a7efb2fe0e68b..fbdcfe26d28ee 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java @@ -38,6 +38,7 @@ import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.ParseField; @@ -840,6 +841,20 @@ public void testRescorePhaseWithInvalidSort() throws Exception { } } ); + + assertResponse( + prepareSearch().addSort(SortBuilders.scoreSort()) + .addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)) + .setTrackScores(true) + .addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50), + response -> { + assertThat(response.getHits().getTotalHits().value(), equalTo(5L)); + assertThat(response.getHits().getHits().length, equalTo(5)); + for (SearchHit hit : response.getHits().getHits()) { + assertThat(hit.getScore(), equalTo(101f)); + } + } + ); } record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {} @@ -879,6 +894,10 @@ public void testRescoreAfterCollapse() throws Exception { .setQuery(fieldValueScoreQuery("firstPassScore")) .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore"))) .setCollapse(new CollapseBuilder("group")); + if (randomBoolean()) { + request.addSort(SortBuilders.scoreSort()); + request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + } assertResponse(request, resp -> { assertThat(resp.getHits().getTotalHits().value(), equalTo(5L)); assertThat(resp.getHits().getHits().length, equalTo(3)); @@ -958,6 +977,10 @@ public void testRescoreAfterCollapseRandom() throws Exception { .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups)) .setCollapse(new CollapseBuilder("group")) .setSize(Math.min(numGroups, 10)); + if (randomBoolean()) { + request.addSort(SortBuilders.scoreSort()); + request.addSort(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); + } long expectedNumHits = numHits; assertResponse(request, resp -> { assertThat(resp.getHits().getTotalHits().value(), equalTo(expectedNumHits)); diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index b87d097413b67..47d3ed337af73 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -73,6 +73,7 @@ import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.RescorePhase; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; @@ -377,7 +378,7 @@ public void preProcess() { ); } if (rescore != null) { - if (sort != null) { + if (RescorePhase.validateSort(sort) == false) { throw new IllegalArgumentException("Cannot use [sort] option in conjunction with [rescore]."); } int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow(); diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index 3554a6dc08b90..8c21abe4180ea 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -48,9 +48,7 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.slice.SliceBuilder; -import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -2341,18 +2339,6 @@ public ActionRequestValidationException validate( validationException = rescorer.validate(this, validationException); } } - - if (pointInTimeBuilder() == null && sorts() != null) { - for (var sortBuilder : sorts()) { - if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder - && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) { - validationException = addValidationError( - "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", - validationException - ); - } - } - } return validationException; } } diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java index cbc04dd460ff5..3d793a164f40a 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java @@ -58,6 +58,7 @@ import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.InternalProfileCollector; import org.elasticsearch.search.rescore.RescoreContext; +import org.elasticsearch.search.rescore.RescorePhase; import org.elasticsearch.search.sort.SortAndFormats; import java.io.IOException; @@ -238,7 +239,7 @@ static CollectorManager createQueryPhaseCollectorMa int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); final boolean rescore = searchContext.rescore().isEmpty() == false; if (rescore) { - assert searchContext.sort() == null; + assert RescorePhase.validateSort(searchContext.sort()); for (RescoreContext rescoreContext : searchContext.rescore()) { numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index 7e3646e7689cc..34aa7988d9d51 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; @@ -22,9 +23,12 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.SearchTimeoutException; +import org.elasticsearch.search.sort.ShardDocSortField; +import org.elasticsearch.search.sort.SortAndFormats; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -39,15 +43,26 @@ public static void execute(SearchContext context) { if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) { return; } - + assert validateSort(context.sort()) : "invalid sort"; TopDocs topDocs = context.queryResult().topDocs().topDocs; if (topDocs.scoreDocs.length == 0) { return; } + // Populate FieldDoc#score using the primary sort field (_score) to ensure compatibility + // with top docs rescoring. + Arrays.stream(topDocs.scoreDocs).forEach(t -> { + if (t instanceof FieldDoc fieldDoc) { + fieldDoc.score = (float) fieldDoc.fields[0]; + } + }); TopFieldGroups topGroups = null; + TopFieldDocs topFields = null; if (topDocs instanceof TopFieldGroups topFieldGroups) { - assert context.collapse() != null; + assert context.collapse() != null && validateSortFields(topFieldGroups.fields); topGroups = topFieldGroups; + } else if (topDocs instanceof TopFieldDocs topFieldDocs) { + assert validateSortFields(topFieldDocs.fields); + topFields = topFieldDocs; } try { Runnable cancellationCheck = getCancellationChecks(context); @@ -56,17 +71,18 @@ public static void execute(SearchContext context) { topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx); // It is the responsibility of the rescorer to sort the resulted top docs, // here we only assert that this condition is met. - assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; + assert topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; ctx.setCancellationChecker(null); } + /** + * Since rescorers are building top docs with score only, we must reconstruct the {@link TopFieldGroups} + * or {@link TopFieldDocs} using their original version before rescoring. + */ if (topGroups != null) { assert context.collapse() != null; - /** - * Since rescorers don't preserve collapsing, we must reconstruct the group and field - * values from the originalTopGroups to create a new {@link TopFieldGroups} from the - * rescored top documents. - */ - topDocs = rewriteTopGroups(topGroups, topDocs); + topDocs = rewriteTopFieldGroups(topGroups, topDocs); + } else if (topFields != null) { + topDocs = rewriteTopFieldDocs(topFields, topDocs); } context.queryResult() .topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats()); @@ -81,29 +97,84 @@ public static void execute(SearchContext context) { } } - private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { - assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0]) - : "rescore must always sort by score descending"; + /** + * Returns whether the provided {@link SortAndFormats} can be used to rescore + * top documents. + */ + public static boolean validateSort(SortAndFormats sortAndFormats) { + if (sortAndFormats == null) { + return true; + } + return validateSortFields(sortAndFormats.sort.getSort()); + } + + private static boolean validateSortFields(SortField[] fields) { + if (fields[0].equals(SortField.FIELD_SCORE) == false) { + return false; + } + if (fields.length == 1) { + return true; + } + + // The ShardDocSortField can be used as a tiebreaker because it maintains + // the natural document ID order within the shard. + if (fields[1] instanceof ShardDocSortField == false || fields[1].getReverse()) { + return false; + } + return true; + } + + private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) { + Map docIdToFieldDoc = Maps.newMapWithExpectedSize(originalTopFieldDocs.scoreDocs.length); + for (int i = 0; i < originalTopFieldDocs.scoreDocs.length; i++) { + docIdToFieldDoc.put(originalTopFieldDocs.scoreDocs[i].doc, (FieldDoc) originalTopFieldDocs.scoreDocs[i]); + } + var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length]; + int pos = 0; + for (var doc : rescoredTopDocs.scoreDocs) { + newScoreDocs[pos] = docIdToFieldDoc.get(doc.doc); + newScoreDocs[pos].score = doc.score; + newScoreDocs[pos].fields[0] = newScoreDocs[pos].score; + pos++; + } + return new TopFieldDocs(originalTopFieldDocs.totalHits, newScoreDocs, originalTopFieldDocs.fields); + } + + private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { + var newFieldDocs = rewriteFieldDocs((FieldDoc[]) originalTopGroups.scoreDocs, rescoredTopDocs.scoreDocs); + Map docIdToGroupValue = Maps.newMapWithExpectedSize(originalTopGroups.scoreDocs.length); for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) { docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]); } - var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length]; var newGroupValues = new Object[originalTopGroups.groupValues.length]; int pos = 0; for (var doc : rescoredTopDocs.scoreDocs) { - newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score }); newGroupValues[pos++] = docIdToGroupValue.get(doc.doc); } return new TopFieldGroups( originalTopGroups.field, originalTopGroups.totalHits, - newScoreDocs, + newFieldDocs, originalTopGroups.fields, newGroupValues ); } + private static FieldDoc[] rewriteFieldDocs(FieldDoc[] originalTopDocs, ScoreDoc[] rescoredTopDocs) { + Map docIdToFieldDoc = Maps.newMapWithExpectedSize(rescoredTopDocs.length); + Arrays.stream(originalTopDocs).forEach(d -> docIdToFieldDoc.put(d.doc, d)); + var newDocs = new FieldDoc[rescoredTopDocs.length]; + int pos = 0; + for (var doc : rescoredTopDocs) { + newDocs[pos] = docIdToFieldDoc.get(doc.doc); + newDocs[pos].score = doc.score; + newDocs[pos].fields[0] = doc.score; + pos++; + } + return newDocs; + } + /** * Returns true if the provided docs are sorted by score. */