From bd920c5ed4390de88ac43a96f1caea52ae18a5bc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Nov 2024 17:36:55 +0100 Subject: [PATCH] Add tests --- .../search/KnnSearchSingleNodeTests.java | 4 +- .../index/query/NestedQueryBuilderTests.java | 1 + ...AbstractKnnVectorQueryBuilderTestCase.java | 109 +++++++++++++----- .../KnnByteVectorQueryBuilderTests.java | 10 +- .../KnnFloatVectorQueryBuilderTests.java | 10 +- .../search/vectors/KnnSearchBuilderTests.java | 2 +- 6 files changed, 101 insertions(+), 35 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 042890001c2ea..a52e3bc910bc2 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -417,7 +417,9 @@ public void testKnnSearchAction() throws IOException { // how the action works (it builds a kNN query under the hood) float[] queryVector = randomVector(); assertResponse( - client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2), + client().prepareSearch("index1", "index2") + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard assertHitCount(response, 5 * 2); diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 6076665e26824..7f4f95cdd2416 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, null, 1, + null, null ); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f93bdd14f0645..d603ad7e39b1f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -56,7 +57,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa abstract DenseVectorFieldMapper.ElementType elementType(); - abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity); + abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ); @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { @@ -88,7 +95,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; Integer k = randomBoolean() ? null : randomIntBetween(1, 100); int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); - KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat()); + KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( + fieldName, + k, + numCands, + randomRescoreVectorBuilder(), + randomFloat() + ); if (randomBoolean()) { List filters = new ArrayList<>(); @@ -99,11 +112,24 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { } queryBuilder.addFilterQueries(filters); } + return queryBuilder; } + protected RescoreVectorBuilder randomRescoreVectorBuilder() { + if (randomBoolean()) { + return null; + } + + return new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { + if (queryBuilder.rescoreVectorBuilder() != null) { + assertTrue(query instanceof org.apache.lucene.queries.function.FunctionScoreQuery); + query = ((org.apache.lucene.queries.function.FunctionScoreQuery) query).getWrappedQuery(); + } if (queryBuilder.getVectorSimilarity() != null) { assertTrue(query instanceof VectorSimilarityQuery); Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery(); @@ -126,21 +152,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que BooleanQuery booleanQuery = builder.build(); Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; // The field should always be resolved to the concrete field + Integer k = queryBuilder.k(); + Integer numCands = queryBuilder.numCands(); + if (queryBuilder.rescoreVectorBuilder() != null) { + Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample(); + k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); + numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands); + } + Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asByteVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); - case FLOAT -> new ESKnnFloatVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asFloatVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); + case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); + case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); @@ -150,7 +172,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), @@ -160,7 +182,7 @@ public void testWrongDimension() { public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); context.setAllowUnmappedFields(false); QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); @@ -168,7 +190,7 @@ public void testNonexistentField() { public void testNonexistentFieldReturnEmpty() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); Query queryNone = query.doToQuery(context); assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } @@ -180,6 +202,7 @@ public void testWrongFieldType() { new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, + null, null ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); @@ -191,14 +214,14 @@ public void testNumCandsLessThanK() { int numCands = 3; IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null) + () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @Override public void testValidOutput() { - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null); String expected = """ { "knn" : { @@ -213,7 +236,7 @@ public void testValidOutput() { }"""; assertEquals(expected, query.toString()); - KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null); + KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null); String expected2 = """ { "knn" : { @@ -240,6 +263,7 @@ public void testMustRewrite() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION, null, + null, null ); query.addFilterQuery(termQuery); @@ -254,9 +278,14 @@ public void testMustRewrite() throws IOException { public void testBWCVersionSerializationFilters() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()); + KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( random(), TransportVersions.V_8_0_0, @@ -268,10 +297,14 @@ public void testBWCVersionSerializationFilters() throws IOException { public void testBWCVersionSerializationSimilarity() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()) - .addFilterQueries(query.filterQueries()); + KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); } @@ -289,11 +322,29 @@ public void testBWCVersionSerializationQuery() throws IOException { vectorData, null, query.numCands(), + null, similarity ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); } + public void testBWCVersionSerializationRescoreVector() throws IOException { + KnnVectorQueryBuilder query = createTestQueryBuilder(); + KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder( + query.getFieldName(), + query.queryVector(), + query.k(), + query.numCands(), + null, + query.getVectorSimilarity() + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + assertBWCSerialization( + query, + queryNoRescoreVector, + TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_8_0, TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) + ); + } + private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { assertSerialization(bwcQuery, version); try (BytesStreamOutput output = new BytesStreamOutput()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 0fc2304e904a4..980e506c0ca35 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { + protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { byte[] vector = new byte[VECTOR_DIMENSION]; for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index ba2245ced3305..75b1f395c57e7 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { + KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { float[] vector = new float[VECTOR_DIMENSION]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 2184e8af54aed..3753e8ce874cb 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -167,7 +167,7 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries) + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, null, similarity).addFilterQueries(filterQueries) .boost(boost); assertEquals(expected, builder.toQueryBuilder()); }