From a7936da7e137fec3ce20f55aec9f11cff6583b87 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 20 Nov 2024 18:05:41 +0100 Subject: [PATCH] Use KnnRescoreVectorQuery to perform rescoring and limiting the number of results from each shard --- .../vectors/DenseVectorFieldMapper.java | 63 +++----- .../search/vectors/KnnRescoreVectorQuery.java | 151 ++++++++++++++++++ 2 files changed, 172 insertions(+), 42 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 6792202462297..4d7366dd3dda9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -30,7 +30,6 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -71,6 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.KnnRescoreVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; @@ -2019,16 +2019,6 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } - if (rescoreOversample != null && indexOptions.type.isQuantized() == false) { - throw new IllegalArgumentException( - "cannot use rescore oversample on field [" - + name() - + "], that uses non-quantized type [" - + indexOptions.type - + "]. " - + "Only quantized index option types support rescore oversample." - ); - } return switch (getElementType()) { case BYTE -> createKnnByteQuery( queryVector.asByteVector(), @@ -2060,6 +2050,10 @@ public Query createKnnQuery( }; } + private boolean needsRescore(Float rescoreOversample) { + return rescoreOversample != null && (indexOptions == null || indexOptions.type == null || indexOptions.type.isQuantized()); + } + private Query createKnnBitQuery( byte[] queryVector, Integer k, @@ -2084,17 +2078,6 @@ private Query createKnnBitQuery( similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityByteValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) - ) - ); - - } return knnQuery; } @@ -2113,7 +2096,7 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Integer adjustedK = k == null || rescoreOversample == null + Integer adjustedK = k == null || needsRescore(rescoreOversample) == false ? null : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands); @@ -2128,16 +2111,14 @@ private Query createKnnByteQuery( similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityByteValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) - ) + if (needsRescore(rescoreOversample)) { + knnQuery = new KnnRescoreVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE), + k, + knnQuery ); - } return knnQuery; } @@ -2167,7 +2148,7 @@ && isNotUnitVector(squaredMagnitude)) { } } - Integer adjustedK = k == null || rescoreOversample == null + Integer adjustedK = k == null || needsRescore(rescoreOversample) == false ? k : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands); @@ -2181,16 +2162,14 @@ && isNotUnitVector(squaredMagnitude)) { similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityFloatValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT) - ) + if (needsRescore(rescoreOversample)) { + knnQuery = new KnnRescoreVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), + k, + knnQuery ); - } return knnQuery; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java new file mode 100644 index 0000000000000..092217d1be7ab --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; +import org.elasticsearch.search.profile.query.QueryProfiler; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Wraps a kNN vector query to rescore the results using the non-quantized vectors + */ +public class KnnRescoreVectorQuery extends Query implements ProfilingQuery { + private final String fieldName; + private final byte[] byteTarget; + private final float[] floatTarget; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final Integer k; + private final Query vectorQuery; + + private long vectorOpsCount; + + public KnnRescoreVectorQuery( + String fieldName, + byte[] byteTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query vectorQuery + ) { + this.fieldName = fieldName; + this.byteTarget = byteTarget; + this.floatTarget = null; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; + this.vectorQuery = vectorQuery; + } + + public KnnRescoreVectorQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query vectorQuery + ) { + this.fieldName = fieldName; + this.byteTarget = null; + this.floatTarget = floatTarget; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; + this.vectorQuery = vectorQuery; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + Query rewritten = super.rewrite(searcher); + if (rewritten != this) { + return rewritten; + } + + final DoubleValuesSource valueSource; + if (byteTarget != null) { + valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction); + } else { + valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + } + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource); + Query query = searcher.rewrite(functionScoreQuery); + + if (k == null) { + // No need to calculate top k - let the request size limit the results + return query; + } + + TopDocs topDocs = searcher.search(query, k); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + int[] docIds = new int[scoreDocs.length]; + float[] scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docIds[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } + + vectorOpsCount = scoreDocs.length; + + return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.setVectorOpsCount(vectorOpsCount); + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(fieldName)) { + visitor.visitLeaf(this); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o; + return Objects.equals(fieldName, that.fieldName) + && Objects.deepEquals(byteTarget, that.byteTarget) + && Objects.deepEquals(floatTarget, that.floatTarget) + && vectorSimilarityFunction == that.vectorSimilarityFunction + && Objects.equals(k, that.k) + && Objects.equals(vectorQuery, that.vectorQuery); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery); + } + + @Override + public String toString(String field) { + final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{"); + sb.append("fieldName='").append(fieldName).append('\''); + if (byteTarget != null) { + sb.append(", byteTarget=").append(Arrays.toString(byteTarget)); + } else { + sb.append(", floatTarget=").append(Arrays.toString(floatTarget)); + } + sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); + sb.append(", k=").append(k); + sb.append(", vectorQuery=").append(vectorQuery); + sb.append('}'); + return sb.toString(); + } +}