From 662173511cce3a2e7bbf523870be3ef94cd6e73b Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 24 Dec 2024 15:51:43 -0800 Subject: [PATCH] Fixed mismatch between document source and score fields when sorting is enabled in hybrid query Signed-off-by: Martin Gaievski --- .../HybridTopFieldDocSortCollector.java | 26 ++++++++++++------- .../HybridTopFieldDocSortCollectorTests.java | 13 +++++++--- .../HybridTopScoreDocCollectorTests.java | 4 +-- 3 files changed, 28 insertions(+), 15 deletions(-) rename src/test/java/org/opensearch/neuralsearch/search/{ => collector}/HybridTopFieldDocSortCollectorTests.java (95%) rename src/test/java/org/opensearch/neuralsearch/search/{ => collector}/HybridTopScoreDocCollectorTests.java (99%) diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java index 2e268d37b..b6e13ec89 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java @@ -9,6 +9,9 @@ import java.util.Objects; import java.util.Locale; import java.util.ArrayList; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -45,7 +48,9 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl private FieldDoc after; private FieldComparator firstComparator; // bottom would be set to null per shard. - private FieldValueHitQueue.Entry bottom; + @Getter(AccessLevel.PACKAGE) + @VisibleForTesting + private FieldValueHitQueue.Entry fieldValueLeafTrackers[]; @Getter private int totalHits; protected int docBase; @@ -203,7 +208,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float comparators[subQueryNumber].copy(slot, doc); add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score); if (queueFull[subQueryNumber]) { - comparators[subQueryNumber].setBottom(bottom.slot); + comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot); } } else { queueFull[subQueryNumber] = true; @@ -216,9 +221,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException { // This hit is competitive - replace bottom element in queue & adjustTop if (numHits > 0) { - comparators[subQueryNumber].copy(bottom.slot, doc); - updateBottom(doc, compoundScores[subQueryNumber]); - comparators[subQueryNumber].setBottom(bottom.slot); + comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc); + updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber); + comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot); } } @@ -254,6 +259,9 @@ protected void initializePriorityQueuesWithComparators(LeafReaderContext context initializeLeafFieldComparators(context, i); } } + if (Objects.isNull(fieldValueLeafTrackers)) { + fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries]; + } if (initializeLeafComparatorsPerSegmentOnce) { for (int i = 0; i < numberOfSubQueries; i++) { initializeComparators(context, i); @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue compoundScore, int subQueryNumber, float score) { FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc); bottomEntry.score = score; - bottom = compoundScore.add(bottomEntry); + fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry); // The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case // slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits // on the current page) and slot = hitsCollected - 1. @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue queueFull[subQueryNumber] = isQueueFull; } - private void updateBottom(int doc, FieldValueHitQueue compoundScore) { - bottom.doc = docBase + doc; - bottom = compoundScore.updateTop(); + private void updateBottom(int doc, FieldValueHitQueue compoundScore, int subQueryIndex) { + fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc; + fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop(); } private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java similarity index 95% rename from src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java rename to src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java index 3bb0e6bcd..f1f50f97e 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopFieldDocSortCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollectorTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.search.collector; import java.util.ArrayList; import java.util.Arrays; @@ -24,6 +24,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FieldValueHitQueue; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; @@ -35,14 +36,13 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; + import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.query.HybridQueryScorer; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; -import org.opensearch.neuralsearch.search.collector.PagingFieldCollector; -import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase { static final String TEXT_FIELD_NAME = "field"; @@ -127,8 +127,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce DocIdSetIterator iterator = hybridQueryScorer.iterator(); int doc = iterator.nextDoc(); + assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers()); while (doc != DocIdSetIterator.NO_MORE_DOCS) { leafCollector.collect(doc); + FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers(); + assertNotNull(fieldValueLeafTrackers); + assertEquals(1, fieldValueLeafTrackers.length); + assertEquals(doc, fieldValueLeafTrackers[0].doc); doc = iterator.nextDoc(); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java similarity index 99% rename from src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java rename to src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java index 1fb66d5b7..4de32f6f2 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollectorTests.java @@ -2,7 +2,7 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.search.collector; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.LeafCollector; @@ -46,7 +46,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import lombok.SneakyThrows; -import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {