Skip to content

Commit

Permalink
Added null checks for fieldInfo in ExactSearcher to avoid NPE while r…
Browse files Browse the repository at this point in the history
…unning exact search for segments with no vector field (opensearch-project#2278)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Nov 22, 2024
1 parent 2d1a408 commit 7523cc3
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 10 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
* Fix NPE in ANN search when a segment doesn't contain vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
### Documentation
### Maintenance
* Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236)
* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
### Refactoring
15 changes: 14 additions & 1 deletion src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import lombok.experimental.UtilityClass;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
Expand All @@ -27,7 +29,7 @@
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;

/**
* A utility class to extract information from FieldInfo.
* A utility class to extract information from FieldInfo and also provides utility functions to extract fieldInfo
*/
@UtilityClass
public class FieldInfoExtractor {
Expand Down Expand Up @@ -103,4 +105,15 @@ public static SpaceType getSpaceType(final ModelDao modelDao, final FieldInfo fi
}
return modelMetadata.getSpaceType();
}

/**
* Get the field info for the given field name, do a null check on the fieldInfo, as this function can return null,
* if the field is not found.
* @param leafReader {@link LeafReader}
* @param fieldName {@link String}
* @return {@link FieldInfo}
*/
public static @Nullable FieldInfo getFieldInfo(final LeafReader leafReader, final String fieldName) {
return leafReader.getFieldInfos().fieldInfo(fieldName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.knn.common.FieldInfoExtractor;

import java.io.IOException;

Expand Down Expand Up @@ -40,7 +41,7 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}
Expand Down
22 changes: 17 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
Expand All @@ -59,7 +60,11 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
final KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
// because of any reason if we are not able to get KNNIterator, return an empty map
if (iterator == null) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
Expand All @@ -74,8 +79,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param leafReaderContext {@link LeafReaderContext}
* @param exactSearcherContext {@link ExactSearcherContext}
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
Expand All @@ -87,7 +92,10 @@ private Map<Integer, Float> doRadialSearch(
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
return Collections.emptyMap();
}
final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
if (KNNEngine.FAISS != engine) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine));
Expand Down Expand Up @@ -149,7 +157,11 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return null;
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ private Map<Integer, Float> doANNSearch(
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo == null) {
log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
Expand Down Expand Up @@ -479,7 +479,7 @@ private boolean isFilteredExactSearchRequireAfterANNSearch(final int filterIdsCo
*/
private boolean isMissingNativeEngineFiles(LeafReaderContext context) {
final SegmentReader reader = Lucene.segmentReader(context.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
// if segment has no documents with at least 1 vector field, field info will be null
if (fieldInfo == null) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.common;

import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.LeafReader;
import org.junit.Assert;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
Expand Down Expand Up @@ -63,4 +65,15 @@ public void testExtractVectorDataType() {
when(fieldInfo.getAttribute("model_id")).thenReturn(null);
assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo));
}

public void testGetFieldInfo_whenDifferentInput_thenSuccess() {
LeafReader leafReader = Mockito.mock(LeafReader.class);
FieldInfos fieldInfos = Mockito.mock(FieldInfos.class);
FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(leafReader.getFieldInfos()).thenReturn(fieldInfos);
Mockito.when(fieldInfos.fieldInfo("invalid")).thenReturn(null);
Mockito.when(fieldInfos.fieldInfo("valid")).thenReturn(fieldInfo);
Assert.assertNull(FieldInfoExtractor.getFieldInfo(leafReader, "invalid"));
Assert.assertEquals(fieldInfo, FieldInfoExtractor.getFieldInfo(leafReader, "valid"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
Expand Down Expand Up @@ -50,6 +51,59 @@ public class ExactSearcherTests extends KNNTestCase {

private static final String SEGMENT_NAME = "0";

@SneakyThrows
public void testExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() {
final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
final KNNQuery query = KNNQuery.builder().field(FIELD_NAME).queryVector(queryVector).k(10).indexName(INDEX_NAME).build();

final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder =
ExactSearcher.ExactSearcherContext.builder().knnQuery(query);

ExactSearcher exactSearcher = new ExactSearcher(null);
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FieldInfos fieldInfos = mock(FieldInfos.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(query.getField())).thenReturn(null);
Map<Integer, Float> docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build());
Mockito.verify(fieldInfos).fieldInfo(query.getField());
Mockito.verify(reader).getFieldInfos();
Mockito.verify(leafReaderContext).reader();
assertEquals(0, docIds.size());
}

@SneakyThrows
public void testRadialSearchExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() {
final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
KNNQuery.Context context = new KNNQuery.Context(10);
final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(queryVector)
.context(context)
.radius(1.0f)
.indexName(INDEX_NAME)
.build();

final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder =
ExactSearcher.ExactSearcherContext.builder().knnQuery(query);

ExactSearcher exactSearcher = new ExactSearcher(null);
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FieldInfos fieldInfos = mock(FieldInfos.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(query.getField())).thenReturn(null);
Map<Integer, Float> docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build());
Mockito.verify(fieldInfos).fieldInfo(query.getField());
Mockito.verify(reader).getFieldInfos();
Mockito.verify(leafReaderContext).reader();
assertEquals(0, docIds.size());
}

@SneakyThrows
public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
try (MockedStatic<KNNVectorValuesFactory> valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) {
Expand All @@ -75,6 +129,7 @@ public void testRadialSearch_whenNoEngineFiles_thenSuccess() {
.queryVector(queryVector)
.radius(radius)
.indexName(INDEX_NAME)
.vectorDataType(VectorDataType.FLOAT)
.context(context)
.build();

Expand Down

0 comments on commit 7523cc3

Please sign in to comment.