Skip to content

Commit

Permalink
[Part 3] Concurrent segment search bug in Sorting (#808)
Browse files Browse the repository at this point in the history
* Cherry picking Concurrent Segment Search Bug Commit

Signed-off-by: Varun Jain <[email protected]>

* Fix Concurrent Segment Search Bug in Sorting

Signed-off-by: Varun Jain <[email protected]>

* Functional Interface

Signed-off-by: Varun Jain <[email protected]>

* Addressing Martin Comments

Signed-off-by: Varun Jain <[email protected]>

* Removing comments

Signed-off-by: Varun Jain <[email protected]>

* Addressing Martin Comments

Signed-off-by: Varun Jain <[email protected]>

* Addressing Martin Comments

Signed-off-by: Varun Jain <[email protected]>

* Addressing Martin comments

Signed-off-by: Varun Jain <[email protected]>

* Address Martin Comments

Signed-off-by: Varun Jain <[email protected]>

* Address Martin Comments

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>
Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
vibrantvarun and martin-gaievski authored Jul 9, 2024
1 parent d0f870c commit ded2788
Show file tree
Hide file tree
Showing 22 changed files with 1,837 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (vectorSupplier().get() == null) {
return this;
}
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter());
if (maxDistance != null) {
knnQueryBuilder.maxDistance(maxDistance);
} else if (minScore != null) {
knnQueryBuilder.minScore(minScore);
} else {
knnQueryBuilder.k(k);
}
return knnQueryBuilder;
return KNNQueryBuilder.builder()
.fieldName(fieldName())
.vector(vectorSupplier.get())
.filter(filter())
.maxDistance(maxDistance)
.minScore(minScore)
.k(k)
.build();
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.collector;

import java.util.List;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.TopDocs;

/**
 * Common contract for Hybrid search collectors.
 * Extends Lucene's {@link Collector} and adds accessors for the collected results, so callers
 * can read results without knowing which concrete hybrid collector produced them.
 */
