Skip to content

Commit

Permalink
WIP - add rescore to custom ESKnnFloatVectorQuery to do exact search …
Browse files Browse the repository at this point in the history
…after an approximate search
  • Loading branch information
carlosdelest committed Oct 31, 2024
1 parent 6182921 commit 3a65731
Show file tree
Hide file tree
Showing 9 changed files with 338 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INDEX_REQUEST_REMOVE_METERING = def(8_780_00_0);
public static final TransportVersion CPU_STAT_STRING_PARSING = def(8_781_00_0);
public static final TransportVersion QUERY_RULES_RETRIEVER = def(8_782_00_0);
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_783_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1983,6 +1983,7 @@ public Query createKnnQuery(
VectorData queryVector,
Integer k,
int numCands,
Float rescoreOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
Expand All @@ -1994,7 +1995,15 @@ public Query createKnnQuery(
}
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter);
case FLOAT -> createKnnFloatQuery(
queryVector.asFloatVector(),
k,
numCands,
rescoreOversample,
filter,
similarityThreshold,
parentFilter
);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
};
}
Expand Down Expand Up @@ -2052,6 +2061,7 @@ private Query createKnnFloatQuery(
float[] queryVector,
Integer k,
int numCands,
Float rescoreOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
Expand All @@ -2073,7 +2083,7 @@ && isNotUnitVector(squaredMagnitude)) {
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter);
: new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, rescoreOversample, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TimeLimitingKnnCollectorManager;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.elasticsearch.search.profile.query.QueryProfiler;

import java.io.IOException;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery {
private final Integer kParam;
private long vectorOpsCount;
private final Float rescoreOversample;

/**
 * Creates a kNN float vector query without rescoring.
 * Equivalent to calling the six-argument constructor with a {@code null} rescore oversample factor.
 *
 * @param field    name of the vector field to search
 * @param target   query vector
 * @param k        value stored as {@code kParam}
 * @param numCands number of candidates handed to the parent query
 * @param filter   filter query handed to the parent query
 */
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) {
    this(field, target, k, numCands, null, filter);
}

/**
 * Creates a kNN float vector query with an optional rescore oversample factor.
 *
 * @param field             name of the vector field to search
 * @param target            query vector
 * @param k                 value stored as {@code kParam}
 * @param numCands          number of candidates; scaled by {@code rescoreOversample} (via {@code adjustCandidates})
 *                          before being handed to the parent query
 * @param rescoreOversample oversample factor, or {@code null} to leave {@code numCands} untouched
 * @param filter            filter query handed to the parent query
 */
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Float rescoreOversample, Query filter) {
    super(field, target, adjustCandidates(numCands, rescoreOversample), filter);
    this.kParam = k;
    this.rescoreOversample = rescoreOversample;
}

/**
 * Scales the number of candidates by the rescore oversample factor.
 *
 * @param numCands          base number of candidates
 * @param rescoreOversample multiplier applied to {@code numCands}, or {@code null} for no scaling
 * @return {@code numCands} when no factor is given, otherwise {@code numCands * rescoreOversample} rounded up
 * @throws IllegalArgumentException if the factor is NaN or not strictly positive
 */
private static int adjustCandidates(int numCands, Float rescoreOversample) {
    if (rescoreOversample == null) {
        return numCands;
    }
    // Fail fast: a NaN or non-positive factor would otherwise produce a candidate count <= 0,
    // which is only rejected further down with a message that never mentions the oversample factor.
    if (Float.isNaN(rescoreOversample) || rescoreOversample <= 0f) {
        throw new IllegalArgumentException("rescore oversample must be greater than 0, got: " + rescoreOversample);
    }
    return (int) Math.ceil(numCands * rescoreOversample);
}

@Override
Expand All @@ -31,8 +48,43 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
return topK;
}

/**
 * Runs the parent's approximate search on one segment and, when a rescore oversample factor is set,
 * re-scores the returned documents with an exact search restricted to those documents.
 *
 * @param context             segment to search
 * @param acceptDocs          documents eligible for matching; may be {@code null}
 * @param visitedLimit        passed through to the parent approximate search
 * @param knnCollectorManager collector manager passed through to the parent approximate search
 * @return the approximate hits when no oversample factor is set, otherwise the exact-search result over those hits
 * @throws IOException if the underlying search fails
 */
@Override
protected TopDocs approximateSearch(
    LeafReaderContext context,
    Bits acceptDocs,
    int visitedLimit,
    KnnCollectorManager knnCollectorManager
) throws IOException {
    TopDocs topDocs = super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager);
    if (rescoreOversample == null) {
        return topDocs;
    }

    // Size the bit set from the segment rather than from acceptDocs: Lucene hands in the segment's live docs,
    // which are null when the segment has no deletions, so acceptDocs.length() would throw an NPE.
    BitSet exactSearchAcceptDocs = topDocsToBitSet(topDocs, context.reader().maxDoc());
    BitSetIterator bitSetIterator = new BitSetIterator(exactSearchAcceptDocs, topDocs.scoreDocs.length);
    // Reuse the query timeout of the approximate phase when the collector manager carries one.
    QueryTimeout queryTimeout = null;
    if (knnCollectorManager instanceof TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) {
        queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
    }
    return exactSearch(context, bitSetIterator, queryTimeout);
}

