diff --git a/docs/changelog/116663.yaml b/docs/changelog/116663.yaml new file mode 100644 index 0000000000000..40bcdea29bc31 --- /dev/null +++ b/docs/changelog/116663.yaml @@ -0,0 +1,5 @@ +pr: 116663 +summary: KNN vector rescoring for quantized vectors +area: Vector Search +type: feature +issues: [] diff --git a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java index 05f456b7f2229..8a7f1405f8f4e 100644 --- a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java +++ b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java @@ -1359,7 +1359,7 @@ public void testKnnQueryNotSupportedInPercolator() throws IOException { """); indicesAdmin().prepareCreate("index1").setMapping(mappings).get(); ensureGreen(); - QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null); + QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null, null); IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1") .setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject()); diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index d08a8e2a6d39c..e49f0634a4887 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -18,7 +18,7 @@ setup: dims: 5 index: true index_options: - type: hnsw + type: int8_hnsw similarity: l2_norm - do: @@ -73,3 +73,59 @@ setup: - match: {hits.total.value: 1} - match: {hits.hits.0._id: "3"} - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + knn: + field: vector + query_vector: [2, 2, 2, 2, 3] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [2, 2, 2, 2, 3] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml new file mode 100644 index 0000000000000..d4bf5e7e9807f --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -0,0 +1,137 @@ +setup: + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [ capabilities ] + capabilities: + - method: GET + path: /_search + capabilities: [ knn_quantized_vector_rescore ] + - skip: + features: "headers" + + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 +--- +"Profile rescored knn search": + + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + "rescore_vector": + "num_candidates_factor": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } + + # Search with similarity to check number of operations are propagated correctly + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + similarity: 100000 + "rescore_vector": + "num_candidates_factor": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index b3d86a066550e..534db18b5eb9c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -541,3 +541,58 @@ setup: num_candidates: 3 - match: { hits.total.value: 0 } +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 5767c895fbe7e..2567a4ac597d9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -108,6 +108,75 @@ setup: - match: { hits.hits.1._id: "3" } - match: { hits.hits.2._id: "2" } --- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test bad quantization parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index b7a5517309949..b1e35789e8737 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -368,6 +368,65 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +# Won't be true for larger datasets, but this helps checking kNN vs rescoring vs exact search +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: hnsw_byte_quantized + body: + size: 3 + query: + knn: + k: 3 + num_candidates: 3 + field: vector + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 5f1af2ca5c52f..54e9eadf42e0b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -549,6 +549,62 @@ setup: - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} --- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: hnsw_byte_quantized + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test odd dimensions fail indexing": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index dcdae04aeabb4..a3cd624ef0ab8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -107,6 +107,75 @@ setup: - match: { hits.hits.1._id: "3" } - match: { hits.hits.2._id: "2" } --- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index 1b439967ba163..a59aedceff3d3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -257,6 +257,61 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index b9a0b16f2bd7a..6796a92122f9a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -344,3 +344,58 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int4_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 139747c5e7ee5..d1d312449cb70 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -262,6 +262,60 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int8_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index 02576ad1b2b01..effa3fff61525 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -405,3 +405,59 @@ setup: - match: {hits.hits.0._id: "1"} - match: {hits.hits.0._source.vector1: [2, -1, 1, 4, -3]} - match: {hits.hits.0._source.vector2: [2, -1, 1, 4, -3]} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index ec7bde4de8435..cdc1d9c64763e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -221,3 +221,59 @@ setup: similarity: l2_norm index_options: type: int8_hnsw + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 0cedfaa873095..213b571a0b4be 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -254,3 +254,60 @@ setup: filter: {"term": {"name": "cow.jpg"}} - length: {hits.hits: 0} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } + diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java index d1021715ceffc..aaab14941d4bb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java @@ -69,7 +69,7 @@ public void testSimpleNested() throws Exception { assertResponse( prepareSearch("test").setKnnSearch( - List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null).innerHit(new InnerHitBuilder())) + List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null, null).innerHit(new InnerHitBuilder())) ).setAllowPartialSearchResults(false), response -> assertThat(response.getHits().getHits().length, greaterThan(0)) ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java index 876edc282c903..95d69a6ebaa86 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentFactory; @@ -71,6 +72,7 @@ public void testProfileDfs() throws Exception { new float[] { randomFloat(), randomFloat(), randomFloat() }, randomIntBetween(5, 10), 50, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); if (randomBoolean()) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java index 537ace30e88f0..40849bea5512e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -84,7 +84,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -98,7 +100,7 @@ public void testTelemetryForRetrievers() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -112,7 +114,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#5 - t // his will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 825b49b50c4ed..7151791d0519a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -143,6 +143,7 @@ static TransportVersion def(int id) { public static final TransportVersion NEW_REFRESH_CLUSTER_BLOCK = def(8_803_00_0); public static final TransportVersion RETRIES_AND_OPERATIONS_IN_BLOBSTORE_STATS = def(8_804_00_0); public static final TransportVersion ADD_DATA_STREAM_OPTIONS_TO_TEMPLATES = def(8_805_00_0); + public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_806_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 08835e0552371..b206c503a2739 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -69,6 +69,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; @@ -122,6 +123,7 @@ public static boolean isNotUnitVector(float magnitude) { public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; + public static final int NUM_CANDS_OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -1215,7 +1217,7 @@ public final int hashCode() { } private enum VectorIndexType { - HNSW("hnsw") { + HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1242,7 +1244,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_HNSW("int8_hnsw") { + INT8_HNSW("int8_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1274,7 +1276,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_HNSW("int4_hnsw") { + INT4_HNSW("int4_hnsw", true) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); @@ -1305,7 +1307,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - FLAT("flat") { + FLAT("flat", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1322,7 +1324,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_FLAT("int8_flat") { + INT8_FLAT("int8_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1344,7 +1346,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_FLAT("int4_flat") { + INT4_FLAT("int4_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1366,7 +1368,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - BBQ_HNSW("bbq_hnsw") { + BBQ_HNSW("bbq_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1393,7 +1395,7 @@ public boolean supportsDimension(int dims) { return dims >= BBQ_MIN_DIMS; } }, - BBQ_FLAT("bbq_flat") { + BBQ_FLAT("bbq_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1416,9 +1418,11 @@ static Optional fromString(String type) { } private final String name; + private final boolean quantized; - VectorIndexType(String name) { + VectorIndexType(String name, boolean quantized) { this.name = name; + this.quantized = quantized; } abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); @@ -1427,6 +1431,10 @@ static Optional fromString(String type) { public abstract boolean supportsDimension(int dims); + public boolean isQuantized() { + return quantized; + } + @Override public String toString() { return name; @@ -1999,6 +2007,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2010,11 +2019,23 @@ public Query createKnnQuery( } return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); - case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter); + case FLOAT -> createKnnFloatQuery( + queryVector.asFloatVector(), + k, + numCands, + numCandsFactor, + filter, + similarityThreshold, + parentFilter + ); case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } + private boolean needsRescore(Float rescoreOversample) { + return rescoreOversample != null && (indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized()); + } + private Query createKnnBitQuery( byte[] queryVector, Integer k, @@ -2068,6 +2089,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2087,9 +2109,27 @@ && isNotUnitVector(squaredMagnitude)) { } } } + + Integer adjustedK = k; + int adjustedNumCands = numCands; + if (needsRescore(numCandsFactor)) { + // Get all candidates, get top k as part of rescoring + adjustedK = null; + // numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. + adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT); + } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter); + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); + if (needsRescore(numCandsFactor)) { + knnQuery = new RescoreKnnVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), + k, + knnQuery + ); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java new file mode 100644 index 0000000000000..74a7dbe168e6b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DoubleValues; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.search.vectors.QueryProfilerProvider; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the + * original vector values stored in the index + */ +public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider { + + private final String field; + private final float[] target; + private final VectorSimilarityFunction vectorSimilarityFunction; + private long vectorOpsCount; + + public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { + this.field = field; + this.target = target; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + final LeafReader reader = ctx.reader(); + + FloatVectorValues vectorValues = reader.getFloatVectorValues(field); + final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + vectorOpsCount++; + return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index())); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o; + return Objects.equals(field, that.field) + && Arrays.equals(target, that.target) + && vectorSimilarityFunction == that.vectorSimilarityFunction; + } + + @Override + public String toString() { + return "VectorSimilarityFloatValueSource(" + field + ", [" + target[0] + ",...], " + vectorSimilarityFunction + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index b4dadf8712199..7b6ee6f7806c0 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -43,6 +43,7 @@ private SearchCapabilities() {} private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; + private static final String KNN_QUANTIZED_VECTOR_RESCORE = "knn_quantized_vector_rescore"; public static final Set CAPABILITIES; static { @@ -54,6 +55,7 @@ private SearchCapabilities() {} capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); + capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); if (RankVectorsFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(RANK_VECTORS_FIELD_MAPPER); capabilities.add(RANK_VECTORS_SCRIPT_ACCESS); diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 76b3f45ffb84a..6a99b51ac679c 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -34,7 +34,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.search.vectors.ProfilingQuery; +import org.elasticsearch.search.vectors.QueryProfilerProvider; import org.elasticsearch.tasks.TaskCancelledException; import java.io.IOException; @@ -224,8 +224,8 @@ static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ); topDocs = searcher.search(knnQuery, ipcm); - if (knnQuery instanceof ProfilingQuery profilingQuery) { - profilingQuery.profile(knnProfiler); + if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(knnProfiler); } knnProfiler.setCollectorResult(ipcm.getCollectorTree()); diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java index 98ddfa95bf156..23ce52b6c5b82 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java +++ b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java @@ -39,10 +39,18 @@ public QueryProfiler() { super(new InternalQueryProfileTree()); } - public void setVectorOpsCount(long vectorOpsCount) { - this.vectorOpsCount = vectorOpsCount; + /** + * Adds a number of vector operations to the current count + * @param vectorOpsCount number of vector ops to add to the profiler + */ + public void addVectorOpsCount(long vectorOpsCount) { + this.vectorOpsCount += vectorOpsCount; } + /** + * Retrieves the number of vector operations performed by the queries + * @return number of vector operations performed by the queries + */ public long getVectorOpsCount() { return this.vectorOpsCount; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index f1464c41ca3be..b29546ded75cd 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -20,8 +20,10 @@ import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -52,6 +54,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -73,6 +76,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { (QueryVectorBuilder) args[2], (int) args[3], (int) args[4], + (RescoreVectorBuilder) args[6], (Float) args[5] ); } @@ -89,6 +93,12 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { PARSER.declareInt(constructorArg(), K_FIELD); PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -104,6 +114,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; + private final RescoreVectorBuilder rescoreVectorBuilder; private final Float similarity; public KnnRetrieverBuilder( @@ -112,6 +123,7 @@ public KnnRetrieverBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { if (queryVector == null && queryVectorBuilder == null) { @@ -137,6 +149,7 @@ public KnnRetrieverBuilder( this.k = k; this.numCands = numCands; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVector, QueryVectorBuilder queryVectorBuilder) { @@ -148,6 +161,7 @@ private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVe this.similarity = clone.similarity; this.retrieverName = clone.retrieverName; this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.rescoreVectorBuilder = clone.rescoreVectorBuilder; } @Override @@ -228,6 +242,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder null, k, numCands, + rescoreVectorBuilder, similarity ); if (preFilterQueryBuilders != null) { @@ -241,6 +256,10 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder searchSourceBuilder.knnSearch(knnSearchBuilders); } + RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + // ---- FOR TESTING XCONTENT PARSING ---- @Override @@ -260,6 +279,10 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept if (similarity != null) { builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity); } + + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } } @Override @@ -271,12 +294,13 @@ public boolean doEquals(Object o) { && ((queryVector == null && that.queryVector == null) || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder); } @Override public int doHashCode() { - int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); + int result = Objects.hash(field, queryVectorBuilder, k, numCands, rescoreVectorBuilder, similarity); result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 413840f2b451b..e77b86c9a363e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index 80704a3b552fe..9b8b4d3e3b008 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 14bb94a366e50..d47585c055094 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements ProfilingQuery { +public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -33,6 +33,10 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + public Integer kParam() { + return kParam; } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index 590d8cfbbaba1..97ce1bc1d8347 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery { +public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -33,6 +33,10 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + public Integer kParam() { + return kParam; } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index 06fb109d6580e..2855fe8bcf0eb 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -36,7 +37,13 @@ public class KnnScoreDocQuery extends Query { private final int[] docs; private final float[] scores; + + // the indexes in docs and scores corresponding to the first matching document in each segment. + // If a segment has no matching documents, it should be assigned the index of the next segment that does. + // There should be a final entry that is always docs.length-1. private final int[] segmentStarts; + + // an object identifying the reader context that was used to build this query private final Object contextIdentity; /** @@ -44,18 +51,31 @@ public class KnnScoreDocQuery extends Query { * * @param docs the global doc IDs of documents that match, in ascending order * @param scores the scores of the matching documents - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. There should be a final entry that is always - * docs.length-1. - * @param contextIdentity an object identifying the reader context that was used to build this - * query + * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { this.docs = docs; this.scores = scores; - this.segmentStarts = segmentStarts; - this.contextIdentity = contextIdentity; + this.segmentStarts = findSegmentStarts(reader, docs); + this.contextIdentity = reader.getContext().id(); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index b5ba97906f0ec..6fa83ccfb6ac2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.TransportVersion; @@ -25,7 +24,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.Objects; /** @@ -151,9 +149,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { scores[i] = scoreDocs[i].score; } - IndexReader reader = context.getIndexReader(); - int[] segmentStarts = findSegmentStarts(reader, docs); - return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id()); + return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); } @Override @@ -167,24 +163,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return super.doRewrite(queryRewriteContext); } - private static int[] findSegmentStarts(IndexReader reader, int[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - @Override protected boolean doEquals(KnnScoreDocQueryBuilder other) { if (scoreDocs.length != other.scoreDocs.length) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 8ce8fc07f3acd..b18ce2dff65cb 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -56,6 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD; public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD; public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { @@ -65,7 +66,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea .queryVectorBuilder((QueryVectorBuilder) args[4]) .k((Integer) args[2]) .numCandidates((Integer) args[3]) - .similarity((Float) args[5]); + .similarity((Float) args[5]) + .rescoreVectorBuilder((RescoreVectorBuilder) args[6]); }); static { @@ -78,13 +80,18 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea ); PARSER.declareInt(optionalConstructorArg(), K_FIELD); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); - PARSER.declareNamedObject( optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); PARSER.declareFieldArray( KnnSearchBuilder.Builder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -116,6 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw String queryName; float boost = DEFAULT_BOOST; InnerHitBuilder innerHitBuilder; + private final RescoreVectorBuilder rescoreVectorBuilder; /** * Defines a kNN search. @@ -124,14 +132,23 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw * @param queryVector the query vector * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param rescoreVectorBuilder rescore vector information */ - public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, + rescoreVectorBuilder, similarity ); } @@ -144,8 +161,15 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) { - this(field, queryVector, null, k, numCands, similarity); + public KnnSearchBuilder( + String field, + VectorData queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity); } /** @@ -156,13 +180,21 @@ public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCand * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + QueryVectorBuilder queryVectorBuilder, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, null, Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())), k, numCands, + rescoreVectorBuilder, similarity ); } @@ -173,9 +205,22 @@ public KnnSearchBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST); + this( + field, + queryVectorBuilder, + queryVector, + new ArrayList<>(), + k, + numCands, + rescoreVectorBuilder, + similarity, + null, + null, + DEFAULT_BOOST + ); } private KnnSearchBuilder( @@ -183,6 +228,7 @@ private KnnSearchBuilder( Supplier querySupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, List filterQueries, Float similarity ) { @@ -194,6 +240,7 @@ private KnnSearchBuilder( this.filterQueries = filterQueries; this.querySupplier = querySupplier; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnSearchBuilder( @@ -203,6 +250,7 @@ private KnnSearchBuilder( List filterQueries, int k, int numCandidates, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity, InnerHitBuilder innerHitBuilder, String queryName, @@ -242,6 +290,7 @@ private KnnSearchBuilder( this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCandidates; + this.rescoreVectorBuilder = rescoreVectorBuilder; this.innerHitBuilder = innerHitBuilder; this.similarity = similarity; this.queryName = queryName; @@ -280,12 +329,25 @@ public KnnSearchBuilder(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(V_8_11_X)) { this.innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new); } + if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); + } else { + this.rescoreVectorBuilder = null; + } } public int k() { return k; } + public int getNumCands() { + return numCands; + } + + public RescoreVectorBuilder getRescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public QueryVectorBuilder getQueryVectorBuilder() { return queryVectorBuilder; } @@ -354,7 +416,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (querySupplier.get() == null) { return this; } - return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries) .innerHit(innerHitBuilder); @@ -377,7 +439,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost) + return new KnnSearchBuilder(field, toSet::get, k, numCands, rescoreVectorBuilder, filterQueries, similarity).boost(boost) .queryName(queryName) .innerHit(innerHitBuilder); } @@ -391,7 +453,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, queryVector, k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(rewrittenQueries) .innerHit(innerHitBuilder); @@ -403,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -419,6 +481,7 @@ public boolean equals(Object o) { KnnSearchBuilder that = (KnnSearchBuilder) o; return k == that.k && numCands == that.numCands + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder) && Objects.equals(field, that.field) && Objects.equals(queryVector, that.queryVector) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) @@ -438,6 +501,7 @@ public int hashCode() { numCands, querySupplier, queryVectorBuilder, + rescoreVectorBuilder, similarity, Objects.hashCode(queryVector), Objects.hashCode(filterQueries), @@ -482,6 +546,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (queryName != null) { builder.field(NAME_FIELD.getPreferredName(), queryName); } + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } return builder; } @@ -522,6 +589,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(V_8_11_X)) { out.writeOptionalWriteable(innerHitBuilder); } + if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalWriteable(rescoreVectorBuilder); + } } public static class Builder { @@ -536,6 +606,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; private InnerHitBuilder innerHitBuilder; + private RescoreVectorBuilder rescoreVectorBuilder; public Builder addFilterQueries(List filterQueries) { Objects.requireNonNull(filterQueries); @@ -588,6 +659,11 @@ public Builder similarity(Float similarity) { return this; } + public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) { + this.rescoreVectorBuilder = rescoreVectorBuilder; + return this; + } + public KnnSearchBuilder build(int size) { int requestSize = size < 0 ? DEFAULT_SIZE : size; int adjustedK = k == null ? requestSize : k; @@ -601,6 +677,7 @@ public KnnSearchBuilder build(int size) { filterQueries, adjustedK, adjustedNumCandidates, + rescoreVectorBuilder, similarity, innerHitBuilder, queryName, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index a28448336ab3f..81b00f1329591 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -256,7 +256,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null); + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index deb7e6bd035b8..c868274eb8a1b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -45,6 +45,7 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -68,8 +69,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( "knn", args -> new KnnVectorQueryBuilder( @@ -79,6 +80,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); PARSER.declareFieldArray( KnnVectorQueryBuilder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -115,14 +123,22 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { private final String fieldName; private final VectorData queryVector; private final Integer k; - private Integer numCands; + private final Integer numCands; private final List filterQueries = new ArrayList<>(); private final Float vectorSimilarity; private final QueryVectorBuilder queryVectorBuilder; private final Supplier queryVectorSupplier; + private final RescoreVectorBuilder rescoreVectorBuilder; - public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + float[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } public KnnVectorQueryBuilder( @@ -132,15 +148,29 @@ public KnnVectorQueryBuilder( Integer numCands, Float vectorSimilarity ) { - this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity); + this(fieldName, null, queryVectorBuilder, null, k, numCands, null, vectorSimilarity); } - public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + byte[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } - public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + VectorData queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, queryVector, null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } private KnnVectorQueryBuilder( @@ -150,6 +180,7 @@ private KnnVectorQueryBuilder( Supplier queryVectorSupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { if (k != null && k < 1) { @@ -187,6 +218,7 @@ private KnnVectorQueryBuilder( this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; + this.rescoreVectorBuilder = rescoreVectorBuilder; } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -227,6 +259,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.queryVectorBuilder = null; } + if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); + } else { + this.rescoreVectorBuilder = null; + } + this.queryVectorSupplier = null; } @@ -261,6 +299,10 @@ public QueryVectorBuilder queryVectorBuilder() { return queryVectorBuilder; } + public RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); @@ -327,6 +369,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalNamedWriteable(queryVectorBuilder); } + if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalWriteable(rescoreVectorBuilder); + } } @Override @@ -360,6 +405,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep } builder.endArray(); } + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -375,7 +423,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { if (queryVectorSupplier.get() == null) { return this; } - return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost) + return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, rescoreVectorBuilder, vectorSimilarity) + .boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -397,9 +446,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost( - boost - ).queryName(queryName).addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + toSet::get, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); } if (ctx.convertToInnerHitsRewriteContext() != null) { return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName); @@ -417,14 +473,25 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity) - .boost(boost) - .queryName(queryName) - .addFilterQueries(rewrittenQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + queryVectorSupplier, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); } return this; } + @Override + protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { + return super.doIndexMetadataRewrite(context); + } + @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); @@ -459,6 +526,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); + Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor(); if (parentPath != null) { final BitSetProducer parentBitSet; @@ -491,14 +559,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet); + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + numCandidatesFactor, + filterQuery, + vectorSimilarity, + parentBitSet + ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder); + return Objects.hash( + fieldName, + Objects.hashCode(queryVector), + k, + numCands, + filterQueries, + vectorSimilarity, + queryVectorBuilder, + rescoreVectorBuilder + ); } @Override @@ -509,7 +594,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(numCands, other.numCands) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) - && Objects.equals(queryVectorBuilder, other.queryVectorBuilder); + && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) + && Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java similarity index 96% rename from server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java rename to server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java index 4d36d8eae57cc..47b0e1e299968 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java @@ -18,7 +18,7 @@ * must provide an implementation for profile() to store profiling information in the {@link QueryProfiler}. */ -public interface ProfilingQuery { +public interface QueryProfilerProvider { /** * Store the profiling information in the {@link QueryProfiler} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java new file mode 100644 index 0000000000000..a9c606b1f8618 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; +import org.elasticsearch.search.profile.query.QueryProfiler; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field + */ +public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { + private final String fieldName; + private final float[] floatTarget; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final Integer k; + private final Query innerQuery; + + private QueryProfilerProvider vectorProfiling; + + public RescoreKnnVectorQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query innerQuery + ) { + this.fieldName = fieldName; + this.floatTarget = floatTarget; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; + this.innerQuery = innerQuery; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + // Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need + // to calculate top k and return directly the query to understand how many comparisons were done + vectorProfiling = (QueryProfilerProvider) valueSource; + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); + Query query = searcher.rewrite(functionScoreQuery); + + if (k == null) { + // No need to calculate top k - let the request size limit the results. + return query; + } + + // Retrieve top k documents from the rescored query + TopDocs topDocs = searcher.search(query, k); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + int[] docIds = new int[scoreDocs.length]; + float[] scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docIds[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } + + return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + } + + public Query innerQuery() { + return innerQuery; + } + + public Integer k() { + return k; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); + } + + if (vectorProfiling == null) { + throw new IllegalStateException("Query should have been rewritten"); + } + vectorProfiling.profile(queryProfiler); + } + + @Override + public void visit(QueryVisitor visitor) { + innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o; + return Objects.equals(fieldName, that.fieldName) + && Arrays.equals(floatTarget, that.floatTarget) + && vectorSimilarityFunction == that.vectorSimilarityFunction + && Objects.equals(k, that.k) + && Objects.equals(innerQuery, that.innerQuery); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); + } + + @Override + public String toString(String field) { + return "KnnRescoreVectorQuery{" + + "fieldName='" + + fieldName + + '\'' + + ", floatTarget=" + + floatTarget[0] + + "..." + + ", vectorSimilarityFunction=" + + vectorSimilarityFunction + + ", k=" + + k + + ", vectorQuery=" + + innerQuery + + '}'; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java new file mode 100644 index 0000000000000..4604d4f0ea325 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class RescoreVectorBuilder implements Writeable, ToXContentObject { + + public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor"); + public static final float MIN_OVERSAMPLE = 1.0F; + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "rescore_vector", + args -> new RescoreVectorBuilder((Float) args[0]) + ); + + static { + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), NUM_CANDIDATES_FACTOR_FIELD); + } + + // Oversample is required as of now as it is the only field in the rescore vector + private final float numCandidatesFactor; + + public RescoreVectorBuilder(float numCandidatesFactor) { + Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); + if (numCandidatesFactor < MIN_OVERSAMPLE) { + throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); + } + this.numCandidatesFactor = numCandidatesFactor; + } + + public RescoreVectorBuilder(StreamInput in) throws IOException { + this.numCandidatesFactor = in.readFloat(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(numCandidatesFactor); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUM_CANDIDATES_FACTOR_FIELD.getPreferredName(), numCandidatesFactor); + builder.endObject(); + return builder; + } + + public static RescoreVectorBuilder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreVectorBuilder that = (RescoreVectorBuilder) o; + return Objects.equals(numCandidatesFactor, that.numCandidatesFactor); + } + + @Override + public int hashCode() { + return Objects.hashCode(numCandidatesFactor); + } + + public float numCandidatesFactor() { + return numCandidatesFactor; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java index 77f60adc4fcd8..a41f3afbe47f0 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.elasticsearch.common.lucene.search.function.MinScoreScorer; +import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Objects; @@ -27,9 +28,10 @@ import static org.elasticsearch.common.Strings.format; /** - * This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery. + * This query provides a simple post-filter for the provided Query to limit the results of the inner query to those that have a similarity + * above a certain threshold */ -public class VectorSimilarityQuery extends Query { +public class VectorSimilarityQuery extends Query implements QueryProfilerProvider { private final float similarity; private final float docScore; private final Query innerKnnQuery; @@ -77,6 +79,13 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new MinScoreWeight(innerWeight, docScore, similarity, this, boost); } + @Override + public void profile(QueryProfiler queryProfiler) { + if (innerKnnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); + } + } + @Override public String toString(String field) { return "VectorSimilarityQuery[" diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 05198a7d49e70..dd648f1dfd65d 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -344,8 +344,8 @@ public void testRewriteShardSearchRequestWithRank() { SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) .knnSearch( List.of( - new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null), - new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null) + new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null, null), + new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null, null) ) ) .rankBuilder(new TestRankBuilder(100)); diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 042890001c2ea..353188af8be3c 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -63,7 +63,7 @@ public void testKnnSearchRemovedVector() throws IOException { client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null, null).boost(5.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -107,7 +107,7 @@ public void testKnnWithQuery() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f).queryName("knn"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f).queryName("knn"); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -156,7 +156,7 @@ public void testKnnFilter() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( QueryBuilders.termsQuery("field", "second") ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { @@ -199,7 +199,7 @@ public void testKnnFilterWithRewrite() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field")) ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { @@ -246,8 +246,8 @@ public void testMultiKnnClauses() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(20f, 21f); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null).boost(10.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null).boost(10.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch, knnSearch2)) @@ -308,8 +308,8 @@ public void testMultiKnnClausesSameDoc() throws IOException { float[] queryVector = randomVector(); // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -381,7 +381,7 @@ public void testKnnFilteredAlias() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null, null); final int expectedHitCount = expectedHits; assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { assertHitCount(response, expectedHitCount); @@ -417,7 +417,9 @@ public void testKnnSearchAction() throws IOException { // how the action works (it builds a kNN query under the hood) float[] queryVector = randomVector(); assertResponse( - client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2), + client().prepareSearch("index1", "index2") + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard assertHitCount(response, 5 * 2); @@ -450,7 +452,7 @@ public void testKnnVectorsWith4096Dims() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(4096); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null, null).boost(5.0f); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { assertHitCount(response, 3); assertEquals(3, response.getHits().getHits().length); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index c83427a975a54..93c7a66a2960f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; @@ -116,8 +117,22 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( List.of( - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat()), - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5, randomBoolean() ? null : randomFloat()) + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ), + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 4, 12, 41 }, + 3, + 5, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) ) ); expectThrows( @@ -132,7 +147,16 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( - List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat())) + List.of( + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) + ) ); // Shouldn't throw because its just one KNN request copyWriteable( @@ -143,6 +167,10 @@ public void testSerializationMultiKNN() throws Exception { ); } + private static RescoreVectorBuilder randomRescoreVectorBuilder() { + return randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + public void testRandomVersionSerialization() throws IOException { SearchRequest searchRequest = createSearchRequest(); TransportVersion version = TransportVersionUtils.randomVersion(random()); @@ -482,7 +510,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(0) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -494,7 +522,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(1)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(2) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -521,7 +549,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ).scroll(new TimeValue(1000)); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -532,7 +560,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(9)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -546,7 +574,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(3)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(3) .from(4) ); @@ -557,7 +585,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .addRescorer(new QueryRescorerBuilder(QueryBuilders.termQuery("rescore", "another term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -569,7 +597,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .suggest(new SuggestBuilder().setGlobalText("test").addSuggestion("suggestion", new TermSuggestionBuilder("term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index a9de118c6b859..ed3d26141fe04 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1367,7 +1367,7 @@ public void testShouldMinimizeRoundtrips() throws Exception { { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); searchRequest.setCcsMinimizeRoundtrips(true); @@ -1382,7 +1382,7 @@ public void testAdjustSearchType() { // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); TransportSearchAction.adjustSearchType(searchRequest, randomBoolean()); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index afbf19db455f3..6e13faa99b4b5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1692,7 +1692,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1701,7 +1701,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1710,7 +1718,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1719,7 +1727,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1728,7 +1736,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1740,6 +1756,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1756,6 +1773,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1783,7 +1801,15 @@ public void testFloatVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1795,6 +1821,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1811,6 +1838,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 9e819f38eae6e..d37b4a4bacb4e 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -23,6 +23,9 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; +import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; @@ -31,8 +34,12 @@ import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; public class DenseVectorFieldTypeTests extends FieldTypeTestCase { private final boolean indexed; @@ -69,11 +76,27 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { ); } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { + return randomFrom( + new DenseVectorFieldMapper.Int8HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.Int4HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + ); + } + private DenseVectorFieldType createFloatFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, indexed, VectorSimilarity.COSINE, @@ -86,7 +109,7 @@ private DenseVectorFieldType createByteFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 5, true, VectorSimilarity.COSINE, @@ -159,7 +182,7 @@ public void testCreateNestedKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -170,14 +193,14 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -191,11 +214,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -209,7 +232,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -227,7 +250,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -247,7 +270,7 @@ public void testFloatCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4, false, VectorSimilarity.COSINE, @@ -256,14 +279,22 @@ public void testFloatCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), + 10, + 10, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); DenseVectorFieldType dotProductField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.DOT_PRODUCT, @@ -276,14 +307,14 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.COSINE, @@ -292,7 +323,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -302,7 +333,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4096, true, VectorSimilarity.COSINE, @@ -313,7 +344,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -321,7 +352,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 4096, true, VectorSimilarity.COSINE, @@ -333,7 +364,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -342,7 +373,7 @@ public void testByteCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, false, VectorSimilarity.COSINE, @@ -351,14 +382,14 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, true, VectorSimilarity.COSINE, @@ -367,14 +398,94 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } + + public void testRescoreOversampleUsedWithoutQuantization() { + DenseVectorFieldMapper.ElementType elementType = randomFrom(FLOAT, BYTE); + DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( + "f", + IndexVersion.current(), + elementType, + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsNonQuantized(), + Collections.emptyMap() + ); + + Query knnQuery = nonQuantizedField.createKnnQuery( + new VectorData(null, new byte[] { 1, 4, 10 }), + 10, + 100, + randomFloatBetween(1.0F, 10.0F, false), + null, + null, + null + ); + + if (elementType == BYTE) { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } + } + + public void testRescoreOversampleModifiesNumCandidates() { + DenseVectorFieldType fieldType = new DenseVectorFieldType( + "f", + IndexVersion.current(), + FLOAT, + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsHnswQuantized(), + Collections.emptyMap() + ); + + // Total results is k, internal k is multiplied by oversample + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10); + // If numCands < k, update numCands to k + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10); + // Oversampling limits for num candidates + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000); + checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000); + } + + private static void checkRescoreQueryParameters( + DenseVectorFieldType fieldType, + Integer k, + int candidates, + float numCandsFactor, + Integer expectedK, + int expectedCandidates, + int expectedResults + ) { + Query query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + k, + candidates, + numCandsFactor, + null, + null, + null + ); + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 6076665e26824..7f4f95cdd2416 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, null, 1, + null, null ); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java index 8e424986f04ee..acf2605cc4a4e 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java @@ -105,7 +105,7 @@ public void testValidateSearchRequest() { ).withMethod(RestRequest.Method.GET).withPath("/some_index/_search").withParams(params).build(); SearchRequest searchRequest = new SearchRequest(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null, null); searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch))); Exception ex = expectThrows( diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 240a677f4cbfd..cdbf4cdff15a7 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -848,7 +848,7 @@ public void testSearchSectionsUsageCollection() throws IOException { searchSourceBuilder.fetchField("field"); // these are not correct runtime mappings but they are counted compared to empty object searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword")); - searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null))); + searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null, null))); searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("pitid"))); searchSourceBuilder.docValueField("field"); searchSourceBuilder.storedField("field"); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 7923cb5f0d918..da28b0eff441f 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -51,8 +52,19 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(field, vector, null, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + field, + vector, + null, + k, + numCands, + rescoreVectorBuilder, + similarity + ); List preFilterQueryBuilders = new ArrayList<>(); @@ -93,6 +105,7 @@ public void testRewrite() throws IOException { assertNull(source.query()); assertThat(source.knnSearch().size(), equalTo(1)); assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size())); + assertThat(source.knnSearch().get(0).getRescoreVectorBuilder(), equalTo(knnRetriever.rescoreVectorBuilder())); for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) { assertThat( source.knnSearch().get(0).getFilterQueries().get(j), diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index ccf33c0b71b6b..eafab1d25c38e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -69,6 +70,7 @@ private List innerRetrievers(QueryRewriteContext queryRewriteC null, randomInt(10), randomIntBetween(10, 100), + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat() ); if (randomBoolean()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f93bdd14f0645..375712ee60861 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -37,11 +37,16 @@ import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -52,23 +57,70 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final String VECTOR_ALIAS_FIELD = "vector_alias"; - static final int VECTOR_DIMENSION = 3; + protected static final Set QUANTIZED_INDEX_TYPES = Set.of( + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "int8_flat", + "int4_flat", + "bbq_flat" + ); + protected static final Set NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat"); + protected static final Set ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream()) + .collect(Collectors.toUnmodifiableSet()); + protected static String indexType; + protected static int vectorDimensions; + + @Before + private void checkIndexTypeAndDimensions() { + // Check that these are initialized - should be done as part of the createAdditionalMappings method + assertNotNull(indexType); + assertNotEquals(0, vectorDimensions); + } abstract DenseVectorFieldMapper.ElementType elementType(); - abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity); + abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ); + + protected boolean isQuantizedElementType() { + return QUANTIZED_INDEX_TYPES.contains(indexType); + } + + protected abstract String randomIndexType(); @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + + // These fields are initialized here, as mappings are initialized only once per test class. + // We want the subclasses to be able to override the index type and vector dimensions so we don't make this static / BeforeClass + // for initialization. + indexType = randomIndexType(); + if (indexType.contains("bbq")) { + vectorDimensions = 64; + } else if (indexType.contains("int4")) { + vectorDimensions = 4; + } else { + vectorDimensions = 3; + } + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(VECTOR_FIELD) .field("type", "dense_vector") - .field("dims", VECTOR_DIMENSION) + .field("dims", vectorDimensions) .field("index", true) .field("similarity", "l2_norm") .field("element_type", elementType()) + .startObject("index_options") + .field("type", indexType) + .endObject() .endObject() .startObject(VECTOR_ALIAS_FIELD) .field("type", "alias") @@ -88,7 +140,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; Integer k = randomBoolean() ? null : randomIntBetween(1, 100); int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); - KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat()); + KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( + fieldName, + k, + numCands, + randomRescoreVectorBuilder(), + randomFloat() + ); if (randomBoolean()) { List filters = new ArrayList<>(); @@ -99,24 +157,32 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { } queryBuilder.addFilterQueries(filters); } + return queryBuilder; } + protected RescoreVectorBuilder randomRescoreVectorBuilder() { + if (randomBoolean()) { + return null; + } + + return new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { if (queryBuilder.getVectorSimilarity() != null) { assertTrue(query instanceof VectorSimilarityQuery); - Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery(); assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); - switch (elementType()) { - case FLOAT -> assertTrue(knnQuery instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(knnQuery instanceof ESKnnByteVectorQuery); - } - } else { - switch (elementType()) { - case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); - } + query = ((VectorSimilarityQuery) query).getInnerKnnQuery(); + } + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + query = rescoreQuery.innerQuery(); + } + switch (elementType()) { + case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); + case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -126,21 +192,18 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que BooleanQuery booleanQuery = builder.build(); Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; // The field should always be resolved to the concrete field + Integer k = queryBuilder.k(); + Integer numCands = queryBuilder.numCands(); + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { + Float numCandsFactor = queryBuilder.rescoreVectorBuilder().numCandidatesFactor(); + int minCands = k == null ? 1 : k; + numCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); + numCands = Math.min(numCands, NUM_CANDS_OVERSAMPLE_LIMIT); + } + Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asByteVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); - case FLOAT -> new ESKnnFloatVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asFloatVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); + case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); + case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); @@ -150,17 +213,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), - containsString("The query vector has a different number of dimensions [2] than the document vectors [3]") + containsString("The query vector has a different number of dimensions [2] than the document vectors [" + vectorDimensions + "]") ); } public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); context.setAllowUnmappedFields(false); QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); @@ -168,7 +231,7 @@ public void testNonexistentField() { public void testNonexistentFieldReturnEmpty() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); Query queryNone = query.doToQuery(context); assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } @@ -180,6 +243,7 @@ public void testWrongFieldType() { new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, + null, null ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); @@ -191,14 +255,14 @@ public void testNumCandsLessThanK() { int numCands = 3; IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null) + () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @Override public void testValidOutput() { - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null); String expected = """ { "knn" : { @@ -213,7 +277,7 @@ public void testValidOutput() { }"""; assertEquals(expected, query.toString()); - KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null); + KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null); String expected2 = """ { "knn" : { @@ -238,7 +302,8 @@ public void testMustRewrite() throws IOException { KnnVectorQueryBuilder query = new KnnVectorQueryBuilder( VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, - VECTOR_DIMENSION, + vectorDimensions, + null, null, null ); @@ -254,9 +319,14 @@ public void testMustRewrite() throws IOException { public void testBWCVersionSerializationFilters() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()); + KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( random(), TransportVersions.V_8_0_0, @@ -268,10 +338,14 @@ public void testBWCVersionSerializationFilters() throws IOException { public void testBWCVersionSerializationSimilarity() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()) - .addFilterQueries(query.filterQueries()); + KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); } @@ -289,11 +363,34 @@ public void testBWCVersionSerializationQuery() throws IOException { vectorData, null, query.numCands(), + null, similarity ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); } + public void testBWCVersionSerializationRescoreVector() throws IOException { + KnnVectorQueryBuilder query = createTestQueryBuilder(); + TransportVersion version = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_8_1, + TransportVersionUtils.getPreviousVersion(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) + ); + VectorData vectorData = version.onOrAfter(TransportVersions.V_8_14_0) + ? query.queryVector() + : VectorData.fromFloats(query.queryVector().asFloatVector()); + Integer k = version.before(TransportVersions.V_8_15_0) ? null : query.k(); + KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + k, + query.numCands(), + null, + query.getVectorSimilarity() + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + assertBWCSerialization(query, queryNoRescoreVector, version); + } + private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { assertSerialization(bwcQuery, version); try (BytesStreamOutput output = new BytesStreamOutput()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 0fc2304e904a4..f6c2e754cec63 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -18,11 +18,22 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { - byte[] vector = new byte[VECTOR_DIMENSION]; + protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + byte[] vector = new byte[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + } + + @Override + protected String randomIndexType() { + return randomFrom(NON_QUANTIZED_INDEX_TYPES); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index ba2245ced3305..6f67e4be29a06 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -18,11 +18,22 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { - float[] vector = new float[VECTOR_DIMENSION]; + KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + float[] vector = new float[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + } + + @Override + protected String randomIndexType() { + return randomFrom(ALL_INDEX_TYPES); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 2184e8af54aed..a39438af5b72a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -52,8 +52,18 @@ public static KnnSearchBuilder randomTestInstance() { float[] vector = randomVector(dim); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); - - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat()); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnSearchBuilder builder = new KnnSearchBuilder( + field, + vector, + k, + numCands, + rescoreVectorBuilder, + randomBoolean() ? null : randomFloat() + ); if (randomBoolean()) { builder.boost(randomFloat()); } @@ -100,46 +110,90 @@ protected KnnSearchBuilder createTestInstance() { @Override protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { - switch (random().nextInt(7)) { + switch (random().nextInt(8)) { case 0: String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5)); - return new KnnSearchBuilder(newField, instance.queryVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + newField, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 1: float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5)); - return new KnnSearchBuilder(instance.field, newVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + newVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 2: // given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10)); - return new KnnSearchBuilder(instance.field, instance.queryVector, newK, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + newK, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 3: Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100)); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, newNumCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + newNumCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 4: - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).addFilterQueries(instance.filterQueries) .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) .boost(instance.boost); case 5: float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) - .boost(newBoost); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(newBoost); case 6: return new KnnSearchBuilder( instance.field, instance.queryVector, instance.k, instance.numCands, + instance.getRescoreVectorBuilder(), randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) ).addFilterQueries(instance.filterQueries).boost(instance.boost); + case 7: + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + randomValueOtherThan( + instance.getRescoreVectorBuilder(), + () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) + ), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(instance.boost); default: throw new IllegalStateException(); } @@ -151,7 +205,10 @@ public void testToQueryBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, similarity); float boost = AbstractQueryBuilder.DEFAULT_BOOST; if (randomBoolean()) { @@ -167,15 +224,16 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries) - .boost(boost); + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, rescoreVectorBuilder, similarity).addFilterQueries( + filterQueries + ).boost(boost); assertEquals(expected, builder.toQueryBuilder()); } public void testNumCandsLessThanK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null) + () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @@ -183,7 +241,7 @@ public void testNumCandsLessThanK() { public void testNumCandsExceedsLimit() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null) + () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); } @@ -191,18 +249,28 @@ public void testNumCandsExceedsLimit() { public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null) + () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null, null) ); assertThat(e.getMessage(), containsString("[k] must be greater than 0")); } + public void testInvalidRescoreVectorBuilder() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null) + ); + assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0")); + } + public void testRewrite() throws Exception { float[] expectedArray = randomVector(randomIntBetween(10, 1024)); + RescoreVectorBuilder expectedRescore = new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); KnnSearchBuilder searchBuilder = new KnnSearchBuilder( "field", new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray), 5, 10, + expectedRescore, 1f ); searchBuilder.boost(randomFloat()); @@ -220,6 +288,7 @@ public void testRewrite() throws Exception { assertThat(rewritten.filterQueries, hasSize(1)); assertThat(rewritten.similarity, equalTo(1f)); assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1)); + assertThat(rewritten.getRescoreVectorBuilder(), equalTo(expectedRescore)); } public static float[] randomVector(int dim) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java new file mode 100644 index 0000000000000..7bbe7dcc155c5 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class RescoreKnnVectorQueryTests extends ESTestCase { + + public static final String FIELD_NAME = "float_vector"; + private final int numDocs; + private final Integer k; + + public RescoreKnnVectorQueryTests(boolean useK) { + this.numDocs = randomIntBetween(10, 100); + this.k = useK ? randomIntBetween(1, numDocs - 1) : null; + } + + public void testRescoreDocs() throws Exception { + int numDims = randomIntBetween(5, 100); + + Integer adjustedK = k; + if (k == null) { + adjustedK = numDocs; + } + + try (Directory d = newDirectory()) { + addRandomDocuments(numDocs, d, numDims); + + try (IndexReader reader = DirectoryReader.open(d)) { + + // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query + // and thus we're rescoring the top k docs. + float[] queryVector = randomVector(numDims); + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + adjustedK, + new MatchAllDocsQuery() + ); + + IndexSearcher searcher = newSearcher(reader, true, false); + TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); + Map rescoredDocs = Arrays.stream(docs.scoreDocs) + .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); + + assertThat(rescoredDocs.size(), equalTo(adjustedK)); + + Collection rescoredScores = new HashSet<>(rescoredDocs.values()); + + // Collect all docs sequentially, and score them using the similarity function to get the top K scores + PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); + + for (LeafReaderContext leafReaderContext : reader.leaves()) { + FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + float[] vectorData = vectorValues.vectorValue(iterator.docID()); + float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); + topK.add(score); + int docId = iterator.docID(); + // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it + // to ensure we found them all + if (rescoredDocs.containsKey(docId)) { + assertThat(rescoredDocs.get(docId), equalTo(score)); + rescoredDocs.remove(docId); + } + } + } + + assertThat(rescoredDocs.size(), equalTo(0)); + + // Check top scoring docs are contained in rescored docs + for (int i = 0; i < adjustedK; i++) { + Float topScore = topK.poll(); + if (rescoredScores.contains(topScore) == false) { + fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + } + } + } + } + } + + public void testProfiling() throws Exception { + int numDims = randomIntBetween(5, 100); + + try (Directory d = newDirectory()) { + addRandomDocuments(numDocs, d, numDims); + + try (IndexReader reader = DirectoryReader.open(d)) { + float[] queryVector = randomVector(numDims); + + checkProfiling(queryVector, reader, new MatchAllDocsQuery()); + checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); + } + } + } + + private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + innerQuery + ); + IndexSearcher searcher = newSearcher(reader, true, false); + searcher.search(rescoreKnnVectorQuery, numDocs); + + QueryProfiler queryProfiler = new QueryProfiler(); + rescoreKnnVectorQuery.profile(queryProfiler); + + long expectedVectorOpsCount = numDocs; + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { + QueryProfiler anotherProfiler = new QueryProfiler(); + queryProfilerProvider.profile(anotherProfiler); + assertThat(anotherProfiler.getVectorOpsCount(), greaterThan(0L)); + expectedVectorOpsCount += anotherProfiler.getVectorOpsCount(); + } + + assertThat(queryProfiler.getVectorOpsCount(), equalTo(expectedVectorOpsCount)); + } + + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } + + /** + * A mock query that is used to test profiling + */ + private static class MockQueryProfilerProvider extends Query implements QueryProfilerProvider { + + private final long vectorOpsCount; + + private MockQueryProfilerProvider(long vectorOpsCount) { + this.vectorOpsCount = vectorOpsCount; + } + + @Override + public String toString(String field) { + return ""; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + throw new UnsupportedEncodingException("Should have been rewritten"); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return new MatchAllDocsQuery(); + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof MockQueryProfilerProvider; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + } + + private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { + try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + Document document = new Document(); + float[] vector = randomVector(numDims); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + document.add(vectorField); + w.addDocument(document); + } + w.commit(); + w.forceMerge(1); + } + } + + @ParametersFactory + public static Iterable parameters() { + List params = new ArrayList<>(); + params.add(new Object[] { true }); + params.add(new Object[] { false }); + + return params; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java index 363d34ca3ff86..6e8cf735983aa 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java +++ b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -61,6 +62,7 @@ import static org.elasticsearch.test.ESTestCase.randomByte; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomFloat; +import static org.elasticsearch.test.ESTestCase.randomFloatBetween; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; import static org.elasticsearch.test.ESTestCase.randomIntBetween; @@ -264,7 +266,12 @@ public static SearchSourceBuilder randomSearchSourceBuilder( } int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); - knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat())); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + knnSearchBuilders.add( + new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, randomBoolean() ? null : randomFloat()) + ); } builder.knnSearch(knnSearchBuilders); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index e00dc9f693ff3..1ca6ef0b43a38 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -97,6 +98,7 @@ public final void testKnnSearchBuilderWireSerialization() throws IOException { createTestInstance(), 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); searchBuilder.queryName(randomAlphaOfLengthBetween(5, 10)); @@ -120,6 +122,7 @@ public final void testKnnSearchRewrite() throws Exception { queryVectorBuilder, 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); KnnSearchBuilder serialized = copyWriteable( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3221ef758c547..899c5d4f21b31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -560,7 +560,7 @@ yield new SparseVectorQueryBuilder( k = Math.max(k, DEFAULT_SIZE); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null); } default -> throw new IllegalStateException( "Field [" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java index af57c95c35615..e438090c99163 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -91,7 +91,7 @@ public void testDenseVector() throws Exception { Map queryMap = (Map) queries.get("dense_vector_1"); float[] vector = readDenseVector(queryMap.get("embeddings")); var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); - KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); var shardRequest = createShardSearchRequest(nestedQueryBuilder); var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 916703446995d..084a7f3de4a53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -102,7 +102,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -116,7 +118,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -146,7 +148,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java index b501967524a6b..723ab146f9fd6 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java @@ -136,7 +136,7 @@ public void setupSuiteScopeCluster() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( prepareSearch("tiny_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -167,7 +167,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -208,8 +208,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -260,8 +260,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(true) @@ -389,8 +389,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -457,8 +457,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -704,7 +704,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -762,7 +762,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -837,8 +837,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -899,8 +899,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -979,7 +979,7 @@ public void testBasicRRFExplain() { // the first result should be the one present in both queries (i.e. doc with text0: 10 and vector: [10]) and the other ones // should only match the knn query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1045,7 +1045,7 @@ public void testRRFExplainUnknownField() { // in this test we try knn with a query on an unknown field that would be rewritten to MatchNoneQuery // so we expect results and explanations only for the first part float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1112,7 +1112,7 @@ public void testRRFExplainOneUnknownFieldSubSearches() { // while the other one would produce a match. // So, we'd have a total of 3 queries, a (rewritten) MatchNoneQuery, a TermQuery, and a kNN query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java index 7269d9c3e5e7f..4dcd1e3156bd2 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java @@ -131,7 +131,7 @@ public void setupIndices() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( client().prepareSearch("tiny_index") @@ -164,7 +164,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -206,8 +206,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -259,8 +259,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -390,8 +390,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -459,8 +459,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -709,7 +709,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -768,7 +768,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -844,8 +844,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -907,8 +907,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index e4e06b5031005..443bf5c2465c9 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -183,7 +183,15 @@ public void testRRFPagination() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -233,7 +241,7 @@ public void testRRFWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -288,7 +296,7 @@ public void testRRFWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -345,7 +353,7 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -411,7 +419,7 @@ public void testMultipleRRFRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -430,7 +438,7 @@ public void testMultipleRRFRetrievers() { ), // this one bring just doc 7 which should be ranked first eventually new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), null ) ), @@ -477,7 +485,7 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -536,7 +544,7 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( @@ -756,6 +764,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), 10, 10, + null, null ); source.retriever( @@ -809,7 +818,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); var rrf = new RRFRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index 69c61fe3bca1f..1a36fba8fdb60 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index 4eaea9a596361..9bc1cd80ea381 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -103,7 +103,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -117,7 +119,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -136,7 +138,7 @@ public void testTelemetryForRRFRetriever() throws IOException { new RRFRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null), + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null), null ), new CompoundRetrieverBuilder.RetrieverSource( @@ -153,7 +155,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java index c0866fa7ea694..50b2d0d626481 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java @@ -884,7 +884,7 @@ public void testKnnSearch() throws Exception { // Since there's no kNN search action at the transport layer, we just emulate // how the action works (it builds a kNN query under the hood) float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f }; - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, 50, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, 50, null, null); if (randomBoolean()) { query.addFilterQuery(new WildcardQueryBuilder("other", "value*")); diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java index bffa53b1f4da6..c2ccb923050ed 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java @@ -441,7 +441,7 @@ public void testKnnSearch() throws IOException { // Since there's no kNN search action at the transport layer, we just emulate // how the action works (it builds a kNN query under the hood) float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f }; - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null); // user1 has access to vector field, so the query should match with the document: assertResponse( @@ -475,7 +475,7 @@ public void testKnnSearch() throws IOException { } ); // user1 can access field1, so the filtered query should match with the document: - KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery( + KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null).addFilterQuery( QueryBuilders.matchQuery("field1", "value1") ); assertHitCount( @@ -486,7 +486,7 @@ public void testKnnSearch() throws IOException { ); // user1 cannot access field2, so the filtered query should not match with the document: - KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery( + KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null).addFilterQuery( QueryBuilders.matchQuery("field2", "value2") ); assertHitCount(