Skip to content

Commit

Permalink
Use the new loadIntoBitSet API to speed up dense conjunctions. (apache#14080)
Browse files Browse the repository at this point in the history

Now that loading doc IDs into a bit set is much more efficient thanks to
auto-vectorization, it has become tempting to evaluate dense conjunctions by
and-ing bit sets.
  • Loading branch information
jpountz authored Dec 19, 2024
1 parent aef16da commit a337d14
Show file tree
Hide file tree
Showing 7 changed files with 484 additions and 10 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ Optimizations

* GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand)

* GITHUB#14080: Use the `DocIdSetIterator#loadIntoBitSet` API to speed up dense
conjunctions. (Adrien Grand)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,9 @@ BulkScorer optionalBulkScorer() throws IOException {
BulkScorer filteredOptionalBulkScorer() throws IOException {
if (subs.get(Occur.MUST).isEmpty() == false
|| subs.get(Occur.FILTER).isEmpty()
|| scoreMode != ScoreMode.TOP_SCORES
|| (scoreMode.needsScores() && scoreMode != ScoreMode.TOP_SCORES)
|| subs.get(Occur.SHOULD).size() <= 1
|| minShouldMatch > 1) {
|| minShouldMatch != 1) {
return null;
}
long cost = cost();
Expand All @@ -318,13 +318,28 @@ BulkScorer filteredOptionalBulkScorer() throws IOException {
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
filters.add(ss.get(cost));
}
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();
if (scoreMode == ScoreMode.TOP_SCORES) {
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
// In the beginning of this method, we exited early if the score mode is not either TOP_SCORES
// or a score mode that doesn't need scores.
assert scoreMode.needsScores() == false;
filters.add(new DisjunctionSumScorer(optionalScorers, scoreMode, cost));

if (filters.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& cost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(filters.stream().map(Scorer::iterator).toList());
}

return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.emptyList()));
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
}

// Return a BulkScorer for the required clauses only
Expand Down Expand Up @@ -378,7 +393,14 @@ private BulkScorer requiredBulkScorer() throws IOException {
&& requiredScoring.size() + requiredNoScoring.size() >= 2
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
if (requiredScoring.isEmpty()
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& leadCost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(
requiredNoScoring.stream().map(Scorer::iterator).toList());
} else {
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
}
}
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
requiredScoring = Collections.singletonList(new BlockMaxConjunctionScorer(requiredScoring));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/**
 * BulkScorer implementation of {@link ConjunctionScorer} that is specialized for dense clauses.
 * Whenever sensible, it intersects clauses by loading their matches into a bit set and computing
 * the intersection of clauses by and-ing these bit sets.
 */
final class DenseConjunctionBulkScorer extends BulkScorer {

  // Use a small-ish window size to make sure that we can take advantage of gaps in the postings of
  // clauses that are not leading iteration.
  static final int WINDOW_SIZE = 4096;
  // Only use bit sets to compute the intersection if more than 1/32th of the docs are expected to
  // match. Experiments suggested that values that are a bit higher than this would work better, but
  // we're erring on the conservative side.
  static final int DENSITY_THRESHOLD_INVERSE = Long.SIZE / 2;

  // Cheapest clause (by DocIdSetIterator#cost), used to lead iteration window by window.
  private final DocIdSetIterator lead;
  // Remaining clauses, sorted by increasing cost.
  private final List<DocIdSetIterator> others;

