From 9637fb7f6393f2a2a5968eabe7fe156b25625e7a Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Mon, 30 Sep 2024 19:45:16 -0700 Subject: [PATCH 1/3] Add support for radial search in exact search When threshold value is set, knn plugin will not be creating graph. Hence, when search request is trigged during that time, exact search will return valid results. However, radial search was never included as part of exact search. This will break radial search when threshold is added and radial search is requested. In this commit, new method is introduced to accept min score and return documents that are greater than min score, similar to how radial search is performed by native engines. This search is independent of engine, but, radial search is supported only for FAISS engine out of all native engines. Signed-off-by: Vijayan Balasubramanian --- .../knn/index/query/ExactSearcher.java | 57 ++++++++- .../opensearch/knn/index/query/KNNWeight.java | 7 +- .../knn/index/query/RNNQueryFactory.java | 1 + .../org/opensearch/knn/index/FaissIT.java | 112 +++++++++++++++++- .../opensearch/knn/index/OpenSearchIT.java | 86 ++++++++++++++ .../opensearch/knn/integ/BinaryIndexIT.java | 28 ----- 6 files changed, 252 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 8e5849abb..ecb884ba5 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; @@ -36,7 +37,9 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; +import java.util.function.Predicate; @Log4j2 @AllArgsConstructor @@ -55,6 +58,9 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getKnnQuery().getRadius() != null) { + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); @@ -62,6 +68,33 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, return searchTopK(iterator, exactSearcherContext.getK()); } + /** + * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. + * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores + * to filter out the documents that does not have given min score. + * @param leafReaderContext + * @param exactSearcherContext + * @param iterator {@link KNNIterator} + * @return Map of docId and score + * @throws IOException exception raised by iterator during traversal + */ + private Map doRadialSearch( + LeafReaderContext leafReaderContext, + ExactSearcherContext exactSearcherContext, + KNNIterator iterator + ) throws IOException { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + if (KNNEngine.FAISS != engine) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + final float minScore = spaceType.scoreTranslation(knnQuery.getRadius()); + return filterDocsByMinScore(exactSearcherContext, iterator, minScore); + } + private Map scoreAllDocs(KNNIterator iterator) throws IOException { final Map docToScore = new HashMap<>(); int docId; @@ -71,15 +104,19 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, Predicate filterScore) throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(k, true); + final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { - if (iterator.score() > topDoc.score) { - topDoc.score = iterator.score(); + final float currentScore = iterator.score(); + if (filterScore != null && Predicate.not(filterScore).test(currentScore)) { + continue; + } + if (currentScore > topDoc.score) { + topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we // have seen till now on top. @@ -98,10 +135,20 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc final ScoreDoc doc = queue.pop(); docToScore.put(doc.doc, doc.score); } - return docToScore; } + private Map searchTopK(KNNIterator iterator, int k) throws IOException { + return searchTopCandidates(iterator, k, null); + } + + private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) + throws IOException { + int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); + Predicate scoreGreaterThanOrEqualToMinScore = score -> score >= minScore; + return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); + } + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d5cd80934..d43a97c43 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -204,8 +204,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() - .k(k) .isParentHits(true) + .k(k) // setting to true, so that if quantization details are present we want to do search on the quantized // vectors as this flow is used in first pass of search. .useQuantizedVectorsForSearch(true) @@ -398,12 +398,9 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); - if (knnQuery.getRadius() != null) { - return false; - } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (filterIdsCount <= knnQuery.getK()) { + if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { return true; } // See user has defined Exact Search filtered threshold. if yes, then use that setting. diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 99152ef6b..b5166866c 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -88,6 +88,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest .indexName(indexName) .parentsFilter(parentFilter) .radius(radius) + .vectorDataType(vectorDataType) .methodParameters(methodParameters) .context(knnQueryContext) .filterQuery(filterQuery) diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index eec520a63..c7862bae8 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -92,6 +92,7 @@ public class FaissIT extends KNNRestTestCase { private static final String INTEGER_FIELD_NAME = "int_field"; private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; + public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; static TestUtils.TestData testData; @@ -1826,6 +1827,111 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() { + final SpaceType spaceType = SpaceType.L2; + + final List mValues = ImmutableList.of(16, 32, 64, 128); + final List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + final List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + final float distance = 300000000000f; + final List> resultsFromDistance = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + distance, + null, + spaceType, + null, + null + ); + assertFalse(resultsFromDistance.isEmpty()); + resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); }); + final float score = spaceType.scoreTranslation(distance); + final List> resultsFromScore = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + null, + score, + spaceType, + null, + null + ); + assertFalse(resultsFromScore.isEmpty()); + resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); }); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialQueryWithFilter_whenNoGraphIsCreated_thenSuccess() { + setupKNNIndexForFilterQuery(buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD)); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder, + null + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1898,6 +2004,10 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( } protected void setupKNNIndexForFilterQuery() throws Exception { + setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); + } + + protected void setupKNNIndexForFilterQuery(Settings settings) throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -1915,7 +2025,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .endObject(); final String mapping = builder.toString(); - createKnnIndex(INDEX_NAME, mapping); + createKnnIndex(INDEX_NAME, settings, mapping); addKnnDocWithAttributes( DOC_ID_1, diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index bf8168d37..c6e5c8fd4 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -814,6 +814,92 @@ public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuil deleteKNNIndex(indexName); } + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSettingUsingRadialSearch() + throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) throws IOException, ParseException { final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index eed2772b4..cca40c55d 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -155,17 +155,6 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } } - @SneakyThrows - public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { - // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); - - // Query - float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; - Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4)); - assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search")); - } - private float getRecall(final Set truth, final Set result) { // Count the number of relevant documents retrieved result.retainAll(truth); @@ -178,23 +167,6 @@ private float getRecall(final Set truth, final Set result) { return (float) relevantRetrieved / totalRelevant; } - private List runRnnQuery( - final String indexName, - final String fieldName, - final float[] queryVector, - final float minScore, - final int size - ) throws Exception { - String query = KNNJsonQueryBuilder.builder() - .fieldName(fieldName) - .vector(ArrayUtils.toObject(queryVector)) - .minScore(minScore) - .build() - .getQueryString(); - Response response = searchKNNIndex(indexName, query, size); - return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName); - } - private List runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k) throws Exception { String query = KNNJsonQueryBuilder.builder() From 54cdfe49e005e90a4254443c7c9f9ca11ec128cb Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Wed, 2 Oct 2024 16:32:10 -0700 Subject: [PATCH 2/3] Fix code review comments Signed-off-by: Vijayan Balasubramanian --- .../knn/index/query/ExactSearcher.java | 16 +- .../org/opensearch/knn/index/FaissIT.java | 7 +- .../knn/index/query/ExactSearcherTests.java | 139 ++++++++++++++++++ .../knn/index/query/KNNWeightTests.java | 1 - .../opensearch/knn/integ/BinaryIndexIT.java | 40 ++++- 5 files changed, 186 insertions(+), 17 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index ecb884ba5..77e993297 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.query; +import com.google.common.base.Predicates; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.NonNull; import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; @@ -65,7 +67,7 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, exactSearcherContext.getK()); + return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); } /** @@ -104,7 +106,8 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopCandidates(KNNIterator iterator, int limit, Predicate filterScore) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate filterScore) + throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); @@ -112,10 +115,7 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { final float currentScore = iterator.score(); - if (filterScore != null && Predicate.not(filterScore).test(currentScore)) { - continue; - } - if (currentScore > topDoc.score) { + if (filterScore.test(currentScore) && currentScore > topDoc.score) { topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we @@ -138,10 +138,6 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { - return searchTopCandidates(iterator, k, null); - } - private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) throws IOException { int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c7862bae8..c494f7f1f 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -623,10 +623,11 @@ public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCrea // Assert we have the right number of documents in the index assertEquals(numDocs, getDocCount(indexName)); - // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + // expect result due to exact search + assertEquals(1, results.size()); deleteKNNIndex(indexName); validateGraphEviction(); @@ -682,7 +683,7 @@ public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() // KNN Query should return empty result final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + assertEquals(1, results.size()); // update index setting to build graph and do force merge // update build vector data structure setting diff --git a/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java new file mode 100644 index 000000000..8492ca1f0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentCommitInfo; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; + +public class ExactSearcherTests extends KNNTestCase { + + private static final String SEGMENT_NAME = "0"; + + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenSuccess() { + try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + final Float score = Collections.min(expectedScores); + final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType); + final int maxResults = 1000; + final KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + KNNWeight.initialize(null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .build(); + + final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder = + ExactSearcher.ExactSearcherContext.builder() + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(false) + .knnQuery(query); + + ExactSearcher exactSearcher = new ExactSearcher(null); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS); + when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + final Map integerFloatMap = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build()); + assertEquals(integerFloatMap.size(), dataVectors.size()); + assertEquals(expectedScores, new ArrayList<>(integerFloatMap.values())); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7a71c44be..7edd4a414 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -360,7 +360,6 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() final Path path = mock(Path.class); when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); - final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); // When no knn fields are available , field info for vector field will be null when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index cca40c55d..7784c4bf4 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -118,7 +118,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100; @@ -133,7 +136,7 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); @@ -141,7 +144,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100; @@ -155,6 +161,17 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } } + @SneakyThrows + public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + + // Query + float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; + Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4)); + assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search")); + } + private float getRecall(final Set truth, final Set result) { // Count the number of relevant documents retrieved result.retainAll(truth); @@ -167,6 +184,23 @@ private float getRecall(final Set truth, final Set result) { return (float) relevantRetrieved / totalRelevant; } + private List runRnnQuery( + final String indexName, + final String fieldName, + final float[] queryVector, + final float minScore, + final int size + ) throws Exception { + String query = KNNJsonQueryBuilder.builder() + .fieldName(fieldName) + .vector(ArrayUtils.toObject(queryVector)) + .minScore(minScore) + .build() + .getQueryString(); + Response response = searchKNNIndex(indexName, query, size); + return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName); + } + private List runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k) throws Exception { String query = KNNJsonQueryBuilder.builder() From e51ae11e019f92eb093f66042bedebed8e187c71 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Thu, 3 Oct 2024 11:11:11 -0700 Subject: [PATCH 3/3] Refactor test Signed-off-by: Vijayan Balasubramanian --- .../opensearch/knn/index/query/KNNWeight.java | 8 +- .../knn/index/query/KNNWeightTests.java | 78 +++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d43a97c43..a7799b4aa 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -95,8 +96,13 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { } public static void initialize(ModelDao modelDao) { + initialize(modelDao, new ExactSearcher(modelDao)); + } + + @VisibleForTesting + static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { KNNWeight.modelDao = modelDao; - KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao); + KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher; } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7edd4a414..449ab80e0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -762,6 +763,83 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.00000001f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class);