From dff866d2d3d08ee50e2224d9b2943a944b467b7c Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 10 Nov 2023 17:23:59 -0800 Subject: [PATCH] Adding null check for case when hybrid query wrapped into bool query Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/query/HybridQueryScorer.java | 2 +- .../neuralsearch/query/HybridQueryIT.java | 88 ++++++++++++++----- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b92edd850..10063185c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 42f0a56e6..9ee3e9dbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -83,7 +83,7 @@ public DocIdSetIterator iterator() { */ @Override public float getMaxScore(int upTo) throws IOException { - return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> { + return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> { try { return scorer.getMaxScore(upTo); } catch (IOException e) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index eec6955ff..171d2f4a4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.SpaceType; @@ -33,6 +34,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index"; private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -188,6 +190,35 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows + public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + MatchQueryBuilder matchQuery2Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + hybridQueryBuilderOnlyTerm.add(matchQuery2Builder); + MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + boolQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertTrue(getHitCount(searchResponseAsMap) > 0); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertTrue((int) total.get("value") > 0); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( @@ -242,32 +273,45 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { TEST_MULTI_DOC_INDEX_NAME, Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) ); - addKnnDoc( - TEST_MULTI_DOC_INDEX_NAME, - "1", - Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector1).toArray()), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT1) - ); - addKnnDoc( - TEST_MULTI_DOC_INDEX_NAME, - "2", - Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector2).toArray()) - ); - addKnnDoc( - TEST_MULTI_DOC_INDEX_NAME, - "3", - Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector3).toArray()), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT2) + addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME); + } + + if (TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 ); - assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME)); + addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } } + private void addDocsToIndex(final String testMultiDocIndexName) { + addKnnDoc( + testMultiDocIndexName, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + testMultiDocIndexName, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + testMultiDocIndexName, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + assertEquals(3, getDocCount(testMultiDocIndexName)); + } + private List> getNestedHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (List>) hitsMap.get("hits");