Skip to content

Commit

Permalink
Speed up advancing on the disjunction iterator. (#14052)
Browse files Browse the repository at this point in the history
Currently, the disjunction iterator puts all clauses in a heap in order to be
able to merge doc IDs in a streaming fashion. This is a good approach for
exhaustive evaluation, when only one clause moves to a different doc ID on
average and the per-iteration cost is in the order of O(log(N)) where N is the
number of clauses.

However, if a selective filter is applied, this could cause many clauses to
move to a different doc ID. In the worst-case scenario, all clauses could move
to a different doc ID and the cost of maintaiting heap invariants could grow to
O(N * log(N)) (every clause introduces a O(log(N)) cost). With many clauses,
this is much higher than the cost of checking all clauses sequentially: O(N).

To protect from this reordering overhead, DisjunctionDISIApproximation now only
puts the cheapest clauses in a heap in a way that tries to achieve up to 1.5
clauses moving to a different doc ID on average. More expensive clauses are
checked linearly.
  • Loading branch information
jpountz authored Dec 16, 2024
1 parent a8d8d6b commit bc341f2
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 89 deletions.
3 changes: 2 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ Improvements

Optimizations
---------------------
(No changes)

* GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand)

Bug Fixes
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.RamUsageEstimator;

/**
Expand Down Expand Up @@ -151,7 +150,8 @@ protected abstract WeightOrDocIdSetIterator rewriteInner(
int fieldDocCount,
Terms terms,
TermsEnum termsEnum,
List<TermAndState> collectedTerms)
List<TermAndState> collectedTerms,
long leadCost)
throws IOException;

private WeightOrDocIdSetIterator rewriteAsBooleanQuery(
Expand Down Expand Up @@ -247,21 +247,22 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
cost = estimateCost(terms, q.getTermsCount());
}

IOSupplier<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
() -> {
IOLongFunction<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
leadCost -> {
if (collectResult) {
return rewriteAsBooleanQuery(context, collectedTerms);
} else {
// Too many terms to rewrite as a simple bq.
// Invoke rewriteInner logic to handle rewriting:
return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms);
return rewriteInner(
context, fieldDocCount, terms, termsEnum, collectedTerms, leadCost);
}
};

return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) throws IOException {
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.apply(leadCost);
final Scorer scorer;
if (weightOrIterator == null) {
scorer = null;
Expand All @@ -281,7 +282,8 @@ public Scorer get(long leadCost) throws IOException {

@Override
public BulkScorer bulkScorer() throws IOException {
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
WeightOrDocIdSetIterator weightOrIterator =
weightOrIteratorSupplier.apply(Long.MAX_VALUE);
final BulkScorer bulkScorer;
if (weightOrIterator == null) {
bulkScorer = null;
Expand Down Expand Up @@ -311,6 +313,10 @@ public long cost() {
};
}

private static interface IOLongFunction<T> {
T apply(long arg) throws IOException;
}

private static long estimateCost(Terms terms, long queryTermsCount) throws IOException {
// Estimate the cost. If the MTQ can provide its term count, we can do a better job
// estimating.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ BulkScorer booleanScorer() throws IOException {
Scorer prohibitedScorer =
prohibited.size() == 1
? prohibited.get(0)
: new DisjunctionSumScorer(prohibited, ScoreMode.COMPLETE_NO_SCORES);
: new DisjunctionSumScorer(
prohibited, ScoreMode.COMPLETE_NO_SCORES, positiveScorerCost);
return new ReqExclBulkScorer(positiveScorer, prohibitedScorer);
}
}
Expand Down Expand Up @@ -509,7 +510,7 @@ private Scorer opt(
if ((scoreMode == ScoreMode.TOP_SCORES && topLevelScoringClause) || minShouldMatch > 1) {
return new WANDScorer(optionalScorers, minShouldMatch, scoreMode, leadCost);
} else {
return new DisjunctionSumScorer(optionalScorers, scoreMode);
return new DisjunctionSumScorer(optionalScorers, scoreMode, leadCost);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,86 @@
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;

/**
* A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided
* iterators.
*
* @lucene.internal
*/
public class DisjunctionDISIApproximation extends DocIdSetIterator {
public final class DisjunctionDISIApproximation extends DocIdSetIterator {

final DisiPriorityQueue subIterators;
final long cost;
public static DisjunctionDISIApproximation of(
Collection<DisiWrapper> subIterators, long leadCost) {

return new DisjunctionDISIApproximation(subIterators, leadCost);
}

// Heap of iterators that lead iteration.
private final DisiPriorityQueue leadIterators;
// List of iterators that will likely advance on every call to nextDoc() / advance()
private final DisiWrapper[] otherIterators;
private final long cost;
private DisiWrapper leadTop;
private int minOtherDoc;

public DisjunctionDISIApproximation(Collection<DisiWrapper> subIterators, long leadCost) {
// Using a heap to store disjunctive clauses is great for exhaustive evaluation, when a single
// clause needs to move through the heap on every iteration on average. However, when
// intersecting with a selective filter, it is possible that all clauses need advancing, which
// makes the reordering cost scale in O(N * log(N)) per advance() call when checking clauses
// linearly would scale in O(N).
// To protect against this reordering overhead, we try to have 1.5 clauses or less that advance
// on every advance() call by only putting clauses into the heap as long as Σ min(1, cost /
// leadCost) <= 1.5, or Σ min(leadCost, cost) <= 1.5 * leadCost. Other clauses are checked
// linearly.

List<DisiWrapper> wrappers = new ArrayList<>(subIterators);
// Sort by descending cost.
wrappers.sort(Comparator.<DisiWrapper>comparingLong(w -> w.cost).reversed());

leadIterators = new DisiPriorityQueue(subIterators.size());

long reorderThreshold = leadCost + (leadCost >> 1);
if (reorderThreshold < 0) { // overflow
reorderThreshold = Long.MAX_VALUE;
}
long reorderCost = 0;
while (wrappers.isEmpty() == false) {
DisiWrapper last = wrappers.getLast();
long inc = Math.min(last.cost, leadCost);
if (reorderCost + inc < 0 || reorderCost + inc > reorderThreshold) {
break;
}
leadIterators.add(wrappers.removeLast());
reorderCost += inc;
}

// Make leadIterators not empty. This helps save conditionals in the implementation which are
// rarely tested.
if (leadIterators.size() == 0) {
leadIterators.add(wrappers.removeLast());
}

otherIterators = wrappers.toArray(DisiWrapper[]::new);

public DisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
this.subIterators = subIterators;
long cost = 0;
for (DisiWrapper w : subIterators) {
for (DisiWrapper w : leadIterators) {
cost += w.cost;
}
for (DisiWrapper w : otherIterators) {
cost += w.cost;
}
this.cost = cost;
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
minOtherDoc = Math.min(minOtherDoc, w.doc);
}
leadTop = leadIterators.top();
}

@Override
Expand All @@ -45,29 +106,62 @@ public long cost() {

@Override
public int docID() {
return subIterators.top().doc;
return Math.min(minOtherDoc, leadTop.doc);
}

@Override
public int nextDoc() throws IOException {
DisiWrapper top = subIterators.top();
final int doc = top.doc;
do {
top.doc = top.approximation.nextDoc();
top = subIterators.updateTop();
} while (top.doc == doc);

return top.doc;
if (leadTop.doc < minOtherDoc) {
int curDoc = leadTop.doc;
do {
leadTop.doc = leadTop.approximation.nextDoc();
leadTop = leadIterators.updateTop();
} while (leadTop.doc == curDoc);
return Math.min(leadTop.doc, minOtherDoc);
} else {
return advance(minOtherDoc + 1);
}
}

@Override
public int advance(int target) throws IOException {
DisiWrapper top = subIterators.top();
do {
top.doc = top.approximation.advance(target);
top = subIterators.updateTop();
} while (top.doc < target);
while (leadTop.doc < target) {
leadTop.doc = leadTop.approximation.advance(target);
leadTop = leadIterators.updateTop();
}

return top.doc;
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
if (w.doc < target) {
w.doc = w.approximation.advance(target);
}
minOtherDoc = Math.min(minOtherDoc, w.doc);
}

return Math.min(leadTop.doc, minOtherDoc);
}

/** Return the linked list of iterators positioned on the current doc. */
public DisiWrapper topList() {
if (leadTop.doc < minOtherDoc) {
return leadIterators.topList();
} else {
return computeTopList();
}
}

private DisiWrapper computeTopList() {
assert leadTop.doc >= minOtherDoc;
DisiWrapper topList = null;
if (leadTop.doc == minOtherDoc) {
topList = leadIterators.topList();
}
for (DisiWrapper w : otherIterators) {
if (w.doc == minOtherDoc) {
w.next = topList;
topList = w;
}
}
return topList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public Scorer get(long leadCost) throws IOException {
for (ScorerSupplier ss : scorerSuppliers) {
scorers.add(ss.get(leadCost));
}
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode);
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode, leadCost);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
* as they are summed into the result.
* @param subScorers The sub scorers this Scorer should iterate on
*/
DisjunctionMaxScorer(float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode)
DisjunctionMaxScorer(
float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
super(subScorers, scoreMode);
super(subScorers, scoreMode, leadCost);
this.subScorers = subScorers;
this.tieBreakerMultiplier = tieBreakerMultiplier;
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,34 @@
/** Base class for Scorers that score disjunctions. */
abstract class DisjunctionScorer extends Scorer {

private final int numClauses;
private final boolean needsScores;

private final DisiPriorityQueue subScorers;
private final DocIdSetIterator approximation;
private final DisjunctionDISIApproximation approximation;
private final TwoPhase twoPhase;

protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
}
this.subScorers = new DisiPriorityQueue(subScorers.size());
for (Scorer scorer : subScorers) {
final DisiWrapper w = new DisiWrapper(scorer, false);
this.subScorers.add(w);
}
this.numClauses = subScorers.size();
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new DisjunctionDISIApproximation(this.subScorers);

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : this.subScorers) {
List<DisiWrapper> wrappers = new ArrayList<>();
for (Scorer scorer : subScorers) {
DisiWrapper w = new DisiWrapper(scorer, false);
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
wrappers.add(w);
}
this.approximation = new DisjunctionDISIApproximation(wrappers, leadCost);

if (hasApproximation == false) { // no sub scorer supports approximations
twoPhase = null;
Expand Down Expand Up @@ -91,7 +88,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) {
super(approximation);
this.matchCost = matchCost;
unverifiedMatches =
new PriorityQueue<DisiWrapper>(DisjunctionScorer.this.subScorers.size()) {
new PriorityQueue<DisiWrapper>(numClauses) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
Expand All @@ -116,7 +113,7 @@ public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper w = subScorers.topList(); w != null; ) {
for (DisiWrapper w = DisjunctionScorer.this.approximation.topList(); w != null; ) {
DisiWrapper next = w.next;

if (w.twoPhaseView == null) {
Expand Down Expand Up @@ -160,12 +157,12 @@ public float matchCost() {

@Override
public final int docID() {
return subScorers.top().doc;
return approximation.docID();
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorers.topList();
return approximation.topList();
} else {
return twoPhase.getSubMatches();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
*
* @param subScorers Array of at least two subscorers.
*/
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(subScorers, scoreMode);
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
super(subScorers, scoreMode, leadCost);
this.scorers = subScorers;
}

Expand Down
Loading

0 comments on commit bc341f2

Please sign in to comment.