Skip to content

Commit

Permalink
Speedup concurrent multi-segment HNWS graph search (#12962)
Browse files Browse the repository at this point in the history
Speedup concurrent multi-segment HNWS graph search by exchanging 
the global top candidated collected so far across segments. These global top 
candidates set the minimum threshold that new candidates need to pass
 to be considered. This allows earlier stopping for segments that don't have 
good candidates.
  • Loading branch information
mayya-sharipova authored Feb 6, 2024
1 parent 635d090 commit d095ed0
Show file tree
Hide file tree
Showing 19 changed files with 901 additions and 22 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ Optimizations

* GITHUB#13036 Optimize counts on two clause term disjunctions. (Adrien Grand, Johannes Fredén)

* GITHUB#12962: Speedup concurrent multi-segment HNWS graph search (Mayya Sharipova, Tom Veasey)


Bug Fixes
---------------------
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
Expand Down
1 change: 1 addition & 0 deletions lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
exports org.apache.lucene.search;
exports org.apache.lucene.search.comparators;
exports org.apache.lucene.search.similarities;
exports org.apache.lucene.search.knn;
exports org.apache.lucene.store;
exports org.apache.lucene.util;
exports org.apache.lucene.util.automaton;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public abstract class AbstractKnnCollector implements KnnCollector {

private long visitedCount;
protected long visitedCount;
private final long visitLimit;
private final int k;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -79,11 +81,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
filterWeight = null;
}

KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher);
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> searchLeaf(context, filterWeight));
tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
}
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

Expand All @@ -95,8 +98,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return createRewrittenQuery(reader, topK);
}

private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight);
private TopDocs searchLeaf(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
Expand All @@ -105,12 +110,14 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IO
return results;
}

private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
private TopDocs getLeafResults(
LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager)
throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();

if (filterWeight == null) {
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager);
}

Scorer scorer = filterWeight.scorer(ctx);
Expand All @@ -128,7 +135,7 @@ private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throw
}

// Perform the approximate kNN search
TopDocs results = approximateSearch(ctx, acceptDocs, cost);
TopDocs results = approximateSearch(ctx, acceptDocs, cost, knnCollectorManager);
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
return results;
} else {
Expand All @@ -155,8 +162,16 @@ protected boolean match(int doc) {
}
}

protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
return new TopKnnCollectorManager(k, searcher);
}

protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException;
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException;

abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;

Expand Down Expand Up @@ -75,10 +76,23 @@ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
}

@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
}
if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
Expand Down Expand Up @@ -76,10 +77,23 @@ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
}

@Override
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit)
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
TopDocs results =
context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) {
// The field does not exist or does not index vectors
return TopDocsCollector.EMPTY_TOPDOCS;
}
if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) {
return TopDocsCollector.EMPTY_TOPDOCS;
}
context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs);
TopDocs results = knnCollector.topDocs();
return results != null ? results : NO_RESULTS;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
*
* @lucene.experimental
*/
public final class TopKnnCollector extends AbstractKnnCollector {
public class TopKnnCollector extends AbstractKnnCollector {

private final NeighborQueue queue;
protected final NeighborQueue queue;

/**
* @param k the number of neighbors to collect
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.search.knn;

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.KnnCollector;

/**
* KnnCollectorManager responsible for creating {@link KnnCollector} instances. Useful to create
* {@link KnnCollector} instances that share global state across leaves, such a global queue of
* results collected so far.
*/
public interface KnnCollectorManager {

/**
* Return a new {@link KnnCollector} instance.
*
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @param context the leaf reader context
*/
KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.search.knn;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.util.hnsw.FloatHeap;

/**
* MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results
* across segments through a shared global queue.
*
* @lucene.experimental
*/
public final class MultiLeafTopKnnCollector extends TopKnnCollector {

// greediness of globally non-competitive search: (0,1]
private static final float DEFAULT_GREEDINESS = 0.9f;
// the global queue of the highest similarities collected so far across all segments
private final BlockingFloatHeap globalSimilarityQueue;
// the local queue of the highest similarities if we are not competitive globally
// the size of this queue is defined by greediness
private final FloatHeap nonCompetitiveQueue;
private final float greediness;
// the queue of the local similarities to periodically update with the global queue
private final FloatHeap updatesQueue;
// interval to synchronize the local and global queues, as a number of visited vectors
private final int interval = 0xff; // 255
private boolean kResultsCollected = false;
private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY;

/**
* @param k the number of neighbors to collect
* @param visitLimit how many vector nodes the results are allowed to visit
*/
public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) {
super(k, visitLimit);
this.greediness = DEFAULT_GREEDINESS;
this.globalSimilarityQueue = globalSimilarityQueue;
this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k)));
this.updatesQueue = new FloatHeap(k);
}

@Override
public boolean collect(int docId, float similarity) {
boolean localSimUpdated = queue.insertWithOverflow(docId, similarity);
boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k());
if (firstKResultsCollected) {
kResultsCollected = true;
}
updatesQueue.offer(similarity);
boolean globalSimUpdated = nonCompetitiveQueue.offer(similarity);

if (kResultsCollected) {
// as we've collected k results, we can start do periodic updates with the global queue
if (firstKResultsCollected || (visitedCount & interval) == 0) {
cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap());
updatesQueue.clear();
globalSimUpdated = true;
}
}
return localSimUpdated || globalSimUpdated;
}

@Override
public float minCompetitiveSimilarity() {
if (kResultsCollected == false) {
return Float.NEGATIVE_INFINITY;
}
return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim));
}

@Override
public TopDocs topDocs() {
assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed";
ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
for (int i = 1; i <= scoreDocs.length; i++) {
scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore());
queue.pop();
}
TotalHits.Relation relation =
earlyTerminated()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs);
}

@Override
public String toString() {
return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.search.knn;

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;

/**
* TopKnnCollectorManager responsible for creating {@link TopKnnCollector} instances. When
* concurrency is supported, the {@link BlockingFloatHeap} is used to track the global top scores
* collected across all leaves.
*/
public class TopKnnCollectorManager implements KnnCollectorManager {

// the number of docs to collect
private final int k;
// the global score queue used to track the top scores collected across all leaves
private final BlockingFloatHeap globalScoreQueue;

public TopKnnCollectorManager(int k, IndexSearcher indexSearcher) {
boolean isMultiSegments = indexSearcher.getIndexReader().leaves().size() > 1;
this.k = k;
this.globalScoreQueue = isMultiSegments ? new BlockingFloatHeap(k) : null;
}

/**
* Return a new {@link TopKnnCollector} instance.
*
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
* @param context the leaf reader context
*/
@Override
public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context)
throws IOException {
if (globalScoreQueue == null) {
return new TopKnnCollector(k, visitedLimit);
} else {
return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue);
}
}
}
Loading

0 comments on commit d095ed0

Please sign in to comment.