Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 2.x] Fixed exception for case when Hybrid query being wrapped into bool query #496

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
Fixed exception for the case when a Hybrid query is wrapped into a bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public DocIdSetIterator iterator() {
*/
@Override
public float getMaxScore(int upTo) throws IOException {
return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
try {
return scorer.getMaxScore(upTo);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -33,6 +34,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
private static final String TEST_QUERY_TEXT = "greetings";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final String TEST_QUERY_TEXT3 = "hello";
Expand Down Expand Up @@ -188,6 +190,35 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult(
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
 * Runs a bool query whose two "should" clauses are a hybrid query (built from two match
 * sub-queries) and a plain match query, against TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, and
 * checks that the search returns hits, a positive max score and a positive total hit count.
 */
@SneakyThrows
public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() {
    initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);

    // Inner hybrid query: two match sub-queries over the same text field.
    MatchQueryBuilder firstSubQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    MatchQueryBuilder secondSubQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
    HybridQueryBuilder nestedHybridQuery = new HybridQueryBuilder();
    nestedHybridQuery.add(firstSubQuery);
    nestedHybridQuery.add(secondSubQuery);

    // Outer bool query: the hybrid query sits next to an ordinary match clause.
    MatchQueryBuilder siblingMatchQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    BoolQueryBuilder wrappingBoolQuery = QueryBuilders.boolQuery();
    wrappingBoolQuery.should(nestedHybridQuery);
    wrappingBoolQuery.should(siblingMatchQuery);

    Map<String, Object> response = search(
        TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
        wrappingBoolQuery,
        null,
        10,
        Map.of("search_pipeline", SEARCH_PIPELINE)
    );

    // The wrapped query must execute and produce scored hits.
    assertTrue(getHitCount(response) > 0);
    assertTrue(getMaxScore(response).isPresent());
    assertTrue(getMaxScore(response).get() > 0.0f);

    // The reported total must be present and positive as well.
    Map<String, Object> totalHits = getTotalHits(response);
    assertNotNull(totalHits.get("value"));
    assertTrue((int) totalHits.get("value") > 0);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -242,32 +273,45 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
TEST_MULTI_DOC_INDEX_NAME,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE))
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"1",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1)
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"2",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector2).toArray())
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"3",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector3).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2)
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
}

if (TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD)) {
prepareKnnIndex(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
1
);
assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
}
}

/**
 * Adds the three fixture documents (ids "1", "2" and "3") to the given index and asserts
 * that the index then reports exactly three documents. Documents "1" and "3" carry both a
 * vector and a text field; document "2" carries only a vector.
 *
 * @param testMultiDocIndexName name of the index to add the documents to
 */
private void addDocsToIndex(final String testMultiDocIndexName) {
    // The same single-field name lists are used by every document, so build them once.
    final List<String> vectorFieldNames = Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1);
    final List<String> textFieldNames = Collections.singletonList(TEST_TEXT_FIELD_NAME_1);

    // Document 1: vector plus text.
    addKnnDoc(
        testMultiDocIndexName,
        "1",
        vectorFieldNames,
        Collections.singletonList(Floats.asList(testVector1).toArray()),
        textFieldNames,
        Collections.singletonList(TEST_DOC_TEXT1)
    );

    // Document 2: vector only, no text field.
    addKnnDoc(testMultiDocIndexName, "2", vectorFieldNames, Collections.singletonList(Floats.asList(testVector2).toArray()));

    // Document 3: vector plus text.
    addKnnDoc(
        testMultiDocIndexName,
        "3",
        vectorFieldNames,
        Collections.singletonList(Floats.asList(testVector3).toArray()),
        textFieldNames,
        Collections.singletonList(TEST_DOC_TEXT2)
    );

    assertEquals(3, getDocCount(testMultiDocIndexName));
}

private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.neuralsearch.query;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -21,6 +23,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.tests.util.TestUtil;

