From 99b651898d995d0245425adc2ce3920ba7e2a30f Mon Sep 17 00:00:00 2001 From: Jack Conradson Date: Tue, 7 Nov 2023 07:58:18 -0800 Subject: [PATCH] Add an additional tiebreaker to RRF (#101847) This change adds an additional tiebreaker for RRF where when two documents have the same RRF "score" such as identical ranks of (3,4) and (4,3) or (1,-) and (-,1), etc. the ordering will fallback to the highest score from query 1 then query 2, and so on. If all scores are equal then the tiebreaker will be shard index followed by doc id, but these are not necessarily stable. This should resolve most of the stability issues outlined as part of (#101232). Closes #101232 --- docs/changelog/101847.yaml | 6 + .../rank/rrf/RRFRankCoordinatorContext.java | 15 +- .../xpack/rank/rrf/RRFRankShardContext.java | 14 + .../xpack/rank/rrf/RRFRankContextTests.java | 323 +++++++++++++++++- 4 files changed, 349 insertions(+), 9 deletions(-) create mode 100644 docs/changelog/101847.yaml diff --git a/docs/changelog/101847.yaml b/docs/changelog/101847.yaml new file mode 100644 index 0000000000000..91922b9e23ed0 --- /dev/null +++ b/docs/changelog/101847.yaml @@ -0,0 +1,6 @@ +pr: 101847 +summary: Add an additional tiebreaker to RRF +area: Ranking +type: bug +issues: + - 101232 diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankCoordinatorContext.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankCoordinatorContext.java index d7b96ad439501..50f3646264a92 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankCoordinatorContext.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankCoordinatorContext.java @@ -127,12 +127,25 @@ protected boolean lessThan(RRFRankDoc a, RRFRankDoc b) { } } - // sort the results based on rrf score, tiebreaker based on smaller shard then smaller doc id + // sort the results based on rrf score, tiebreaker based on + // larger individual query score from 1 to n, smaller shard then smaller doc id RRFRankDoc[] sortedResults = results.values().toArray(RRFRankDoc[]::new); Arrays.sort(sortedResults, (RRFRankDoc rrf1, RRFRankDoc rrf2) -> { if (rrf1.score != rrf2.score) { return rrf1.score < rrf2.score ? 1 : -1; } + assert rrf1.positions.length == rrf2.positions.length; + for (int qi = 0; qi < rrf1.positions.length; ++qi) { + if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) { + if (rrf1.scores[qi] != rrf2.scores[qi]) { + return rrf1.scores[qi] < rrf2.scores[qi] ? 1 : -1; + } + } else if (rrf1.positions[qi] != NO_RANK) { + return -1; + } else if (rrf2.positions[qi] != NO_RANK) { + return 1; + } + } if (rrf1.shardIndex != rrf2.shardIndex) { return rrf1.shardIndex < rrf2.shardIndex ? -1 : 1; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankShardContext.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankShardContext.java index e251207bdcb2a..e22e328193700 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankShardContext.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankShardContext.java @@ -17,6 +17,8 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.search.rank.RankDoc.NO_RANK; + /** * Executes queries and generates results on the shard for RRF. */ @@ -74,6 +76,18 @@ public RRFRankShardResult combine(List rankResults) { if (rrf1.score != rrf2.score) { return rrf1.score < rrf2.score ? 1 : -1; } + assert rrf1.positions.length == rrf2.positions.length; + for (int qi = 0; qi < rrf1.positions.length; ++qi) { + if (rrf1.positions[qi] != NO_RANK && rrf2.positions[qi] != NO_RANK) { + if (rrf1.scores[qi] != rrf2.scores[qi]) { + return rrf1.scores[qi] < rrf2.scores[qi] ? 1 : -1; + } + } else if (rrf1.positions[qi] != NO_RANK) { + return -1; + } else if (rrf2.positions[qi] != NO_RANK) { + return 1; + } + } return rrf1.doc < rrf2.doc ? -1 : 1; }); // trim the results to window size diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java index f1f19a371ed07..5cb89c071c767 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankContextTests.java @@ -239,31 +239,338 @@ public void testCoordinatorRank() { expected.score = 0.6666667f; assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[0]); - expected = new RRFRankDoc(1, 1, 2); + expected = new RRFRankDoc(3, 1, 2); expected.rank = 2; + expected.positions[0] = 0; + expected.positions[1] = NO_RANK; + expected.scores[0] = 10.0f; + expected.scores[1] = 0.0f; + expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[1]); + + expected = new RRFRankDoc(1, 1, 2); + expected.rank = 3; expected.positions[0] = NO_RANK; expected.positions[1] = 0; expected.scores[0] = 0.0f; expected.scores[1] = 8.0f; expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[2]); + + expected = new RRFRankDoc(2, 2, 2); + expected.rank = 4; + expected.positions[0] = 3; + expected.positions[1] = 3; + expected.scores[0] = 8.5f; + expected.scores[1] = 6.5f; + expected.score = 0.4f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[3]); + } + + public void testShardTieBreaker() { + RRFRankShardContext context = new RRFRankShardContext(null, 0, 10, 1); + + List topDocs = List.of( + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(1, 10.0f, -1), new ScoreDoc(2, 9.0f, -1) }), + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(2, 8.0f, -1), new ScoreDoc(1, 7.0f, -1) }) + ); + + RRFRankShardResult result = context.combine(topDocs); + assertEquals(2, result.queryCount); + assertEquals(2, result.rrfRankDocs.length); + + RRFRankDoc expected = new RRFRankDoc(1, -1, 2); + expected.rank = 1; + expected.positions[0] = 0; + expected.positions[1] = 1; + expected.scores[0] = 10.0f; + expected.scores[1] = 7.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[0]); + + expected = new RRFRankDoc(2, -1, 2); + expected.rank = 2; + expected.positions[0] = 1; + expected.positions[1] = 0; + expected.scores[0] = 9.0f; + expected.scores[1] = 8.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[1]); + + topDocs = List.of( + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(1, 10.0f, -1), new ScoreDoc(2, 9.0f, -1), new ScoreDoc(3, 9.0f, -1) }), + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(4, 11.0f, -1), new ScoreDoc(3, 9.0f, -1), new ScoreDoc(2, 7.0f, -1) }) + ); + + result = context.combine(topDocs); + assertEquals(2, result.queryCount); + assertEquals(4, result.rrfRankDocs.length); + + expected = new RRFRankDoc(3, -1, 2); + expected.rank = 1; + expected.positions[0] = 2; + expected.positions[1] = 1; + expected.scores[0] = 9.0f; + expected.scores[1] = 9.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[0]); + + expected = new RRFRankDoc(2, -1, 2); + expected.rank = 2; + expected.positions[0] = 1; + expected.positions[1] = 2; + expected.scores[0] = 9.0f; + expected.scores[1] = 7.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[1]); + + expected = new RRFRankDoc(1, -1, 2); + expected.rank = 3; + expected.positions[0] = 0; + expected.positions[1] = -1; + expected.scores[0] = 10.0f; + expected.scores[1] = 0.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[2]); + + expected = new RRFRankDoc(4, -1, 2); + expected.rank = 4; + expected.positions[0] = -1; + expected.positions[1] = 0; + expected.scores[0] = 0.0f; + expected.scores[1] = 11.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[3]); + + topDocs = List.of( + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(1, 10.0f, -1), new ScoreDoc(3, 3.0f, -1) }), + new TopDocs(null, new ScoreDoc[] { new ScoreDoc(2, 8.0f, -1), new ScoreDoc(4, 5.0f, -1) }) + ); + + result = context.combine(topDocs); + assertEquals(2, result.queryCount); + assertEquals(4, result.rrfRankDocs.length); + + expected = new RRFRankDoc(1, -1, 2); + expected.rank = 1; + expected.positions[0] = 0; + expected.positions[1] = -1; + expected.scores[0] = 10.0f; + expected.scores[1] = 0.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[0]); + + expected = new RRFRankDoc(2, -1, 2); + expected.rank = 2; + expected.positions[0] = -1; + expected.positions[1] = 0; + expected.scores[0] = 0.0f; + expected.scores[1] = 8.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[1]); + + expected = new RRFRankDoc(3, -1, 2); + expected.rank = 3; + expected.positions[0] = 1; + expected.positions[1] = -1; + expected.scores[0] = 3.0f; + expected.scores[1] = 0.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[2]); + + expected = new RRFRankDoc(4, -1, 2); + expected.rank = 4; + expected.positions[0] = -1; + expected.positions[1] = 1; + expected.scores[0] = 0.0f; + expected.scores[1] = 5.0f; + expected.score = Float.NaN; + assertRDEquals(expected, result.rrfRankDocs[3]); + } + + public void testCoordinatorRankTieBreaker() { + RRFRankCoordinatorContext context = new RRFRankCoordinatorContext(4, 0, 5, 1); + + QuerySearchResult qsr0 = new QuerySearchResult(); + qsr0.setShardIndex(1); + RRFRankDoc rd11 = new RRFRankDoc(1, -1, 2); + rd11.positions[0] = 0; + rd11.positions[1] = 0; + rd11.scores[0] = 10.0f; + rd11.scores[1] = 7.0f; + qsr0.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd11 })); + + QuerySearchResult qsr1 = new QuerySearchResult(); + qsr1.setShardIndex(2); + RRFRankDoc rd21 = new RRFRankDoc(1, -1, 2); + rd21.positions[0] = 0; + rd21.positions[1] = 0; + rd21.scores[0] = 9.0f; + rd21.scores[1] = 8.0f; + qsr1.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd21 })); + + TopDocsStats tds = new TopDocsStats(0); + SortedTopDocs std = context.rank(List.of(qsr0, qsr1), tds); + + assertEquals(2, tds.fetchHits); + assertEquals(2, std.scoreDocs().length); + + RRFRankDoc expected = new RRFRankDoc(1, 1, 2); + expected.rank = 1; + expected.positions[0] = 0; + expected.positions[1] = 1; + expected.scores[0] = 10.0f; + expected.scores[1] = 7.0f; + expected.score = 0.8333333730697632f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[0]); + + expected = new RRFRankDoc(1, 2, 2); + expected.rank = 2; + expected.positions[0] = 1; + expected.positions[1] = 0; + expected.scores[0] = 9.0f; + expected.scores[1] = 8.0f; + expected.score = 0.8333333730697632f; assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[1]); - expected = new RRFRankDoc(3, 1, 2); + qsr0 = new QuerySearchResult(); + qsr0.setShardIndex(1); + rd11 = new RRFRankDoc(1, -1, 2); + rd11.positions[0] = 0; + rd11.positions[1] = -1; + rd11.scores[0] = 10.0f; + rd11.scores[1] = 0.0f; + RRFRankDoc rd12 = new RRFRankDoc(2, -1, 2); + rd12.positions[0] = 0; + rd12.positions[1] = 1; + rd12.scores[0] = 9.0f; + rd12.scores[1] = 7.0f; + qsr0.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd11, rd12 })); + + qsr1 = new QuerySearchResult(); + qsr1.setShardIndex(2); + rd21 = new RRFRankDoc(1, -1, 2); + rd21.positions[0] = -1; + rd21.positions[1] = 0; + rd21.scores[0] = 0.0f; + rd21.scores[1] = 11.0f; + RRFRankDoc rd22 = new RRFRankDoc(2, -1, 2); + rd22.positions[0] = 0; + rd22.positions[1] = 1; + rd22.scores[0] = 9.0f; + rd22.scores[1] = 9.0f; + qsr1.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd21, rd22 })); + + tds = new TopDocsStats(0); + std = context.rank(List.of(qsr0, qsr1), tds); + + assertEquals(4, tds.fetchHits); + assertEquals(4, std.scoreDocs().length); + + expected = new RRFRankDoc(2, 2, 2); + expected.rank = 1; + expected.positions[0] = 2; + expected.positions[1] = 1; + expected.scores[0] = 9.0f; + expected.scores[1] = 9.0f; + expected.score = 0.5833333730697632f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[0]); + + expected = new RRFRankDoc(2, 1, 2); + expected.rank = 2; + expected.positions[0] = 1; + expected.positions[1] = 2; + expected.scores[0] = 9.0f; + expected.scores[1] = 7.0f; + expected.score = 0.5833333730697632f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[1]); + + expected = new RRFRankDoc(1, 1, 2); expected.rank = 3; expected.positions[0] = 0; - expected.positions[1] = NO_RANK; + expected.positions[1] = -1; + expected.scores[0] = 10.0f; + expected.scores[1] = 0.0f; + expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[2]); + + expected = new RRFRankDoc(1, 2, 2); + expected.rank = 4; + expected.positions[0] = -1; + expected.positions[1] = 0; + expected.scores[0] = 0.0f; + expected.scores[1] = 11.0f; + expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[3]); + + qsr0 = new QuerySearchResult(); + qsr0.setShardIndex(1); + rd11 = new RRFRankDoc(1, -1, 2); + rd11.positions[0] = 0; + rd11.positions[1] = -1; + rd11.scores[0] = 10.0f; + rd11.scores[1] = 0.0f; + rd12 = new RRFRankDoc(2, -1, 2); + rd12.positions[0] = -1; + rd12.positions[1] = 0; + rd12.scores[0] = 0.0f; + rd12.scores[1] = 12.0f; + qsr0.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd11, rd12 })); + + qsr1 = new QuerySearchResult(); + qsr1.setShardIndex(2); + rd21 = new RRFRankDoc(1, -1, 2); + rd21.positions[0] = 0; + rd21.positions[1] = -1; + rd21.scores[0] = 3.0f; + rd21.scores[1] = 0.0f; + rd22 = new RRFRankDoc(2, -1, 2); + rd22.positions[0] = -1; + rd22.positions[1] = 0; + rd22.scores[0] = 0.0f; + rd22.scores[1] = 5.0f; + qsr1.setRankShardResult(new RRFRankShardResult(2, new RRFRankDoc[] { rd21, rd22 })); + + tds = new TopDocsStats(0); + std = context.rank(List.of(qsr0, qsr1), tds); + + assertEquals(4, tds.fetchHits); + assertEquals(4, std.scoreDocs().length); + + expected = new RRFRankDoc(1, 1, 2); + expected.rank = 1; + expected.positions[0] = 0; + expected.positions[1] = -1; expected.scores[0] = 10.0f; expected.scores[1] = 0.0f; expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[0]); + + expected = new RRFRankDoc(2, 1, 2); + expected.rank = 2; + expected.positions[0] = -1; + expected.positions[1] = 0; + expected.scores[0] = 0.0f; + expected.scores[1] = 12.0f; + expected.score = 0.5f; + assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[1]); + + expected = new RRFRankDoc(1, 2, 2); + expected.rank = 3; + expected.positions[0] = 1; + expected.positions[1] = -1; + expected.scores[0] = 3.0f; + expected.scores[1] = 0.0f; + expected.score = 0.3333333333333333f; assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[2]); expected = new RRFRankDoc(2, 2, 2); expected.rank = 4; - expected.positions[0] = 3; - expected.positions[1] = 3; - expected.scores[0] = 8.5f; - expected.scores[1] = 6.5f; - expected.score = 0.4f; + expected.positions[0] = -1; + expected.positions[1] = 1; + expected.scores[0] = 0.0f; + expected.scores[1] = 5.0f; + expected.score = 0.3333333333333333f; assertRDEquals(expected, (RRFRankDoc) std.scoreDocs()[3]); } }