From e0ed9cb9bfaf6368047efcfec5c986d9c3c2b912 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Sun, 3 Nov 2024 17:14:39 -0800 Subject: [PATCH 1/2] refactoring code Signed-off-by: Varun Jain --- .../common/MinClusterVersionUtil.java | 5 + .../processor/NormalizationProcessor.java | 26 ++- .../NormalizationProcessorWorkflow.java | 56 ++++--- .../processor/combination/ScoreCombiner.java | 2 +- .../CombineScoresDto.java | 6 +- .../dto/NormalizationExecuteDto.java | 38 +++++ .../neuralsearch/query/HybridQuery.java | 15 +- .../query/HybridQueryBuilder.java | 39 ++++- .../search/query/HybridCollectorManager.java | 44 ++--- .../NormalizationProcessorTests.java | 16 +- .../NormalizationProcessorWorkflowTests.java | 154 +++++++++++------- .../ScoreCombinationTechniqueTests.java | 2 +- .../query/HybridQueryBuilderTests.java | 89 ++++++++++ .../neuralsearch/query/HybridQueryIT.java | 47 +++--- .../neuralsearch/query/HybridQueryTests.java | 36 ++-- .../query/HybridQueryWeightTests.java | 6 +- .../HybridAggregationProcessorTests.java | 4 +- .../query/HybridCollectorManagerTests.java | 113 +++---------- .../util/HybridQueryUtilTests.java | 2 +- 19 files changed, 427 insertions(+), 273 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/{combination => dto}/CombineScoresDto.java (82%) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java diff --git a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java index 0f5cbefcf..05e04e84a 100644 --- a/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/common/MinClusterVersionUtil.java @@ -22,6 +22,7 @@ public final class MinClusterVersionUtil { private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0; + private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0; // Note this minimal version will act as a override private static final Map MINIMAL_VERSION_NEURAL = ImmutableMap.builder() @@ -38,6 +39,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH); } + public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY); + } + public static boolean isClusterOnOrAfterMinReqVersion(String key) { Version version; if (MINIMAL_VERSION_NEURAL.containsKey(key)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 8d737efae..a30bd7f56 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -16,6 +16,7 @@ import org.opensearch.action.search.SearchPhaseName; import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import 
org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; @@ -58,21 +59,16 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - int fromValueForSingleShard = 0; - boolean isSingleShard = false; - if (searchPhaseContext.getNumShards() == 1 && fetchSearchResult.isPresent()) { - isSingleShard = true; - fromValueForSingleShard = searchPhaseContext.getRequest().source().from(); - } - - normalizationWorkflow.execute( - querySearchResults, - fetchSearchResult, - normalizationTechnique, - combinationTechnique, - fromValueForSingleShard, - isSingleShard - ); + // Builds the data transfer object to pass into execute + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationWorkflow.execute(normalizationExecuteDto); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index 97ce9af20..1e3c8fc0d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -18,10 +18,12 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.FieldDoc; +import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; @@ -47,18 +49,17 @@ public class NormalizationProcessorWorkflow { /** * Start execution of this workflow - * @param querySearchResults input data with QuerySearchResult from multiple shards - * @param normalizationTechnique technique for score normalization - * @param combinationTechnique technique for score combination + * @param normalizationExecuteDto holds the querySearchResults (QuerySearchResult input data + * from multiple shards), the fetchSearchResultOptional, the normalizationTechnique for score normalization, + * the combinationTechnique for score combination, and the searchPhaseContext.
*/ - public void execute( - final List querySearchResults, - final Optional fetchSearchResultOptional, - final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique, - final int fromValueForSingleShard, - final boolean isSingleShard - ) { + public void execute(final NormalizationExecuteDto normalizationExecuteDto) { + final List querySearchResults = normalizationExecuteDto.getQuerySearchResults(); + final Optional fetchSearchResultOptional = normalizationExecuteDto.getFetchSearchResultOptional(); + final ScoreNormalizationTechnique normalizationTechnique = normalizationExecuteDto.getNormalizationTechnique(); + final ScoreCombinationTechnique combinationTechnique = normalizationExecuteDto.getCombinationTechnique(); + final SearchPhaseContext searchPhaseContext = normalizationExecuteDto.getSearchPhaseContext(); + // save original state List unprocessedDocIds = unprocessedDocIds(querySearchResults); @@ -75,8 +76,8 @@ public void execute( .scoreCombinationTechnique(combinationTechnique) .querySearchResults(querySearchResults) .sort(evaluateSortCriteria(querySearchResults, queryTopDocs)) - .fromValueForSingleShard(fromValueForSingleShard) - .isSingleShard(isSingleShard) + .fromValueForSingleShard(searchPhaseContext.getRequest().source().from()) + .isFetchResultsPresent(fetchSearchResultOptional.isPresent()) .build(); // combine @@ -86,7 +87,12 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(combineScoresDTO); - updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds, fromValueForSingleShard); + updateOriginalFetchResults( + querySearchResults, + fetchSearchResultOptional, + unprocessedDocIds, + combineScoresDTO.getFromValueForSingleShard() + ); } /** @@ -117,7 +123,6 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) final List querySearchResults = combineScoresDTO.getQuerySearchResults(); final List queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults); final Sort sort = combineScoresDTO.getSort(); - final int from = querySearchResults.get(0).from(); int totalScoreDocsCount = 0; for (int index = 0; index < querySearchResults.size(); index++) { QuerySearchResult querySearchResult = querySearchResults.get(index); @@ -127,14 +132,16 @@ private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) buildTopDocs(updatedTopDocs, sort), maxScoreForShard(updatedTopDocs, sort != null) ); - if (combineScoresDTO.isSingleShard()) { + // The fetch phase ran before the normalization phase, therefore update the from value in the result of each shard. + // This ensures the results are trimmed correctly.
+ if (combineScoresDTO.isFetchResultsPresent()) { querySearchResult.from(combineScoresDTO.getFromValueForSingleShard()); } querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats()); } - if ((from > 0 || combineScoresDTO.getFromValueForSingleShard() > 0) - && (from > totalScoreDocsCount || combineScoresDTO.getFromValueForSingleShard() > totalScoreDocsCount)) { + final int from = querySearchResults.get(0).from(); + if (from > 0 && from > totalScoreDocsCount) { throw new IllegalArgumentException( String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results") ); @@ -231,6 +238,9 @@ private void updateOriginalFetchResults( QuerySearchResult querySearchResult = querySearchResults.get(0); TopDocs topDocs = querySearchResult.topDocs().topDocs; + // When the normalization process executes before the fetch phase, from = 0 is applicable. + // When the normalization process runs after the fetch phase, the search hits have already been fetched. Therefore, use the from + // value sent in the search request. // iterate over the normalized/combined scores, that solves (1) and (3) SearchHit[] updatedSearchHitArray = new SearchHit[topDocs.scoreDocs.length - fromValueForSingleShard]; for (int i = fromValueForSingleShard; i < topDocs.scoreDocs.length; i++) { @@ -242,14 +252,6 @@ private void updateOriginalFetchResults( updatedSearchHitArray[i - fromValueForSingleShard] = searchHit; } - // iterate over the normalized/combined scores, that solves (1) and (3) - // SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { - // // get fetched hit content by doc_id - // SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); - // // update score to normalized/combined value (3) - // searchHit.score(scoreDoc.score); - // return searchHit; - // }).toArray(SearchHit[]::new); SearchHits updatedSearchHits = new SearchHits( updatedSearchHitArray, querySearchResult.getTotalHits(), diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 9ae58e8f0..c70ab0b78 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -26,6 +26,7 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; /** * Abstracts combination of scores in query search results.
@@ -69,7 +70,6 @@ public void combineScores(final CombineScoresDto combineScoresDTO) { Sort sort = combineScoresDTO.getSort(); combineScoresDTO.getQueryTopDocs() .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort)); - } private void combineShardScores( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java b/src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java similarity index 82% rename from src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java rename to src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java index 42ebf6ea2..77444f383 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/CombineScoresDto.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/dto/CombineScoresDto.java @@ -2,9 +2,10 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.processor.combination; +package org.opensearch.neuralsearch.processor.dto; import java.util.List; + import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; @@ -12,6 +13,7 @@ import org.apache.lucene.search.Sort; import org.opensearch.common.Nullable; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.search.query.QuerySearchResult; /** @@ -30,5 +32,5 @@ public class CombineScoresDto { @Nullable private Sort sort; private int fromValueForSingleShard; - private boolean isSingleShard; + private boolean isFetchResultsPresent; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java b/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java new file mode 100644 index 000000000..1ddda83d2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/dto/NormalizationExecuteDto.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import java.util.List; +import java.util.Optional; + +/** + * DTO object to hold the data required by the execute method + * of the NormalizationProcessorWorkflow class.
+ */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizationExecuteDto { + @NonNull + private List querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; + @NonNull + private SearchPhaseContext searchPhaseContext; +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 405df5f1a..14514df60 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -15,6 +15,7 @@ import java.util.Objects; import java.util.concurrent.Callable; +import lombok.Getter; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -31,21 +32,25 @@ * Implementation of Query interface for type "hybrid". It allows execution of multiple sub-queries and collect individual * scores for each sub-query. */ +@Getter public final class HybridQuery extends Query implements Iterable { private final List subQueries; - private int paginationDepth; + private Integer paginationDepth; /** * Create new instance of hybrid query object based on collection of sub queries and filter query * @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores * @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is */ - public HybridQuery(final Collection subQueries, final List filterQueries, int paginationDepth) { + public HybridQuery(final Collection subQueries, final List filterQueries, Integer paginationDepth) { Objects.requireNonNull(subQueries, "collection of queries must not be null"); if (subQueries.isEmpty()) { throw new IllegalArgumentException("collection of queries must not be empty"); } + if (paginationDepth != null && paginationDepth == 0) { + throw new IllegalArgumentException("pagination depth must not be zero"); + } if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) { this.subQueries = new ArrayList<>(subQueries); } else { @@ -61,7 +66,7 @@ public HybridQuery(final Collection subQueries, final List filterQ this.paginationDepth = paginationDepth; } - public HybridQuery(final Collection subQueries, final int paginationDepth) { + public HybridQuery(final Collection subQueries, final Integer paginationDepth) { this(subQueries, List.of(), paginationDepth); } @@ -192,10 +197,6 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } - public int getPaginationDepth() { - return paginationDepth; - } - /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index 1353a0b61..03e54c9fb 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -35,6 +35,8 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.common.MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery; + /** * Class abstract creation of a Query 
type "hybrid". Hybrid query will allow execution of multiple sub-queries and * collects score for each of those sub-query. @@ -53,14 +55,17 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries = new ArrayList<>(); private String fieldName; - private int paginationDepth; - + private Integer paginationDepth; static final int MAX_NUMBER_OF_SUB_QUERIES = 5; + private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 1; + private static final int UPPER_BOUND_OF_PAGINATION_DEPTH = 10000; public HybridQueryBuilder(StreamInput in) throws IOException { super(in); queries.addAll(readQueries(in)); - paginationDepth = in.readInt(); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + paginationDepth = in.readOptionalInt(); + } } /** @@ -71,7 +76,9 @@ public HybridQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { writeQueries(out, queries); - out.writeInt(paginationDepth); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + out.writeOptionalInt(paginationDepth); + } } /** @@ -154,7 +161,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException { float boost = AbstractQueryBuilder.DEFAULT_BOOST; - int paginationDepth = 0; + Integer paginationDepth = null; final List queries = new ArrayList<>(); String queryName = null; @@ -224,7 +231,10 @@ public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOEx HybridQueryBuilder compoundQueryBuilder = new HybridQueryBuilder(); compoundQueryBuilder.queryName(queryName); compoundQueryBuilder.boost(boost); - compoundQueryBuilder.paginationDepth(paginationDepth); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + validatePaginationDepth(paginationDepth); + compoundQueryBuilder.paginationDepth(paginationDepth); + } for (QueryBuilder query : queries) { compoundQueryBuilder.add(query); } @@ -244,7 +254,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I if (changed) { newBuilder.queryName(queryName); newBuilder.boost(boost); - newBuilder.paginationDepth(paginationDepth); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + newBuilder.paginationDepth(paginationDepth); + } return newBuilder; } else { return this; @@ -267,7 +279,9 @@ protected boolean doEquals(HybridQueryBuilder obj) { EqualsBuilder equalsBuilder = new EqualsBuilder(); equalsBuilder.append(fieldName, obj.fieldName); equalsBuilder.append(queries, obj.queries); - equalsBuilder.append(paginationDepth, obj.paginationDepth); + if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) { + equalsBuilder.append(paginationDepth, obj.paginationDepth); + } return equalsBuilder.isEquals(); } @@ -308,6 +322,15 @@ private Collection toQueries(Collection queryBuilders, Quer return queries; } + private static void validatePaginationDepth(Integer paginationDepth) { + if (paginationDepth != null + && (paginationDepth < LOWER_BOUND_OF_PAGINATION_DEPTH || paginationDepth > UPPER_BOUND_OF_PAGINATION_DEPTH)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-1000. 
Received: %s", paginationDepth) + ); + } + } + /** * visit method to parse the HybridQueryBuilder by a visitor */ diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 7b6d8229b..e5427a53a 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -74,7 +74,6 @@ public abstract class HybridCollectorManager implements CollectorManager 0) { searchContext.from(0); } @@ -477,34 +483,30 @@ private ReduceableSearchResult reduceSearchResults(final List 0 && (Objects.isNull(paginationDepth))) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth is missing in the search request")); } - - if (paginationDepth != 0) { - validatePaginationDepth(paginationDepth); + if (paginationDepth != null) { return paginationDepth; - } else if (searchContext.from() > 0 && paginationDepth == 0) { - return DEFAULT_PAGINATION_DEPTH; } else { + // Switch to from+size retrieval size when pagination_depth is null. return searchContext.from() + searchContext.size(); } } - private static void validatePaginationDepth(int depth) { - if (depth < 0 || depth > 10000) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-1000. Received: %s", depth) - ); + private static HybridQuery getHybridQueryFromAbstractQuery(Query query) { + HybridQuery hybridQuery; + // In case of nested fields and alias filter, hybrid query is wrapped under bool query and lies in the first clause. + if (query instanceof BooleanQuery) { + BooleanQuery booleanQuery = (BooleanQuery) query; + hybridQuery = (HybridQuery) booleanQuery.clauses().get(0).getQuery(); + } else { + hybridQuery = (HybridQuery) query; } + return hybridQuery; } /** diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index fc700da92..218515a97 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -6,8 +6,6 @@ import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -136,6 +134,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -181,6 +180,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio } SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -205,6 +205,7 @@ public void 
testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -249,6 +250,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(1); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -274,7 +276,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -330,7 +332,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { @@ -346,6 +348,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -410,6 +413,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -419,7 +423,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any(), anyInt(), anyBoolean()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { @@ -435,6 +439,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() ); SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); searchRequest.setBatchedReduceSize(4); AtomicReference onPartialMergeFailure = new AtomicReference<>(); QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( @@ -497,6 +502,7 @@ public void 
testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); IllegalStateException exception = expectThrows( IllegalStateException.class, () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 0b5b9f978..e09aea187 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -20,8 +20,11 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.dto.NormalizationExecuteDto; import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; @@ -36,6 +39,7 @@ import org.opensearch.test.OpenSearchTestCase; public class NormalizationProcessorWorkflowTests extends OpenSearchTestCase { + private static final String INDEX_NAME = "normalization-index"; public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationCombination() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( @@ -73,14 +77,19 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDTO = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); } @@ -117,14 +126,19 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResults.add(querySearchResult); } - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + 
NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } @@ -178,14 +192,18 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -240,14 +258,18 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -294,17 +316,19 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ) - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); 
+ searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + expectThrows(IllegalStateException.class, () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto)); } public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -348,14 +372,20 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(0); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDto); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -399,16 +429,22 @@ public void testNormalization_whenFromIsGreaterThanResultsSize_thenFail() { querySearchResults.add(querySearchResult); } + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source().from(17); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); + + NormalizationExecuteDto normalizationExecuteDto = NormalizationExecuteDto.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - 0, - false - ) + () -> normalizationProcessorWorkflow.execute(normalizationExecuteDto) ); assertEquals( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 918f3f45b..cbfacbc88 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.processor; import java.util.Collections; -import 
org.opensearch.neuralsearch.processor.combination.CombineScoresDto; +import org.opensearch.neuralsearch.processor.dto.CombineScoresDto; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 5aab7581c..e0bdfee41 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -476,6 +476,93 @@ public void testFromXContent_whenQueriesCountIsGreaterThanFive_thenFail() { assertThat(exception.getMessage(), containsString("Number of sub-queries exceeds maximum supported by [hybrid] query")); } + @SneakyThrows + public void testFromXContent_whenPaginationDepthIsInvalid_thenFail() { + setUpClusterService(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("pagination_depth", -1) + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry( + List.of( + new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(TermQueryBuilder.NAME), TermQueryBuilder::fromXContent), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(NeuralQueryBuilder.NAME), + NeuralQueryBuilder::fromXContent + ), + new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField(HybridQueryBuilder.NAME), + HybridQueryBuilder::fromXContent + ) + ) + ); + XContentParser contentParser = createParser( + namedXContentRegistry, + xContentBuilder.contentType().xContent(), + BytesReference.bytes(xContentBuilder) + ); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> HybridQueryBuilder.fromXContent(contentParser) + ); + assertThat(exception.getMessage(), containsString("Pagination depth should lie in the range of 1-10000. 
Received: -1")); + + XContentBuilder xContentBuilder1 = XContentFactory.jsonBuilder() + .startObject() + .field("pagination_depth", 10001) + .startArray("queries") + .startObject() + .startObject(NeuralQueryBuilder.NAME) + .startObject(VECTOR_FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .endObject() + .endObject() + .endObject() + .startObject() + .startObject(TermQueryBuilder.NAME) + .field(TEXT_FIELD_NAME, TERM_QUERY_TEXT) + .endObject() + .endObject() + .endArray() + .endObject(); + + XContentParser contentParser1 = createParser( + namedXContentRegistry, + xContentBuilder1.contentType().xContent(), + BytesReference.bytes(xContentBuilder1) + ); + contentParser1.nextToken(); + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> HybridQueryBuilder.fromXContent(contentParser1) + ); + assertThat(exception1.getMessage(), containsString("Pagination depth should lie in the range of 1-1000. Received: 10001")); + } + @SneakyThrows public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); @@ -599,6 +686,7 @@ public void testHashAndEquals_whenSameOrIdenticalObject_thenReturnEqual() { } public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { + setUpClusterService(); String modelId = "testModelId"; String fieldName = "fieldTwo"; String queryText = "query text"; @@ -687,6 +775,7 @@ public void testHashAndEquals_whenSubQueriesDifferent_thenReturnNotEqual() { @SneakyThrows public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery() { + setUpClusterService(); HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index a0d6dcd1a..0eb2238aa 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -845,7 +845,7 @@ public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessf initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } finally { @@ -860,7 +860,7 @@ public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccess initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); - 
testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } finally { @@ -875,7 +875,7 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSucces initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME); testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { @@ -890,7 +890,7 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSucce initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); - testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME); + testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME); testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME); testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME); } finally { @@ -927,30 +927,31 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc } @SneakyThrows - public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(String indexName) { + public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(String indexName) { HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder(); hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder()); - Map searchResponseAsMap = search( - indexName, - hybridQueryBuilderOnlyMatchAll, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE), - null, - null, - null, - false, - null, - 2 + ResponseException responseException = assertThrows( + ResponseException.class, + () -> search( + indexName, + hybridQueryBuilderOnlyMatchAll, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null, + null, + false, + null, + 2 + ) ); - assertEquals(2, getHitCount(searchResponseAsMap)); - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(4, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); + org.hamcrest.MatcherAssert.assertThat( + responseException.getMessage(), + allOf(containsString("pagination_depth is missing in the search request")) + ); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index 5cbcc7b2a..43b609545 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -97,11 +97,11 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { HybridQuery query1 = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 0 + null ); HybridQuery query2 = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 0 + null ); HybridQuery query3 = new HybridQuery( List.of( @@ -123,7 +123,7 @@ public void testQueryBasics_whenMultipleDifferentQueries_thenSuccessful() { countOfQueries++; } assertEquals(2, countOfQueries); - assertEquals(5, query3.getPaginationDepth()); + assertEquals(5, (int) query3.getPaginationDepth()); } @SneakyThrows @@ -147,7 +147,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 0 + null ); Query rewritten = hybridQueryWithTerm.rewrite(reader); // term query is the same after we rewrite it @@ -166,11 +166,11 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K); Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext); - HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery), 0); + HybridQuery hybridQueryWithKnn = new HybridQuery(List.of(knnQuery), null); rewritten = hybridQueryWithKnn.rewrite(reader); assertSame(hybridQueryWithKnn, rewritten); - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 0)); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), null)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); w.close(); @@ -204,7 +204,7 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithMatch_thenRetu HybridQuery query = new HybridQuery( List.of(new TermQuery(new Term(TEXT_FIELD_NAME, field1Value)), new TermQuery(new Term(TEXT_FIELD_NAME, field2Value))), - 0 + null ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -250,7 +250,7 @@ public void testWithRandomDocuments_whenOneTermSubQueryWithoutMatch_thenReturnSu DirectoryReader reader = DirectoryReader.open(w); IndexSearcher searcher = newSearcher(reader); - HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), 0); + HybridQuery query = new HybridQuery(List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), null); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -287,7 +287,7 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR HybridQuery query = new HybridQuery( List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), - 0 + null ); // executing search query, getting up to 3 docs in result TopDocs hybridQueryResult = searcher.search(query, 3); @@ -301,10 +301,22 @@ public void testWithRandomDocuments_whenMultipleTermSubQueriesWithoutMatch_thenR @SneakyThrows public void testWithRandomDocuments_whenNoSubQueries_thenFail() 
{ - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), 0)); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new HybridQuery(List.of(), null)); assertThat(exception.getMessage(), containsString("collection of queries must not be empty")); } + @SneakyThrows + public void testWithRandomDocuments_whenPaginationDepthIsZero_thenFail() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new HybridQuery( + List.of(new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT)), new TermQuery(new Term(TEXT_FIELD_NAME, QUERY_TEXT))), + 0 + ) + ); + assertThat(exception.getMessage(), containsString("pagination depth must not be zero")); + } + @SneakyThrows public void testToString_whenCallQueryToString_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); @@ -319,7 +331,7 @@ public void testToString_whenCallQueryToString_thenSuccessful() { .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT)) .toQuery(mockQueryShardContext) ), - 0 + null ); String queryString = query.toString(TEXT_FIELD_NAME); @@ -340,7 +352,7 @@ public void testFilter_whenSubQueriesWithFilterPassed_thenSuccessful() { QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_ANOTHER_QUERY_TEXT).toQuery(mockQueryShardContext) ), List.of(filter), - 0 + null ); QueryUtils.check(hybridQuery); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 0c75f22ca..dcb910c55 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -61,7 +61,7 @@ public void testScorerIterator_whenExecuteQuery_thenScorerIteratorSuccessful() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 0 + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -118,7 +118,7 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { .toQuery(mockQueryShardContext), QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) ), - 0 + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = hybridQueryWithTerm.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f); @@ -166,7 +166,7 @@ public void testExplain_whenCallExplain_thenFail() { IndexReader reader = DirectoryReader.open(w); HybridQuery hybridQueryWithTerm = new HybridQuery( List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)), - 0 + null ); IndexSearcher searcher = newSearcher(reader); Weight weight = searcher.createWeight(hybridQueryWithTerm, ScoreMode.COMPLETE, 1.0f); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java index 27718df41..4e2dfbd89 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -69,7 +69,7 @@ public void 
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java
index 27718df41..4e2dfbd89 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java
@@ -69,7 +69,7 @@ public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSucce
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -129,7 +129,7 @@ public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessf
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
index 78831a7ca..f23f99714 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java
@@ -93,7 +93,7 @@ public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() {
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -124,7 +124,7 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() {
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -155,7 +155,7 @@ public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() {
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world");
         ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext));
@@ -199,7 +199,7 @@ public void testPostFilter_whenConcurrentSearch_thenSuccessful() {
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world");
         Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext);
@@ -245,7 +245,7 @@ public void testReduce_whenMatchedDocs_thenSuccessful() {
 
         HybridQuery hybridQueryWithTerm = new HybridQuery(
             List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)),
-            0
+            null
         );
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -346,7 +346,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingIsApplied_thenSucc
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -383,7 +383,7 @@ public void testNewCollector_whenNotConcurrentSearchAndSortingAndSearchAfterAreA
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -414,7 +414,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
 
-        HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQueryWithMatchAll = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), null);
         when(searchContext.query()).thenReturn(hybridQueryWithMatchAll);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
         IndexReader indexReader = mock(IndexReader.class);
@@ -512,7 +512,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
             ),
-            0
+            null
         );
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -634,7 +634,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
 
-        HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), 0);
+        HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)), null);
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
         IndexReader indexReader = mock(IndexReader.class);
@@ -764,7 +764,7 @@ public void testReduceAndRescore_whenMatchedDocsAndRescoreContextPresent_thenSuc
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
             ),
-            0
+            null
         );
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -883,7 +883,7 @@ public void testRescoreWithConcurrentSegmentSearch_whenMatchedDocsAndRescore_the
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext),
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY3).toQuery(mockQueryShardContext)
             ),
-            0
+            null
         );
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -1027,7 +1027,7 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() {
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext),
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext)
             ),
-            0
+            null
         );
         when(searchContext.query()).thenReturn(hybridQueryWithTerm);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -1087,71 +1087,14 @@ public void testReduceAndRescore_whenRescorerThrowsException_thenFail() {
     }
 
     @SneakyThrows
-    public void testNumDocsCount_whenPaginationDepthIsLessThanZero_thenFail() {
+    public void testCreateCollectorManager_whenFromIsEqualToZeroAndPaginationDepthInRange_thenSuccessful() {
         SearchContext searchContext = mock(SearchContext.class);
         QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), -1);
-
-        when(searchContext.query()).thenReturn(hybridQuery);
-        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
-        IndexReader indexReader = mock(IndexReader.class);
-        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
-        when(searchContext.searcher()).thenReturn(indexSearcher);
-
-        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
-        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
-        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
-
-        IllegalArgumentException illegalArgumentException = assertThrows(
-            IllegalArgumentException.class,
-            () -> HybridCollectorManager.createHybridCollectorManager(searchContext)
-        );
-        assertEquals(
-            String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-1000. Received: -1"),
-            illegalArgumentException.getMessage()
-        );
-    }
-
-    @SneakyThrows
-    public void testNumDocsCount_whenPaginationDepthIsGreaterThan10000_thenFail() {
-        SearchContext searchContext = mock(SearchContext.class);
-        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
-        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
-        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
-        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10001);
-
-        when(searchContext.query()).thenReturn(hybridQuery);
-        ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
-        IndexReader indexReader = mock(IndexReader.class);
-        when(indexSearcher.getIndexReader()).thenReturn(indexReader);
-        when(searchContext.searcher()).thenReturn(indexSearcher);
-
-        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> classCollectorManagerMap = new HashMap<>();
-        when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
-        when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
-
-        IllegalArgumentException illegalArgumentException = assertThrows(
-            IllegalArgumentException.class,
-            () -> HybridCollectorManager.createHybridCollectorManager(searchContext)
-        );
-        assertEquals(
-            String.format(Locale.ROOT, "Pagination depth should lie in the range of 1-1000. Received: 10001"),
-            illegalArgumentException.getMessage()
-        );
-    }
-
-    @SneakyThrows
-    public void testCreateCollectorManager_whenPaginationDepthAndFromAreEqualToZero_thenSuccessful() {
-        SearchContext searchContext = mock(SearchContext.class);
-        QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
-        TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
-        when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
-        TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        // pagination_depth=10
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 10);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -1176,16 +1119,16 @@ public void testCreateCollectorManager_whenPaginationDepthAndFromAreEqualToZero_
     }
 
     @SneakyThrows
-    public void testCreateCollectorManager_whenPaginationDepthIsEqualToZeroAndFromIsGreaterThanZero_thenSuccessful() {
+    public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() {
         SearchContext searchContext = mock(SearchContext.class);
         // From >0
-        searchContext.from(5);
+        when(searchContext.from()).thenReturn(5);
         QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
         TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
         when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
         TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
-        // if pagination_depth ==0 then internally by default it will pick 10 as the depth
-        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), 0);
+        // pagination_depth is not set (null), so creating the collector manager must fail when from > 0
+        HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext)), null);
 
         when(searchContext.query()).thenReturn(hybridQuery);
         ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
@@ -1197,16 +1140,14 @@ public void testCreateCollectorManager_whenPaginationDepthIsEqualToZeroAndFromIs
         when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap);
         when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);
 
-        CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
-        assertNotNull(hybridCollectorManager);
-        assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager);
-
-        Collector collector = hybridCollectorManager.newCollector();
-        assertNotNull(collector);
-        assertTrue(collector instanceof HybridTopScoreDocCollector);
-
-        Collector secondCollector = hybridCollectorManager.newCollector();
-        assertSame(collector, secondCollector);
+        IllegalArgumentException illegalArgumentException = assertThrows(
+            IllegalArgumentException.class,
+            () -> HybridCollectorManager.createHybridCollectorManager(searchContext)
+        );
+        assertEquals(
+            String.format(Locale.ROOT, "pagination_depth is missing in the search request"),
+            illegalArgumentException.getMessage()
+        );
     }
 
     @SneakyThrows
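Note: the removed range-check tests and the rewritten tests above pin down the depth handling as of this first patch: an explicit depth must lie in range, a missing depth combined with from > 0 is an error, and a missing depth with from == 0 falls back to from + size. A condensed sketch of that contract, reconstructed from the old code that the second patch removes below (the helper shape is illustrative):

    // Depth resolution as exercised by the tests above (patch 1 semantics); sketch only.
    static int resolvePaginationDepth(Integer paginationDepth, int from, int size) {
        if (from > 0 && paginationDepth == null) {
            // paging without an explicit depth cannot be satisfied consistently across shards
            throw new IllegalArgumentException("pagination_depth is missing in the search request");
        }
        if (paginationDepth != null) {
            return paginationDepth; // explicit depth wins; range is validated elsewhere
        }
        return from + size; // default: collect just enough docs for the requested page
    }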
diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java
index 3be4ad090..0fc0980c4 100644
--- a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java
@@ -46,7 +46,7 @@ public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess()
                     .toQuery(mockQueryShardContext),
                 QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext)
             ),
-            0
+            null
         );
 
         SearchContext searchContext = mock(SearchContext.class);
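Note on the second patch that follows: it moves the defaulting from query time to serialization time. The coordinator now always writes a pagination_depth (substituting a default of 10 for a missing value) whenever the whole cluster understands the field, so the shard-side check can simply reject null instead of recomputing from + size. A condensed sketch of the coordinator-side behavior after this patch (the helper shape is an assumption; constant and check names are taken from the diff below):

    // Coordinator-side defaulting after patch 2; sketch only.
    static Integer effectiveDepth(Integer requestedDepth, boolean clusterSupportsPaginationField) {
        if (clusterSupportsPaginationField && requestedDepth == null) {
            return 10; // DEFAULT_PAGINATION_DEPTH in the diff below
        }
        return requestedDepth; // shards still reject null defensively (see HybridCollectorManager)
    }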
From c5d3c465118c8ee50d456afd88c25139b88ff48b Mon Sep 17 00:00:00 2001
From: Varun Jain
Date: Sun, 3 Nov 2024 18:59:26 -0800
Subject: [PATCH 2/2] Fixing Integ Tests

Signed-off-by: Varun Jain
---
 .../query/HybridQueryBuilder.java             |  8 +++-
 .../search/query/HybridCollectorManager.java  | 14 +++---
 .../neuralsearch/query/HybridQueryIT.java     | 48 +++++++++----------
 3 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
index 03e54c9fb..0b52b90e6 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java
@@ -55,8 +55,9 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
     private final List<QueryBuilder> queries = new ArrayList<>();
 
     private String fieldName;
-    private Integer paginationDepth;
+    private Integer paginationDepth = null;
 
     static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
+    private static final int DEFAULT_PAGINATION_DEPTH = 10;
     private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 1;
     private static final int UPPER_BOUND_OF_PAGINATION_DEPTH = 10000;
@@ -108,7 +109,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
             queryBuilder.toXContent(builder, params);
         }
         builder.endArray();
-        builder.field(DEPTH_FIELD.getPreferredName(), paginationDepth);
+        if (isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) {
+            builder.field(DEPTH_FIELD.getPreferredName(), paginationDepth == null ? DEFAULT_PAGINATION_DEPTH : paginationDepth);
+        }
         printBoostAndQueryName(builder);
         builder.endObject();
     }
@@ -159,6 +162,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
      * @throws IOException
      */
     public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
+        log.debug("fromXContent called");
         float boost = AbstractQueryBuilder.DEFAULT_BOOST;
 
         Integer paginationDepth = null;
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index e5427a53a..8a86117a6 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -485,16 +485,16 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
-        if (searchContext.from() > 0 && (Objects.isNull(paginationDepth))) {
+        if (Objects.isNull(paginationDepth)) {
             throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth is missing in the search request"));
         }
-        if (paginationDepth != null) {
-            return paginationDepth;
-        } else {
-            // Switch to from+size retrieval size when pagination_depth is null.
-            return searchContext.from() + searchContext.size();
-        }
+        return paginationDepth;
     }
 
     private static HybridQuery getHybridQueryFromAbstractQuery(Query query) {
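Note: with the builder-side default above and the stricter shard-side null check just shown, callers can page through hybrid results with plain from/size whether or not they set a depth explicitly, which is what the integration-test changes below verify. A hedged usage sketch (the paginationDepth accessor name is inferred from the DEPTH_FIELD handling in this patch, not confirmed by it):

    // Usage sketch; assumes the builder accessor paginationDepth(int) implied by this patch.
    HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
    hybridQuery.add(QueryBuilders.termQuery("title", "wind"));
    hybridQuery.add(QueryBuilders.matchAllQuery());
    hybridQuery.paginationDepth(50); // optional: omit it and the coordinator serializes the default of 10

    SearchSourceBuilder source = new SearchSourceBuilder()
        .query(hybridQuery)
        .from(2)   // paging works because each shard now collects up to the pagination depth
        .size(10);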
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index 0eb2238aa..16728cfd7 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -845,7 +846,7 @@ public void testPaginationOnSingleShard_whenConcurrentSearchEnabled_thenSuccessf
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
         testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
-        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
+        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
     } finally {
@@ -860,7 +861,7 @@ public void testPaginationOnSingleShard_whenConcurrentSearchDisabled_thenSuccess
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
         testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
-        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
+        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
         testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
     } finally {
@@ -875,7 +876,7 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchEnabled_thenSucces
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);
         createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
         testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME);
-        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME);
+        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME);
         testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME);
         testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME);
     } finally {
@@ -890,7 +891,7 @@ public void testPaginationOnMultipleShard_whenConcurrentSearchDisabled_thenSucce
         initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME);
         createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
         testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME);
-        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(TEST_MULTI_DOC_INDEX_NAME);
+        testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(TEST_MULTI_DOC_INDEX_NAME);
         testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail(TEST_MULTI_DOC_INDEX_NAME);
         testHybridQuery_whenPaginationDepthIsOutOfRange_thenFail(TEST_MULTI_DOC_INDEX_NAME);
     } finally {
@@ -927,31 +928,30 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc
     }
 
     @SneakyThrows
-    public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenFail(String indexName) {
+    public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful(String indexName) {
         HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
         hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
 
-        ResponseException responseException = assertThrows(
-            ResponseException.class,
-            () -> search(
-                indexName,
-                hybridQueryBuilderOnlyMatchAll,
-                null,
-                10,
-                Map.of("search_pipeline", SEARCH_PIPELINE),
-                null,
-                null,
-                null,
-                false,
-                null,
-                2
-            )
+        Map<String, Object> searchResponseAsMap = search(
+            indexName,
+            hybridQueryBuilderOnlyMatchAll,
+            null,
+            10,
+            Map.of("search_pipeline", SEARCH_PIPELINE),
+            null,
+            null,
+            null,
+            false,
+            null,
+            2
         );
 
-        org.hamcrest.MatcherAssert.assertThat(
-            responseException.getMessage(),
-            allOf(containsString("pagination_depth is missing in the search request"))
-        );
+        assertEquals(2, getHitCount(searchResponseAsMap));
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(4, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }
 
     @SneakyThrows