Skip to content

Commit

Permalink
Add mult-leaf optimizations for diversify children collector (#13121)
Browse files Browse the repository at this point in the history
This adds multi-leaf optimizations for diversified children collector. This means as children vectors are collected within a block join, we can share information between leaves to speed up vector search.

To make this happen, I refactored the multi-leaf collector slightly. Now, instead of inheriting from TopKnnCollector, we inject a inner collector.
  • Loading branch information
benwtrent authored Mar 5, 2024
1 parent 51122f8 commit 012b959
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 37 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ Optimizations

* GITHUB#13085: Remove unnecessary toString() / substring() calls to save some String allocations (Dmitry Cherniachenko)

* GITHUB#13121: Speedup multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962.
(Ben Trent)

Bug Fixes
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public final int k() {
@Override
public abstract boolean collect(int docId, float similarity);

public abstract int numCollected();

@Override
public abstract float minCompetitiveSimilarity();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ public TopDocs topDocs() {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
}

@Override
public int numCollected() {
return queue.size();
}

@Override
public String toString() {
return "TopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,9 @@ public TopDocs topDocs() {
return new TopDocs(
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
}

@Override
public int numCollected() {
return scoreDocList.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@

package org.apache.lucene.search.knn;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.AbstractKnnCollector;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.util.hnsw.FloatHeap;

/**
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
* MultiLeafKnnCollector is a specific KnnCollector that can exchange the top collected results
* across segments through a shared global queue.
*
* @lucene.experimental
*/
public final class MultiLeafTopKnnCollector extends TopKnnCollector {
public final class MultiLeafKnnCollector implements KnnCollector {

// greediness of globally non-competitive search: (0,1]
private static final float DEFAULT_GREEDINESS = 0.9f;
Expand All @@ -46,23 +45,55 @@ public final class MultiLeafTopKnnCollector extends TopKnnCollector {
private final int interval = 0xff; // 255
private boolean kResultsCollected = false;
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;
private final AbstractKnnCollector subCollector;

/**
* Create a new MultiLeafKnnCollector.
*
* @param k the number of neighbors to collect
* @param visitLimit how many vector nodes the results are allowed to visit
* @param globalSimilarityQueue the global queue of the highest similarities collected so far
* across all segments
* @param subCollector the local collector
*/
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
super(k, visitLimit);
public MultiLeafKnnCollector(
int k, BlockingFloatHeap globalSimilarityQueue, AbstractKnnCollector subCollector) {
this.greediness = DEFAULT_GREEDINESS;
this.subCollector = subCollector;
this.globalSimilarityQueue = globalSimilarityQueue;
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
this.updatesQueue = new FloatHeap(k);
}

@Override
public boolean earlyTerminated() {
return subCollector.earlyTerminated();
}

@Override
public void incVisitedCount(int count) {
subCollector.incVisitedCount(count);
}

@Override
public long visitedCount() {
return subCollector.visitedCount();
}

@Override
public long visitLimit() {
return subCollector.visitLimit();
}

@Override
public int k() {
return subCollector.k();
}

@Override
public boolean collect(int docId, float similarity) {
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
boolean localSimUpdated = subCollector.collect(docId, similarity);
boolean firstKResultsCollected =
(kResultsCollected == false && subCollector.numCollected() == k());
if (firstKResultsCollected) {
kResultsCollected = true;
}
Expand All @@ -71,7 +102,7 @@ public boolean collect(int docId, float similarity) {

if (kResultsCollected) {
// as we've collected k results, we can start do periodic updates with the global queue
if (firstKResultsCollected || (visitedCount & interval) == 0) {
if (firstKResultsCollected || (subCollector.visitedCount() & interval) == 0) {
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
updatesQueue.clear();
globalSimUpdated = true;
Expand All @@ -85,26 +116,18 @@ public float minCompetitiveSimilarity() {
if (kResultsCollected == false) {
return Float.NEGATIVE_INFINITY;
}
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
return Math.max(
subCollector.minCompetitiveSimilarity(),
Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
}

@Override
public TopDocs topDocs() {
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
for (int i = 1; i <= scoreDocs.length; i++) {
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
queue.pop();
}
TotalHits.Relation relation =
earlyTerminated()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
return subCollector.topDocs();
}

@Override
public String toString() {
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
return "MultiLeafKnnCollector[subCollector=" + subCollector + "]";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;

Expand Down Expand Up @@ -48,12 +49,11 @@ public TopKnnCollectorManager(int k, IndexSearcher indexSearcher) {
* @param context the leaf reader context
*/
@Override
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
throws IOException {
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
if (globalScoreQueue == null) {
return new TopKnnCollector(k, visitedLimit);
} else {
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue);
return new MultiLeafKnnCollector(k, globalScoreQueue, new TopKnnCollector(k, visitedLimit));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
}

@Override
Expand All @@ -136,12 +136,10 @@ protected TopDocs approximateSearch(
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context);
if (collector == null) {
return NO_RESULTS;
}
KnnCollector collector =
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
return collector.topDocs();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter);
return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter, searcher);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ public TopDocs topDocs() {
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
}

@Override
public int numCollected() {
return heap.size();
}

/**
* This is a minimum binary heap, inspired by {@link org.apache.lucene.util.LongHeap}. But instead
* of encoding and using `long` values. Node ids and scores are kept separate. Additionally, this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.MultiLeafKnnCollector;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;

/**
* DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link
Expand All @@ -32,16 +36,20 @@ public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollec
private final int k;
// filter identifying the parent documents.
private final BitSetProducer parentsFilter;
private final BlockingFloatHeap globalScoreQueue;

/**
* Constructor
*
* @param k - the number of top k vectors to collect
* @param parentsFilter Filter identifying the parent documents.
*/
public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) {
public DiversifyingNearestChildrenKnnCollectorManager(
int k, BitSetProducer parentsFilter, IndexSearcher indexSearcher) {
this.k = k;
this.parentsFilter = parentsFilter;
this.globalScoreQueue =
indexSearcher.getIndexReader().leaves().size() > 1 ? new BlockingFloatHeap(k) : null;
}

/**
Expand All @@ -51,12 +59,18 @@ public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer pare
* @param context the leaf reader context
*/
@Override
public DiversifyingNearestChildrenKnnCollector newCollector(
int visitedLimit, LeafReaderContext context) throws IOException {
public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException {
BitSet parentBitSet = parentsFilter.getBitSet(context);
if (parentBitSet == null) {
return null;
}
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
if (globalScoreQueue == null) {
return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet);
} else {
return new MultiLeafKnnCollector(
k,
globalScoreQueue,
new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet));
}
}
}

0 comments on commit 012b959

Please sign in to comment.