From bc1e5c692158ff6dbcb80da390cc374c6ac8eaec Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 20 Nov 2024 18:04:41 +0100 Subject: [PATCH] Minor refactoring to reuse KnnScoreDocQuery --- .../search/vectors/KnnScoreDocQuery.java | 38 ++++++++++++++----- .../vectors/KnnScoreDocQueryBuilder.java | 24 +----------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index bb83b8528c6c8..db7484dddc226 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -37,7 +38,13 @@ public class KnnScoreDocQuery extends Query { private final int[] docs; private final float[] scores; + + // the indexes in docs and scores corresponding to the first matching document in each segment. + // If a segment has no matching documents, it should be assigned the index of the next segment that does. + // There should be a final entry that is always docs.length. private final int[] segmentStarts; + // an object identifying the reader context that was used to build this query + private final Object contextIdentity; /** @@ -45,18 +52,31 @@ public class KnnScoreDocQuery extends Query { * * @param docs the global doc IDs of documents that match, in ascending order * @param scores the scores of the matching documents - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. 
There should be a final entry that is always - * docs.length-1. - * @param contextIdentity an object identifying the reader context that was used to build this - * query + * @param reader the index reader used to compute the per-segment start offsets and to identify the reader context */ - KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { this.docs = docs; this.scores = scores; - this.segmentStarts = segmentStarts; - this.contextIdentity = contextIdentity; + this.segmentStarts = findSegmentStarts(reader, docs); + this.contextIdentity = reader.getContext().id(); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index f52addefc8b1c..10bee9ec66c2c 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.TransportVersion; @@ -25,7 +24,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.Objects; /** @@ -153,9 +151,7 @@ protected Query doToQuery(SearchExecutionContext context) 
throws IOException { scores[i] = scoreDocs[i].score; } - IndexReader reader = context.getIndexReader(); - int[] segmentStarts = findSegmentStarts(reader, docs); - return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id()); + return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); } @Override @@ -169,24 +165,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return super.doRewrite(queryRewriteContext); } - private static int[] findSegmentStarts(IndexReader reader, int[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - @Override protected boolean doEquals(KnnScoreDocQueryBuilder other) { if (scoreDocs.length != other.scoreDocs.length) {