Skip to content

Commit

Permalink
Parameterize recore knn vector query tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 29, 2024
1 parent 257b75d commit 81384f2
Showing 1 changed file with 181 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@

package org.elasticsearch.search.vectors;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -24,9 +29,12 @@
import org.apache.lucene.store.Directory;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
Expand All @@ -37,65 +45,56 @@
public class RescoreKnnVectorQueryTests extends ESTestCase {

public static final String FIELD_NAME = "float_vector";
private final int numDocs;
private final VectorProvider vectorProvider;
private final Integer k;

public void testRescoresTopK() throws Exception {
int numDocs = randomIntBetween(10, 100);
testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1));
}

public void testRescoresNoKParameter() throws Exception {
testRescoreDocs(randomIntBetween(10, 100), null);
public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) {
this.vectorProvider = vectorProvider;
this.numDocs = randomIntBetween(10, 100);;
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
}

private void testRescoreDocs(int numDocs, Integer k) throws Exception {
public void testRescoreDocs() throws Exception {
int numDims = randomIntBetween(5, 100);

Integer adjustedK = k;
if (k == null) {
k = numDocs;
adjustedK = numDocs;
}

try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) {
for (int i = 0; i < numDocs; i++) {
Document document = new Document();
float[] vector = randomVector(numDims);
KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector);
document.add(vectorField);
w.addDocument(document);
}
w.commit();
w.forceMerge(1);
}
addRandomDocuments(numDocs, d, numDims, vectorProvider);

try (IndexReader reader = DirectoryReader.open(d)) {
float[] queryVector = randomVector(numDims);

RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector,
VectorSimilarityFunction.COSINE,
k,
new MatchAllDocsQuery()
);
// Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
// and thus we're rescoring the top k docs.
VectorData queryVector = vectorProvider.randomVector(numDims);
RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, adjustedK);

IndexSearcher searcher = newSearcher(reader, true, false);
TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs);
Map<Integer, Float> rescoredDocs = Arrays.stream(docs.scoreDocs)
.collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score));

assertThat(rescoredDocs.size(), equalTo(k));
assertThat(rescoredDocs.size(), equalTo(adjustedK));

Collection<Float> rescoredScores = new HashSet<>(rescoredDocs.values());

Collection<Float> rescoredScores = new ArrayList<>(rescoredDocs.values());
// Collect all docs sequentially, and score them using the similarity function to get the top K scores
PriorityQueue<Float> topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1));

for (LeafReaderContext leafReaderContext : reader.leaves()) {
FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME);
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
KnnVectorValues vectorValues = vectorProvider.vectorValues(leafReaderContext.reader());
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
float[] vector = floatVectorValues.vectorValue(iterator.index());
float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector);
VectorData vectorData = vectorProvider.dataVectorForDoc(vectorValues, iterator.docID());
float score = vectorProvider.score(queryVector, vectorData);
topK.add(score);
int docId = iterator.docID();
// If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it
// to ensure we found them all
if (rescoredDocs.containsKey(docId)) {
assertThat(rescoredDocs.get(docId), equalTo(score));
rescoredDocs.remove(docId);
Expand All @@ -106,7 +105,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
assertThat(rescoredDocs.size(), equalTo(0));

// Check top scoring docs are contained in rescored docs
for (int i = 0; i < k; i++) {
for (int i = 0; i < adjustedK; i++) {
Float topScore = topK.poll();
if (rescoredScores.contains(topScore) == false) {
fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores);
Expand All @@ -116,12 +115,154 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
}
}

private static float[] randomVector(int numDims) {
float[] vector = new float[numDims];
for (int j = 0; j < numDims; j++) {
vector[j] = randomFloatBetween(0, 1, true);
private interface VectorProvider {
VectorData randomVector(int numDimensions);

RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k);

KnnVectorValues vectorValues(LeafReader leafReader) throws IOException;

void addVectorField(Document document, VectorData vector);

VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException;

float score(VectorData queryVector, VectorData dataVector);
}

private static class FloatVectorProvider implements VectorProvider {
@Override
public VectorData randomVector(int numDimensions) {
float[] vector = new float[numDimensions];
for (int j = 0; j < numDimensions; j++) {
vector[j] = randomFloatBetween(0, 1, true);
}
return VectorData.fromFloats(vector);
}

@Override
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
return new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector.floatVector(),
VectorSimilarityFunction.COSINE,
k,
new MatchAllDocsQuery()
);
}

@Override
public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException {
return leafReader.getFloatVectorValues(FIELD_NAME);
}

@Override
public void addVectorField(Document document, VectorData vector) {
KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector.floatVector());
document.add(vectorField);
}

@Override
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
return VectorData.fromFloats(((FloatVectorValues)vectorValues).vectorValue(docId));
}

@Override
public float score(VectorData queryVector, VectorData dataVector) {
return VectorSimilarityFunction.COSINE.compare(queryVector.floatVector(), dataVector.floatVector());
}
return vector;
}

private static class ByteVectorProvider implements VectorProvider {
@Override
public VectorData randomVector(int numDimensions) {
byte[] vector = new byte[numDimensions];
for (int j = 0; j < numDimensions; j++) {
vector[j] = randomByte();
}
return VectorData.fromBytes(vector);
}

@Override
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
return new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector.byteVector(),
VectorSimilarityFunction.COSINE,
k,
new MatchAllDocsQuery()
);
}

@Override
public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException {
return leafReader.getByteVectorValues(FIELD_NAME);
}

@Override
public void addVectorField(Document document, VectorData vector) {
KnnByteVectorField vectorField = new KnnByteVectorField(FIELD_NAME, vector.byteVector());
document.add(vectorField);
}

@Override
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
return VectorData.fromBytes(((ByteVectorValues)vectorValues).vectorValue(docId));
}

@Override
public float score(VectorData queryVector, VectorData dataVector) {
return VectorSimilarityFunction.COSINE.compare(queryVector.byteVector(), dataVector.byteVector());
}
}

private static void addRandomDocuments(int numDocs, Directory d, int numDims, VectorProvider vectorProvider) throws IOException {
try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) {
for (int i = 0; i < numDocs; i++) {
Document document = new Document();
VectorData vector = vectorProvider.randomVector(numDims);
vectorProvider.addVectorField(document, vector);
w.addDocument(document);
}
w.commit();
w.forceMerge(1);
}
}

@ParametersFactory
public static Iterable<Object[]> parameters() {

List<Object[]> params = new ArrayList<>();
params.add(new Object[] {new FloatVectorProvider(), true});
params.add(new Object[] {new FloatVectorProvider(), false});
params.add(new Object[] {new ByteVectorProvider(), true});
params.add(new Object[] {new ByteVectorProvider(), false});

return params;
}

// public void testProfiling() throws Exception {
// int numDocs = randomIntBetween(10, 100);
// int numDims = randomIntBetween(5, 100);
//
// try (Directory d = newDirectory()) {
// addRandomDocuments(numDocs, d, numDims, vectorProvider);
//
// try (IndexReader reader = DirectoryReader.open(d)) {
// float[] queryVector = randomVector(numDims);
//
// RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
// FIELD_NAME,
// queryVector,
// VectorSimilarityFunction.COSINE,
// randomIntBetween(5, numDocs - 1),
// new MatchAllDocsQuery()
// );
//
// IndexSearcher searcher = newSearcher(reader, true, false);
// QueryProfiler queryProfiler = new QueryProfiler();
// rescoreKnnVectorQuery.profile(queryProfiler);
// }
// }
// }

}

0 comments on commit 81384f2

Please sign in to comment.