Skip to content

Commit

Permalink
Move BooleanScorer to work on top of Scorers rather than BulkScorers. (
Browse files Browse the repository at this point in the history
…#13931)

I was looking at some queries where Lucene performs significantly worse than
Tantivy at https://tantivy-search.github.io/bench/, and found out that we get
quite some overhead from implementing `BooleanScorer` on top of `BulkScorer`
(effectively implemented by `DefaultBulkScorer` since it only runs term queries
as boolean clauses) rather than `Scorer` directly.

The `CountOrHighHigh` and `CountOrHighMed` tasks are a bit noisy on my machine,
so I did 3 runs on wikibigall, and all of them had speedups for these two
tasks, often with a very low p-value.

In theory, this change could make things slower when the inner query has a
specialized bulk scorer, such as `MatchAllDocsQuery` or a conjunction. It does
feel right to optimize for term queries though.
  • Loading branch information
jpountz authored Oct 21, 2024
1 parent 86457a5 commit a779a64
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 197 deletions.
5 changes: 5 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Optimizations

* GITHUB#13930: Use growNoCopy when copying bytes in BytesRefBuilder. (Ignacio Vera)

* GITHUB#13931: Refactored `BooleanScorer` to evaluate matches of sub clauses
using the `Scorer` abstraction rather than the `BulkScorer` abstraction. This
speeds up exhaustive evaluation of disjunctions of term queries.
(Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
191 changes: 87 additions & 104 deletions lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.PriorityQueue;

/**
* {@link BulkScorer} that is used for pure disjunctions and disjunctions that have low values of
* {@link BooleanQuery.Builder#setMinimumNumberShouldMatch(int)} and dense clauses. This scorer
* scores documents by batches of 2048 docs.
* scores documents by batches of 4,096 docs.
*/
final class BooleanScorer extends BulkScorer {

Expand All @@ -41,71 +42,32 @@ static class Bucket {
int freq;
}

private class BulkScorerAndDoc {
final BulkScorer scorer;
final long cost;
int next;

BulkScorerAndDoc(BulkScorer scorer) {
this.scorer = scorer;
this.cost = scorer.cost();
this.next = -1;
}

void advance(int min) throws IOException {
score(orCollector, null, min, min);
}

void score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
next = scorer.score(collector, acceptDocs, min, max);
}
}

// See WANDScorer for an explanation
private static long cost(Collection<BulkScorer> scorers, int minShouldMatch) {
final PriorityQueue<BulkScorer> pq =
new PriorityQueue<BulkScorer>(scorers.size() - minShouldMatch + 1) {
@Override
protected boolean lessThan(BulkScorer a, BulkScorer b) {
return a.cost() > b.cost();
}
};
for (BulkScorer scorer : scorers) {
pq.insertWithOverflow(scorer);
}
long cost = 0;
for (BulkScorer scorer = pq.pop(); scorer != null; scorer = pq.pop()) {
cost += scorer.cost();
}
return cost;
}

static final class HeadPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class HeadPriorityQueue extends PriorityQueue<DisiWrapper> {

public HeadPriorityQueue(int maxSize) {
super(maxSize);
}

@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
return a.next < b.next;
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.doc < b.doc;
}
}

static final class TailPriorityQueue extends PriorityQueue<BulkScorerAndDoc> {
static final class TailPriorityQueue extends PriorityQueue<DisiWrapper> {

public TailPriorityQueue(int maxSize) {
super(maxSize);
}

@Override
protected boolean lessThan(BulkScorerAndDoc a, BulkScorerAndDoc b) {
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.cost < b.cost;
}

public BulkScorerAndDoc get(int i) {
public DisiWrapper get(int i) {
Objects.checkIndex(i, size());
return (BulkScorerAndDoc) getHeapArray()[1 + i];
return (DisiWrapper) getHeapArray()[1 + i];
}
}

Expand All @@ -115,39 +77,14 @@ public BulkScorerAndDoc get(int i) {
// This is basically an inlined FixedBitSet... seems to help with bound checks
final long[] matching = new long[SET_SIZE];

final BulkScorerAndDoc[] leads;
final DisiWrapper[] leads;
final HeadPriorityQueue head;
final TailPriorityQueue tail;
final Score score = new Score();
final int minShouldMatch;
final long cost;
final boolean needsScores;

final class OrCollector implements LeafCollector {
Scorable scorer;

@Override
public void setScorer(Scorable scorer) {
this.scorer = scorer;
}

@Override
public void collect(int doc) throws IOException {
final int i = doc & MASK;
final int idx = i >>> 6;
matching[idx] |= 1L << i;
if (buckets != null) {
final Bucket bucket = buckets[i];
bucket.freq++;
if (needsScores) {
bucket.score += scorer.score();
}
}
}
}

final OrCollector orCollector = new OrCollector();

final class DocIdStreamView extends DocIdStream {

int base;
Expand Down Expand Up @@ -194,7 +131,7 @@ public int count() throws IOException {

private final DocIdStreamView docIdStreamView = new DocIdStreamView();

BooleanScorer(Collection<BulkScorer> scorers, int minShouldMatch, boolean needsScores) {
BooleanScorer(Collection<Scorer> scorers, int minShouldMatch, boolean needsScores) {
if (minShouldMatch < 1 || minShouldMatch > scorers.size()) {
throw new IllegalArgumentException(
"minShouldMatch should be within 1..num_scorers. Got " + minShouldMatch);
Expand All @@ -211,38 +148,71 @@ public int count() throws IOException {
} else {
buckets = null;
}
this.leads = new BulkScorerAndDoc[scorers.size()];
this.leads = new DisiWrapper[scorers.size()];
this.head = new HeadPriorityQueue(scorers.size() - minShouldMatch + 1);
this.tail = new TailPriorityQueue(minShouldMatch - 1);
this.minShouldMatch = minShouldMatch;
this.needsScores = needsScores;
for (BulkScorer scorer : scorers) {
final BulkScorerAndDoc evicted = tail.insertWithOverflow(new BulkScorerAndDoc(scorer));
LongArrayList costs = new LongArrayList(scorers.size());
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
costs.add(w.cost);
final DisiWrapper evicted = tail.insertWithOverflow(w);
if (evicted != null) {
head.add(evicted);
}
}
this.cost = cost(scorers, minShouldMatch);
this.cost = ScorerUtil.costWithMinShouldMatch(costs.stream(), costs.size(), minShouldMatch);
}

@Override
public long cost() {
return cost;
}

private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
throws IOException {
boolean needsScores = BooleanScorer.this.needsScores;
long[] matching = BooleanScorer.this.matching;
Bucket[] buckets = BooleanScorer.this.buckets;

DocIdSetIterator it = w.iterator;
Scorer scorer = w.scorer;
int doc = w.doc;
if (doc < min) {
doc = it.advance(min);
}
for (; doc < max; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc & MASK;
final int idx = i >> 6;
matching[idx] |= 1L << i;
if (buckets != null) {
final Bucket bucket = buckets[i];
bucket.freq++;
if (needsScores) {
bucket.score += scorer.score();
}
}
}
}

w.doc = doc;
}

private void scoreWindowIntoBitSetAndReplay(
LeafCollector collector,
Bits acceptDocs,
int base,
int min,
int max,
BulkScorerAndDoc[] scorers,
DisiWrapper[] scorers,
int numScorers)
throws IOException {
for (int i = 0; i < numScorers; ++i) {
final BulkScorerAndDoc scorer = scorers[i];
assert scorer.next < max;
scorer.score(orCollector, acceptDocs, min, max);
final DisiWrapper w = scorers[i];
assert w.doc < max;
scoreDisiWrapperIntoBitSet(w, acceptDocs, min, max);
}

docIdStreamView.base = base;
Expand All @@ -251,20 +221,20 @@ private void scoreWindowIntoBitSetAndReplay(
Arrays.fill(matching, 0L);
}

private BulkScorerAndDoc advance(int min) throws IOException {
private DisiWrapper advance(int min) throws IOException {
assert tail.size() == minShouldMatch - 1;
final HeadPriorityQueue head = this.head;
final TailPriorityQueue tail = this.tail;
BulkScorerAndDoc headTop = head.top();
BulkScorerAndDoc tailTop = tail.top();
while (headTop.next < min) {
DisiWrapper headTop = head.top();
DisiWrapper tailTop = tail.top();
while (headTop.doc < min) {
if (tailTop == null || headTop.cost <= tailTop.cost) {
headTop.advance(min);
headTop.doc = headTop.iterator.advance(min);
headTop = head.updateTop();
} else {
// swap the top of head and tail
final BulkScorerAndDoc previousHeadTop = headTop;
tailTop.advance(min);
final DisiWrapper previousHeadTop = headTop;
tailTop.doc = tailTop.iterator.advance(min);
headTop = head.updateTop(tailTop);
tailTop = tail.updateTop(previousHeadTop);
}
Expand All @@ -282,9 +252,11 @@ private void scoreWindowMultipleScorers(
throws IOException {
while (maxFreq < minShouldMatch && maxFreq + tail.size() >= minShouldMatch) {
// a match is still possible
final BulkScorerAndDoc candidate = tail.pop();
candidate.advance(windowMin);
if (candidate.next < windowMax) {
final DisiWrapper candidate = tail.pop();
if (candidate.doc < windowMin) {
candidate.doc = candidate.iterator.advance(windowMin);
}
if (candidate.doc < windowMax) {
leads[maxFreq++] = candidate;
} else {
head.add(candidate);
Expand All @@ -304,49 +276,60 @@ private void scoreWindowMultipleScorers(

// Push back scorers into head and tail
for (int i = 0; i < maxFreq; ++i) {
final BulkScorerAndDoc evicted = head.insertWithOverflow(leads[i]);
final DisiWrapper evicted = head.insertWithOverflow(leads[i]);
if (evicted != null) {
tail.add(evicted);
}
}
}

private void scoreWindowSingleScorer(
BulkScorerAndDoc bulkScorer,
DisiWrapper w,
LeafCollector collector,
Bits acceptDocs,
int windowMin,
int windowMax,
int max)
throws IOException {
assert tail.size() == 0;
final int nextWindowBase = head.top().next & ~MASK;
final int nextWindowBase = head.top().doc & ~MASK;
final int end = Math.max(windowMax, Math.min(max, nextWindowBase));

bulkScorer.score(collector, acceptDocs, windowMin, end);
DocIdSetIterator it = w.iterator;
int doc = w.doc;
if (doc < windowMin) {
doc = it.advance(windowMin);
}
collector.setScorer(w.scorer);
for (; doc < end; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
collector.collect(doc);
}
}
w.doc = doc;

// reset the scorer that should be used for the general case
collector.setScorer(score);
}

private BulkScorerAndDoc scoreWindow(
BulkScorerAndDoc top, LeafCollector collector, Bits acceptDocs, int min, int max)
private DisiWrapper scoreWindow(
DisiWrapper top, LeafCollector collector, Bits acceptDocs, int min, int max)
throws IOException {
final int windowBase = top.next & ~MASK; // find the window that the next match belongs to
final int windowBase = top.doc & ~MASK; // find the window that the next match belongs to
final int windowMin = Math.max(min, windowBase);
final int windowMax = Math.min(max, windowBase + SIZE);

// Fill 'leads' with all scorers from 'head' that are in the right window
leads[0] = head.pop();
int maxFreq = 1;
while (head.size() > 0 && head.top().next < windowMax) {
while (head.size() > 0 && head.top().doc < windowMax) {
leads[maxFreq++] = head.pop();
}

if (minShouldMatch == 1 && maxFreq == 1) {
// special case: only one scorer can match in the current window,
// we can collect directly
final BulkScorerAndDoc bulkScorer = leads[0];
final DisiWrapper bulkScorer = leads[0];
scoreWindowSingleScorer(bulkScorer, collector, acceptDocs, windowMin, windowMax, max);
return head.add(bulkScorer);
} else {
Expand All @@ -360,11 +343,11 @@ private BulkScorerAndDoc scoreWindow(
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
collector.setScorer(score);

BulkScorerAndDoc top = advance(min);
while (top.next < max) {
DisiWrapper top = advance(min);
while (top.doc < max) {
top = scoreWindow(top, collector, acceptDocs, min, max);
}

return top.next;
return top.doc;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ BulkScorer optionalBulkScorer() throws IOException {
return new MaxScoreBulkScorer(maxDoc, optionalScorers);
}

List<BulkScorer> optional = new ArrayList<BulkScorer>();
List<Scorer> optional = new ArrayList<Scorer>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optional.add(ss.bulkScorer());
optional.add(ss.get(Long.MAX_VALUE));
}

return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());
Expand Down
Loading

0 comments on commit a779a64

Please sign in to comment.