  // Scratch bit sets covering one window of WINDOW_SIZE doc IDs. Invariant: both are empty
  // between calls to scoreWindowUsingBitSet (asserted at the top of that method).
  private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE);
  private final FixedBitSet clauseWindowMatches = new FixedBitSet(WINDOW_SIZE);
  private final DocIdStreamView docIdStreamView = new DocIdStreamView();

  /**
   * Sole constructor.
   *
   * @param iterators the conjunction's clauses; at least 2 are required
   * @throws IllegalArgumentException if fewer than 2 clauses are provided
   */
  DenseConjunctionBulkScorer(List<DocIdSetIterator> iterators) {
    if (iterators.size() <= 1) {
      throw new IllegalArgumentException("Expected 2 or more clauses, got " + iterators.size());
    }
    // Copy before sorting so that the caller's list is not mutated.
    iterators = new ArrayList<>(iterators);
    iterators.sort(Comparator.comparingLong(DocIdSetIterator::cost));
    lead = iterators.get(0);
    others = List.copyOf(iterators.subList(1, iterators.size()));
  }

  @Override
  public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
    // No clause can match before the furthest-advanced non-lead clause, so raise min accordingly.
    for (DocIdSetIterator it : others) {
      min = Math.max(min, it.docID());
    }

    if (lead.docID() < min) {
      lead.advance(min);
    }

    if (lead.docID() >= max) {
      return lead.docID();
    }

    // If the collector exposes a competitive iterator, treat it as one more conjunctive clause so
    // that non-competitive docs are filtered out of the windows as well.
    List<DocIdSetIterator> otherIterators = this.others;
    DocIdSetIterator collectorIterator = collector.competitiveIterator();
    if (collectorIterator != null) {
      otherIterators = new ArrayList<>(otherIterators);
      otherIterators.add(collectorIterator);
    }

    final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new);

    // Process [lead.docID(), max) one window at a time. Windows are anchored on the lead's
    // current doc so that gaps in the lead's postings are skipped entirely. The long cast
    // guards against int overflow near Integer.MAX_VALUE.
    int windowMax;
    do {
      windowMax = (int) Math.min(max, (long) lead.docID() + WINDOW_SIZE);
      scoreWindowUsingBitSet(collector, acceptDocs, others, windowMax);
    } while (windowMax < max);

    return lead.docID();
  }

  /**
   * Returns the next set bit of {@code set} at or after index {@code i}, or
   * {@link DocIdSetIterator#NO_MORE_DOCS} if {@code i} is past the window.
   */
  private static int advance(FixedBitSet set, int i) {
    if (i >= WINDOW_SIZE) {
      return DocIdSetIterator.NO_MORE_DOCS;
    } else {
      return set.nextSetBit(i);
    }
  }

  /**
   * Scores one window [lead.docID(), max): loads the lead's matches into {@link #windowMatches},
   * and-s in the densest clauses via bit sets, finishes sparse tails by leap-frogging, then hands
   * the surviving doc IDs to the collector as a {@link DocIdStream}.
   */
  private void scoreWindowUsingBitSet(
      LeafCollector collector, Bits acceptDocs, DocIdSetIterator[] others, int max)
      throws IOException {
    assert windowMatches.scanIsEmpty();
    assert clauseWindowMatches.scanIsEmpty();

    int offset = lead.docID();
    lead.intoBitSet(acceptDocs, max, windowMatches, offset);

    // And in other clauses via bit sets while the intersection remains dense enough (at least
    // 1 match per DENSITY_THRESHOLD_INVERSE docs in the window).
    int upTo = 0;
    for (;
        upTo < others.length
            && windowMatches.cardinality() >= WINDOW_SIZE / DENSITY_THRESHOLD_INVERSE;
        upTo++) {
      DocIdSetIterator other = others[upTo];
      if (other.docID() < offset) {
        other.advance(offset);
      }
      // No need to apply acceptDocs on other clauses since we already applied live docs on the
      // leading clause.
      other.intoBitSet(null, max, clauseWindowMatches, offset);
      windowMatches.and(clauseWindowMatches);
      clauseWindowMatches.clear();
    }

    if (upTo < others.length) {
      // If the leading clause is sparse on this doc ID range or if the intersection became sparse
      // after applying a few clauses, we finish evaluating the intersection using the traditional
      // leap-frog approach. This proved important with a query such as "+secretary +of +state" on
      // wikibigall, where the intersection becomes sparse after intersecting "secretary" and
      // "state".
      advanceHead:
      for (int windowMatch = windowMatches.nextSetBit(0);
          windowMatch != DocIdSetIterator.NO_MORE_DOCS; ) {
        int doc = offset + windowMatch;
        for (int i = upTo; i < others.length; ++i) {
          DocIdSetIterator other = others[i];
          int otherDoc = other.docID();
          if (otherDoc < doc) {
            otherDoc = other.advance(doc);
          }
          if (doc != otherDoc) {
            // This candidate is not in clause i: clear every bit up to the clause's current
            // position, since none of those docs can match either.
            int clearUpTo = Math.min(WINDOW_SIZE, otherDoc - offset);
            windowMatches.clear(windowMatch, clearUpTo);
            windowMatch = advance(windowMatches, clearUpTo);
            continue advanceHead;
          }
        }
        windowMatch = advance(windowMatches, windowMatch + 1);
      }
    }

    docIdStreamView.offset = offset;
    collector.collect(docIdStreamView);
    windowMatches.clear();

    // If another clause is more advanced than the lead then advance the lead, it's important to
    // take advantage of large gaps in the postings lists of other clauses.
    int maxOtherDocID = -1;
    for (DocIdSetIterator other : others) {
      maxOtherDocID = Math.max(maxOtherDocID, other.docID());
    }
    if (lead.docID() < maxOtherDocID) {
      lead.advance(maxOtherDocID);
    }
  }

  @Override
  public long cost() {
    // The lead is the cheapest clause, so it bounds the cost of the conjunction.
    return lead.cost();
  }

  /**
   * Read-only {@link DocIdStream} view over {@link #windowMatches}, shifted by {@link #offset}.
   * Only valid while the window's bits are populated (i.e. during {@code collector.collect}).
   */
  final class DocIdStreamView extends DocIdStream {

    // Doc ID of bit 0 of windowMatches for the current window.
    int offset;

    @Override
    public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
      int offset = this.offset;
      long[] bitArray = windowMatches.getBits();
      for (int idx = 0; idx < bitArray.length; idx++) {
        long bits = bitArray[idx];
        // Emit one doc per set bit, lowest first; clearing the lowest set bit each iteration.
        while (bits != 0L) {
          int ntz = Long.numberOfTrailingZeros(bits);
          consumer.accept(offset + ((idx << 6) | ntz));
          bits ^= 1L << ntz;
        }
      }
    }

    @Override
    public int count() throws IOException {
      return windowMatches.cardinality();
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/**
* A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided
Expand Down Expand Up @@ -141,6 +143,23 @@ public int advance(int target) throws IOException {
return Math.min(leadTop.doc, minOtherDoc);
}

@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
    throws IOException {
  // Bulk-load this disjunction's matches below upTo into the bit set: pop lead iterators off the
  // heap one at a time, letting each load its own matches. Per the DocIdSetIterator#intoBitSet
  // contract, each call is expected to leave the iterator on its first doc at or beyond upTo, so
  // the loop terminates once the heap's top is positioned at upTo or later.
  while (leadTop.doc < upTo) {
    leadTop.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
    leadTop.doc = leadTop.approximation.docID();
    leadTop = leadIterators.updateTop();
  }

  // Non-lead iterators are all drained unconditionally, so recompute the minimum doc ID across
  // them from scratch to restore the class invariant on minOtherDoc.
  minOtherDoc = Integer.MAX_VALUE;
  for (DisiWrapper w : otherIterators) {
    w.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
    w.doc = w.approximation.docID();
    minOtherDoc = Math.min(minOtherDoc, w.doc);
  }
}

/** Return the linked list of iterators positioned on the current doc. */
public DisiWrapper topList() {
if (leadTop.doc < minOtherDoc) {
Expand Down
11 changes: 9 additions & 2 deletions lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public final class FixedBitSet extends BitSet {
private static final long BASE_RAM_BYTES_USED =
RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class);

// An array that is small enough to use reasonable amounts of RAM and large enough to allow
// Arrays#mismatch to use SIMD instructions and multiple registers under the hood.
private static long[] ZEROES = new long[32];

private final long[] bits; // Array of longs holding the bits
private final int numBits; // The number of bits in use
private final int numWords; // The exact number of longs needed to hold numBits (<= bits.length)
Expand Down Expand Up @@ -470,8 +474,11 @@ public boolean scanIsEmpty() {
// Depends on the ghost bits being clear!
final int count = numWords;

for (int i = 0; i < count; i++) {
if (bits[i] != 0) return false;
for (int i = 0; i < count; i += ZEROES.length) {
int cmpLen = Math.min(ZEROES.length, bits.length - i);
if (Arrays.equals(bits, i, i + cmpLen, ZEROES, 0, cmpLen) == false) {
return false;
}
}

return true;
Expand Down
Loading

0 comments on commit a337d14

Please sign in to comment.