diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 892648077507..eb16b3be8c1d 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -244,6 +244,9 @@ Optimizations * GITHUB#13036 Optimize counts on two clause term disjunctions. (Adrien Grand, Johannes Fredén) +* GITHUB#12962: Speedup concurrent multi-segment HNWS graph search (Mayya Sharipova, Tom Veasey) + + Bug Fixes --------------------- * GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh) diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 2bddbdd81a94..b4619dc53a09 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -42,6 +42,7 @@ exports org.apache.lucene.search; exports org.apache.lucene.search.comparators; exports org.apache.lucene.search.similarities; + exports org.apache.lucene.search.knn; exports org.apache.lucene.store; exports org.apache.lucene.util; exports org.apache.lucene.util.automaton; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java index 45586e29b053..a75465679f42 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnCollector.java @@ -23,7 +23,7 @@ */ public abstract class AbstractKnnCollector implements KnnCollector { - private long visitedCount; + protected long visitedCount; private final long visitLimit; private final int k; diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index f0de3ea7a396..6c2c50a03e43 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -29,6 +29,8 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.TopKnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -79,11 +81,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } + KnnCollectorManager knnCollectorManager = getKnnCollectorManager(k, indexSearcher); TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, filterWeight)); + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -95,8 +98,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return createRewrittenQuery(reader, topK); } - private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException { - TopDocs results = getLeafResults(ctx, filterWeight); + private TopDocs searchLeaf( + LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) + throws IOException { + TopDocs results = getLeafResults(ctx, filterWeight, knnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -105,12 +110,14 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IO return results; } - private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException { + private TopDocs getLeafResults( + LeafReaderContext ctx, Weight filterWeight, KnnCollectorManager knnCollectorManager) + throws IOException { Bits liveDocs = ctx.reader().getLiveDocs(); int maxDoc = ctx.reader().maxDoc(); if (filterWeight == null) { - return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE); + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager); } Scorer scorer = filterWeight.scorer(ctx); @@ -128,7 +135,7 @@ private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throw } // Perform the approximate kNN search - TopDocs results = approximateSearch(ctx, acceptDocs, cost); + TopDocs results = approximateSearch(ctx, acceptDocs, cost, knnCollectorManager); if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) { return results; } else { @@ -155,8 +162,16 @@ protected boolean match(int doc) { } } + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + return new TopKnnCollectorManager(k, searcher); + } + protected abstract TopDocs approximateSearch( - LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException; + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) + throws IOException; abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException; diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 681ed2b2d9ff..a07342bff33f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; @@ -75,10 +76,23 @@ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { } @Override - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - TopDocs results = - context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return TopDocsCollector.EMPTY_TOPDOCS; + } + if (Math.min(knnCollector.k(), context.reader().getByteVectorValues(fi.name).size()) == 0) { + return TopDocsCollector.EMPTY_TOPDOCS; + } + context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs); + TopDocs results = knnCollector.topDocs(); return results != null ? results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 63b4f8821595..3d8430a45ff4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; @@ -76,10 +77,23 @@ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { } @Override - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - TopDocs results = - context.reader().searchNearestVectors(field, target, k, acceptDocs, visitedLimit); + KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); + FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field); + if (fi == null || fi.getVectorDimension() == 0) { + // The field does not exist or does not index vectors + return TopDocsCollector.EMPTY_TOPDOCS; + } + if (Math.min(knnCollector.k(), context.reader().getFloatVectorValues(fi.name).size()) == 0) { + return TopDocsCollector.EMPTY_TOPDOCS; + } + context.reader().searchNearestVectors(field, target, knnCollector, acceptDocs); + TopDocs results = knnCollector.topDocs(); return results != null ? results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java index 1f67704da4d1..59d2b2abe39d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopKnnCollector.java @@ -25,9 +25,9 @@ * * @lucene.experimental */ -public final class TopKnnCollector extends AbstractKnnCollector { +public class TopKnnCollector extends AbstractKnnCollector { - private final NeighborQueue queue; + protected final NeighborQueue queue; /** * @param k the number of neighbors to collect diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/KnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/KnnCollectorManager.java new file mode 100644 index 000000000000..5e1eeecc8409 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/KnnCollectorManager.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.knn; + +import java.io.IOException; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.KnnCollector; + +/** + * KnnCollectorManager responsible for creating {@link KnnCollector} instances. Useful to create + * {@link KnnCollector} instances that share global state across leaves, such a global queue of + * results collected so far. + */ +public interface KnnCollectorManager { + + /** + * Return a new {@link KnnCollector} instance. + * + * @param visitedLimit the maximum number of nodes that the search is allowed to visit + * @param context the leaf reader context + */ + KnnCollector newCollector(int visitedLimit, LeafReaderContext context) throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java new file mode 100644 index 000000000000..782a4059b20c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/MultiLeafTopKnnCollector.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.hnsw.BlockingFloatHeap; +import org.apache.lucene.util.hnsw.FloatHeap; + +/** + * MultiLeafTopKnnCollector is a specific KnnCollector that can exchange the top collected results + * across segments through a shared global queue. + * + * @lucene.experimental + */ +public final class MultiLeafTopKnnCollector extends TopKnnCollector { + + // greediness of globally non-competitive search: (0,1] + private static final float DEFAULT_GREEDINESS = 0.9f; + // the global queue of the highest similarities collected so far across all segments + private final BlockingFloatHeap globalSimilarityQueue; + // the local queue of the highest similarities if we are not competitive globally + // the size of this queue is defined by greediness + private final FloatHeap nonCompetitiveQueue; + private final float greediness; + // the queue of the local similarities to periodically update with the global queue + private final FloatHeap updatesQueue; + // interval to synchronize the local and global queues, as a number of visited vectors + private final int interval = 0xff; // 255 + private boolean kResultsCollected = false; + private float cachedGlobalMinSim = Float.NEGATIVE_INFINITY; + + /** + * @param k the number of neighbors to collect + * @param visitLimit how many vector nodes the results are allowed to visit + */ + public MultiLeafTopKnnCollector(int k, int visitLimit, BlockingFloatHeap globalSimilarityQueue) { + super(k, visitLimit); + this.greediness = DEFAULT_GREEDINESS; + this.globalSimilarityQueue = globalSimilarityQueue; + this.nonCompetitiveQueue = new FloatHeap(Math.max(1, Math.round((1 - greediness) * k))); + this.updatesQueue = new FloatHeap(k); + } + + @Override + public boolean collect(int docId, float similarity) { + boolean localSimUpdated = queue.insertWithOverflow(docId, similarity); + boolean firstKResultsCollected = (kResultsCollected == false && queue.size() == k()); + if (firstKResultsCollected) { + kResultsCollected = true; + } + updatesQueue.offer(similarity); + boolean globalSimUpdated = nonCompetitiveQueue.offer(similarity); + + if (kResultsCollected) { + // as we've collected k results, we can start do periodic updates with the global queue + if (firstKResultsCollected || (visitedCount & interval) == 0) { + cachedGlobalMinSim = globalSimilarityQueue.offer(updatesQueue.getHeap()); + updatesQueue.clear(); + globalSimUpdated = true; + } + } + return localSimUpdated || globalSimUpdated; + } + + @Override + public float minCompetitiveSimilarity() { + if (kResultsCollected == false) { + return Float.NEGATIVE_INFINITY; + } + return Math.max(queue.topScore(), Math.min(nonCompetitiveQueue.peek(), cachedGlobalMinSim)); + } + + @Override + public TopDocs topDocs() { + assert queue.size() <= k() : "Tried to collect more results than the maximum number allowed"; + ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; + for (int i = 1; i <= scoreDocs.length; i++) { + scoreDocs[scoreDocs.length - i] = new ScoreDoc(queue.topNode(), queue.topScore()); + queue.pop(); + } + TotalHits.Relation relation = + earlyTerminated() + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TopDocs(new TotalHits(visitedCount(), relation), scoreDocs); + } + + @Override + public String toString() { + return "MultiLeafTopKnnCollector[k=" + k() + ", size=" + queue.size() + "]"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java new file mode 100644 index 000000000000..df4431df5fc7 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/TopKnnCollectorManager.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.knn; + +import java.io.IOException; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.util.hnsw.BlockingFloatHeap; + +/** + * TopKnnCollectorManager responsible for creating {@link TopKnnCollector} instances. When + * concurrency is supported, the {@link BlockingFloatHeap} is used to track the global top scores + * collected across all leaves. + */ +public class TopKnnCollectorManager implements KnnCollectorManager { + + // the number of docs to collect + private final int k; + // the global score queue used to track the top scores collected across all leaves + private final BlockingFloatHeap globalScoreQueue; + + public TopKnnCollectorManager(int k, IndexSearcher indexSearcher) { + boolean isMultiSegments = indexSearcher.getIndexReader().leaves().size() > 1; + this.k = k; + this.globalScoreQueue = isMultiSegments ? new BlockingFloatHeap(k) : null; + } + + /** + * Return a new {@link TopKnnCollector} instance. + * + * @param visitedLimit the maximum number of nodes that the search is allowed to visit + * @param context the leaf reader context + */ + @Override + public TopKnnCollector newCollector(int visitedLimit, LeafReaderContext context) + throws IOException { + if (globalScoreQueue == null) { + return new TopKnnCollector(k, visitedLimit); + } else { + return new MultiLeafTopKnnCollector(k, visitedLimit, globalScoreQueue); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/package-info.java b/lucene/core/src/java/org/apache/lucene/search/knn/package-info.java new file mode 100644 index 000000000000..8363eff359a9 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Classes related to vector search: knn and vector fields. */ +package org.apache.lucene.search.knn; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/BlockingFloatHeap.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/BlockingFloatHeap.java new file mode 100644 index 000000000000..a81eaf2fee04 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/BlockingFloatHeap.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import java.util.concurrent.locks.ReentrantLock; + +/** + * A blocking bounded min heap that stores floats. The top element is the lowest value of the heap. + * + *

