Skip to content

Commit

Permalink
Adding comments
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun committed May 30, 2024
1 parent f2fabb7 commit f8d4930
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ private void updateOriginalQueryResults(
)
);
}

for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ public class ScoreCombiner {
*
* @param queryTopDocs query results that need to be normalized, mutated by method execution
* @param scoreCombinationTechnique exact combination method that should be applied
* @param isSortingEnabled if sorting is enabled or not
* @param sort sort criteria
*/
public void combineScores(
final List<CompoundTopDocs> queryTopDocs,
final ScoreCombinationTechnique scoreCombinationTechnique,
final boolean isSortingEnabled,
final Sort sort
) {
final boolean isSortByScore = isSortByScore(sort);
final boolean isSortByScore = checkIfSortOrderByScore(sort);
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
queryTopDocs.forEach(
Expand Down Expand Up @@ -87,14 +89,15 @@ private void combineShardScores(
scoreCombinationTechnique
);

// If sorting is enabled then create the map of sort field values per docId
Map<Integer, Object[]> docIdSortFieldMap = null;
List<TopFieldDocs> topFieldDocs = null;
if (isSortingEnabled) {
topFieldDocs = topDocsPerSubQuery.stream()
.filter(topDocs -> topDocs.scoreDocs.length != 0)
.map(topDocs -> (TopFieldDocs) topDocs)
.collect(Collectors.toList());
docIdSortFieldMap = getDocIdFieldMap(compoundQueryTopDocs, isSortByScore, combinedNormalizedScoresByDocId);
docIdSortFieldMap = getDocIdSortFieldsMap(compoundQueryTopDocs, isSortByScore, combinedNormalizedScoresByDocId);
}

// - sort documents by scores and take first "max number" of docs
Expand All @@ -112,7 +115,7 @@ private void combineShardScores(
);
}

private boolean isSortByScore(Sort sort) {
private boolean checkIfSortOrderByScore(Sort sort) {
if (sort != null) {
for (SortField sortField : sort.getSort()) {
if (sortField.getType().equals(SortField.Type.SCORE)) {
Expand All @@ -123,7 +126,7 @@ private boolean isSortByScore(Sort sort) {
return false;
}

private Map<Integer, Object[]> getDocIdFieldMap(
private Map<Integer, Object[]> getDocIdSortFieldsMap(
final CompoundTopDocs compoundTopDocs,
final boolean isSortByScore,
Map<Integer, Float> combinedNormalizedScoresByDocId
Expand All @@ -137,6 +140,7 @@ private Map<Integer, Object[]> getDocIdFieldMap(
FieldDoc fieldDoc = (FieldDoc) scoreDoc;

if (docIdSortFieldMap.get(fieldDoc.doc) == null) {
// When sorting by score, use the combined normalized score as the sort field value.
if (isSortByScore) {
docIdSortFieldMap.put(fieldDoc.doc, new Object[] { combinedNormalizedScoresByDocId.get(fieldDoc.doc) });
} else {
Expand Down Expand Up @@ -164,7 +168,11 @@ private List<Integer> getSortedDocIds(
for (TopFieldDocs topFieldDoc : topFieldDocs) {
topN += topFieldDoc.scoreDocs.length;
}

// Merge the sorted results of the individual sub-queries into one final sorted result per shard.
final TopDocs sortedTopDocs = TopDocs.merge(sort, 0, topN, topFieldDocs.toArray(new TopFieldDocs[0]), getTieBreaker());

// Remove duplicates from the sorted top docs.
Set<Integer> uniqueDocIds = new LinkedHashSet<>();
for (ScoreDoc scoreDoc : sortedTopDocs.scoreDocs) {
uniqueDocIds.add(scoreDoc.doc);
Expand Down Expand Up @@ -255,6 +263,7 @@ private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, final lon
return new TotalHits(maxHits, totalHits);
}

// Tie-breaker to merge multiple top docs
private Comparator<ScoreDoc> getTieBreaker() {
final Comparator<ScoreDoc> Sorting_TIE_BREAKER = (o1, o2) -> {
int scoreComparison = Double.compare(o1.score, o2.score);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@
import org.opensearch.neuralsearch.query.HybridQueryScorer;
import org.opensearch.common.Nullable;

/*
Collects the TopFieldDocs after executing a hybrid query. Uses HybridQueryTopDocs as a DTO to handle the results of each sub-query.
The results of each individual sub-query are sorted according to the sort criteria sent in the search request.
*/
@Log4j2
public abstract class HybridTopDocSortCollector implements Collector {
public abstract class HybridTopFieldDocSortCollector implements Collector {
final int numHits;
final HitsThresholdChecker hitsThresholdChecker;
int docBase;
Expand Down Expand Up @@ -66,11 +70,12 @@ public abstract class HybridTopDocSortCollector implements Collector {
// internal versions. If someone will define a constructor with any other
// visibility, then anyone will be able to extend the class, which is not what
// we want.
private HybridTopDocSortCollector(final int numHits, final HitsThresholdChecker hitsThresholdChecker) {
private HybridTopFieldDocSortCollector(final int numHits, final HitsThresholdChecker hitsThresholdChecker) {
this.numHits = numHits;
this.hitsThresholdChecker = hitsThresholdChecker;
}

// Add the entry to the priority queue
void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore, int i, float score) {
FieldValueHitQueue.Entry bottomEntry = new FieldValueHitQueue.Entry(slot, docBase + doc);
bottomEntry.score = score;
Expand All @@ -83,7 +88,6 @@ void add(int slot, int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoun
}

/**
 * Overwrites the doc id of the current {@code bottom} entry and then refreshes the {@code bottom} reference.
 *
 * @param doc doc id that is added to {@code docBase} before being stored on the bottom entry
 * @param compoundScore queue on which {@code updateTop()} is called; its return value becomes the new {@code bottom}
 */
void updateBottom(int doc, FieldValueHitQueue<FieldValueHitQueue.Entry> compoundScore) {
// Only the doc id is rewritten here; this method itself does not touch bottom.score.
// NOTE(review): the visible part of add() assigns the caller-supplied score to the entry (not Float.NaN) - confirm against the full method.
bottom.doc = docBase + doc;
// Re-read the bottom entry from the queue after the in-place modification above.
bottom = compoundScore.updateTop();
}
Expand Down Expand Up @@ -285,7 +289,7 @@ boolean thresholdCheck(int doc, int subQueryNumber) throws IOException {

}

public static class SimpleFieldCollector extends HybridTopDocSortCollector {
public static class SimpleFieldCollector extends HybridTopFieldDocSortCollector {
final Sort sort;
final int numHits;

Expand Down Expand Up @@ -315,14 +319,14 @@ public void collect(int doc) throws IOException {
continue;
}
collectedHits[i]++;
maxScore = Math.max(score, maxScore);
if (queueFull[i]) {
if (thresholdCheck(doc, i)) {
return;
}
collectCompetitiveHit(doc, i);
} else {
collectHit(doc, collectedHits[i], i, score);
maxScore = Math.max(score, maxScore);
}

}
Expand All @@ -335,7 +339,7 @@ public List<TopFieldDocs> topDocs() {
}
}

public static class PagingFieldCollector extends HybridTopDocSortCollector {
public static class PagingFieldCollector extends HybridTopFieldDocSortCollector {

final Sort sort;
final int numHits;
Expand Down Expand Up @@ -379,13 +383,12 @@ public void collect(int doc) throws IOException {
if (resultsFoundOnPreviousPage) {
return;
}

maxScore = Math.max(score, maxScore);
if (queueFull[i]) {
collectCompetitiveHit(doc, i);
} else {
collectedHits[i]++;
collectHit(doc, collectedHits[i], i, score);
maxScore = Math.max(score, maxScore);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopDocSortCollector;
import org.opensearch.neuralsearch.search.HybridTopFieldDocSortCollector;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
Expand Down Expand Up @@ -117,7 +117,11 @@ public Collector newCollector() {
Collector hybridcollector;
if (sortAndFormats != null) {
if (after == null) {
hybridcollector = new HybridTopDocSortCollector.SimpleFieldCollector(numHits, hitsThresholdChecker, sortAndFormats.sort);
hybridcollector = new HybridTopFieldDocSortCollector.SimpleFieldCollector(
numHits,
hitsThresholdChecker,
sortAndFormats.sort
);
} else {
if (after.fields == null) {
throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
Expand All @@ -129,7 +133,7 @@ public Collector newCollector() {
);
}

hybridcollector = new HybridTopDocSortCollector.PagingFieldCollector(
hybridcollector = new HybridTopFieldDocSortCollector.PagingFieldCollector(
numHits,
hitsThresholdChecker,
sortAndFormats.sort,
Expand Down Expand Up @@ -157,7 +161,7 @@ public Collector newCollector() {
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) {
final List<HybridTopScoreDocCollector> hybridTopScoreDocCollectors = new ArrayList<>();
final List<HybridTopDocSortCollector> hybridSortedTopDocCollectors = new ArrayList<>();
final List<HybridTopFieldDocSortCollector> hybridSortedTopDocCollectors = new ArrayList<>();
// check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
// in case multiple collector managers are registered. We use hybrid scores collector to format scores into
// format specific for hybrid search query: start, sub-query-delimiter, scores, stop
Expand All @@ -166,30 +170,30 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
if (sub instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub);
} else if (sub instanceof HybridTopDocSortCollector.SimpleFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopDocSortCollector.SimpleFieldCollector) sub);
} else if (sub instanceof HybridTopDocSortCollector.PagingFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopDocSortCollector.PagingFieldCollector) sub);
} else if (sub instanceof HybridTopFieldDocSortCollector.SimpleFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector.SimpleFieldCollector) sub);
} else if (sub instanceof HybridTopFieldDocSortCollector.PagingFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector.PagingFieldCollector) sub);
}
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector);
} else if (collector instanceof HybridTopDocSortCollector.SimpleFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopDocSortCollector.SimpleFieldCollector) collector);
} else if (collector instanceof HybridTopDocSortCollector.PagingFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopDocSortCollector.PagingFieldCollector) collector);
} else if (collector instanceof HybridTopFieldDocSortCollector.SimpleFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector.SimpleFieldCollector) collector);
} else if (collector instanceof HybridTopFieldDocSortCollector.PagingFieldCollector) {
hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector.PagingFieldCollector) collector);
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector());
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopDocSortCollector.SimpleFieldCollector) {
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector.SimpleFieldCollector) {
hybridSortedTopDocCollectors.add(
(HybridTopDocSortCollector.SimpleFieldCollector) ((FilteredCollector) collector).getCollector()
(HybridTopFieldDocSortCollector.SimpleFieldCollector) ((FilteredCollector) collector).getCollector()
);
} else if (collector instanceof FilteredCollector
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopDocSortCollector.PagingFieldCollector) {
&& ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector.PagingFieldCollector) {
hybridSortedTopDocCollectors.add(
(HybridTopDocSortCollector.PagingFieldCollector) ((FilteredCollector) collector).getCollector()
(HybridTopFieldDocSortCollector.PagingFieldCollector) ((FilteredCollector) collector).getCollector()
);
}
}
Expand All @@ -208,22 +212,22 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
}

if (!hybridSortedTopDocCollectors.isEmpty()) {
HybridTopDocSortCollector hybridSortedTopScoreDocCollector = hybridSortedTopDocCollectors.stream()
HybridTopFieldDocSortCollector hybridSortedTopScoreDocCollector = hybridSortedTopDocCollectors.stream()
.findFirst()
.orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));

HybridTopDocSortCollector.SimpleFieldCollector simpleFieldCollector;
HybridTopDocSortCollector.PagingFieldCollector pagingFieldCollector;
HybridTopFieldDocSortCollector.SimpleFieldCollector simpleFieldCollector;
HybridTopFieldDocSortCollector.PagingFieldCollector pagingFieldCollector;
List<TopFieldDocs> topFieldDocs;
long maxTotalHits;
float maxScore;
if (hybridSortedTopScoreDocCollector instanceof HybridTopDocSortCollector.SimpleFieldCollector) {
simpleFieldCollector = (HybridTopDocSortCollector.SimpleFieldCollector) hybridSortedTopScoreDocCollector;
if (hybridSortedTopScoreDocCollector instanceof HybridTopFieldDocSortCollector.SimpleFieldCollector) {
simpleFieldCollector = (HybridTopFieldDocSortCollector.SimpleFieldCollector) hybridSortedTopScoreDocCollector;
topFieldDocs = simpleFieldCollector.topDocs();
maxTotalHits = simpleFieldCollector.getTotalHits();
maxScore = simpleFieldCollector.getMaxScore();
} else {
pagingFieldCollector = (HybridTopDocSortCollector.PagingFieldCollector) hybridSortedTopScoreDocCollector;
pagingFieldCollector = (HybridTopFieldDocSortCollector.PagingFieldCollector) hybridSortedTopScoreDocCollector;
topFieldDocs = pagingFieldCollector.topDocs();
maxTotalHits = pagingFieldCollector.getTotalHits();
maxScore = pagingFieldCollector.getMaxScore();
Expand Down Expand Up @@ -263,7 +267,7 @@ private static void validateSortCriteria(SearchContext searchContext, boolean tr
}
if (trackScores && isSortByField) {
throw new IllegalArgumentException(
"Hybrid search results are sorted by any field, docId or _id, track_scores must be set to true."
"Hybrid search results when sorted by any field, docId or _id, track_scores must be set to false."
);
}
if (trackScores && isSortByScore) {
Expand Down

0 comments on commit f8d4930

Please sign in to comment.