diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a8d987c7..0740fa39d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] +- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index f75c7f1d9..1e560a11b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -34,7 +34,6 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); - public static final Set ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 7bac6c126..8e6c97f05 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -26,7 +26,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; -import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS; /** * Creates the Lucene k-NN queries @@ -110,15 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .build(); } - if (createQueryRequest.getRescoreContext().isPresent()) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - return knnQuery; + return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java index 19842a67a..727c508fb 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java @@ -19,9 +19,8 @@ * A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given * set of nested document IDs. * - * The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents - * of that parent document will be iterated multiple times. - * + * It is permissible for {@link #docIds} to contain multiple nested document IDs linked to a single parent document. + * In such cases, this iterator will still iterate over each nested document ID only once. */ public class GroupedNestedDocIdSetIterator extends DocIdSetIterator { private final BitSet parentBitSet; @@ -99,9 +98,14 @@ public long cost() { private long calculateCost() { long numDocs = 0; + int lastDocId = -1; for (int docId : docIds) { - for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) { - if (filterBits.get(i)) { + if (docId < lastDocId) { + continue; + } + + for (lastDocId = parentBitSet.prevSetBit(docId) + 1; lastDocId < parentBitSet.nextSetBit(docId); lastDocId++) { + if (filterBits.get(lastDocId)) { numDocs++; } } @@ -111,12 +115,19 @@ private long calculateCost() { private void moveToNextIndex() { currentIndex++; - if (currentIndex >= docIds.size()) { - currentDocId = NO_MORE_DOCS; + while (currentIndex < docIds.size()) { + // Advance currentIndex until the docId at the currentIndex is greater than currentDocId. + // This ensures proper handling when docIds contain multiple entries under the same parent ID + // that have already been iterated. + if (docIds.get(currentIndex) <= currentDocId) { + currentIndex++; + continue; + } + currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; + currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); + assert currentParentId != NO_MORE_DOCS; return; } - currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; - currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); - assert currentParentId != NO_MORE_DOCS; + currentDocId = NO_MORE_DOCS; } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index b609bb0df..30c8007e6 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -33,6 +33,7 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -191,7 +192,7 @@ public void testDoToQuery_Normal() throws Exception { when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); @@ -599,8 +600,8 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { // Then assertNotNull(query); - assertTrue(query.getClass().isAssignableFrom(KNNQuery.class)); - assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters()); + assertTrue(query.getClass().isAssignableFrom(NativeEngineKnnVectorQuery.class)); + assertEquals(HNSW_METHOD_PARAMS, ((NativeEngineKnnVectorQuery) query).getKnnQuery().getMethodParameters()); } public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { @@ -670,7 +671,7 @@ public void testDoToQuery_FromModel() { KNNQueryBuilder.initialize(modelDao); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); @@ -1026,7 +1027,7 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); assertNull(query.getQueryVector()); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index eff2ca895..329222636 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -69,7 +69,7 @@ public void setUp() throws Exception { public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - Query query = KNNQueryFactory.create( + Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create( BaseQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -78,14 +78,14 @@ public void testCreateCustomKNNQuery() { .k(testK) .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .build() - ); + )).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); assertEquals(testK, ((KNNQuery) query).getK()); - query = KNNQueryFactory.create( + query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create( BaseQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -94,7 +94,7 @@ public void testCreateCustomKNNQuery() { .k(testK) .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .build() - ); + )).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); @@ -269,7 +269,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { .filter(FILTER_QUERY_BUILDER) .build(); - final Query actual = KNNQueryFactory.create(createQueryRequest); + final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); // Then assertEquals(expectedQuery, actual); @@ -303,7 +303,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc .filter(FILTER_QUERY_BUILDER) .build(); - final Query actual = KNNQueryFactory.create(createQueryRequest); + final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); // Then assertEquals(expectedQuery, actual); @@ -338,7 +338,7 @@ public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnTo .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); mockedNestedHelper.close(); assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass()); } @@ -367,7 +367,7 @@ public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery( .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); mockedNestedHelper.close(); assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass()); } @@ -388,7 +388,7 @@ public void testCreate_whenFaissWithParentFilter_thenSuccess() { .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .build(); - final Query query = KNNQueryFactory.create(createQueryRequest); + final Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); @@ -441,7 +441,7 @@ public void testCreate_whenBinary_thenSuccess() { .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - Query query = KNNQueryFactory.create(createQueryRequest); + Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertNotNull(((KNNQuery) query).getByteQueryVector()); assertNull(((KNNQuery) query).getQueryVector()); @@ -488,7 +488,7 @@ public void testCreate_whenExpandNestedDocsQueryWithFaiss_thenCreateNativeEngine } public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() { - testExpandNestedDocsQuery(KNNEngine.NMSLIB, KNNQuery.class, VectorDataType.FLOAT); + testExpandNestedDocsQuery(KNNEngine.NMSLIB, NativeEngineKnnVectorQuery.class, VectorDataType.FLOAT); } public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() { diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java index 55f3d91d9..976b50ea6 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java @@ -70,4 +70,33 @@ public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsEx assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); } + + public void testGroupedNestedDocIdSetIterator_whenGivenMultipleDocsUnderSameParent_thenBehaveAsExpected() throws Exception { + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet parentBitSet = new FixedBitSet(new long[1], 11); + parentBitSet.set(2); + parentBitSet.set(7); + parentBitSet.set(10); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(0, 1, 3, 4, 5, 8, 9); + GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + groupedNestedDocIdSetIterator.advance(1); + assertEquals(1, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(8, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.advance(9); + assertEquals(9, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); + assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); + } } diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index 6f243ff3a..6498eeda5 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -68,15 +68,6 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath()); } - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { // Create Index diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java index a706dd0cd..f16fa1494 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java @@ -29,15 +29,6 @@ @Log4j2 @AllArgsConstructor public class BinaryIndexInvalidMappingIT extends KNNRestTestCase { - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - private String description; private String indexMapping; private String expectedExceptionMessage; diff --git a/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java index 164aa7100..2ec1808be 100644 --- a/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java +++ b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java @@ -14,6 +14,7 @@ import org.junit.After; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; @@ -70,12 +71,6 @@ public class ExpandNestedDocsIT extends KNNRestTestCase { private Mode mode; private Integer dimension; - @After - @SneakyThrows - public final void cleanUp() { - deleteKNNIndex(INDEX_NAME); - } - @ParametersFactory(argumentFormatting = "description:%1$s; engine:%2$s, data_type:%3$s, mode:%4$s, dimension:%5$s") public static Collection parameters() throws IOException { int dimension = 1; @@ -99,13 +94,19 @@ public static Collection parameters() throws IOException { Mode.ON_DISK, // Currently, on disk mode only supports dimension of multiple of 8 dimension * 8 - ) + ), + $("Nmslib with float format and in memory mode", KNNEngine.NMSLIB, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, dimension) ) ); } @SneakyThrows public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc() { + if (engine == KNNEngine.NMSLIB) { + // NMSLIB does not support filtering + return; + } + int numberOfNestedFields = 2; createKnnIndex(engine, mode, dimension, dataType); addRandomVectorsWithTopLevelField(1, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); @@ -131,6 +132,11 @@ public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc( @SneakyThrows public void testExpandNestedDocs_whenFilteredOnNestedFieldDoc_thenReturnFilteredNestedDoc() { + if (engine == KNNEngine.NMSLIB) { + // NMSLIB does not support filtering + return; + } + int numberOfNestedFields = 2; createKnnIndex(engine, mode, dimension, dataType); addRandomVectorsWithMetadata(1, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE)); @@ -175,7 +181,9 @@ public void testExpandNestedDocs_whenMultiShards_thenReturnCorrectResult() { // Run Float[] queryVector = createVector(); - Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, numberOfDocuments, queryVector); + // NMSLIB does not support dedup per parent documents. Therefore, we need to multiply the k by number of nestedFields. + int k = engine == KNNEngine.NMSLIB ? numberOfDocuments * numberOfNestedFields : numberOfDocuments; + Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, k, queryVector); // Verify String entity = EntityUtils.toString(response.getEntity()); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java index a5daac9e5..e0914e7ff 100644 --- a/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchBinaryIT.java @@ -39,15 +39,6 @@ public static Collection parameters() { return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); } - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testFilteredSearchHnswBinary_whenDoingApproximateSearch_thenReturnCorrectResults() { validateFilteredSearchHnswBinary(INDEX_NAME, false); diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java index fe4dc7db9..450fcf400 100644 --- a/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchByteIT.java @@ -25,15 +25,6 @@ @Log4j2 public class FilteredSearchByteIT extends KNNRestTestCase { - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testFilteredSearchWithFaissHnswByte_whenDoingApproximateSearch_thenReturnCorrectResults() { validateFilteredSearchWithFaissHnswByte(INDEX_NAME, false); diff --git a/src/test/java/org/opensearch/knn/integ/IndexIT.java b/src/test/java/org/opensearch/knn/integ/IndexIT.java index 02faf6a0f..460dc5b4d 100644 --- a/src/test/java/org/opensearch/knn/integ/IndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/IndexIT.java @@ -52,15 +52,6 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath()); } - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testFaissHnsw_when1000Data_thenRecallIsAboveNinePointZero() { // Create Index diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index d1288c5f3..d5d3168b7 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -842,43 +842,39 @@ private void createIndexAndAssertScriptScore( */ Settings settings = Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", enableKnn).build(); createKnnIndex(INDEX_NAME, settings, mapper); - try { - final int numDocsWithField = randomIntBetween(4, 10); - Map dataset = createDataset( - v -> scoreFunction.apply(queryVector, v), - dimensions, - numDocsWithField, - dense, - vectorDataType - ); - final float[] dummyVector = new float[1]; - dataset.forEach((k, v) -> { - final float[] vector = (v != null) ? v.getVector() : dummyVector; - ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); - }); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": FIELD_NAME, - * "vector": queryVector - * } - */ - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", spaceType.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - assertTrue(results.stream().allMatch(r -> dataset.get(r.getDocId()).equals(r))); - } finally { - deleteKNNIndex(INDEX_NAME); - } + final int numDocsWithField = randomIntBetween(4, 10); + Map dataset = createDataset( + v -> scoreFunction.apply(queryVector, v), + dimensions, + numDocsWithField, + dense, + vectorDataType + ); + final float[] dummyVector = new float[1]; + dataset.forEach((k, v) -> { + final float[] vector = (v != null) ? v.getVector() : dummyVector; + ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); + }); + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + /* + * params": { + * "field": FIELD_NAME, + * "vector": queryVector + * } + */ + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType.getValue()); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertTrue(results.stream().allMatch(r -> dataset.get(r.getDocId()).equals(r))); } } diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java index 05291783d..cc32f4f0a 100644 --- a/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchBinaryIT.java @@ -39,15 +39,6 @@ public static Collection parameters() { return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); } - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testNestedSearchHnswBinary_whenKIsTwo_thenReturnTwoResults() { diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java index 7985d08a7..e6af947ab 100644 --- a/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchByteIT.java @@ -25,15 +25,6 @@ @Log4j2 public class NestedSearchByteIT extends KNNRestTestCase { - @After - public void cleanUp() { - try { - deleteKNNIndex(INDEX_NAME); - } catch (Exception e) { - log.error(e); - } - } - @SneakyThrows public void testNestedSearchWithFaissHnswByte_whenKIsTwo_thenReturnTwoResults() { String nestedFieldName = "nested"; diff --git a/src/test/java/org/opensearch/knn/integ/NestedSearchIT.java b/src/test/java/org/opensearch/knn/integ/NestedSearchIT.java index c9b33d70a..f8bcf28aa 100644 --- a/src/test/java/org/opensearch/knn/integ/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/integ/NestedSearchIT.java @@ -58,12 +58,6 @@ public class NestedSearchIT extends KNNRestTestCase { private static final int M = 16; private static final SpaceType SPACE_TYPE = SpaceType.L2; - @After - @SneakyThrows - public final void cleanUp() { - deleteKNNIndex(INDEX_NAME); - } - @SneakyThrows public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() { createKnnIndex(2, KNNEngine.LUCENE.getName()); diff --git a/src/test/java/org/opensearch/knn/integ/search/ConcurrentSegmentSearchIT.java b/src/test/java/org/opensearch/knn/integ/search/ConcurrentSegmentSearchIT.java index cc64cfe9c..933dacc0f 100644 --- a/src/test/java/org/opensearch/knn/integ/search/ConcurrentSegmentSearchIT.java +++ b/src/test/java/org/opensearch/knn/integ/search/ConcurrentSegmentSearchIT.java @@ -78,8 +78,6 @@ public void testConcurrentSegmentSearch_thenSucceed() { updateIndexSettings(indexName, Settings.builder().put("index.search.concurrent_segment_search.mode", "all")); verifySearch(indexName, fieldName, k); - - deleteKNNIndex(indexName); } /*