Skip to content

Commit

Permalink
Adds in lazy execution for Lucene kNN queries
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal committed Dec 5, 2024
1 parent 9276c77 commit 564c783
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand Down Expand Up @@ -106,9 +107,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter));
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter));
default:
throw new IllegalArgumentException(
String.format(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

/**
* LuceneEngineKnnVectorQuery is a wrapper around a vector queries for the Lucene engine.
* This enables us to defer rewrites until weight creation to optimize repeated execution
* of Lucene based k-NN queries.
*/
@AllArgsConstructor
@Log4j2
public class LuceneEngineKnnVectorQuery extends Query {
private final Query luceneQuery;

/*
Prevents repeated rewrites of the query for the Lucene engine.
*/
@Override
public Query rewrite(IndexSearcher indexSearcher) {
return this;
}

/*
Rewrites the query just before weight creation.
*/
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
Query rewrittenQuery = luceneQuery.rewrite(searcher);
return rewrittenQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public String toString(String s) {
return luceneQuery.toString();
}

@Override
public void visit(QueryVisitor queryVisitor) {
queryVisitor.visitLeaf(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o;
return luceneQuery.equals(otherQuery.luceneQuery);
}

@Override
public int hashCode() {
return luceneQuery.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.junit.Before;
Expand All @@ -33,6 +32,7 @@
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -512,7 +512,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {

// Then
assertNotNull(query);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
assertTrue(query.getClass().isAssignableFrom(LuceneEngineKnnVectorQuery.class));
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.junit.Before;
import org.mockito.Mock;
Expand All @@ -29,6 +27,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand Down Expand Up @@ -119,7 +118,7 @@ public void testCreateLuceneDefaultQuery() {
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
assertEquals(KnnFloatVectorQuery.class, query.getClass());
assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass());
}
}

Expand All @@ -137,7 +136,7 @@ public void testLuceneFloatVectorQuery() {
);

// efsearch > k
Query expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null);
Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null));
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
Expand All @@ -152,7 +151,7 @@ public void testLuceneFloatVectorQuery() {
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
Expand All @@ -165,7 +164,7 @@ public void testLuceneFloatVectorQuery() {
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);
}

Expand All @@ -183,7 +182,7 @@ public void testLuceneByteVectorQuery() {
);

// efsearch > k
Query expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null);
Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null));
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
Expand All @@ -198,7 +197,7 @@ public void testLuceneByteVectorQuery() {
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
Expand All @@ -211,7 +210,7 @@ public void testLuceneByteVectorQuery() {
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);
}

Expand All @@ -234,7 +233,7 @@ public void testCreateLuceneQueryWithFilter() {
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertEquals(KnnFloatVectorQuery.class, query.getClass());
assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass());
}
}

Expand Down Expand Up @@ -310,8 +309,8 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc
}

public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() {
validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, LuceneEngineKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, LuceneEngineKnnVectorQuery.class);
}

public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnToChildBlockJoinQueryForFilters() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.openMocks;

public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase {

@Mock
IndexSearcher indexSearcher;

@Mock
Query luceneQuery;

@Mock
Weight weight;

@Mock
QueryVisitor queryVisitor;

@Spy
@InjectMocks
LuceneEngineKnnVectorQuery objectUnderTest;

@Override
public void setUp() throws Exception {
super.setUp();
openMocks(this);
when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery);
when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight);
}

public void testRewrite() {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
}

public void testCreateWeight() throws Exception {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f);
verify(luceneQuery, times(1)).rewrite(indexSearcher);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
assertEquals(weight, actualWeight);
}

public void testVisit() {
objectUnderTest.visit(queryVisitor);
verify(queryVisitor).visitLeaf(objectUnderTest);
}

public void testEquals() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery, otherQuery);
assertEquals(mainQuery, mainQuery);
assertNotEquals(mainQuery, null);
assertNotEquals(mainQuery, new Object());
LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null);
assertNotEquals(mainQuery, otherQuery2);
}

public void testHashCode() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.hashCode(), luceneQuery.hashCode());
}

public void testToString() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.toString(), luceneQuery.toString());
}
}

0 comments on commit 564c783

Please sign in to comment.