From 3a65731de67c23ad64385ab8f46962d4e7594dce Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 31 Oct 2024 13:41:08 +0100 Subject: [PATCH] WIP - add rescore to custom ESKnnFloatVectorQuery to do exact search after an approximate search --- .../org/elasticsearch/TransportVersions.java | 1 + .../vectors/DenseVectorFieldMapper.java | 14 ++- .../search/vectors/ESKnnFloatVectorQuery.java | 56 +++++++++- .../search/vectors/KnnVectorQueryBuilder.java | 60 +++++++++- .../search/vectors/VectorRescoreQuery.java | 58 ++++++++++ .../vectors/VectorRescorerQueryBuilder.java | 104 ++++++++++++++++++ .../vectors/DenseVectorFieldMapperTests.java | 40 ++++++- .../vectors/DenseVectorFieldTypeTests.java | 30 +++-- ...AbstractKnnVectorQueryBuilderTestCase.java | 1 + 9 files changed, 338 insertions(+), 26 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/VectorRescoreQuery.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/VectorRescorerQueryBuilder.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 7bf3204b7e1a..f818fc6f61e4 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -185,6 +185,7 @@ static TransportVersion def(int id) { public static final TransportVersion INDEX_REQUEST_REMOVE_METERING = def(8_780_00_0); public static final TransportVersion CPU_STAT_STRING_PARSING = def(8_781_00_0); public static final TransportVersion QUERY_RULES_RETRIEVER = def(8_782_00_0); + public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_783_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 809532c0e8f5..19bd26bf36b2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1983,6 +1983,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, + Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -1994,7 +1995,15 @@ public Query createKnnQuery( } return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); - case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter); + case FLOAT -> createKnnFloatQuery( + queryVector.asFloatVector(), + k, + numCands, + rescoreOversample, + filter, + similarityThreshold, + parentFilter + ); case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } @@ -2052,6 +2061,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, + Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2073,7 +2083,7 @@ && isNotUnitVector(squaredMagnitude)) { } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter); + : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, rescoreOversample, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index be0437af9131..cd5058b960cc 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -9,18 +9,35 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TimeLimitingKnnCollectorManager; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; import org.elasticsearch.search.profile.query.QueryProfiler; +import java.io.IOException; + public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery { private final Integer kParam; private long vectorOpsCount; + private final Float rescoreOversample; - public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) { - super(field, target, numCands, filter); + public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Float rescoreOversample, Query filter) { + super(field, target, adjustCandidates(numCands, rescoreOversample), filter); this.kParam = k; + this.rescoreOversample = rescoreOversample; + } + + private static int adjustCandidates(int numCands, Float rescoreOversample) { + return rescoreOversample == null ? numCands : (int) Math.ceil(numCands * rescoreOversample); } @Override @@ -31,8 +48,43 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { return topK; } + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager + ) throws IOException { + TopDocs topDocs = super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); + if (rescoreOversample == null) { + return topDocs; + } + + BitSet exactSearchAcceptDocs = topDocsToBitSet(topDocs, acceptDocs.length()); + BitSetIterator bitSetIterator = new BitSetIterator(exactSearchAcceptDocs, topDocs.scoreDocs.length); + QueryTimeout queryTimeout = null; + if (knnCollectorManager instanceof TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) { + queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout(); + } + return exactSearch(context, bitSetIterator, queryTimeout); + } + @Override public void profile(QueryProfiler queryProfiler) { queryProfiler.setVectorOpsCount(vectorOpsCount); } + + // Convert TopDocs to BitSet + private static BitSet topDocsToBitSet(TopDocs topDocs, int numBits) { + // Create a FixedBitSet with a size equal to the maximum number of documents + BitSet bitSet = new FixedBitSet(numBits); + + // Iterate through each document in TopDocs + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + // Set the corresponding bit for each doc ID + bitSet.set(scoreDoc.doc); + } + + return bitSet; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index deb7e6bd035b..c9e30e6ba15d 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -45,6 +45,7 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -66,6 +67,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder queryVectorSupplier; + private final Float rescoreOversample; public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); @@ -151,6 +156,19 @@ private KnnVectorQueryBuilder( Integer k, Integer numCands, Float vectorSimilarity + ) { + this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, 0F); + } + + private KnnVectorQueryBuilder( + String fieldName, + VectorData queryVector, + QueryVectorBuilder queryVectorBuilder, + Supplier queryVectorSupplier, + Integer k, + Integer numCands, + Float vectorSimilarity, + Float rescoreOversample ) { if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); @@ -187,6 +205,7 @@ private KnnVectorQueryBuilder( this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; + this.rescoreOversample = rescoreOversample; } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -227,6 +246,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.queryVectorBuilder = null; } + if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreOversample = in.readOptionalFloat(); + } else { + this.rescoreOversample = null; + } + this.queryVectorSupplier = null; } @@ -252,6 +277,10 @@ public Integer numCands() { return numCands; } + public Float rescoreOversample() { + return rescoreOversample; + } + public List filterQueries() { return filterQueries; } @@ -327,6 +356,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalNamedWriteable(queryVectorBuilder); } + if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalFloat(rescoreOversample); + } } @Override @@ -491,14 +523,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet); + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + rescoreOversample, + filterQuery, + vectorSimilarity, + parentBitSet + ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, rescoreOversample, filterQuery, vectorSimilarity, null); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder); + return Objects.hash( + fieldName, + Objects.hashCode(queryVector), + k, + numCands, + filterQueries, + vectorSimilarity, + queryVectorBuilder, + rescoreOversample + ); } @Override @@ -509,7 +558,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(numCands, other.numCands) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) - && Objects.equals(queryVectorBuilder, other.queryVectorBuilder); + && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) + && Objects.equals(rescoreOversample, other.rescoreOversample); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorRescoreQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorRescoreQuery.java new file mode 100644 index 000000000000..14da01c9c93e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorRescoreQuery.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +public class VectorRescoreQuery extends Query { + + private final KnnFloatVectorQuery knnQuery; + + public VectorRescoreQuery(KnnFloatVectorQuery knnQuery) { + this.knnQuery = knnQuery; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return super.createWeight(searcher, scoreMode, boost); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); + } + + @Override + public String toString(String field) { + return ""; + } + + @Override + public void visit(QueryVisitor visitor) { + + } + + @Override + public boolean equals(Object obj) { + return false; + } + + @Override + public int hashCode() { + return 0; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorRescorerQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorRescorerQueryBuilder.java new file mode 100644 index 000000000000..1524cb4c52d4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorRescorerQueryBuilder.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.routing.Preference; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; + +// WIP - Perform a query on the rewrite phase, then do an exact query with a filter +public class VectorRescorerQueryBuilder extends AbstractQueryBuilder { + + private final QueryBuilder knnQueryBuilder; + + public VectorRescorerQueryBuilder(QueryBuilder knnQueryBuilder) { + this.knnQueryBuilder = knnQueryBuilder; + } + + @Override + protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { + QueryBuilder rewrittenQueryBuilder = knnQueryBuilder.rewrite(context); + if (rewrittenQueryBuilder != knnQueryBuilder) { + return new VectorRescorerQueryBuilder(rewrittenQueryBuilder); + } + + // Query query = knnQueryBuilder.toQuery(context); + context.registerAsyncAction((client, listener) -> { + String[] indices = Arrays.stream(context.getResolvedIndices().getConcreteLocalIndices()) + .map(Index::getName) + .toArray(String[]::new); + client.prepareSearch(indices) + .setPreference(Preference.ONLY_LOCAL.type()) + .setQuery(knnQueryBuilder) + .execute(new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + + } + + @Override + public void onFailure(Exception e) { + + } + }); + }); + // client.execute(new ESKnnFloatVectorQuery("field", new float[0], 0, 0, query), listener); + // }); + + return this; // super.doIndexMetadataRewrite(context); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + return null; + } + + @Override + protected boolean doEquals(VectorRescorerQueryBuilder other) { + return false; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return ""; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index de084cd4582e..1d949c657c2f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1683,7 +1683,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1692,7 +1700,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1701,7 +1709,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1710,7 +1718,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1722,6 +1738,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1738,6 +1755,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1765,7 +1783,15 @@ public void testFloatVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1777,6 +1803,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1793,6 +1820,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 6433cf2f1c0d..c8aba47b91d2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -169,7 +169,7 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -190,11 +190,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -255,7 +255,15 @@ public void testFloatCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), + 10, + 10, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -275,7 +283,7 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -291,7 +299,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -312,7 +320,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -332,7 +340,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -350,7 +358,7 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -366,13 +374,13 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f93bdd14f064..f166784e0ca1 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -139,6 +139,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que queryBuilder.queryVector().asFloatVector(), queryBuilder.k(), queryBuilder.numCands(), + queryBuilder.rescoreOversample(), filterQuery ); };