Skip to content

Commit

Permalink
Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code …
Browse files Browse the repository at this point in the history
…duplication (#13529)

* Lucene99HnswVectorsReader.search float-vs-byte variants: reduce code duplication

* action review feedback: use org.apache.lucene.util.IOSupplier
  • Loading branch information
cpoerschke authored Jul 1, 2024
1 parent 0ad270d commit f4cd4b4
Showing 1 changed file with 24 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
Expand Down Expand Up @@ -248,45 +249,39 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
if (knnCollector.k() < scorer.maxOrd()) {
HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds);
} else {
// if k is larger than the number of vectors, we can just iterate over all vectors
// and collect them
for (int i = 0; i < scorer.maxOrd(); i++) {
if (acceptedOrds == null || acceptedOrds.get(i)) {
if (knnCollector.earlyTerminated()) {
break;
}
knnCollector.incVisitedCount(1);
knnCollector.collect(scorer.ordToDoc(i), scorer.score(i));
}
}
}
search(
fields.get(field),
knnCollector,
acceptDocs,
VectorEncoding.FLOAT32,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}

@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
search(
fields.get(field),
knnCollector,
acceptDocs,
VectorEncoding.BYTE,
() -> flatVectorsReader.getRandomVectorScorer(field, target));
}

private void search(
FieldEntry fieldEntry,
KnnCollector knnCollector,
Bits acceptDocs,
VectorEncoding vectorEncoding,
IOSupplier<RandomVectorScorer> scorerSupplier)
throws IOException {

if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|| fieldEntry.vectorEncoding != vectorEncoding) {
return;
}
final RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
final RandomVectorScorer scorer = scorerSupplier.get();
final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
Expand Down

0 comments on commit f4cd4b4

Please sign in to comment.