diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 8e5849abb..77e993297 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.query; +import com.google.common.base.Predicates; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.NonNull; import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; @@ -21,6 +23,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; @@ -36,7 +39,9 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; +import java.util.function.Predicate; @Log4j2 @AllArgsConstructor @@ -55,11 +60,41 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getKnnQuery().getRadius() != null) { + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, exactSearcherContext.getK()); + return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); + } + + /** + * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. + * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores + * to filter out the documents that does not have given min score. + * @param leafReaderContext + * @param exactSearcherContext + * @param iterator {@link KNNIterator} + * @return Map of docId and score + * @throws IOException exception raised by iterator during traversal + */ + private Map doRadialSearch( + LeafReaderContext leafReaderContext, + ExactSearcherContext exactSearcherContext, + KNNIterator iterator + ) throws IOException { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + if (KNNEngine.FAISS != engine) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + final float minScore = spaceType.scoreTranslation(knnQuery.getRadius()); + return filterDocsByMinScore(exactSearcherContext, iterator, minScore); } private Map scoreAllDocs(KNNIterator iterator) throws IOException { @@ -71,15 +106,17 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate filterScore) + throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(k, true); + final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { - if (iterator.score() > topDoc.score) { - topDoc.score = iterator.score(); + final float currentScore = iterator.score(); + if (filterScore.test(currentScore) && currentScore > topDoc.score) { + topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we // have seen till now on top. @@ -98,10 +135,16 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc final ScoreDoc doc = queue.pop(); docToScore.put(doc.doc, doc.score); } - return docToScore; } + private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) + throws IOException { + int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); + Predicate scoreGreaterThanOrEqualToMinScore = score -> score >= minScore; + return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); + } + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d5cd80934..a7799b4aa 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -95,8 +96,13 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { } public static void initialize(ModelDao modelDao) { + initialize(modelDao, new ExactSearcher(modelDao)); + } + + @VisibleForTesting + static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { KNNWeight.modelDao = modelDao; - KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao); + KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher; } @Override @@ -204,8 +210,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() - .k(k) .isParentHits(true) + .k(k) // setting to true, so that if quantization details are present we want to do search on the quantized // vectors as this flow is used in first pass of search. .useQuantizedVectorsForSearch(true) @@ -398,12 +404,9 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); - if (knnQuery.getRadius() != null) { - return false; - } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (filterIdsCount <= knnQuery.getK()) { + if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { return true; } // See user has defined Exact Search filtered threshold. if yes, then use that setting. diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 99152ef6b..b5166866c 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -88,6 +88,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest .indexName(indexName) .parentsFilter(parentFilter) .radius(radius) + .vectorDataType(vectorDataType) .methodParameters(methodParameters) .context(knnQueryContext) .filterQuery(filterQuery) diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index eec520a63..c494f7f1f 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -92,6 +92,7 @@ public class FaissIT extends KNNRestTestCase { private static final String INTEGER_FIELD_NAME = "int_field"; private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; + public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; static TestUtils.TestData testData; @@ -622,10 +623,11 @@ public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCrea // Assert we have the right number of documents in the index assertEquals(numDocs, getDocCount(indexName)); - // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + // expect result due to exact search + assertEquals(1, results.size()); deleteKNNIndex(indexName); validateGraphEviction(); @@ -681,7 +683,7 @@ public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() // KNN Query should return empty result final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + assertEquals(1, results.size()); // update index setting to build graph and do force merge // update build vector data structure setting @@ -1826,6 +1828,111 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() { + final SpaceType spaceType = SpaceType.L2; + + final List mValues = ImmutableList.of(16, 32, 64, 128); + final List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + final List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + final float distance = 300000000000f; + final List> resultsFromDistance = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + distance, + null, + spaceType, + null, + null + ); + assertFalse(resultsFromDistance.isEmpty()); + resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); }); + final float score = spaceType.scoreTranslation(distance); + final List> resultsFromScore = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + null, + score, + spaceType, + null, + null + ); + assertFalse(resultsFromScore.isEmpty()); + resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); }); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialQueryWithFilter_whenNoGraphIsCreated_thenSuccess() { + setupKNNIndexForFilterQuery(buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD)); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder, + null + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1898,6 +2005,10 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( } protected void setupKNNIndexForFilterQuery() throws Exception { + setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); + } + + protected void setupKNNIndexForFilterQuery(Settings settings) throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -1915,7 +2026,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .endObject(); final String mapping = builder.toString(); - createKnnIndex(INDEX_NAME, mapping); + createKnnIndex(INDEX_NAME, settings, mapping); addKnnDocWithAttributes( DOC_ID_1, diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index bf8168d37..c6e5c8fd4 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -814,6 +814,92 @@ public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuil deleteKNNIndex(indexName); } + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSettingUsingRadialSearch() + throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) throws IOException, ParseException { final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); diff --git a/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java new file mode 100644 index 000000000..8492ca1f0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentCommitInfo; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; + +public class ExactSearcherTests extends KNNTestCase { + + private static final String SEGMENT_NAME = "0"; + + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenSuccess() { + try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + final Float score = Collections.min(expectedScores); + final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType); + final int maxResults = 1000; + final KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + KNNWeight.initialize(null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .build(); + + final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder = + ExactSearcher.ExactSearcherContext.builder() + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(false) + .knnQuery(query); + + ExactSearcher exactSearcher = new ExactSearcher(null); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS); + when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + final Map integerFloatMap = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build()); + assertEquals(integerFloatMap.size(), dataVectors.size()); + assertEquals(expectedScores, new ArrayList<>(integerFloatMap.values())); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7a71c44be..449ab80e0 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -360,7 +361,6 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() final Path path = mock(Path.class); when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); - final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); // When no knn fields are available , field info for vector field will be null when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); @@ -763,6 +763,83 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.00000001f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index eed2772b4..7784c4bf4 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -118,7 +118,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100; @@ -133,7 +136,7 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); @@ -141,7 +144,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100;