Skip to content

Commit

Permalink
Fixed mismatch between document source and score fields when sorting …
Browse files Browse the repository at this point in the history
…is enabled in hybrid query

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 24, 2024
1 parent 22ba5d3 commit 6621735
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import java.util.Objects;
import java.util.Locale;
import java.util.ArrayList;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -45,7 +48,9 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
private FieldDoc after;
private FieldComparator<?> firstComparator;
// bottom would be set to null per shard.
private FieldValueHitQueue.Entry bottom;
@Getter(AccessLevel.PACKAGE)
@VisibleForTesting
private FieldValueHitQueue.Entry fieldValueLeafTrackers[];
@Getter
private int totalHits;
protected int docBase;
Expand Down Expand Up @@ -203,7 +208,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
comparators[subQueryNumber].copy(slot, doc);
add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score);
if (queueFull[subQueryNumber]) {
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
} else {
queueFull[subQueryNumber] = true;
Expand All @@ -216,9 +221,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
// This hit is competitive - replace bottom element in queue & adjustTop
if (numHits > 0) {
comparators[subQueryNumber].copy(bottom.slot, doc);
updateBottom(doc, compoundScores[subQueryNumber]);
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc);
updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
}

Expand Down Expand Up @@ -254,6 +259,9 @@ protected void initializePriorityQueuesWithComparators(LeafReaderContext context
initializeLeafFieldComparators(context, i);
}
}
if (Objects.isNull(fieldValueLeafTrackers)) {
fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries];
}
if (initializeLeafComparatorsPerSegmentOnce) {
for (int i = 0; i < numberOfSubQueries; i++) {
initializeComparators(context, i);
Expand Down Expand Up @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<Fiel
private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryNumber, float score) {
FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc);
bottomEntry.score = score;
bottom = compoundScore.add(bottomEntry);
fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry);
// The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case
// slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits
// on the current page) and slot = hitsCollected - 1.
Expand All @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry>
queueFull[subQueryNumber] = isQueueFull;
}

private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore) {
bottom.doc = docBase + doc;
bottom = compoundScore.updateTop();
private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryIndex) {
fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc;
fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop();
}

private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -24,6 +24,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
Expand All @@ -35,14 +36,13 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
Expand Down Expand Up @@ -127,8 +127,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce
DocIdSetIterator iterator = hybridQueryScorer.iterator();

int doc = iterator.nextDoc();
assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers());
while (doc != DocIdSetIterator.NO_MORE_DOCS) {
leafCollector.collect(doc);
FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers();
assertNotNull(fieldValueLeafTrackers);
assertEquals(1, fieldValueLeafTrackers.length);
assertEquals(doc, fieldValueLeafTrackers[0].doc);
doc = iterator.nextDoc();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
Expand Down Expand Up @@ -46,7 +46,7 @@
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {

Expand Down

0 comments on commit 6621735

Please sign in to comment.