public interface HybridSearchCollector extends Collector {
/**
 * Returns the collected results, one entry per sub-query.
 *
 * @return List of topDocs which contains topDocs of individual subqueries.
 */
List<? extends TopDocs> topDocs();

/**
 * Returns the number of hits collected.
 *
 * @return count of total hits per shard
 */
int getTotalHits();

/**
 * Returns the highest score seen while collecting.
 *
 * @return maxScore found on a shard
 */
float getMaxScore();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.ScoreDoc;
Expand All @@ -38,7 +37,7 @@
The individual query results are sorted as per the sort criteria sent in the search request.
*/
@Log4j2
public abstract class HybridTopFieldDocSortCollector implements Collector {
public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector {
private final int numHits;
private final HitsThresholdChecker hitsThresholdChecker;
private final Sort sort;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import lombok.Getter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
Expand All @@ -30,7 +29,7 @@
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
*/
@Log4j2
public class HybridTopScoreDocCollector implements Collector {
public class HybridTopScoreDocCollector implements HybridSearchCollector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import java.util.Comparator;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;

/**
 * Orders {@link FieldDoc} instances by the values of their sort fields, in the priority order of
 * the supplied sort criteria, and delegates to a tie breaker when every sort field compares equal.
 */
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryFieldDocComparator implements Comparator<FieldDoc> {
    final SortField[] sortFields;
    final FieldComparator<?>[] comparators;
    final int[] reverseMul;
    final Comparator<ScoreDoc> tieBreaker;

    /**
     * Prepares one field comparator and one direction multiplier per sort field.
     *
     * @param sortFields sort criteria, highest priority first
     * @param tieBreaker ordering used when all sort field values compare equal
     */
    public HybridQueryFieldDocComparator(SortField[] sortFields, Comparator<ScoreDoc> tieBreaker) {
        this.sortFields = sortFields;
        this.tieBreaker = tieBreaker;
        final int fieldCount = sortFields.length;
        this.comparators = new FieldComparator[fieldCount];
        this.reverseMul = new int[fieldCount];
        int slot = 0;
        for (final SortField field : sortFields) {
            this.comparators[slot] = field.getComparator(1, Pruning.NONE);
            // a reversed sort field flips the sign of the comparison result
            this.reverseMul[slot] = field.getReverse() ? -1 : 1;
            slot++;
        }
    }

    /**
     * Compares two field docs sort field by sort field; the first non-zero result wins.
     *
     * @param firstFD first field doc
     * @param secondFD second field doc
     * @return negative, zero or positive value as defined by {@link Comparator#compare}
     */
    @Override
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public int compare(final FieldDoc firstFD, final FieldDoc secondFD) {
        int position = 0;
        while (position < comparators.length) {
            // raw type is required: compareValues cannot be called through a wildcard-typed comparator
            final FieldComparator fieldComparator = comparators[position];
            final int ordering = reverseMul[position] * fieldComparator.compareValues(
                firstFD.fields[position],
                secondFD.fields[position]
            );
            if (ordering != 0) {
                return ordering;
            }
            position++;
        }
        return tieBreakCompare(firstFD, secondFD, tieBreaker);
    }

    /**
     * Applies the tie breaker to two docs whose sort field values are all equal.
     */
    private int tieBreakCompare(ScoreDoc firstDoc, ScoreDoc secondDoc, Comparator<ScoreDoc> tieBreaker) {
        assert tieBreaker != null;
        return tieBreaker.compare(firstDoc, secondDoc);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement;

/**
 * Combines two hybrid-query ScoreDoc arrays into one, sub-query by sub-query.
 */
@NoArgsConstructor(access = AccessLevel.PACKAGE)
class HybridQueryScoreDocsMerger<T extends ScoreDoc> {

    private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3;

    /**
     * Merges two score doc arrays; the result holds all hits per sub-query from both inputs.
     * Both inputs and the output use the Hybrid Query specific layout (start element, one delimiter
     * element per sub-query followed by that sub-query's hits, stop element), so this method must
     * not be used for ScoreDocs of other query types.
     * Hits of every sub-query are expected to be already ordered consistently with the comparator.
     * A new array is returned; the input arrays are not modified.
     *
     * @param sourceScoreDocs original score docs from query result
     * @param newScoreDocs new score docs that need to be merged into the existing ones
     * @param comparator comparator used to order hits of the same sub-query
     * @param isSortEnabled true when sorting is enabled for the request
     * @return merged array of score docs
     * @throws NullPointerException if an inspected input array is null
     * @throws IllegalArgumentException if an input array has fewer elements than the minimal layout
     */
    public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator<T> comparator, final boolean isSortEnabled) {
        if (hasTooFewElements(sourceScoreDocs) || hasTooFewElements(newScoreDocs)) {
            throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements");
        }
        // capacity is an upper bound: combined length of both inputs, the result is usually shorter
        final List<T> merged = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length);
        final int sourceStopIdx = sourceScoreDocs.length - 1;
        final int newStopIdx = newScoreDocs.length - 1;

        // the start element is taken from source; the one in the new array is skipped
        merged.add(sourceScoreDocs[0]);
        int sourceIdx = 1;
        int newIdx = 1;

        while (sourceIdx < sourceStopIdx && newIdx < newStopIdx) {
            // one iteration handles one sub-query: keep the source delimiter, skip the new one
            merged.add(sourceScoreDocs[sourceIdx++]);
            newIdx++;
            // interleave while both arrays still have hits for this sub-query
            while (hasHitAt(sourceScoreDocs, sourceIdx) && hasHitAt(newScoreDocs, newIdx)) {
                final boolean takeSource = compareCondition(sourceScoreDocs[sourceIdx], newScoreDocs[newIdx], comparator, isSortEnabled);
                if (takeSource) {
                    merged.add(sourceScoreDocs[sourceIdx++]);
                } else {
                    merged.add(newScoreDocs[newIdx++]);
                }
            }
            // at most one of the two loops below does any work: copy what is left of this sub-query
            while (hasHitAt(sourceScoreDocs, sourceIdx)) {
                merged.add(sourceScoreDocs[sourceIdx++]);
            }
            while (hasHitAt(newScoreDocs, newIdx)) {
                merged.add(newScoreDocs[newIdx++]);
            }
        }
        // close the result with the stop element of source
        merged.add(sourceScoreDocs[sourceStopIdx]);

        // runtime component type of the result depends on whether hits carry sort values
        final ScoreDoc[] resultType = isSortEnabled ? new FieldDoc[0] : new ScoreDoc[0];
        return merged.toArray((T[]) resultType);
    }

    /**
     * Tells whether the array is shorter than the minimal hybrid query layout; fails on null.
     */
    private boolean hasTooFewElements(final T[] scoreDocs) {
        return Objects.requireNonNull(scoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC;
    }

    /**
     * Tells whether the given position is inside the array and holds a regular hit element.
     */
    private boolean hasHitAt(final T[] scoreDocs, final int position) {
        return position < scoreDocs.length && isHybridQueryScoreDocElement(scoreDocs[position]);
    }

    /**
     * Decides whether the hit from the source array goes into the result before the hit from the new array.
     * With sorting enabled the source hit is taken only when it compares strictly lower;
     * otherwise it is taken when it compares greater or equal.
     */
    private boolean compareCondition(
        final T oldScoreDoc,
        final T secondScoreDoc,
        final Comparator<T> comparator,
        final boolean isSortEnabled
    ) {
        final int ordering = comparator.compare(oldScoreDoc, secondScoreDoc);
        return isSortEnabled ? ordering < 0 : ordering >= 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;

import java.util.Comparator;
import java.util.Objects;
import org.opensearch.search.sort.SortAndFormats;

/**
* Utility class for merging TopDocs and MaxScore across multiple search queries
*/
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
class TopDocsMerger {

private HybridQueryScoreDocsMerger docsMerger;
private SortAndFormats sortAndFormats;
@VisibleForTesting
protected static Comparator<ScoreDoc> SCORE_DOC_BY_SCORE_COMPARATOR;
@VisibleForTesting
protected static HybridQueryFieldDocComparator FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR;
private final Comparator<ScoreDoc> MERGING_TIE_BREAKER = (o1, o2) -> {
int docIdComparison = Integer.compare(o1.doc, o2.doc);
return docIdComparison;
};

/**
* Uses hybrid query score docs merger to merge internal score docs
*/
TopDocsMerger(final SortAndFormats sortAndFormats) {
this.sortAndFormats = sortAndFormats;
if (isSortingEnabled()) {
docsMerger = new HybridQueryScoreDocsMerger<FieldDoc>();
FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER);
} else {
docsMerger = new HybridQueryScoreDocsMerger<>();
SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score);
}
}

/**
* Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object.
* @param source TopDocsAndMaxScore for the original query
* @param newTopDocs TopDocsAndMaxScore for the new query
* @return merged TopDocsAndMaxScore object
*/
public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) {
if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) {
return source;
}
TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs);
TopDocsAndMaxScore result = new TopDocsAndMaxScore(
getTopDocs(getMergedScoreDocs(source.topDocs.scoreDocs, newTopDocs.topDocs.scoreDocs), mergedTotalHits),
Math.max(source.maxScore, newTopDocs.maxScore)
);
return result;
}

private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) {
// merged value is a lower bound - if both are equal_to than merged will also be equal_to,
// otherwise assign greater_than_or_equal
TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|| newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation);
}

private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits) {
if (isSortingEnabled()) {
return new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort());
}
return new TopDocs(mergedTotalHits, mergedScoreDocs);
}

private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs) {
// Case 1 when sorting is enabled then below will be the TopDocs format
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1 | [1]
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_2 | [1]
// ...
// doc_id | magic_number_1 | [1]

// Case 2 when sorting is disabled then below will be the TopDocs format
// we need to merge hits per individual sub-query
// format of results in both new and source TopDocs is following
// doc_id | magic_number_1
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_2
// ...
// doc_id | magic_number_1
return docsMerger.merge(source, newScoreDocs, comparator(), isSortingEnabled());
}

private Comparator<? extends ScoreDoc> comparator() {
return sortAndFormats != null ? FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR : SCORE_DOC_BY_SCORE_COMPARATOR;
}

private boolean isSortingEnabled() {
return sortAndFormats != null;
}
}
Loading

0 comments on commit ded2788

Please sign in to comment.