Skip to content

Commit

Permalink
Support expand_nested_docs parameter for nmslib engine
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 14, 2024
1 parent aa6936a commit f298037
Show file tree
Hide file tree
Showing 18 changed files with 119 additions and 154 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
Expand Down Expand Up @@ -110,15 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

return knnQuery;
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

Integer requestEfSearch = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
* A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given
* set of nested document IDs.
*
* The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents
* of that parent document will be iterated multiple times.
*
* It is permissible for {@link #docIds} to contain multiple nested document IDs linked to a single parent document.
* In such cases, this iterator will still iterate over each nested document ID only once.
*/
public class GroupedNestedDocIdSetIterator extends DocIdSetIterator {
private final BitSet parentBitSet;
Expand Down Expand Up @@ -99,9 +98,14 @@ public long cost() {

private long calculateCost() {
long numDocs = 0;
int lastDocId = -1;
for (int docId : docIds) {
for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) {
if (filterBits.get(i)) {
if (docId < lastDocId) {
continue;
}

for (lastDocId = parentBitSet.prevSetBit(docId) + 1; lastDocId < parentBitSet.nextSetBit(docId); lastDocId++) {
if (filterBits.get(lastDocId)) {
numDocs++;
}
}
Expand All @@ -111,12 +115,19 @@ private long calculateCost() {

private void moveToNextIndex() {
currentIndex++;
if (currentIndex >= docIds.size()) {
currentDocId = NO_MORE_DOCS;
while (currentIndex < docIds.size()) {
// Advance currentIndex until the docId at the currentIndex is greater than currentDocId.
// This ensures proper handling when docIds contain multiple entries under the same parent ID
// that have already been iterated.
if (docIds.get(currentIndex) <= currentDocId) {
currentIndex++;
continue;
}
currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1;
currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex));
assert currentParentId != NO_MORE_DOCS;
return;
}
currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1;
currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex));
assert currentParentId != NO_MORE_DOCS;
currentDocId = NO_MORE_DOCS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -191,7 +192,7 @@ public void testDoToQuery_Normal() throws Exception {
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4));
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertEquals(knnQueryBuilder.getK(), query.getK());
assertEquals(knnQueryBuilder.fieldName(), query.getField());
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
Expand Down Expand Up @@ -599,8 +600,8 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {

// Then
assertNotNull(query);
assertTrue(query.getClass().isAssignableFrom(KNNQuery.class));
assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters());
assertTrue(query.getClass().isAssignableFrom(NativeEngineKnnVectorQuery.class));
assertEquals(HNSW_METHOD_PARAMS, ((NativeEngineKnnVectorQuery) query).getKnnQuery().getMethodParameters());
}

public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() {
Expand Down Expand Up @@ -670,7 +671,7 @@ public void testDoToQuery_FromModel() {
KNNQueryBuilder.initialize(modelDao);

when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertEquals(knnQueryBuilder.getK(), query.getK());
assertEquals(knnQueryBuilder.fieldName(), query.getField());
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
Expand Down Expand Up @@ -1026,7 +1027,7 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception {
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32));
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertArrayEquals(expectedQueryVector, query.getByteQueryVector());
assertNull(query.getQueryVector());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public void setUp() throws Exception {

public void testCreateCustomKNNQuery() {
for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) {
Query query = KNNQueryFactory.create(
Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
Expand All @@ -78,14 +78,14 @@ public void testCreateCustomKNNQuery() {
.k(testK)
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector());
assertEquals(testK, ((KNNQuery) query).getK());

query = KNNQueryFactory.create(
query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
Expand All @@ -94,7 +94,7 @@ public void testCreateCustomKNNQuery() {
.k(testK)
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
)).getKnnQuery();

assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
Expand Down Expand Up @@ -269,7 +269,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
.filter(FILTER_QUERY_BUILDER)
.build();

final Query actual = KNNQueryFactory.create(createQueryRequest);
final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();

// Then
assertEquals(expectedQuery, actual);
Expand Down Expand Up @@ -303,7 +303,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc
.filter(FILTER_QUERY_BUILDER)
.build();

final Query actual = KNNQueryFactory.create(createQueryRequest);
final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();

// Then
assertEquals(expectedQuery, actual);
Expand Down Expand Up @@ -338,7 +338,7 @@ public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnTo
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
mockedNestedHelper.close();
assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass());
}
Expand Down Expand Up @@ -367,7 +367,7 @@ public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery(
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
mockedNestedHelper.close();
assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass());
}
Expand All @@ -388,7 +388,7 @@ public void testCreate_whenFaissWithParentFilter_thenSuccess() {
.vectorDataType(VectorDataType.FLOAT)
.context(mockQueryShardContext)
.build();
final Query query = KNNQueryFactory.create(createQueryRequest);
final Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
Expand Down Expand Up @@ -441,7 +441,7 @@ public void testCreate_whenBinary_thenSuccess() {
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertNotNull(((KNNQuery) query).getByteQueryVector());
assertNull(((KNNQuery) query).getQueryVector());
Expand Down Expand Up @@ -488,7 +488,7 @@ public void testCreate_whenExpandNestedDocsQueryWithFaiss_thenCreateNativeEngine
}

public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() {
testExpandNestedDocsQuery(KNNEngine.NMSLIB, KNNQuery.class, VectorDataType.FLOAT);
testExpandNestedDocsQuery(KNNEngine.NMSLIB, NativeEngineKnnVectorQuery.class, VectorDataType.FLOAT);
}

public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,33 @@ public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsEx
assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID());
assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost());
}

public void testGroupedNestedDocIdSetIterator_whenGivenMultipleDocsUnderSameParent_thenBehaveAsExpected() throws Exception {
// 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent)
BitSet parentBitSet = new FixedBitSet(new long[1], 11);
parentBitSet.set(2);
parentBitSet.set(7);
parentBitSet.set(10);

BitSet filterBits = new FixedBitSet(new long[1], 11);
filterBits.set(1);
filterBits.set(8);
filterBits.set(9);

// Run
Set<Integer> docIds = Set.of(0, 1, 3, 4, 5, 8, 9);
GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits);

// Verify
Set<Integer> expectedDocIds = Set.of(1, 8, 9);
groupedNestedDocIdSetIterator.advance(1);
assertEquals(1, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.nextDoc();
assertEquals(8, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.advance(9);
assertEquals(9, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.nextDoc();
assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID());
assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost());
}
}
9 changes: 0 additions & 9 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath());
}

