Skip to content

Commit

Permalink
Minor refactoring to reuse KnnScoreDocQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 20, 2024
1 parent ff2c1e9 commit bc1e5c6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
Expand Down Expand Up @@ -37,26 +38,45 @@
public class KnnScoreDocQuery extends Query {
private final int[] docs;
private final float[] scores;

// the indexes in docs and scores corresponding to the first matching document in each segment.
// If a segment has no matching documents, it should be assigned the index of the next segment that does.
// There should be a final entry that is always docs.length (one past the last valid index).
private final int[] segmentStarts;
// an object identifying the reader context that was used to build this query

private final Object contextIdentity;

/**
* Creates a query.
*
* @param docs the global doc IDs of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
* document in each segment. If a segment has no matching documents, it should be assigned
* the index of the next segment that does. There should be a final entry that is always
* docs.length-1.
* @param contextIdentity an object identifying the reader context that was used to build this
* query
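* @param reader the index reader whose leaves are used to compute the segment starts, and whose
* context identity is recorded as the context this query was built for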
*/
KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
this.segmentStarts = findSegmentStarts(reader, docs);
this.contextIdentity = reader.getContext().id();
}

/**
 * Computes, for every leaf of {@code reader}, the index into {@code docs} of the first entry
 * whose value is greater than or equal to that leaf's {@code docBase}.
 *
 * <p>The returned array has one slot per leaf plus a trailing slot that always holds
 * {@code docs.length}. Slot 0 is left at 0.
 *
 * @param reader the reader whose leaves define the segment boundaries
 * @param docs doc IDs to partition; must be sorted ascending, as {@link Arrays#binarySearch}
 *     requires
 * @return the start index in {@code docs} for each leaf, followed by {@code docs.length}
 */
private static int[] findSegmentStarts(IndexReader reader, int[] docs) {
    final int leafCount = reader.leaves().size();
    final int[] segmentStarts = new int[leafCount + 1];
    segmentStarts[leafCount] = docs.length;

    // With zero or one leaf the loop body never runs and only the trailing slot is populated.
    // NOTE(review): the search window is narrowed to start at the previous hit; this presumes
    // docBase grows with the leaf ordinal -- confirm against the reader contract.
    int from = 0;
    for (int leaf = 1; leaf < leafCount; leaf++) {
        final int docBase = reader.leaves().get(leaf).docBase;
        final int found = Arrays.binarySearch(docs, from, docs.length, docBase);
        // A negative result encodes the insertion point as -(insertionPoint) - 1.
        from = found >= 0 ? found : -(found + 1);
        segmentStarts[leaf] = from;
    }
    return segmentStarts;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.TransportVersion;
Expand All @@ -25,7 +24,6 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
Expand Down Expand Up @@ -153,9 +151,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
scores[i] = scoreDocs[i].score;
}

IndexReader reader = context.getIndexReader();
int[] segmentStarts = findSegmentStarts(reader, docs);
return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id());
return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
}

@Override
Expand All @@ -169,24 +165,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return super.doRewrite(queryRewriteContext);
}

/**
 * Computes, for every leaf of {@code reader}, the index into {@code docs} of the first entry
 * whose value is greater than or equal to that leaf's {@code docBase}.
 *
 * <p>The returned array has one slot per leaf plus a trailing slot that always holds
 * {@code docs.length}. Slot 0 is left at 0.
 *
 * @param reader the reader whose leaves define the segment boundaries
 * @param docs doc IDs to partition; must be sorted ascending, as {@link Arrays#binarySearch}
 *     requires
 * @return the start index in {@code docs} for each leaf, followed by {@code docs.length}
 */
private static int[] findSegmentStarts(IndexReader reader, int[] docs) {
    final int leafCount = reader.leaves().size();
    final int[] segmentStarts = new int[leafCount + 1];
    segmentStarts[leafCount] = docs.length;

    // With zero or one leaf the loop body never runs and only the trailing slot is populated.
    // NOTE(review): the search window is narrowed to start at the previous hit; this presumes
    // docBase grows with the leaf ordinal -- confirm against the reader contract.
    int from = 0;
    for (int leaf = 1; leaf < leafCount; leaf++) {
        final int docBase = reader.leaves().get(leaf).docBase;
        final int found = Arrays.binarySearch(docs, from, docs.length, docBase);
        // A negative result encodes the insertion point as -(insertionPoint) - 1.
        from = found >= 0 ? found : -(found + 1);
        segmentStarts[leaf] = from;
    }
    return segmentStarts;
}

@Override
protected boolean doEquals(KnnScoreDocQueryBuilder other) {
if (scoreDocs.length != other.scoreDocs.length) {
Expand Down

0 comments on commit bc1e5c6

Please sign in to comment.