Expand Down Expand Up @@ -169,6 +172,63 @@ public void testWithRandomDocuments_whenMultipleScorersAndSomeScorersEmpty_thenR
testWithQuery(docs, scores, hybridQueryScorer);
}

/**
 * Exercises HybridQueryScorer#getMaxScore with three sub-scorer configurations:
 * all sub-scorers non-null (expects a positive max score), a mix of null and non-null
 * sub-scorers (expects a positive max score), and only null sub-scorers (expects 0.0f).
 */
@SneakyThrows
public void testMaxScore_whenMultipleScorers_thenSuccessful() {
    final int upperDocId = TestUtil.nextInt(random(), 10, 10_000);
    final Pair<int[], float[]> generated = generateDocuments(upperDocId);
    final int[] docIds = generated.getLeft();
    final float[] docScores = generated.getRight();

    final Weight mockedWeight = mock(Weight.class);

    // Scenario 1: every sub-scorer is present.
    HybridQueryScorer allPresentScorer = new HybridQueryScorer(
        mockedWeight,
        Arrays.asList(
            scorer(docIds, docScores, fakeWeight(new MatchAllDocsQuery())),
            scorer(docIds, docScores, fakeWeight(new MatchNoDocsQuery()))
        )
    );
    float observedMaxScore = allPresentScorer.getMaxScore(Integer.MAX_VALUE);
    assertTrue(observedMaxScore > 0.0f);

    // Scenario 2: null entries surround a single real sub-scorer.
    HybridQueryScorer partiallyNullScorer = new HybridQueryScorer(
        mockedWeight,
        Arrays.asList(null, scorer(docIds, docScores, fakeWeight(new MatchAllDocsQuery())), null)
    );
    observedMaxScore = partiallyNullScorer.getMaxScore(Integer.MAX_VALUE);
    assertTrue(observedMaxScore > 0.0f);

    // Scenario 3: nothing but null sub-scorers, so the max score is expected to be zero.
    HybridQueryScorer allNullScorer = new HybridQueryScorer(mockedWeight, Arrays.asList(null, null));
    observedMaxScore = allNullScorer.getMaxScore(Integer.MAX_VALUE);
    assertEquals(0.0f, observedMaxScore, 0.0f);
}

/**
 * Checks that when a sub-scorer's getMaxScore throws an IOException, calling
 * HybridQueryScorer#getMaxScore raises a RuntimeException whose message contains
 * the text of the original exception.
 */
@SneakyThrows
public void testMaxScoreFailures_whenScorerThrowsException_thenFail() {
    final int upperDocId = TestUtil.nextInt(random(), 10, 10_000);
    final Pair<int[], float[]> generated = generateDocuments(upperDocId);
    final int[] docIds = generated.getLeft();
    final float[] docScores = generated.getRight();

    final Weight mockedWeight = mock(Weight.class);

    // Sub-scorer stub that fails on any getMaxScore call.
    final Scorer failingSubScorer = mock(Scorer.class);
    when(failingSubScorer.getWeight()).thenReturn(fakeWeight(new MatchAllDocsQuery()));
    when(failingSubScorer.iterator()).thenReturn(iterator(docIds));
    when(failingSubScorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception"));

    final HybridQueryScorer scorerUnderTest = new HybridQueryScorer(mockedWeight, Arrays.asList(failingSubScorer));

    // The checked IOException is expected to reach the caller as an unchecked exception.
    final RuntimeException thrown = expectThrows(
        RuntimeException.class,
        () -> scorerUnderTest.getMaxScore(Integer.MAX_VALUE)
    );
    assertTrue(thrown.getMessage().contains("Test exception"));
}

private Pair<int[], float[]> generateDocuments(int maxDocId) {
final int numDocs = RandomizedTest.randomIntBetween(1, maxDocId / 2);
final int[] docs = new int[numDocs];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() {
List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext))
);
IndexSearcher searcher = newSearcher(reader);
Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f);
Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);

assertNotNull(weight);

LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
Scorer scorer = weight.scorer(leafReaderContext);

assertNotNull(scorer);
Expand Down
Loading