A primitive priority queue that maintains a partial ordering of its elements such that the + * least element can always be found in constant time. Implementation is based on {@link + * org.apache.lucene.util.LongHeap} + * + * @lucene.internal + */ +public final class BlockingFloatHeap { + private final int maxSize; + private final float[] heap; + private final ReentrantLock lock; + private int size; + + public BlockingFloatHeap(int maxSize) { + this.maxSize = maxSize; + this.heap = new float[maxSize + 1]; + this.lock = new ReentrantLock(); + this.size = 0; + } + + /** + * Inserts a value into this heap. + * + *

If the number of values would exceed the heap's maxSize, the least value is discarded + * + * @param value the value to add + * @return the new 'top' element in the queue. + */ + public float offer(float value) { + lock.lock(); + try { + if (size < maxSize) { + push(value); + return heap[1]; + } else { + if (value >= heap[1]) { + updateTop(value); + } + return heap[1]; + } + } finally { + lock.unlock(); + } + } + + /** + * Inserts array of values into this heap. + * + *

Values must be sorted in ascending order. + * + * @param values a set of values to insert, must be sorted in ascending order + * @return the new 'top' element in the queue. + */ + public float offer(float[] values) { + lock.lock(); + try { + for (int i = values.length - 1; i >= 0; i--) { + if (size < maxSize) { + push(values[i]); + } else { + if (values[i] >= heap[1]) { + updateTop(values[i]); + } else { + break; + } + } + } + return heap[1]; + } finally { + lock.unlock(); + } + } + + /** + * Removes and returns the head of the heap + * + * @return the head of the heap, the smallest value + * @throws IllegalStateException if the heap is empty + */ + public float poll() { + if (size > 0) { + float result; + + lock.lock(); + try { + result = heap[1]; // save first value + heap[1] = heap[size]; // move last to first + size--; + downHeap(1); // adjust heap + } finally { + lock.unlock(); + } + return result; + } else { + throw new IllegalStateException("The heap is empty"); + } + } + + /** + * Retrieves, but does not remove, the head of this heap. + * + * @return the head of the heap, the smallest value + */ + public float peek() { + lock.lock(); + try { + return heap[1]; + } finally { + lock.unlock(); + } + } + + /** + * Returns the number of elements in this heap. + * + * @return the number of elements in this heap + */ + public int size() { + lock.lock(); + try { + return size; + } finally { + lock.unlock(); + } + } + + private void push(float element) { + size++; + heap[size] = element; + upHeap(size); + } + + private float updateTop(float value) { + heap[1] = value; + downHeap(1); + return heap[1]; + } + + private void downHeap(int i) { + float value = heap[i]; // save top value + int j = i << 1; // find smaller child + int k = j + 1; + if (k <= size && heap[k] < heap[j]) { + j = k; + } + while (j <= size && heap[j] < value) { + heap[i] = heap[j]; // shift up child + i = j; + j = i << 1; + k = j + 1; + if (k <= size && heap[k] < heap[j]) { + j = k; + } + } + heap[i] = value; // install saved value + } + + private void upHeap(int origPos) { + int i = origPos; + float value = heap[i]; // save bottom value + int j = i >>> 1; + while (j > 0 && value < heap[j]) { + heap[i] = heap[j]; // shift parents down + i = j; + j = j >>> 1; + } + heap[i] = value; // install saved value + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/FloatHeap.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/FloatHeap.java new file mode 100644 index 000000000000..e1a267a3c015 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/FloatHeap.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +/** + * A bounded min heap that stores floats. The top element is the lowest value of the heap. + * + *

A primitive priority queue that maintains a partial ordering of its elements such that the + * least element can always be found in constant time. Implementation is based on {@link + * org.apache.lucene.util.LongHeap} + * + * @lucene.internal + */ +public final class FloatHeap { + private final int maxSize; + private final float[] heap; + private int size; + + public FloatHeap(int maxSize) { + this.maxSize = maxSize; + this.heap = new float[maxSize + 1]; + this.size = 0; + } + + /** + * Inserts a value into this heap. + * + *

If the number of values would exceed the heap's maxSize, the least value is discarded + * + * @param value the value to add + * @return whether the value was added (unless the heap is full, or the new value is less than the + * top value) + */ + public boolean offer(float value) { + if (size >= maxSize) { + if (value < heap[1]) { + return false; + } + updateTop(value); + return true; + } + push(value); + return true; + } + + public float[] getHeap() { + float[] result = new float[size]; + System.arraycopy(this.heap, 1, result, 0, size); + return result; + } + + /** + * Removes and returns the head of the heap + * + * @return the head of the heap, the smallest value + * @throws IllegalStateException if the heap is empty + */ + public float poll() { + if (size > 0) { + float result; + result = heap[1]; // save first value + heap[1] = heap[size]; // move last to first + size--; + downHeap(1); // adjust heap + return result; + } else { + throw new IllegalStateException("The heap is empty"); + } + } + + /** + * Retrieves, but does not remove, the head of this heap. + * + * @return the head of the heap, the smallest value + */ + public float peek() { + return heap[1]; + } + + /** + * Returns the number of elements in this heap. + * + * @return the number of elements in this heap + */ + public int size() { + return size; + } + + public void clear() { + size = 0; + } + + private void push(float element) { + size++; + heap[size] = element; + upHeap(size); + } + + private float updateTop(float value) { + heap[1] = value; + downHeap(1); + return heap[1]; + } + + private void downHeap(int i) { + float value = heap[i]; // save top value + int j = i << 1; // find smaller child + int k = j + 1; + if (k <= size && heap[k] < heap[j]) { + j = k; + } + while (j <= size && heap[j] < value) { + heap[i] = heap[j]; // shift up child + i = j; + j = i << 1; + k = j + 1; + if (k <= size && heap[k] < heap[j]) { + j = k; + } + } + heap[i] = value; // install saved value + } + + private void upHeap(int origPos) { + int i = origPos; + float value = heap[i]; // save bottom value + int j = i >>> 1; + while (j > 0 && value < heap[j]) { + heap[i] = heap[j]; // shift parents down + i = j; + j = j >>> 1; + } + heap[i] = value; // install saved value + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index eb6413e53a32..3c82cd6b33e4 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -547,6 +547,7 @@ public void testByteVectorValues() throws IOException { 5, leaf.getLiveDocs(), Integer.MAX_VALUE)); + } else { DocIdSetIterator iter = leaf.getByteVectorValues("vector"); scanAndRetrieve(leaf, iter); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestBlockingFloatHeap.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestBlockingFloatHeap.java new file mode 100644 index 000000000000..07df7b3044ef --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestBlockingFloatHeap.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; + +import java.util.concurrent.CountDownLatch; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.SuppressForbidden; + +public class TestBlockingFloatHeap extends LuceneTestCase { + + public void testBasicOperations() { + BlockingFloatHeap heap = new BlockingFloatHeap(3); + heap.offer(2); + heap.offer(4); + heap.offer(1); + heap.offer(3); + assertEquals(3, heap.size()); + assertEquals(2, heap.peek(), 0); + + assertEquals(2, heap.poll(), 0); + assertEquals(3, heap.poll(), 0); + assertEquals(4, heap.poll(), 0); + assertEquals(0, heap.size(), 0); + } + + public void testBasicOperations2() { + int size = atLeast(10); + BlockingFloatHeap heap = new BlockingFloatHeap(size); + double sum = 0, sum2 = 0; + + for (int i = 0; i < size; i++) { + float next = random().nextFloat(100f); + sum += next; + heap.offer(next); + } + + float last = Float.NEGATIVE_INFINITY; + for (long i = 0; i < size; i++) { + float next = heap.poll(); + assertTrue(next >= last); + last = next; + sum2 += last; + } + assertEquals(sum, sum2, 0.01); + } + + @SuppressForbidden(reason = "Thread sleep") + public void testMultipleThreads() throws Exception { + Thread[] threads = new Thread[randomIntBetween(3, 20)]; + final CountDownLatch latch = new CountDownLatch(1); + BlockingFloatHeap globalHeap = new BlockingFloatHeap(1); + + for (int i = 0; i < threads.length; i++) { + threads[i] = + new Thread( + () -> { + try { + latch.await(); + int numIterations = randomIntBetween(10, 100); + float bottomValue = 0; + + while (numIterations-- > 0) { + bottomValue += randomIntBetween(0, 5); + globalHeap.offer(bottomValue); + Thread.sleep(randomIntBetween(0, 50)); + + float globalBottomValue = globalHeap.peek(); + assertTrue(globalBottomValue >= bottomValue); + bottomValue = globalBottomValue; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + threads[i].start(); + } + + latch.countDown(); + for (Thread t : threads) { + t.join(); + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestFloatHeap.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestFloatHeap.java new file mode 100644 index 000000000000..004412496cd2 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestFloatHeap.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestFloatHeap extends LuceneTestCase { + + public void testBasicOperations() { + FloatHeap heap = new FloatHeap(3); + heap.offer(2); + heap.offer(4); + heap.offer(1); + heap.offer(3); + assertEquals(3, heap.size()); + assertEquals(2, heap.peek(), 0); + + assertEquals(2, heap.poll(), 0); + assertEquals(3, heap.poll(), 0); + assertEquals(4, heap.poll(), 0); + assertEquals(0, heap.size(), 0); + } + + public void testBasicOperations2() { + int size = atLeast(10); + FloatHeap heap = new FloatHeap(size); + double sum = 0, sum2 = 0; + + for (int i = 0; i < size; i++) { + float next = random().nextFloat(100f); + sum += next; + heap.offer(next); + } + + float last = Float.NEGATIVE_INFINITY; + for (long i = 0; i < size; i++) { + float next = heap.poll(); + assertTrue(next >= last); + last = next; + sum2 += last; + } + assertEquals(sum, sum2, 0.01); + } + + public void testClear() { + FloatHeap heap = new FloatHeap(3); + heap.offer(20); + heap.offer(40); + heap.offer(30); + assertEquals(3, heap.size()); + assertEquals(20, heap.peek(), 0); + + heap.clear(); + assertEquals(0, heap.size(), 0); + assertEquals(20, heap.peek(), 0); + + heap.offer(15); + heap.offer(35); + assertEquals(2, heap.size()); + assertEquals(15, heap.peek(), 0); + + assertEquals(15, heap.poll(), 0); + assertEquals(35, heap.poll(), 0); + assertEquals(0, heap.size(), 0); + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index b8c494fa05d4..dff70f31bacb 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Query; @@ -33,6 +34,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -123,7 +125,16 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept } @Override - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) throws IOException { BitSet parentBitSet = parentsFilter.getBitSet(context); if (parentBitSet == null) { diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index e08b3d4c4ff4..a84d809ac6cf 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -26,6 +26,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; @@ -33,6 +34,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; @@ -123,14 +125,21 @@ protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator accept } @Override - protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + return new DiversifyingNearestChildrenKnnCollectorManager(k, parentsFilter); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) throws IOException { - BitSet parentBitSet = parentsFilter.getBitSet(context); - if (parentBitSet == null) { + KnnCollector collector = knnCollectorManager.newCollector(visitedLimit, context); + if (collector == null) { return NO_RESULTS; } - KnnCollector collector = - new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); context.reader().searchNearestVectors(field, query, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java new file mode 100644 index 000000000000..8e8a54eedfc1 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingNearestChildrenKnnCollectorManager.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.join; + +import java.io.IOException; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.BitSet; + +/** + * DiversifyingNearestChildrenKnnCollectorManager responsible for creating {@link + * DiversifyingNearestChildrenKnnCollector} instances. + */ +public class DiversifyingNearestChildrenKnnCollectorManager implements KnnCollectorManager { + + // the number of docs to collect + private final int k; + // filter identifying the parent documents. + private final BitSetProducer parentsFilter; + + /** + * Constructor + * + * @param k - the number of top k vectors to collect + * @param parentsFilter Filter identifying the parent documents. + */ + public DiversifyingNearestChildrenKnnCollectorManager(int k, BitSetProducer parentsFilter) { + this.k = k; + this.parentsFilter = parentsFilter; + } + + /** + * Return a new {@link DiversifyingNearestChildrenKnnCollector} instance. + * + * @param visitedLimit the maximum number of nodes that the search is allowed to visit + * @param context the leaf reader context + */ + @Override + public DiversifyingNearestChildrenKnnCollector newCollector( + int visitedLimit, LeafReaderContext context) throws IOException { + BitSet parentBitSet = parentsFilter.getBitSet(context); + if (parentBitSet == null) { + return null; + } + return new DiversifyingNearestChildrenKnnCollector(k, visitedLimit, parentBitSet); + } +}