diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 188a90209..c3365f6a1 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -47,7 +47,7 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; private final DocIdSetIterator approximation; - HybridScorePropagator disjunctionBlockPropagator; + HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; private final TwoPhase twoPhase; public HybridQueryScorer(Weight weight, List subScorers) throws IOException { @@ -56,23 +56,19 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept public HybridQueryScorer(Weight weight, List subScorers, ScoreMode scoreMode) throws IOException { super(weight); - // max this.subScorers = Collections.unmodifiableList(subScorers); - // custom subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); - // base this.subScorersPQ = initializeSubScorersPQ(); - // base boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; - this.approximation = new HybridDisjunctionDISIApproximation(this.subScorersPQ); - // max + + this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ); if (scoreMode == ScoreMode.TOP_SCORES) { - this.disjunctionBlockPropagator = new HybridScorePropagator(subScorers); + this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers); } else { this.disjunctionBlockPropagator = null; } - // base + boolean hasApproximation = false; float sumMatchCost = 0; long sumApproxCost = 0; @@ -269,6 +265,10 @@ public Collection getChildren() throws IOException { return children; } + /** + * Object returned by Scorer.twoPhaseIterator() to provide an approximation of a DocIdSetIterator. + * After calling nextDoc() or advance(int) on the iterator returned by approximation(), you need to check matches() to confirm if the retrieved document ID is a match. + */ static class TwoPhase extends TwoPhaseIterator { private final float matchCost; // list of verified matches on the current doc @@ -292,11 +292,10 @@ protected boolean lessThan(DisiWrapper a, DisiWrapper b) { } DisiWrapper getSubMatches() throws IOException { - // iteration order does not matter - for (DisiWrapper w : unverifiedMatches) { - if (w.twoPhaseView.matches()) { - w.next = verifiedMatches; - verifiedMatches = w; + for (DisiWrapper wrapper : unverifiedMatches) { + if (wrapper.twoPhaseView.matches()) { + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; } } unverifiedMatches.clear(); @@ -308,39 +307,38 @@ public boolean matches() throws IOException { verifiedMatches = null; unverifiedMatches.clear(); - for (DisiWrapper w = subScorers.topList(); w != null;) { - DisiWrapper next = w.next; + for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) { + DisiWrapper next = wrapper.next; - if (w.twoPhaseView == null) { + if (Objects.isNull(wrapper.twoPhaseView)) { // implicitly verified, move it to verifiedMatches - w.next = verifiedMatches; - verifiedMatches = w; + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; if (!needsScores) { // we can stop here return true; } } else { - unverifiedMatches.add(w); + unverifiedMatches.add(wrapper); } - w = next; + wrapper = next; } - if (verifiedMatches != null) { + if (Objects.nonNull(verifiedMatches)) { return true; } // verify subs that have an two-phase iterator // least-costly ones first while (unverifiedMatches.size() > 0) { - DisiWrapper w = unverifiedMatches.pop(); - if (w.twoPhaseView.matches()) { - w.next = null; - verifiedMatches = w; + DisiWrapper wrapper = unverifiedMatches.pop(); + if (wrapper.twoPhaseView.matches()) { + wrapper.next = null; + verifiedMatches = wrapper; return true; } } - return false; } @@ -350,18 +348,22 @@ public float matchCost() { } } - static class HybridDisjunctionDISIApproximation extends DocIdSetIterator { - final DocIdSetIterator delegate; + /** + * A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports + * sub iterators that return empty results + */ + static class HybridSubqueriesDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator docIdSetIterator; final DisiPriorityQueue subIterators; - public HybridDisjunctionDISIApproximation(DisiPriorityQueue subIterators) { - delegate = new DisjunctionDISIApproximation(subIterators); + public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) { + docIdSetIterator = new DisjunctionDISIApproximation(subIterators); this.subIterators = subIterators; } @Override public long cost() { - return delegate.cost(); + return docIdSetIterator.cost(); } @Override @@ -369,7 +371,7 @@ public int docID() { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.docID(); + return docIdSetIterator.docID(); } @Override @@ -377,15 +379,15 @@ public int nextDoc() throws IOException { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.nextDoc(); + return docIdSetIterator.nextDoc(); } @Override - public int advance(int target) throws IOException { + public int advance(final int target) throws IOException { if (subIterators.size() == 0) { return NO_MORE_DOCS; } - return delegate.advance(target); + return docIdSetIterator.advance(target); } } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java similarity index 78% rename from src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java rename to src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java index 92e1bbf7e..6b47a098d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridScorePropagator.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java @@ -13,7 +13,16 @@ import java.util.Comparator; import java.util.Objects; -public class HybridScorePropagator { +/** + * This class functions as a utility for propagating block boundaries within disjunctions. + * In disjunctions, where a match occurs if any subclause matches, a common approach might involve returning + * the minimum block boundary across all clauses. However, this method can introduce performance challenges, + * particularly when dealing with high minimum competitive scores and clauses with low scores that no longer + * significantly contribute to the iteration process. Therefore, this class computes block boundaries solely for clauses + * with a maximum score equal to or exceeding the minimum competitive score, or for the clause with the maximum + * score if such a clause is absent. + */ +public class HybridScoreBlockBoundaryPropagator { private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { try { @@ -27,7 +36,7 @@ public class HybridScorePropagator { private final float[] maxScores; private int leadIndex = 0; - HybridScorePropagator(Collection scorers) throws IOException { + HybridScoreBlockBoundaryPropagator(final Collection scorers) throws IOException { this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); for (Scorer scorer : this.scorers) { scorer.advanceShallow(0); @@ -73,7 +82,6 @@ int advanceShallow(int target) throws IOException { break; } } - return upTo; } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 36a9002e8..1d715a14c 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -34,6 +34,10 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +/** + * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. + * In most cases it will be wrapped in MultiCollectorManager. + */ @RequiredArgsConstructor public abstract class HybridCollectorManager implements CollectorManager { @@ -43,6 +47,12 @@ public abstract class HybridCollectorManager implements CollectorManager uniqueDocs = new HashSet<>(); + while (uniqueDocs.size() < numDocs) { + uniqueDocs.add(random().nextInt(maxDoc)); + } + final int[] docs = new int[numDocs]; + int i = 0; + for (int doc : uniqueDocs) { + docs[i++] = doc; + } + Arrays.sort(docs); + final float[] scores1 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores1[i] = random().nextFloat(); + } + final float[] scores2 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores2[i] = random().nextFloat(); + } + + Weight weight = mock(Weight.class); + + HybridQueryScorer queryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorerWithTwoPhaseIterator(docs, scores1, fakeWeight(new MatchAllDocsQuery()), maxDoc), + scorerWithTwoPhaseIterator(docs, scores2, fakeWeight(new MatchNoDocsQuery()), maxDoc) + ) + ); + + int doc = -1; + int idx = 0; + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + doc = queryScorer.iterator().nextDoc(); + if (idx == docs.length) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } else { + assertEquals(docs[idx], doc); + assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f); + } + idx++; + } + } + + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { + final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); + return new Scorer(weight) { + + int lastScoredDoc = -1; + + public DocIdSetIterator iterator() { + return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public float score() { + assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID()); + lastScoredDoc = docID(); + final int idx = Arrays.binarySearch(docs, docID()); + return scores[idx]; + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return new TwoPhaseIterator(iterator) { + + @Override + public boolean matches() { + return Arrays.binarySearch(docs, iterator.docID()) >= 0; + } + + @Override + public float matchCost() { + return 10; + } + }; + } + }; + } + private Pair generateDocuments(int maxDocId) { final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2); final int[] docs = new int[numDocs]; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java new file mode 100644 index 000000000..5bf0948ea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class HybridScoreBlockBoundaryPropagatorTests extends OpenSearchQueryTestCase { + + public void testAdvanceShallow_whenMinCompetitiveScoreSet_thenSuccessful() throws IOException { + Scorer scorer1 = new MockScorer(10, 0.6f); + Scorer scorer2 = new MockScorer(40, 1.5f); + Scorer scorer3 = new MockScorer(30, 2f); + Scorer scorer4 = new MockScorer(120, 4f); + + List scorers = Arrays.asList(scorer1, scorer2, scorer3, scorer4); + Collections.shuffle(scorers, random()); + HybridScoreBlockBoundaryPropagator propagator = new HybridScoreBlockBoundaryPropagator(scorers); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.1f); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.8f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.4f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.9f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(2.5f); + assertEquals(120, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(7f); + assertEquals(120, propagator.advanceShallow(0)); + } + + private static class MockWeight extends Weight { + + MockWeight() { + super(new MatchNoDocsQuery()); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return null; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } + + private static class MockScorer extends Scorer { + + final int boundary; + final float maxScore; + + MockScorer(int boundary, float maxScore) throws IOException { + super(new MockWeight()); + this.boundary = boundary; + this.maxScore = maxScore; + } + + @Override + public int docID() { + return 0; + } + + @Override + public float score() { + throw new UnsupportedOperationException(); + } + + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public void setMinCompetitiveScore(float minCompetitiveScore) {} + + @Override + public float getMaxScore(int upTo) throws IOException { + return maxScore; + } + + @Override + public int advanceShallow(int target) { + assert target <= boundary; + return boundary; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 2951dd666..674a3ebe6 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -54,9 +54,7 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEXT_FIELD_NAME = "field"; private static final String TERM_QUERY_TEXT = "keyword"; - private static final float DELTA_FOR_ASSERTION = 0.001f; - private static final float MAX_SCORE = 0.611f; @SneakyThrows public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() {