Skip to content

Commit

Permalink
Adding unit tests, minor refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 4, 2024
1 parent d48c06c commit 274a416
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public final class HybridQueryScorer extends Scorer {
private final Map<Query, List<Integer>> queryToIndex;

private final DocIdSetIterator approximation;
HybridScorePropagator disjunctionBlockPropagator;
HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
private final TwoPhase twoPhase;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
Expand All @@ -56,23 +56,19 @@ public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOExcept

public HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(weight);
// max
this.subScorers = Collections.unmodifiableList(subScorers);
// custom
subScores = new float[subScorers.size()];
this.queryToIndex = mapQueryToIndex();
// base
this.subScorersPQ = initializeSubScorersPQ();
// base
boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ);
// max

this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ);
if (scoreMode == ScoreMode.TOP_SCORES) {
this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers);
this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers);
} else {
this.disjunctionBlockPropagator = null;
}
// base

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
Expand Down Expand Up @@ -269,6 +265,10 @@ public Collection<ChildScorable> getChildren() throws IOException {
return children;
}

/**
* Object returned by Scorer.twoPhaseIterator() to provide an approximation of a DocIdSetIterator.
* After calling nextDoc() or advance(int) on the iterator returned by approximation(), you need to check matches() to confirm if the retrieved document ID is a match.
*/
static class TwoPhase extends TwoPhaseIterator {
private final float matchCost;
// list of verified matches on the current doc
Expand All @@ -292,11 +292,10 @@ protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
}

