diff --git a/CHANGELOG.md b/CHANGELOG.md index 5290345db..2c790ead3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x) ### Features +- Pagination in Hybrid query ([]()) ### Enhancements ### Bug Fixes ### Infrastructure
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..8d737efae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -58,7 +58,21 @@ public void process( } List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + int fromValueForSingleShard = 0; + boolean isSingleShard = false; + if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) { + isSingleShard = true; + fromValueForSingleShard = searchPhaseContext.getRequest().source().from(); + } + + normalizationWorkflow.execute( + querySearchResults, + fetchSearchResult, + normalizationTechnique, + combinationTechnique, + fromValueForSingleShard, + isSingleShard + ); } @Override
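When a request resolves to a single shard, the query and fetch phases travel together, so the coordinator never applies the `from` offset as a separate step; the processor therefore records the requested offset here and applies it itself after scores are normalized and combined. A minimal sketch of that capture rule, using hypothetical names rather than the plugin's classes:

```java
// Sketch only: when exactly one shard answered and the fetch result is bundled
// with the query result, the processor must remember and apply "from" itself.
final class SingleShardPaginationState {
    final boolean isSingleShard;
    final int fromValueForSingleShard;

    SingleShardPaginationState(int numShards, boolean fetchResultPresent, int requestedFrom) {
        this.isSingleShard = numShards == 1 && fetchResultPresent;
        // with multiple shards the coordinator slices later, so 0 is kept here
        this.fromValueForSingleShard = this.isSingleShard ? requestedFrom : 0;
    }
}
```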
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index fe9dd64b6..97ce9af20 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -55,7 +55,9 @@ public void execute( final List<QuerySearchResult> querySearchResults, final Optional<FetchSearchResult> fetchSearchResultOptional, final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique + final ScoreCombinationTechnique combinationTechnique, + final int fromValueForSingleShard, + final boolean isSingleShard ) { // save original state List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -73,6 +75,8 @@ public void execute( .scoreCombinationTechnique(combinationTechnique) .querySearchResults(querySearchResults) .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) + .fromValueForSingleShard(fromValueForSingleShard) + .isSingleShard(isSingleShard) .build(); // combine @@ -82,7 +86,7 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds); + updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds, fromValueForSingleShard); } /** @@ -123,10 +127,14 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) buildTopDocs(updatedTopDocs, sort), maxScoreForShard(updatedTopDocs, sort != null) ); + if (combineScoresDTO.isSingleShard()) { + querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); + } querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); } - if (from > 0 && from > totalScoreDocsCount) { + if ((from > 0 || combineScoresDTO.getFromValueForSingleShard() > 0) + && (from > totalScoreDocsCount || combineScoresDTO.getFromValueForSingleShard() > totalScoreDocsCount)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") ); } @@ -189,7 +197,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) { private void updateOriginalFetchResults( final List<QuerySearchResult> querySearchResults, final Optional<FetchSearchResult> fetchSearchResultOptional, - final List<Integer> docIds + final List<Integer> docIds, + final int fromValueForSingleShard ) { if (fetchSearchResultOptional.isEmpty()) { return; } @@ -221,14 +230,26 @@ private void updateOriginalFetchResults( QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; + // iterate over the normalized/combined scores, that solves (1) and (3) - SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { + SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard]; + for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { + ScoreDoc scoreDoc = topDocs.scoreDocs[i]; // get fetched hit content by doc_id SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); // update score to normalized/combined value (3) searchHit.score(scoreDoc.score); - return searchHit; - }).toArray(SearchHit[]::new); + updatedSearchHitArray[i - fromValueForSingleShard] = searchHit; + } SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, querySearchResult.getTotalHits(),
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java index c4783969b..42ebf6ea2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java @@ -29,4 +29,6 @@ public class CombineScoresDto { private List<QuerySearchResult> querySearchResults; @Nullable private Sort sort; + private int fromValueForSingleShard; + private boolean isSingleShard; }
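The loop above replaces the previous stream so that the fetched hits can be sliced at `fromValueForSingleShard` before they are returned: only hits at positions `from` and beyond survive, re-indexed from zero. A plain-array sketch of the same trimming rule (hypothetical names, not the plugin's types):

```java
import java.util.Arrays;

// Sketch only: the offset-trimming rule with String[] standing in for SearchHit[].
final class OffsetTrimSketch {
    static String[] trimToPage(String[] combinedHits, int from) {
        if (from < 0 || from > combinedHits.length) {
            throw new IllegalArgumentException("requested offset is past the end of the result set");
        }
        // keep hits[from..length), shifted down so the page starts at index 0
        return Arrays.copyOfRange(combinedHits, from, combinedHits.length);
    }

    public static void main(String[] args) {
        String[] page = trimToPage(new String[] { "doc1", "doc2", "doc3", "doc4" }, 2);
        System.out.println(Arrays.toString(page)); // [doc3, doc4]
    }
}
```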
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index cc8f63175..4a979995a 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -6,7 +6,9 @@ import java.util.Locale; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Weight; @@ -56,6 +58,7 @@ * In most cases it will be wrapped in MultiCollectorManager. */ @RequiredArgsConstructor +@Log4j2 public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> { private final int numHits; @@ -68,6 +71,7 @@ public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> { + private static final int DEFAULT_PAGINATION_DEPTH = 10; + if (searchContext.from() > 0) { + searchContext.from(0); + } Weight filteringWeight = null; // Check for post filter to create weight for filter query and later use that weight in the search workflow @@ -412,6 +416,42 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchResult> results) { + private static int getPaginationDepth(final SearchContext searchContext) { + final HybridQuery hybridQuery = (HybridQuery) searchContext.query(); + final int paginationDepth = hybridQuery.getPaginationDepth(); + validatePaginationDepth(paginationDepth); + if (searchContext.from() > 0 && paginationDepth == 0) { + return DEFAULT_PAGINATION_DEPTH; + } else { + return searchContext.from() + searchContext.size(); + } + } + + private static void validatePaginationDepth(int depth) { + if (depth < 0 || depth > 10000) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-10000. Received: %s", depth) + ); + } + } + /** * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to * use saved state of collector
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index a5505d671..12a38c3de 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -60,10 +60,6 @@ public boolean searchWith( validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } else { - // TODO remove this check after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. - // if (searchContext.from() != 0) { - // throw new IllegalArgumentException("In the current OpenSearch version pagination is not supported with hybrid query"); - // } Query hybridQuery = extractHybridQuery(searchContext, query); QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext); return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } @@ -214,6 +210,5 @@ protected boolean searchWithCollector( hasTimeout ); } - } }
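Reading the collector-manager change as a whole: each shard now collects up to a pagination depth worth of candidates instead of only `from + size`, and the shard-level `from` is zeroed so slicing happens once, after normalization. A standalone sketch of the depth-resolution rule as this diff implies it (method and constant names are illustrative, not the plugin's):

```java
// Sketch only: resolves how many hits a shard should collect for a hybrid query.
final class PaginationDepthSketch {
    static final int DEFAULT_PAGINATION_DEPTH = 10;
    static final int MAX_PAGINATION_DEPTH = 10_000;

    static int resolveCollectSize(int paginationDepth, int from, int size) {
        if (paginationDepth < 0 || paginationDepth > MAX_PAGINATION_DEPTH) {
            throw new IllegalArgumentException(
                "Pagination depth should lie in the range of 1-10000. Received: " + paginationDepth
            );
        }
        // depth of 0 means "not set": fall back to a small default when the
        // caller paginates (from > 0), otherwise collect exactly from + size
        if (from > 0 && paginationDepth == 0) {
            return DEFAULT_PAGINATION_DEPTH;
        }
        return from + size;
    }
}
```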
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..fc700da92 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -6,6 +6,8 @@ import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -272,7 +274,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -328,7 +330,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -417,7 +419,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 59fb51563..0b5b9f978 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; @@ -76,7 +77,9 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ); TestUtils.assertQueryResultScores(querySearchResults); @@ -118,7 +121,9 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResults, Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); @@ -177,7 +182,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ); TestUtils.assertQueryResultScores(querySearchResults); @@ -237,7 +244,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ); TestUtils.assertQueryResultScores(querySearchResults); @@ -291,7 +300,9 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ) ); } @@ -341,13 +352,71 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu querySearchResults, Optional.of(fetchSearchResult), ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false ); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); } + public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List<QuerySearchResult> querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + null + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + // requested page is out of bounds for the total number of results + querySearchResult.from(17); + querySearchResults.add(querySearchResult); + } + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.empty(), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD, + 0, + false + ) + ); + + assertEquals( + String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results"), + illegalArgumentException.getMessage() + ); + } + private static SearchHits getSearchHits() { SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()),
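The integration tests that follow drive pagination through the public request surface. For orientation, a hedged example of what such a request looks like in code; index and field names are placeholders, and `paginationDepth` is the builder method these tests exercise:

```java
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;

// Hedged sketch: request the second page (results 3-4) of a hybrid query.
public class HybridPaginationRequestExample {
    public static SearchSourceBuilder secondPage() {
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(new MatchAllQueryBuilder());
        hybridQuery.paginationDepth(10); // collect up to 10 candidates per sub-query

        return new SearchSourceBuilder()
            .query(hybridQuery)
            .from(2)  // skip the first two combined results
            .size(2); // return the next two
    }
}
```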
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 610e08dd0..e20dd5e01 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -28,6 +28,7 @@ import org.junit.Before; import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -320,6 +321,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); hybridQueryBuilderOnlyTerm.add(termQueryBuilder); hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + hybridQueryBuilderOnlyTerm.paginationDepth(10); Map<String, Object> searchResponseAsMap = search( TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, @@ -793,46 +795,178 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS } } - // TODO remove this test after following issue https://github.com/opensearch-project/neural-search/issues/280 gets resolved. @SneakyThrows - public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() { + public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); - HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); - hybridQueryBuilderOnlyTerm.add(matchQueryBuilder); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ResponseException exceptionNoNestedTypes = expectThrows( - ResponseException.class, - () -> search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - hybridQueryBuilderOnlyTerm, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE), - null, - null, - null, - false, - null, - 10 - ) + @SneakyThrows + public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } - ); + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } - org.hamcrest.MatcherAssert.assertThat( - exceptionNoNestedTypes.getMessage(), - allOf( - containsString("In the current OpenSearch version pagination is not supported with hybrid query"), - containsString("illegal_argument_exception") - ) - ); + @SneakyThrows + public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); } } + @SneakyThrows + public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(10); + + Map<String, Object> searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map<String, Object> total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + Map<String, Object> searchResponseAsMap = search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ); + + assertEquals(2, getHitCount(searchResponseAsMap)); + Map<String, Object> total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + + @SneakyThrows + public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 5 + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("Reached end of search result, increase pagination_depth value to see more results")) + ); + } + + @SneakyThrows + public void testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(String indexName) { + HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); + hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); + hybridQueryBuilderOnlyMatchAll.paginationDepth(100001); + + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 0 ) );
+ org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("Pagination depth should lie in the range of 1-10000. Received: 100001")) + ); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index f821e7ddf..5cbcc7b2a 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -96,16 +96,19 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery query1 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 0 ); HybridQuery query2 = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 0 ); HybridQuery query3 = new HybridQuery( List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 5 ); QueryUtils.check(query1); QueryUtils.checkEqual(query1, query2); @@ -120,6 +123,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); + assertEquals(5, query3.getPaginationDepth()); } @SneakyThrows @@ -142,7 +146,8 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), + 0 ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it @@ -161,11 +166,11 @@ KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K); Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext); - HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery)); + HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery), 0); rewritten = hybridQueryWithKnn.rewrite(reader); assertSame(hybridQueryWithKnn, rewritten); - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 0)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -198,7 +203,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new
Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), + 0 ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -244,7 +250,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)))); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), 0); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -280,7 +286,8 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR IndexSearcher searcher = newSearcher(reader); HybridQuery query = new HybridQuery( - List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))) + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + 0 ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -294,7 +301,7 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() { - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of())); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 0)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } @@ -311,7 +318,8 @@ public void testToString_whenCallQueryToString_thenSuccessful() { new BoolQueryBuilder().should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) - ) + ), + 0 ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -331,7 +339,8 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), - List.of(filter) + List.of(filter), + 0 ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index f44e762f0..27718df41 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -69,7 +69,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); 
when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -129,7 +129,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index de9c6006b..a94980f30 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -46,12 +46,14 @@ import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; import org.opensearch.search.DocValueFormat; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import static org.mockito.ArgumentMatchers.eq; @@ -78,7 +80,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -109,7 +111,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -140,7 +142,7 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder 
termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); @@ -184,7 +186,7 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); @@ -229,7 +231,8 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); HybridQuery hybridQueryWithTerm = new HybridQuery( - List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)), + 0 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -330,7 +333,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -367,7 +370,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); - HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQuery); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -398,7 +401,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithMatchAll = new 
HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQueryWithMatchAll); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -495,7 +498,8 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD List.of( QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) - ) + ), + 0 ); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); @@ -617,7 +621,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 0); when(searchContext.query()).thenReturn(hybridQueryWithTerm); ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); IndexReader indexReader = mock(IndexReader.class); @@ -734,4 +738,158 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD reader2.close(); directory2.close(); } + + @SneakyThrows + public void testNumDocsCount_whenPaginationDepthIsLessThanZero_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), -1); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-10000. Received: -1"), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testNumDocsCount_whenPaginationDepthIsGreaterThan10000_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10001); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-10000. Received: 10001"), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testCreateCollectorManager_whenPaginationDepthAndFromAreEqualToZero_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testCreateCollectorManager_whenPaginationDepthIsEqualToZeroAndFromIsGreaterThanZero_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + // from > 0 + when(searchContext.from()).thenReturn(5); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
+ TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + // if pagination_depth == 0, the default depth of 10 is applied internally + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testScrollWithHybridQuery_thenFail() { + SearchContext searchContext = mock(SearchContext.class); + ScrollContext scrollContext = new ScrollContext(); + when(searchContext.scrollContext()).thenReturn(scrollContext); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map<Class<?>, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> HybridCollectorManager.createHybridCollectorManager(searchContext) + ); + assertEquals( + String.format(Locale.ROOT, "Scroll operation is not supported in hybrid query"), + illegalArgumentException.getMessage() + ); + } }
diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java index be9dbc2cc..3be4ad090 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -45,7 +45,8 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() .rewrite(mockQueryShardContext) .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) - ) + ), + 0 ); SearchContext searchContext = mock(SearchContext.class);
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 7ad0e63f8..00f607dfd 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -594,7 +594,6 @@ protected Map<String, Object> search( if (requestParams != null && !requestParams.isEmpty()) { requestParams.forEach(request::addParameter); } - logger.info("Sorting request " + builder.toString()); request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
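Taken together, the test changes pin down the new constructor contract: every `HybridQuery` now carries a pagination depth alongside its sub-queries, with 0 standing for "not set". A hedged, self-contained recap of the two-argument form used throughout the tests (field and term values are placeholders):

```java
import java.util.List;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermQuery;
import org.opensearch.neuralsearch.query.HybridQuery;

// Sketch of the constructor shape exercised above, not a prescribed usage.
public class HybridQueryConstructionExample {
    public static HybridQuery withDepth() {
        return new HybridQuery(
            List.of(new TermQuery(new Term("text", "hello")), new TermQuery(new Term("text", "world"))),
            10 // pagination depth; pass 0 when the request does not paginate
        );
    }
}
```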