Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fixed document source and score field mismatch in sorted hybrid queries #1043

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,8 +22,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
- Support empty string for fields in text embedding processor ([#1041](https://github.com/opensearch-project/neural-search/pull/1041))
### Bug Fixes
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,6 +9,9 @@
import java.util.Objects;
import java.util.Locale;
import java.util.ArrayList;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -44,8 +47,10 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Nullable
private FieldDoc after;
private FieldComparator<?> firstComparator;
// bottom would be set to null per shard.
private FieldValueHitQueue.Entry bottom;
// the array stores bottom elements of the min heap of sorted hits for each sub query
@Getter(AccessLevel.PACKAGE)
@VisibleForTesting
private FieldValueHitQueue.Entry fieldValueLeafTrackers[];
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
@Getter
private int totalHits;
protected int docBase;
Expand All @@ -65,6 +70,7 @@ public abstract class HybridTopFieldDocSortCollector implements HybridSearchColl
@Getter
protected float maxScore = 0.0f;
protected int[] collectedHits;
private boolean needsInitialization = true;

// searchSortPartOfIndexSort is used to evaluate whether to perform index sort or not.
private Boolean searchSortPartOfIndexSort = null;
Expand Down Expand Up @@ -203,7 +209,7 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
comparators[subQueryNumber].copy(slot, doc);
add(slot, doc, compoundScores[subQueryNumber], subQueryNumber, score);
if (queueFull[subQueryNumber]) {
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
} else {
queueFull[subQueryNumber] = true;
Expand All @@ -216,9 +222,9 @@ protected void collectHit(int doc, int hitsCollected, int subQueryNumber, float
protected void collectCompetitiveHit(int doc, int subQueryNumber) throws IOException {
// This hit is competitive - replace bottom element in queue & adjustTop
if (numHits > 0) {
comparators[subQueryNumber].copy(bottom.slot, doc);
updateBottom(doc, compoundScores[subQueryNumber]);
comparators[subQueryNumber].setBottom(bottom.slot);
comparators[subQueryNumber].copy(fieldValueLeafTrackers[subQueryNumber].slot, doc);
updateBottom(doc, compoundScores[subQueryNumber], subQueryNumber);
comparators[subQueryNumber].setBottom(fieldValueLeafTrackers[subQueryNumber].slot);
}
}

Expand All @@ -245,14 +251,16 @@ protected boolean thresholdCheck(int doc, int subQueryNumber) throws IOException
The method initializes once per search request.
*/
protected void initializePriorityQueuesWithComparators(LeafReaderContext context, int numberOfSubQueries) throws IOException {
if (compoundScores == null) {
if (needsInitialization) {
compoundScores = new FieldValueHitQueue[numberOfSubQueries];
comparators = new LeafFieldComparator[numberOfSubQueries];
queueFull = new boolean[numberOfSubQueries];
collectedHits = new int[numberOfSubQueries];
for (int i = 0; i < numberOfSubQueries; i++) {
initializeLeafFieldComparators(context, i);
}
fieldValueLeafTrackers = new FieldValueHitQueue.Entry[numberOfSubQueries];
needsInitialization = false;
}
if (initializeLeafComparatorsPerSegmentOnce) {
for (int i = 0; i < numberOfSubQueries; i++) {
Expand Down Expand Up @@ -369,7 +377,7 @@ private void populateResults(ScoreDoc[] results, int howMany, PriorityQueue<Fiel
private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryNumber, float score) {
FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc);
bottomEntry.score = score;
bottom = compoundScore.add(bottomEntry);
fieldValueLeafTrackers[subQueryNumber] = compoundScore.add(bottomEntry);
// The queue is full either when totalHits == numHits (in SimpleFieldCollector), in which case
// slot = totalHits - 1, or when hitsCollected == numHits (in PagingFieldCollector this is hits
// on the current page) and slot = hitsCollected - 1.
Expand All @@ -381,9 +389,9 @@ private void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry>
queueFull[subQueryNumber] = isQueueFull;
}

private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore) {
bottom.doc = docBase + doc;
bottom = compoundScore.updateTop();
private void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int subQueryIndex) {
fieldValueLeafTrackers[subQueryIndex].doc = docBase + doc;
fieldValueLeafTrackers[subQueryIndex] = compoundScore.updateTop();
}

private boolean canEarlyTerminate(Sort searchSort, Sort indexSort) {
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -24,6 +24,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
Expand All @@ -35,14 +36,13 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.collector.PagingFieldCollector;
import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopFieldDocSortCollectorTests extends OpenSearchQueryTestCase {
static final String TEXT_FIELD_NAME = "field";
Expand Down Expand Up @@ -127,8 +127,13 @@ public void testSimpleFieldCollectorTopDocs_whenCreateNewAndGetTopDocs_thenSucce
DocIdSetIterator iterator = hybridQueryScorer.iterator();

int doc = iterator.nextDoc();
assertNull(hybridTopFieldDocSortCollector.getFieldValueLeafTrackers());
while (doc != DocIdSetIterator.NO_MORE_DOCS) {
leafCollector.collect(doc);
FieldValueHitQueue.Entry[] fieldValueLeafTrackers = hybridTopFieldDocSortCollector.getFieldValueLeafTrackers();
assertNotNull(fieldValueLeafTrackers);
assertEquals(1, fieldValueLeafTrackers.length);
assertEquals(doc, fieldValueLeafTrackers[0].doc);
doc = iterator.nextDoc();
}

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search;
package org.opensearch.neuralsearch.search.collector;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
Expand Down Expand Up @@ -46,7 +46,7 @@
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

public class HybridTopScoreDocCollectorTests extends OpenSearchQueryTestCase {

Expand Down
Loading