DisiWrapper getSubMatches() throws IOException {
// iteration order does not matter
for (DisiWrapper w : unverifiedMatches) {
if (w.twoPhaseView.matches()) {
w.next = verifiedMatches;
verifiedMatches = w;
for (DisiWrapper wrapper : unverifiedMatches) {
if (wrapper.twoPhaseView.matches()) {
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;
}
}
unverifiedMatches.clear();
Expand All @@ -308,39 +307,38 @@ public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper w = subScorers.topList(); w != null;) {
DisiWrapper next = w.next;
for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) {
DisiWrapper next = wrapper.next;

if (w.twoPhaseView == null) {
if (Objects.isNull(wrapper.twoPhaseView)) {
// implicitly verified, move it to verifiedMatches
w.next = verifiedMatches;
verifiedMatches = w;
wrapper.next = verifiedMatches;
verifiedMatches = wrapper;

if (!needsScores) {
// we can stop here
return true;
}
} else {
unverifiedMatches.add(w);
unverifiedMatches.add(wrapper);
}
w = next;
wrapper = next;
}

if (verifiedMatches != null) {
if (Objects.nonNull(verifiedMatches)) {
return true;
}

// verify subs that have an two-phase iterator
// least-costly ones first
while (unverifiedMatches.size() > 0) {
DisiWrapper w = unverifiedMatches.pop();
if (w.twoPhaseView.matches()) {
w.next = null;
verifiedMatches = w;
DisiWrapper wrapper = unverifiedMatches.pop();
if (wrapper.twoPhaseView.matches()) {
wrapper.next = null;
verifiedMatches = wrapper;
return true;
}
}

return false;
}

Expand All @@ -350,42 +348,46 @@ public float matchCost() {
}
}

static class HybridDisjunctionDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator delegate;
/**
* A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports
* sub iterators that return empty results
*/
static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
final DocIdSetIterator docIdSetIterator;
final DisiPriorityQueue subIterators;

public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
delegate = new DisjunctionDISIApproximation(subIterators);
public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) {
docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
this.subIterators = subIterators;
}

@Override
public long cost() {
return delegate.cost();
return docIdSetIterator.cost();
}

@Override
public int docID() {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return delegate.docID();
return docIdSetIterator.docID();
}

@Override
public int nextDoc() throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return delegate.nextDoc();
return docIdSetIterator.nextDoc();
}

@Override
public int advance(int target) throws IOException {
public int advance(final int target) throws IOException {
if (subIterators.size() == 0) {
return NO_MORE_DOCS;
}
return delegate.advance(target);
return docIdSetIterator.advance(target);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
import java.util.Comparator;
import java.util.Objects;

public class HybridScorePropagator {
/**
* This class functions as a utility for propagating block boundaries within disjunctions.
* In disjunctions, where a match occurs if any subclause matches, a common approach might involve returning
* the minimum block boundary across all clauses. However, this method can introduce performance challenges,
* particularly when dealing with high minimum competitive scores and clauses with low scores that no longer
* significantly contribute to the iteration process. Therefore, this class computes block boundaries solely for clauses
* with a maximum score equal to or exceeding the minimum competitive score, or for the clause with the maximum
* score if such a clause is absent.
*/
public class HybridScoreBlockBoundaryPropagator {

private static final Comparator<Scorer> MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> {
try {
Expand All @@ -27,7 +36,7 @@ public class HybridScorePropagator {
private final float[] maxScores;
private int leadIndex = 0;

HybridScorePropagator(Collection<Scorer> scorers) throws IOException {
HybridScoreBlockBoundaryPropagator(final Collection<Scorer> scorers) throws IOException {
this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new);
for (Scorer scorer : this.scorers) {
scorer.advanceShallow(0);
Expand Down Expand Up @@ -73,7 +82,6 @@ int advanceShallow(int target) throws IOException {
break;
}
}

return upTo;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

/**
* Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits.
* In most cases it will be wrapped in MultiCollectorManager.
*/
@RequiredArgsConstructor
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

Expand All @@ -43,6 +47,12 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled.
* @param searchContext
* @return
* @throws IOException
*/
public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final int totalNumDocs = Math.max(0, reader.numDocs());
Expand Down Expand Up @@ -184,6 +194,10 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats
return sortAndFormats == null ? null : sortAndFormats.formats;
}

/**
* Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to
* use saved state of collector
*/
static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
Collector maxScoreCollector;

Expand All @@ -210,6 +224,10 @@ public Collector newCollector() {
}
}

/**
* Implementation of the HybridCollector that doesn't save collector's state and return new instance of every
* call of newCollector
*/
static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {

public HybridCollectorConcurrentSearchManager(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.TestUtil;

Expand Down Expand Up @@ -223,6 +224,99 @@ public void testMaxScoreFailures_whenScorerThrowsException_thenFail() {
assertTrue(runtimeException.getMessage().contains("Test exception"));
}

@SneakyThrows
public void testApproximationIterator_whenSubScorerSupportsApproximation_thenSuccessful() {
final int maxDoc = TestUtil.nextInt(random(), 10, 1_000);
final int numDocs = TestUtil.nextInt(random(), 1, maxDoc / 2);
final Set<Integer> uniqueDocs = new HashSet<>();
while (uniqueDocs.size() < numDocs) {
uniqueDocs.add(random().nextInt(maxDoc));
}
final int[] docs = new int[numDocs];
int i = 0;
for (int doc : uniqueDocs) {
docs[i++] = doc;
}
Arrays.sort(docs);
final float[] scores1 = new float[numDocs];
for (i = 0; i < numDocs; ++i) {
scores1[i] = random().nextFloat();
}
final float[] scores2 = new float[numDocs];
for (i = 0; i < numDocs; ++i) {
scores2[i] = random().nextFloat();
}

Weight weight = mock(Weight.class);

HybridQueryScorer queryScorer = new HybridQueryScorer(
weight,
Arrays.asList(
scorerWithTwoPhaseIterator(docs, scores1, fakeWeight(new MatchAllDocsQuery()), maxDoc),
scorerWithTwoPhaseIterator(docs, scores2, fakeWeight(new MatchNoDocsQuery()), maxDoc)
)
);

int doc = -1;
int idx = 0;
while (doc != DocIdSetIterator.NO_MORE_DOCS) {
doc = queryScorer.iterator().nextDoc();
if (idx == docs.length) {
assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc);
} else {
assertEquals(docs[idx], doc);
assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f);
}
idx++;
}
}

protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) {
final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc);
return new Scorer(weight) {

int lastScoredDoc = -1;

public DocIdSetIterator iterator() {
return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
}

@Override
public int docID() {
return iterator.docID();
}

@Override
public float score() {
assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID());
lastScoredDoc = docID();
final int idx = Arrays.binarySearch(docs, docID());
return scores[idx];
}

@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE;
}

@Override
public TwoPhaseIterator twoPhaseIterator() {
return new TwoPhaseIterator(iterator) {

@Override
public boolean matches() {
return Arrays.binarySearch(docs, iterator.docID()) >= 0;
}

@Override
public float matchCost() {
return 10;
}
};
}
};
}

private Pair<int[], float[]> generateDocuments(int maxDocId) {
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
Expand Down
Loading

0 comments on commit 274a416

Please sign in to comment.