@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

@SneakyThrows
public void testHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() {
// Create Index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@
@Log4j2
@AllArgsConstructor
public class BinaryIndexInvalidMappingIT extends KNNRestTestCase {
@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

private String description;
private String indexMapping;
private String expectedExceptionMessage;
Expand Down
24 changes: 16 additions & 8 deletions src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -70,12 +71,6 @@ public class ExpandNestedDocsIT extends KNNRestTestCase {
private Mode mode;
private Integer dimension;

@After
@SneakyThrows
public final void cleanUp() {
deleteKNNIndex(INDEX_NAME);
}

@ParametersFactory(argumentFormatting = "description:%1$s; engine:%2$s, data_type:%3$s, mode:%4$s, dimension:%5$s")
public static Collection<Object[]> parameters() throws IOException {
int dimension = 1;
Expand All @@ -99,13 +94,19 @@ public static Collection<Object[]> parameters() throws IOException {
Mode.ON_DISK,
// Currently, on disk mode only supports dimension of multiple of 8
dimension * 8
)
),
$("Nmslib with float format and in memory mode", KNNEngine.NMSLIB, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, dimension)
)
);
}

@SneakyThrows
public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc() {
if (engine == KNNEngine.NMSLIB) {
// NMSLIB does not support filtering
return;
}

int numberOfNestedFields = 2;
createKnnIndex(engine, mode, dimension, dataType);
addRandomVectorsWithTopLevelField(1, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE);
Expand All @@ -131,6 +132,11 @@ public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc(

@SneakyThrows
public void testExpandNestedDocs_whenFilteredOnNestedFieldDoc_thenReturnFilteredNestedDoc() {
if (engine == KNNEngine.NMSLIB) {
// NMSLIB does not support filtering
return;
}

int numberOfNestedFields = 2;
createKnnIndex(engine, mode, dimension, dataType);
addRandomVectorsWithMetadata(1, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE));
Expand Down Expand Up @@ -175,7 +181,9 @@ public void testExpandNestedDocs_whenMultiShards_thenReturnCorrectResult() {

// Run
Float[] queryVector = createVector();
Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, numberOfDocuments, queryVector);
// NMSLIB does not support dedup per parent documents. Therefore, we need to multiply the k by number of nestedFields.
int k = engine == KNNEngine.NMSLIB ? numberOfDocuments * numberOfNestedFields : numberOfDocuments;
Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, k, queryVector);

// Verify
String entity = EntityUtils.toString(response.getEntity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ public static Collection<Object[]> parameters() {
return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS });
}

@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

@SneakyThrows
public void testFilteredSearchHnswBinary_whenDoingApproximateSearch_thenReturnCorrectResults() {
validateFilteredSearchHnswBinary(INDEX_NAME, false);
Expand Down
Loading

0 comments on commit f298037

Please sign in to comment.