Skip to content

Commit

Permalink
Adds profiling, including a small refactoring of the QueryProfiler interface
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 29, 2024
1 parent 81384f2 commit 1347d4b
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ public QueryProfiler() {
super(new InternalQueryProfileTree());
}

public void setVectorOpsCount(long vectorOpsCount) {
this.vectorOpsCount = vectorOpsCount;
public void addVectorOpsCount(long vectorOpsCount) {
this.vectorOpsCount += vectorOpsCount;
}

public long getVectorOpsCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.setVectorOpsCount(vectorOpsCount);
queryProfiler.addVectorOpsCount(vectorOpsCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.setVectorOpsCount(vectorOpsCount);
queryProfiler.addVectorOpsCount(vectorOpsCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.setVectorOpsCount(vectorOpsCount);
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

public Integer kParam() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.setVectorOpsCount(vectorOpsCount);
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

public Integer kParam() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ public RescoreKnnVectorQuery(
public Query rewrite(IndexSearcher searcher) throws IOException {
assert byteTarget == null ^ floatTarget == null : "Either byteTarget or floatTarget must be set";

Query rewritten = super.rewrite(searcher);
if (rewritten != this) {
return rewritten;
}

final DoubleValuesSource valueSource;
if (byteTarget != null) {
valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction);
Expand Down Expand Up @@ -115,7 +110,10 @@ public Integer k() {

@Override
public void profile(QueryProfiler queryProfiler) {
queryProfiler.setVectorOpsCount(vectorOpsCount);
if (innerQuery instanceof ProfilingQuery profilingQuery) {
profilingQuery.profile(queryProfiler);
}
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.lucene.search.function.MinScoreScorer;
import org.elasticsearch.search.profile.query.QueryProfiler;

import java.io.IOException;
import java.util.Objects;
Expand All @@ -30,7 +31,7 @@
/**
* This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery.
*/
public class VectorSimilarityQuery extends Query {
public class VectorSimilarityQuery extends Query implements ProfilingQuery {
private final float similarity;
private final float docScore;
private final Query innerKnnQuery;
Expand Down Expand Up @@ -78,6 +79,13 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return new MinScoreWeight(innerWeight, docScore, similarity, this, boost);
}

/**
 * Forwards profiling to the wrapped kNN query.
 * <p>
 * This query does not add anything to the profiler itself: when {@code innerKnnQuery} implements
 * {@link ProfilingQuery} its {@code profile} method is invoked with the given profiler, otherwise the
 * profiler is left untouched.
 *
 * @param queryProfiler the profiler handed to the inner query
 */
@Override
public void profile(QueryProfiler queryProfiler) {
    // Nothing to report when the wrapped query does not take part in profiling.
    if ((innerKnnQuery instanceof ProfilingQuery) == false) {
        return;
    }
    ((ProfilingQuery) innerKnnQuery).profile(queryProfiler);
}

@Override
public String toString(String field) {
return "VectorSimilarityQuery["
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
Expand All @@ -41,6 +46,7 @@

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;

public class RescoreKnnVectorQueryTests extends ESTestCase {

Expand All @@ -51,7 +57,8 @@ public class RescoreKnnVectorQueryTests extends ESTestCase {

public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) {
this.vectorProvider = vectorProvider;
this.numDocs = randomIntBetween(10, 100);;
this.numDocs = randomIntBetween(10, 100);
;
this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
}

Expand All @@ -71,7 +78,11 @@ public void testRescoreDocs() throws Exception {
// Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
// and thus we're rescoring the top k docs.
VectorData queryVector = vectorProvider.randomVector(numDims);
RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, adjustedK);
RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(
queryVector,
adjustedK,
new MatchAllDocsQuery()
);

IndexSearcher searcher = newSearcher(reader, true, false);
TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs);
Expand Down Expand Up @@ -115,10 +126,90 @@ public void testRescoreDocs() throws Exception {
}
}

/**
 * Indexes random documents and checks the profiled vector operations count of a
 * {@link RescoreKnnVectorQuery} for two kinds of inner query: a plain {@link MatchAllDocsQuery}
 * and a {@link MockProfilingQuery} that reports its own vector operations.
 */
public void testProfiling() throws Exception {
    final int dimensions = randomIntBetween(5, 100);

    try (Directory directory = newDirectory()) {
        // Documents must be written before the reader is opened on the directory.
        addRandomDocuments(numDocs, directory, dimensions, vectorProvider);

        try (IndexReader indexReader = DirectoryReader.open(directory)) {
            final VectorData target = vectorProvider.randomVector(dimensions);

            // Inner query that is a plain Lucene query.
            final Query plainInner = new MatchAllDocsQuery();
            checkProfiling(target, indexReader, plainInner);

            // Inner query that contributes its own vector operations to the profiler.
            final Query profilingInner = new MockProfilingQuery(randomIntBetween(1, 100));
            checkProfiling(target, indexReader, profilingInner);
        }
    }
}

/**
 * Executes a rescore query wrapping {@code innerQuery}, profiles it, and asserts the reported vector
 * operations count.
 * <p>
 * The expected count is {@code k} (or 0 when {@code k} is {@code null}) plus whatever the inner query
 * reports when it implements {@link ProfilingQuery}.
 *
 * @param queryVector the vector used to build the rescore query
 * @param reader      reader over the indexed test documents
 * @param innerQuery  the query wrapped by the rescore query
 * @throws IOException if searching fails
 */
private void checkProfiling(VectorData queryVector, IndexReader reader, Query innerQuery) throws IOException {
    final RescoreKnnVectorQuery rescoreQuery = vectorProvider.createRescoreQuery(queryVector, k, innerQuery);
    newSearcher(reader, true, false).search(rescoreQuery, numDocs);

    final QueryProfiler actualProfiler = new QueryProfiler();
    rescoreQuery.profile(actualProfiler);

    // Contribution expected from the rescoring itself.
    long expectedOps = k == null ? 0L : k;

    // Contribution expected from the inner query, measured with a separate profiler.
    if (innerQuery instanceof ProfilingQuery) {
        final QueryProfiler innerProfiler = new QueryProfiler();
        ((ProfilingQuery) innerQuery).profile(innerProfiler);
        final long innerOps = innerProfiler.getVectorOpsCount();
        assertThat(innerOps, greaterThan(0L));
        expectedOps += innerOps;
    }

    assertThat(actualProfiler.getVectorOpsCount(), equalTo(expectedOps));
}

/**
 * A mock query that is used to test profiling.
 * <p>
 * It builds its weight from a {@link MatchAllDocsQuery} and, when profiled, adds a fixed number of
 * vector operations to the profiler. Equality and hash code are based on that number so that two
 * mocks configured differently are not treated as the same query.
 */
private static class MockProfilingQuery extends Query implements ProfilingQuery {

    /** Number of vector operations reported on every {@link #profile(QueryProfiler)} call. */
    private final long vectorOpsCount;

    private MockProfilingQuery(long vectorOpsCount) {
        this.vectorOpsCount = vectorOpsCount;
    }

    /** Returns a descriptive representation so that test failure output identifies the mock. */
    @Override
    public String toString(String field) {
        return "MockProfilingQuery[vectorOpsCount=" + vectorOpsCount + "]";
    }

    /** Delegates weight creation to a fresh {@link MatchAllDocsQuery}. */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return new MatchAllDocsQuery().createWeight(searcher, scoreMode, boost);
    }

    /** Intentionally a no-op: the mock has no terms or sub-queries to expose. */
    @Override
    public void visit(QueryVisitor visitor) {}

    /** Two mocks are equal only when they report the same number of vector operations. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        return obj instanceof MockProfilingQuery other && vectorOpsCount == other.vectorOpsCount;
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return 31 * MockProfilingQuery.class.hashCode() + Long.hashCode(vectorOpsCount);
    }

    /** Adds the configured number of vector operations to the given profiler. */
    @Override
    public void profile(QueryProfiler queryProfiler) {
        queryProfiler.addVectorOpsCount(vectorOpsCount);
    }
}

/**
* Vector operations depend on the type of vector field used. This interface abstracts the operations needed to perform the tests
*/
private interface VectorProvider {
VectorData randomVector(int numDimensions);

RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k);
RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery);

KnnVectorValues vectorValues(LeafReader leafReader) throws IOException;

Expand All @@ -140,14 +231,8 @@ public VectorData randomVector(int numDimensions) {
}

@Override
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
return new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector.floatVector(),
VectorSimilarityFunction.COSINE,
k,
new MatchAllDocsQuery()
);
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) {
return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.floatVector(), VectorSimilarityFunction.COSINE, k, innerQuery);
}

@Override
Expand All @@ -163,7 +248,7 @@ public void addVectorField(Document document, VectorData vector) {

@Override
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
return VectorData.fromFloats(((FloatVectorValues)vectorValues).vectorValue(docId));
return VectorData.fromFloats(((FloatVectorValues) vectorValues).vectorValue(docId));
}

@Override
Expand All @@ -183,14 +268,8 @@ public VectorData randomVector(int numDimensions) {
}

@Override
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) {
return new RescoreKnnVectorQuery(
FIELD_NAME,
queryVector.byteVector(),
VectorSimilarityFunction.COSINE,
k,
new MatchAllDocsQuery()
);
public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) {
return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.byteVector(), VectorSimilarityFunction.COSINE, k, innerQuery);
}

@Override
Expand All @@ -206,7 +285,7 @@ public void addVectorField(Document document, VectorData vector) {

@Override
public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException {
return VectorData.fromBytes(((ByteVectorValues)vectorValues).vectorValue(docId));
return VectorData.fromBytes(((ByteVectorValues) vectorValues).vectorValue(docId));
}

@Override
Expand All @@ -230,39 +309,12 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims, Ve

@ParametersFactory
public static Iterable<Object[]> parameters() {

List<Object[]> params = new ArrayList<>();
params.add(new Object[] {new FloatVectorProvider(), true});
params.add(new Object[] {new FloatVectorProvider(), false});
params.add(new Object[] {new ByteVectorProvider(), true});
params.add(new Object[] {new ByteVectorProvider(), false});
params.add(new Object[] { new FloatVectorProvider(), true });
params.add(new Object[] { new FloatVectorProvider(), false });
params.add(new Object[] { new ByteVectorProvider(), true });
params.add(new Object[] { new ByteVectorProvider(), false });

return params;
}

// public void testProfiling() throws Exception {
// int numDocs = randomIntBetween(10, 100);
// int numDims = randomIntBetween(5, 100);
//
// try (Directory d = newDirectory()) {
// addRandomDocuments(numDocs, d, numDims, vectorProvider);
//
// try (IndexReader reader = DirectoryReader.open(d)) {
// float[] queryVector = randomVector(numDims);
//
// RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
// FIELD_NAME,
// queryVector,
// VectorSimilarityFunction.COSINE,
// randomIntBetween(5, numDocs - 1),
// new MatchAllDocsQuery()
// );
//
// IndexSearcher searcher = newSearcher(reader, true, false);
// QueryProfiler queryProfiler = new QueryProfiler();
// rescoreKnnVectorQuery.profile(queryProfiler);
// }
// }
// }

}

0 comments on commit 1347d4b

Please sign in to comment.