/**
 * Hands this query's vector operations counter to the given profiler.
 *
 * @param queryProfiler profiler receiving the current {@code vectorOpsCount}
 */
@Override
public void profile(QueryProfiler queryProfiler) {
    final long opsCount = this.vectorOpsCount;
    queryProfiler.setVectorOpsCount(opsCount);
}

/**
 * Builds a bit set of {@code numBits} bits in which the bit of every document id found in
 * {@code topDocs.scoreDocs} is set.
 *
 * @param topDocs hits whose document ids are marked
 * @param numBits size of the bit set to allocate
 * @return the populated bit set
 */
private static BitSet topDocsToBitSet(TopDocs topDocs, int numBits) {
    final BitSet docIds = new FixedBitSet(numBits);
    final ScoreDoc[] hits = topDocs.scoreDocs;
    for (int i = 0; i < hits.length; i++) {
        docIds.set(hits[i].doc);
    }
    return docIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE;
import static org.elasticsearch.common.Strings.format;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
Expand All @@ -66,6 +67,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
public static final ParseField RESCORE_VECTOR_OVERSAMPLE = new ParseField("rescore_vector_oversample");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");

Expand All @@ -79,7 +81,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
null,
(Integer) args[2],
(Integer) args[3],
(Float) args[4]
(Float) args[4],
(Float) args[5]
)
);

Expand All @@ -106,6 +109,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
ObjectParser.ValueType.OBJECT_ARRAY
);
declareStandardFields(PARSER);
PARSER.declareFloat(optionalConstructorArg(), RESCORE_VECTOR_OVERSAMPLE);
}

public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
Expand All @@ -120,6 +124,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
private final Float vectorSimilarity;
private final QueryVectorBuilder queryVectorBuilder;
private final Supplier<float[]> queryVectorSupplier;
private final Float rescoreOversample;

public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
Expand Down Expand Up @@ -151,6 +156,19 @@ private KnnVectorQueryBuilder(
Integer k,
Integer numCands,
Float vectorSimilarity
) {
this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, 0F);
}

private KnnVectorQueryBuilder(
String fieldName,
VectorData queryVector,
QueryVectorBuilder queryVectorBuilder,
Supplier<float[]> queryVectorSupplier,
Integer k,
Integer numCands,
Float vectorSimilarity,
Float rescoreOversample
) {
if (k != null && k < 1) {
throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
Expand Down Expand Up @@ -187,6 +205,7 @@ private KnnVectorQueryBuilder(
this.vectorSimilarity = vectorSimilarity;
this.queryVectorBuilder = queryVectorBuilder;
this.queryVectorSupplier = queryVectorSupplier;
this.rescoreOversample = rescoreOversample;
}

public KnnVectorQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -227,6 +246,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
} else {
this.queryVectorBuilder = null;
}
if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
this.rescoreOversample = in.readOptionalFloat();
} else {
this.rescoreOversample = null;
}

this.queryVectorSupplier = null;
}

Expand All @@ -252,6 +277,10 @@ public Integer numCands() {
return numCands;
}

/**
 * @return the rescore oversample factor held by this builder; may be {@code null}
 */
public Float rescoreOversample() {
    return this.rescoreOversample;
}

/**
 * @return the filter queries held by this builder (the internal list itself, not a copy)
 */
public List<QueryBuilder> filterQueries() {
    return this.filterQueries;
}
Expand Down Expand Up @@ -327,6 +356,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
out.writeOptionalNamedWriteable(queryVectorBuilder);
}
if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) {
out.writeOptionalFloat(rescoreOversample);
}
}

@Override
Expand Down Expand Up @@ -491,14 +523,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
// Now join the filterQuery & parentFilter to provide the matching blocks of children
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
rescoreOversample,
filterQuery,
vectorSimilarity,
parentBitSet
);
}
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null);
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, rescoreOversample, filterQuery, vectorSimilarity, null);
}

/**
 * Computes the hash code from the builder's fields, including {@code rescoreOversample},
 * which {@code doEquals} also compares.
 *
 * @return hash over field name, query vector, k, num candidates, filters, similarity,
 *         query vector builder and rescore oversample
 */
@Override
protected int doHashCode() {
    return Objects.hash(
        fieldName,
        Objects.hashCode(queryVector),
        k,
        numCands,
        filterQueries,
        vectorSimilarity,
        queryVectorBuilder,
        rescoreOversample
    );
}

@Override
Expand All @@ -509,7 +558,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
&& Objects.equals(numCands, other.numCands)
&& Objects.equals(filterQueries, other.filterQueries)
&& Objects.equals(vectorSimilarity, other.vectorSimilarity)
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder);
&& Objects.equals(queryVectorBuilder, other.queryVectorBuilder)
&& Objects.equals(rescoreOversample, other.rescoreOversample);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

public class VectorRescoreQuery extends Query {

private final KnnFloatVectorQuery knnQuery;

public VectorRescoreQuery(KnnFloatVectorQuery knnQuery) {
this.knnQuery = knnQuery;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return super.createWeight(searcher, scoreMode, boost);
}

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}

@Override
public String toString(String field) {
return "";
}

@Override
public void visit(QueryVisitor visitor) {

}

@Override
public boolean equals(Object obj) {
return false;
}

@Override
public int hashCode() {
return 0;
}
}
Loading

0 comments on commit 3a65731

Please sign in to comment.