From af84883c3171b0e1133d4b08255853b933ae2c88 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 25 Jun 2024 11:16:04 -0700 Subject: [PATCH 01/10] Cherry picking Concurrent Segment Search Bug Commit Signed-off-by: Varun Jain --- CHANGELOG.md | 1 + .../search/query/HybridCollectorManager.java | 156 ++++++----- .../query/HybridQueryScoreDocsMerger.java | 83 ++++++ .../search/query/TopDocsMerger.java | 74 +++++ .../util/HybridSearchResultFormatUtil.java | 24 ++ .../query/HybridQueryAggregationsIT.java | 22 +- .../neuralsearch/query/HybridQueryIT.java | 140 +++++++++- .../query/HybridQueryPostFilterIT.java | 8 +- .../BaseAggregationsWithHybridQueryIT.java | 1 - .../BucketAggregationsWithHybridQueryIT.java | 54 ++-- .../MetricAggregationsWithHybridQueryIT.java | 34 +-- ...PipelineAggregationsWithHybridQueryIT.java | 16 +- .../query/HybridCollectorManagerTests.java | 127 +++++++++ .../HybridQueryScoreDocsMergerTests.java | 154 +++++++++++ .../search/query/TopDocsMergerTests.java | 255 ++++++++++++++++++ .../HybridSearchResultFormatUtilTests.java | 25 +- .../neuralsearch/BaseNeuralSearchIT.java | 1 + 17 files changed, 1037 insertions(+), 138 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a72fcdaa..4b6df7e66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800)) ### 
Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 68d0d559c..b1159c851 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -41,6 +41,8 @@ import java.util.List; import java.util.Objects; +import static org.apache.lucene.search.TotalHits.Relation; +import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; @@ -56,14 +58,14 @@ public abstract class HybridCollectorManager implements CollectorManager collectors) { - final List hybridTopScoreDocCollectors = new ArrayList<>(); - final List hybridSortedTopDocCollectors = new ArrayList<>(); - // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper - // in case multiple collector managers are registered. 
We use hybrid scores collector to format scores into - // format specific for hybrid search query: start, sub-query-delimiter, scores, stop - for (final Collector collector : collectors) { - if (collector instanceof MultiCollectorWrapper) { - for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { - if (sub instanceof HybridTopScoreDocCollector) { - hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub); - } else if (sub instanceof HybridTopFieldDocSortCollector) { - hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) sub); - } - } - } else if (collector instanceof HybridTopScoreDocCollector) { - hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); - } else if (collector instanceof HybridTopFieldDocSortCollector) { - hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) collector); - } else if (collector instanceof FilteredCollector - && ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) { - hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector()); - } else if (collector instanceof FilteredCollector - && ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector) { - hybridSortedTopDocCollectors.add((HybridTopFieldDocSortCollector) ((FilteredCollector) collector).getCollector()); - } + final List hybridSearchCollectors = getHybridSearchCollectors(collectors); + if (hybridSearchCollectors.isEmpty()) { + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); } + return reduceSearchResults(getSearchResults(hybridSearchCollectors)); + } - if (!hybridTopScoreDocCollectors.isEmpty()) { - HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() - .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); - List topDocs = 
hybridTopScoreDocCollector.topDocs(); - TopDocs newTopDocs = getNewTopDocs( - getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()), - topDocs - ); - TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore()); - return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + private List getSearchResults(List hybridSearchCollectors) { + List results = new ArrayList<>(); + DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats); + for (Collector collector : hybridSearchCollectors) { + TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats); + results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats)); } + return results; + } - // TODO: Cater the fix for the Bug https://github.com/opensearch-project/neural-search/issues/799 - if (!hybridSortedTopDocCollectors.isEmpty()) { - HybridTopFieldDocSortCollector hybridSortedTopScoreDocCollector = hybridSortedTopDocCollectors.stream() - .findFirst() - .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); - - List topFieldDocs = hybridSortedTopScoreDocCollector.topDocs(); - long maxTotalHits = hybridSortedTopScoreDocCollector.getTotalHits(); - float maxScore = hybridSortedTopScoreDocCollector.getMaxScore(); - - TopDocs newTopDocs = getNewTopFieldDocs( - getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, isSingleShard, maxTotalHits), + private TopDocsAndMaxScore getTopDocsAndAndMaxScore(Collector collector, DocValueFormat[] docValueFormats) { + float maxScore; + TopDocs newTopDocs; + if (docValueFormats != null) { + HybridTopFieldDocSortCollector hybridTopFieldDocSortCollector = (HybridTopFieldDocSortCollector) collector; + List topFieldDocs = hybridTopFieldDocSortCollector.topDocs(); + maxScore = 
hybridTopFieldDocSortCollector.getMaxScore(); + newTopDocs = getNewTopFieldDocs( + getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, hybridTopFieldDocSortCollector.getTotalHits()), topFieldDocs, sortAndFormats.sort.getSort() ); - TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + } else { + HybridTopScoreDocCollector hybridTopScoreDocCollector = (HybridTopScoreDocCollector) collector; + List topDocs = hybridTopScoreDocCollector.topDocs(); + maxScore = hybridTopScoreDocCollector.getMaxScore(); + newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()), topDocs); } + return new TopDocsAndMaxScore(newTopDocs, maxScore); + } - throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + private List getHybridSearchCollectors(Collection collectors) { + final List hybridSearchCollectors = new ArrayList<>(); + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof HybridTopScoreDocCollector || sub instanceof HybridTopFieldDocSortCollector) { + hybridSearchCollectors.add(sub); + } + } + } else if (collector instanceof HybridTopScoreDocCollector || collector instanceof HybridTopFieldDocSortCollector) { + hybridSearchCollectors.add(collector); + } else if (collector instanceof FilteredCollector + && (((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector + || ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector)) { + hybridSearchCollectors.add(((FilteredCollector) collector).getCollector()); + } + } + return hybridSearchCollectors; } private static void validateSortCriteria(SearchContext 
searchContext, boolean trackScores) { @@ -302,10 +297,11 @@ private TopDocs getNewTopDocs(final TotalHits totalHits, final List top return new TopDocs(totalHits, scoreDocs); } - private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final boolean isSingleShard, final long maxTotalHits) { - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final long maxTotalHits) { + final Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? Relation.GREATER_THAN_OR_EQUAL_TO + : Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { return new TotalHits(0, relation); } @@ -372,6 +368,40 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats return sortAndFormats == null ? null : sortAndFormats.formats; } + private void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore topDocsAndMaxScore, DocValueFormat[] docValueFormats) { + // this is case of first collector, query result object doesn't have any top docs set, so we can + // just set new top docs without merge + // this call is effectively checking if QuerySearchResult.topDoc is null. using it in such way because + // getter throws exception in case topDocs is null + if (result.hasConsumedTopDocs()) { + result.topDocs(topDocsAndMaxScore, docValueFormats); + return; + } + // in this case top docs are already present in result, and we need to merge next result object with what we have. 
+ // if collector doesn't have any hits we can just skip it and save some cycles by not doing merge + if (topDocsAndMaxScore.topDocs.totalHits.value == 0) { + return; + } + // we need to do actual merge because query result and current collector both have some score hits + TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs(); + TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); + result.topDocs(mergeTopDocsAndMaxScores, docValueFormats); + } + + /** + * For collection of search results, return a single one that has results from all individual result objects. + * @param results collection of search results + * @return single search result that represents all results as one object + */ + private ReduceableSearchResult reduceSearchResults(List results) { + return (result) -> { + for (ReduceableSearchResult r : results) { + // call reduce for results of each single collector, this will update top docs in query result + r.reduce(result); + } + }; + } + /** * Implementation of the HybridCollector that reuses instance of collector on each even call. 
This allows caller to * use saved state of collector @@ -382,7 +412,6 @@ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager public HybridCollectorNonConcurrentManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight, @@ -391,10 +420,10 @@ public HybridCollectorNonConcurrentManager( super( numHits, hitsThresholdChecker, - isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight, + TOP_DOCS_MERGER_TOP_SCORES, (FieldDoc) searchAfter ); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); @@ -421,7 +450,6 @@ static class HybridCollectorConcurrentSearchManager extends HybridCollectorManag public HybridCollectorConcurrentSearchManager( int numHits, HitsThresholdChecker hitsThresholdChecker, - boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight, @@ -430,10 +458,10 @@ public HybridCollectorConcurrentSearchManager( super( numHits, hitsThresholdChecker, - isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight, + TOP_DOCS_MERGER_TOP_SCORES, (FieldDoc) searchAfter ); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java new file mode 100644 index 000000000..7eb6e2b55 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.lucene.search.ScoreDoc; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +import static 
org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; + +/** + * Merges two ScoreDoc arrays into one + */ +@NoArgsConstructor(access = AccessLevel.PACKAGE) +class HybridQueryScoreDocsMerger { + + private static final int MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC = 3; + + /** + * Merge two score docs objects, result ScoreDocs[] object will have all hits per sub-query from both original objects. + * Input and output ScoreDocs are in format that is specific to Hybrid Query. This method should not be used for ScoreDocs from + * other query types. + * Logic is based on assumption that hits of every sub-query are sorted by score. + * Method returns new object and doesn't mutate original ScoreDocs arrays. + * @param sourceScoreDocs original score docs from query result + * @param newScoreDocs new score docs that we need to merge into existing scores + * @return merged array of ScoreDocs objects + */ + public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator) { + if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC + || Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) { + throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements"); + } + // we overshoot and preallocate more than we need - length of both top docs combined. 
+ // we will take only portion of the array at the end + List mergedScoreDocs = new ArrayList<>(sourceScoreDocs.length + newScoreDocs.length); + int sourcePointer = 0; + // mark beginning of hybrid query results by start element + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + // new pointer is set to 1 as we don't care about it start-stop element + int newPointer = 1; + + while (sourcePointer < sourceScoreDocs.length - 1 && newPointer < newScoreDocs.length - 1) { + // every iteration is for results of one sub-query + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + newPointer++; + // simplest case when both arrays have results for sub-query + while (sourcePointer < sourceScoreDocs.length + && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer]) + && newPointer < newScoreDocs.length + && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } else { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + // at least one object got exhausted at this point, now merge all elements from object that's left + while (sourcePointer < sourceScoreDocs.length && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer])) { + mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); + sourcePointer++; + } + while (newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { + mergedScoreDocs.add(newScoreDocs[newPointer]); + newPointer++; + } + } + // mark end of hybrid query results by end element + mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]); + return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java new file mode 
100644 index 000000000..0e6adfb1a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; + +import java.util.Comparator; +import java.util.Objects; + +/** + * Utility class for merging TopDocs and MaxScore across multiple search queries + */ +@RequiredArgsConstructor(access = AccessLevel.PACKAGE) +class TopDocsMerger { + + private final HybridQueryScoreDocsMerger scoreDocsMerger; + @VisibleForTesting + protected static final Comparator SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + /** + * Uses hybrid query score docs merger to merge internal score docs + */ + static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>()); + + /** + * Merge TopDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. + * @param source TopDocsAndMaxScore for the original query + * @param newTopDocs TopDocsAndMaxScore for the new query + * @return merged TopDocsAndMaxScore object + */ + public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { + return source; + } + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... 
+ // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + source.topDocs.scoreDocs, + newTopDocs.topDocs.scoreDocs, + SCORE_DOC_BY_SCORE_COMPARATOR + ); + TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); + TopDocsAndMaxScore result = new TopDocsAndMaxScore( + new TopDocs(mergedTotalHits, mergedScoreDocs), + Math.max(source.maxScore, newTopDocs.maxScore) + ); + return result; + } + + private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + // merged value is a lower bound - if both are equal_to than merged will also be equal_to, + // otherwise assign greater_than_or_equal + TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + || newTopDocs.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java index d7be3851f..fa196a533 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java @@ -70,6 +70,30 @@ public static FieldDoc createFieldDocDelimiterElementForHybridSearchResults(fina return new FieldDoc(docId, MAGIC_NUMBER_DELIMITER, fields); } + /** + * Checking if passed scoreDocs object is a special element (start/stop or delimiter) in the list of hybrid query result scores + * @param scoreDoc score doc object to check on + * @return true if it is a special element + */ + public static boolean isHybridQuerySpecialElement(final ScoreDoc 
scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return isHybridQueryStartStopElement(scoreDoc) || isHybridQueryDelimiterElement(scoreDoc); + } + + /** + * Checking if passed scoreDocs object is a document score element + * @param scoreDoc score doc object to check on + * @return true if element has score + */ + public static boolean isHybridQueryScoreDocElement(final ScoreDoc scoreDoc) { + if (Objects.isNull(scoreDoc)) { + return false; + } + return !isHybridQuerySpecialElement(scoreDoc); + } + /** * This method is for creating dummy sort object for the field docs having magic number scores which acts as delimiters. * The sort object should be in the same type of the field on which sorting criteria is applied. diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 9e72dfcb1..4bc40add8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -100,43 +100,43 @@ protected boolean preserveClusterUponCompletion() { @SneakyThrows public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgSumMinMaxAggs(); } @SneakyThrows public void testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAvgSumMinMaxAggs(); } @SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testMaxAggsOnSingleShardCluster(); } 
@SneakyThrows public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testMaxAggsOnSingleShardCluster(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateRange(); } @SneakyThrows public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateRange(); } @SneakyThrows public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); try { prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); @@ -177,14 +177,14 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(false, true); testPostFilterWithComplexHybridQuery(false, true); } @@ -420,14 +420,14 @@ 
private void testAvgSumMinMaxAggs() { @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } @SneakyThrows public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPostFilterWithSimpleHybridQuery(true, true); testPostFilterWithComplexHybridQuery(true, true); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 43e302698..a650087b4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -22,6 +22,8 @@ import java.util.Set; import java.util.stream.IntStream; +import org.apache.commons.lang.RandomStringUtils; +import org.apache.commons.lang.math.RandomUtils; import org.apache.lucene.search.join.ScoreMode; import org.junit.Before; import org.opensearch.client.ResponseException; @@ -47,6 +49,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-hybrid-multi-doc-nested-type-single-shard-index"; private static final String TEST_INDEX_WITH_KEYWORDS_ONE_SHARD = "test-hybrid-keywords-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_ONE_SHARD = "test-hybrid-doc-qty-single-shard-index"; + private static final String TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS = "test-hybrid-doc-qty-multiple-shards-index"; private static final String TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS = 
"test-hybrid-keywords-three-shards-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; @@ -76,6 +80,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; + protected static final int SINGLE_SHARD = 1; + protected static final int MULTIPLE_SHARDS = 3; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -692,6 +698,101 @@ public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { } } + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenSingleShardIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 1_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_ONE_SHARD, SINGLE_SHARD, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_ONE_SHARD, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // 
verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_ONE_SHARD, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + + @SneakyThrows + public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenSuccessful() { + try { + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); + int numberOfDocumentsInIndex = 2_000; + initializeIndexIfNotExist(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, MULTIPLE_SHARDS, numberOfDocumentsInIndex); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.matchAllQuery()); + hybridQueryBuilder.add(QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(0).lte(1000)); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + numberOfDocumentsInIndex, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + int queryHitCount = getHitCount(firstSearchResponseAsMap); + assertEquals(numberOfDocumentsInIndex, queryHitCount); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + 
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(numberOfDocumentsInIndex, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_DOC_QTY_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -784,7 +885,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), List.of(TEST_NESTED_TYPE_FIELD_NAME_1), - 1 + SINGLE_SHARD ), "" ); @@ -805,7 +906,14 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_INDEX_WITH_KEYWORDS_ONE_SHARD.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD)) { createIndexWithConfiguration( indexName, - buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 1), + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + SINGLE_SHARD + ), "" ); addDocWithKeywordsAndIntFields( @@ -901,6 +1009,34 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { } } + @SneakyThrows + private void initializeIndexIfNotExist(String indexName, int numberOfShards, int numberOfDocuments) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + 
List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(KEYWORD_FIELD_1), + List.of(), + numberOfShards + ), + "" + ); + for (int i = 0; i < numberOfDocuments; i++) { + addDocWithKeywordsAndIntFields( + indexName, + String.valueOf(i), + INTEGER_FIELD_PRICE, + RandomUtils.nextInt(1000), + KEYWORD_FIELD_1, + RandomStringUtils.randomAlphabetic(10) + ); + } + } + } + private void addDocsToIndex(final String testMultiDocIndexName) { addKnnDoc( testMultiDocIndexName, diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java index 7d33d07fe..8f8ae8cc4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -68,7 +68,7 @@ public static void setUpCluster() { @SneakyThrows public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); @@ -81,7 +81,7 @@ public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_the @SneakyThrows public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); @@ -94,7 +94,7 @@ 
public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_th @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); @@ -107,7 +107,7 @@ public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_ @SneakyThrows public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled_thenSuccessful() { try { - updateClusterSettings("search.concurrent_segment_search.enabled", false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java index 48fb8f8d6..5cc5b9170 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java @@ -77,7 +77,6 @@ public class BaseAggregationsWithHybridQueryIT extends BaseNeuralSearchIT { protected static final String AVG_AGGREGATION_NAME = "avg_field"; protected static final String GENERIC_AGGREGATION_NAME = "my_aggregation"; protected static final String DATE_AGGREGATION_NAME = "date_aggregation"; - protected static final String 
CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH = "search.concurrent_segment_search.enabled"; @BeforeClass @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java index ce8854eed..7385f48e5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java @@ -68,165 +68,165 @@ public class BucketAggregationsWithHybridQueryIT extends BaseAggregationsWithHyb @SneakyThrows public void testBucketAndNestedAggs_whenAdjacencyMatrix_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAdjacencyMatrixAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenAdjacencyMatrix_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAdjacencyMatrixAggs(); } @SneakyThrows public void testBucketAndNestedAggs_whenDiversifiedSampler_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDiversifiedSampler(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDiversifiedSampler_thenFail() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDiversifiedSampler(); } @SneakyThrows public void testBucketAndNestedAggs_whenAvgNestedIntoFilter_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testAvgNestedIntoFilter(); } 
@SneakyThrows public void testWithConcurrentSegmentSearch_whenAvgNestedIntoFilter_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgNestedIntoFilter(); } @SneakyThrows public void testBucketAndNestedAggs_whenSumNestedIntoFilters_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSumNestedIntoFilters(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSumNestedIntoFilters_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSumNestedIntoFilters(); } @SneakyThrows public void testBucketAggs_whenGlobalAggUsedWithQuery_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testGlobalAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenGlobalAggUsedWithQuery_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testGlobalAggs(); } @SneakyThrows public void testBucketAggs_whenHistogramAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testHistogramAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenHistogramAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testHistogramAggs(); } @SneakyThrows public void testBucketAggs_whenNestedAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + 
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testNestedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenNestedAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testNestedAggs(); } @SneakyThrows public void testBucketAggs_whenSamplerAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSampler(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSamplerAgg_thenFail() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSampler(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + 
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketScriptAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testMetricAggs_whenTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testTermsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testTermsAggs(); } @SneakyThrows public void testMetricAggs_whenSignificantTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSignificantTermsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSignificantTermsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + 
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSignificantTermsAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java index 36c853984..aa69272fb 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java @@ -81,103 +81,103 @@ public class MetricAggregationsWithHybridQueryIT extends BaseAggregationsWithHyb */ @SneakyThrows public void testWithConcurrentSegmentSearch_whenAvgAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testAvgAggs(); } @SneakyThrows public void testMetricAggs_whenCardinalityAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testCardinalityAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenCardinalityAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testCardinalityAggs(); } @SneakyThrows public void testMetricAggs_whenExtendedStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testExtendedStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenExtendedStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testExtendedStatsAggs(); } @SneakyThrows public void 
testMetricAggs_whenTopHitsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testTopHitsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenTopHitsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testTopHitsAggs(); } @SneakyThrows public void testMetricAggs_whenPercentileRank_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPercentileRankAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenPercentileRank_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPercentileRankAggs(); } @SneakyThrows public void testMetricAggs_whenPercentile_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testPercentileAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenPercentile_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testPercentileAggs(); } @SneakyThrows public void testMetricAggs_whenScriptedMetrics_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testScriptedMetricsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenScriptedMetrics_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testScriptedMetricsAggs(); } 
@SneakyThrows public void testMetricAggs_whenSumAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testSumAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenSumAgg_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testSumAggs(); } @SneakyThrows public void testMetricAggs_whenValueCount_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testValueCountAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenValueCount_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testValueCountAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java index 168dce1e0..fd118629b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java @@ -53,49 +53,49 @@ public class PipelineAggregationsWithHybridQueryIT extends BaseAggregationsWithH @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { - 
updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketStatsAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketScriptedAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToBucketSortAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToBucketSortAggs(); } @SneakyThrows public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true); testDateBucketedSumsPipelinedToCumulativeSumAggs(); } @SneakyThrows public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { - updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + 
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); testDateBucketedSumsPipelinedToCumulativeSumAggs(); } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index a6e8337af..8ea464e9a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -67,6 +67,7 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; + private static final String QUERY2 = "hi"; private static final float DELTA_FOR_ASSERTION = 0.001f; @SneakyThrows @@ -481,4 +482,130 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() { reader.close(); directory.close(); } + + @SneakyThrows + public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY2).toQuery(mockQueryShardContext) + ) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + 
when(indexReader.numDocs()).thenReturn(2); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + SearchContext searchContext2 = mock(SearchContext.class); + + ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); + IndexReader indexReader2 = mock(IndexReader.class); + when(indexReader2.numDocs()).thenReturn(1); + when(indexSearcher2.getIndexReader()).thenReturn(indexReader); + when(searchContext2.searcher()).thenReturn(indexSearcher2); + when(searchContext2.size()).thenReturn(1); + + when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); + when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory2 = newDirectory(); + final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft2 
= new FieldType(TextField.TYPE_NOT_STORED); + ft2.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector1 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + HybridTopScoreDocCollector collector2 = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = 
((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(2, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(6, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[2].score > 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[3].score, DELTA_FOR_ASSERTION); + assertTrue(scoreDocs[4].score > 0); + + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[5].score, DELTA_FOR_ASSERTION); + // we have to assert that one of hits is max score because scores are generated for each run and order is not guaranteed + assertTrue(Float.compare(scoreDocs[2].score, maxScore) == 0 || Float.compare(scoreDocs[4].score, maxScore) == 0); + + w.close(); + reader.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java new file mode 100644 index 000000000..2147578c9 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import org.apache.lucene.search.ScoreDoc; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import static 
org.opensearch.neuralsearch.search.query.TopDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; + +public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { + + private static final float DELTA_FOR_ASSERTION = 0.001f; + + public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scores = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + createStartStopElementForHybridSearchResults(2) }; + + NullPointerException exception = assertThrows( + NullPointerException.class, + () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("score docs cannot be null", exception.getMessage()); + + exception = assertThrows(NullPointerException.class, () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR)); + assertEquals("score docs cannot be null", exception.getMessage()); + + ScoreDoc[] lessElementsScoreDocs = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), new ScoreDoc(1, 0.7f) }; + + IllegalArgumentException notEnoughException = assertThrows( + IllegalArgumentException.class, + () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + + notEnoughException = 
assertThrows( + IllegalArgumentException.class, + () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + } + + public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(10, mergedScoreDocs.length); + + // check format, all elements one by one + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertScoreDoc(mergedScoreDocs[2], 1, 0.7f); + assertScoreDoc(mergedScoreDocs[3], 0, 0.5f); + assertScoreDoc(mergedScoreDocs[4], 2, 0.3f); + assertScoreDoc(mergedScoreDocs[5], 4, 0.3f); + assertScoreDoc(mergedScoreDocs[6], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[7].score, 0); + assertScoreDoc(mergedScoreDocs[8], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[9].score, 0); + } + + public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + 
HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(8, mergedScoreDocs.length); + + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertScoreDoc(mergedScoreDocs[2], 1, 0.7f); + assertScoreDoc(mergedScoreDocs[3], 4, 0.3f); + assertScoreDoc(mergedScoreDocs[4], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[5].score, 0); + assertScoreDoc(mergedScoreDocs[6], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[7].score, 0); + } + + public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + + ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) }; + ScoreDoc[] scoreDocsNew = new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + 
createStartStopElementForHybridSearchResults(2) }; + + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + + assertNotNull(mergedScoreDocs); + assertEquals(4, mergedScoreDocs.length); + // check format, all elements one by one + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[1].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, mergedScoreDocs[2].score, 0); + assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[3].score, 0); + } + + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { + assertEquals(expectedDocId, scoreDoc.doc); + assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java new file mode 100644 index 000000000..5a99f3f3a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; +import static 
org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; + +public class TopDocsMergerTests extends OpenSearchQueryTestCase { + + private static final float DELTA_FOR_ASSERTION = 0.001f; + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(6, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 5 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 
delimiters + // 5 + 1 + 2 + 2 = 10 + assertEquals(10, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + assertScoreDoc(scoreDocs[4], 2, 0.3f); + assertScoreDoc(scoreDocs[5], 4, 0.3f); + assertScoreDoc(scoreDocs[6], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[7].score, 0); + assertScoreDoc(scoreDocs[8], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[9].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = 
topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 3 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 3 + 1 + 2 + 2 = 8 + assertEquals(8, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 4, 0.3f); + assertScoreDoc(scoreDocs[4], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[5].score, 0); + assertScoreDoc(scoreDocs[6], 4, 0.6f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[7].score, 0); + } + + @SneakyThrows + public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0); + TopDocs topDocsNew = new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + 
createDelimiterElementForHybridSearchResults(2), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(0, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[2].score, 0); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, 0); + } + + @SneakyThrows + public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + + TopDocs topDocsOriginal = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + createDelimiterElementForHybridSearchResults(0), + createStartStopElementForHybridSearchResults(0) } + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(2), + createDelimiterElementForHybridSearchResults(2), + new 
ScoreDoc(1, 0.7f), + new ScoreDoc(4, 0.3f), + new ScoreDoc(5, 0.05f), + createDelimiterElementForHybridSearchResults(2), + new ScoreDoc(4, 0.6f), + createStartStopElementForHybridSearchResults(2) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); + + assertNotNull(firstMergedTopDocsAndMaxScore); + + // merge results from collector 3 + TopDocs topDocsThirdCollector = new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(3), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(3, 0.4f), + createDelimiterElementForHybridSearchResults(3), + new ScoreDoc(7, 0.85f), + new ScoreDoc(9, 0.2f), + createStartStopElementForHybridSearchResults(3) } + ); + TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f); + TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.merge( + firstMergedTopDocsAndMaxScore, + topDocsAndMaxScoreThirdCollector + ); + + assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(9, finalMergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, finalMergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 6 from sub-query1 and 3 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 6 + 3 + 2 + 2 = 13 + assertEquals(13, finalMergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + ScoreDoc[] scoreDocs = finalMergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, 0); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, 0); + assertScoreDoc(scoreDocs[2], 1, 0.7f); + assertScoreDoc(scoreDocs[3], 0, 0.5f); + 
assertScoreDoc(scoreDocs[4], 3, 0.4f); + assertScoreDoc(scoreDocs[5], 2, 0.3f); + assertScoreDoc(scoreDocs[6], 4, 0.3f); + assertScoreDoc(scoreDocs[7], 5, 0.05f); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[8].score, 0); + assertScoreDoc(scoreDocs[9], 7, 0.85f); + assertScoreDoc(scoreDocs[10], 4, 0.6f); + assertScoreDoc(scoreDocs[11], 9, 0.2f); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[12].score, 0); + } + + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { + assertEquals(expectedDocId, scoreDoc.doc); + assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java index 16f2f10ce..d84e196cd 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java @@ -4,16 +4,14 @@ */ package org.opensearch.neuralsearch.search.util; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; import static 
org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryScoreDocElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQuerySpecialElement; import org.apache.lucene.search.ScoreDoc; import org.opensearch.common.Randomness; @@ -57,4 +55,23 @@ public void testCreateElements_whenCreateStartStopAndDelimiterElements_thenSucce assertEquals(docId, delimiterElement.doc); assertEquals(MAGIC_NUMBER_DELIMITER, delimiterElement.score, 0.0f); } + + public void testSpecialElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + ScoreDoc startStopElement = new ScoreDoc(docId, MAGIC_NUMBER_START_STOP); + assertTrue(isHybridQuerySpecialElement(startStopElement)); + assertFalse(isHybridQueryScoreDocElement(startStopElement)); + + ScoreDoc delimiterElement = new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER); + assertTrue(isHybridQuerySpecialElement(delimiterElement)); + assertFalse(isHybridQueryScoreDocElement(delimiterElement)); + } + + public void testScoreElementCheck_whenElementIsSpecialAndIsNotSpecial_thenSuccessful() { + int docId = 1; + float score = Randomness.get().nextFloat(); + ScoreDoc startStopElement = new ScoreDoc(docId, score); + assertFalse(isHybridQuerySpecialElement(startStopElement)); + assertTrue(isHybridQueryScoreDocElement(startStopElement)); + } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 2682ee7c7..689e4bf98 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -83,6 +83,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { 
"processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); + protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); From 6e31805352dc49ca1b6a86b4720c4549fecbd044 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Fri, 28 Jun 2024 10:47:14 -0700 Subject: [PATCH 02/10] Fix Concurrent Segment Search Bug in Sorting Signed-off-by: Varun Jain --- .../search/query/HybridCollectorManager.java | 33 +- .../query/HybridQueryFieldDocComparator.java | 57 ++++ .../query/HybridQueryScoreDocsMerger.java | 22 +- .../search/query/TopDocsMerger.java | 70 ++++- .../query/HybridCollectorManagerTests.java | 126 ++++++++ .../HybridQueryScoreDocsMergerTests.java | 221 ++++++++++++- .../search/query/TopDocsMergerTests.java | 290 +++++++++++++++++- 7 files changed, 784 insertions(+), 35 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index b1159c851..5bd2fa79d 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -42,7 +42,6 @@ import java.util.Objects; import static org.apache.lucene.search.TotalHits.Relation; -import static org.opensearch.neuralsearch.search.query.TopDocsMerger.TOP_DOCS_MERGER_TOP_SCORES; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static 
org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; @@ -156,7 +155,7 @@ public ReduceableSearchResult reduce(Collection collectors) { return reduceSearchResults(getSearchResults(hybridSearchCollectors)); } - private List getSearchResults(List hybridSearchCollectors) { + private List getSearchResults(final List hybridSearchCollectors) { List results = new ArrayList<>(); DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats); for (Collector collector : hybridSearchCollectors) { @@ -166,7 +165,7 @@ private List getSearchResults(List hybridSear return results; } - private TopDocsAndMaxScore getTopDocsAndAndMaxScore(Collector collector, DocValueFormat[] docValueFormats) { + private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final Collector collector, final DocValueFormat[] docValueFormats) { float maxScore; TopDocs newTopDocs; if (docValueFormats != null) { @@ -187,7 +186,7 @@ private TopDocsAndMaxScore getTopDocsAndAndMaxScore(Collector collector, DocValu return new TopDocsAndMaxScore(newTopDocs, maxScore); } - private List getHybridSearchCollectors(Collection collectors) { + private List getHybridSearchCollectors(final Collection collectors) { final List hybridSearchCollectors = new ArrayList<>(); for (final Collector collector : collectors) { if (collector instanceof MultiCollectorWrapper) { @@ -368,7 +367,11 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats return sortAndFormats == null ? 
null : sortAndFormats.formats; } - private void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore topDocsAndMaxScore, DocValueFormat[] docValueFormats) { + private void reduceCollectorResults( + final QuerySearchResult result, + final TopDocsAndMaxScore topDocsAndMaxScore, + final DocValueFormat[] docValueFormats + ) { // this is case of first collector, query result object doesn't have any top docs set, so we can // just set new top docs without merge // this call is effectively checking if QuerySearchResult.topDoc is null. using it in such way because @@ -384,8 +387,18 @@ private void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore } // we need to do actual merge because query result and current collector both have some score hits TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs(); - TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); - result.topDocs(mergeTopDocsAndMaxScores, docValueFormats); + result.topDocs(getMergeTopDocsAndMaxScores(originalTotalDocsAndHits, topDocsAndMaxScore), docValueFormats); + } + + private TopDocsAndMaxScore getMergeTopDocsAndMaxScores( + final TopDocsAndMaxScore originalTotalDocsAndHits, + final TopDocsAndMaxScore topDocsAndMaxScore + ) { + if (sortAndFormats != null) { + return topDocsMerger.mergeFieldDocs(originalTotalDocsAndHits, topDocsAndMaxScore, sortAndFormats); + } else { + return topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); + } } /** @@ -393,7 +406,7 @@ private void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore * @param results collection of search results * @return single search result that represents all results as one object */ - private ReduceableSearchResult reduceSearchResults(List results) { + private ReduceableSearchResult reduceSearchResults(final List results) { return (result) -> { for (ReduceableSearchResult r : results) { // call reduce for results of each 
single collector, this will update top docs in query result @@ -423,7 +436,7 @@ public HybridCollectorNonConcurrentManager( trackTotalHitsUpTo, sortAndFormats, filteringWeight, - TOP_DOCS_MERGER_TOP_SCORES, + new TopDocsMerger(sortAndFormats), (FieldDoc) searchAfter ); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); @@ -461,7 +474,7 @@ public HybridCollectorConcurrentSearchManager( trackTotalHitsUpTo, sortAndFormats, filteringWeight, - TOP_DOCS_MERGER_TOP_SCORES, + new TopDocsMerger(sortAndFormats), (FieldDoc) searchAfter ); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java new file mode 100644 index 000000000..d09750dfb --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryFieldDocComparator.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import java.util.Comparator; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.Pruning; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.SortField; + +/** + * Comparator class that compares two field docs as per the sorting criteria + */ +@RequiredArgsConstructor(access = AccessLevel.PACKAGE) +class HybridQueryFieldDocComparator implements Comparator { + final SortField[] sortFields; + final FieldComparator[] comparators; + final int[] reverseMul; + final Comparator tieBreaker; + + public HybridQueryFieldDocComparator(SortField[] sortFields, Comparator tieBreaker) { + this.sortFields = sortFields; + this.tieBreaker = tieBreaker; + comparators = new FieldComparator[sortFields.length]; + reverseMul = new 
int[sortFields.length]; + for (int compIDX = 0; compIDX < sortFields.length; compIDX++) { + final SortField sortField = sortFields[compIDX]; + comparators[compIDX] = sortField.getComparator(1, Pruning.NONE); + reverseMul[compIDX] = sortField.getReverse() ? -1 : 1; + } + } + + @Override + public int compare(final FieldDoc firstFD, final FieldDoc secondFD) { + for (int compIDX = 0; compIDX < comparators.length; compIDX++) { + final FieldComparator comp = comparators[compIDX]; + + final int cmp = reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]); + + if (cmp != 0) { + return cmp; + } + } + return tieBreakCompare(firstFD, secondFD, tieBreaker); + } + + private int tieBreakCompare(ScoreDoc firstDoc, ScoreDoc secondDoc, Comparator tieBreaker) { + assert tieBreaker != null; + int value = tieBreaker.compare(firstDoc, secondDoc); + return value; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java index 7eb6e2b55..662cb2242 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java @@ -6,6 +6,7 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import java.util.ArrayList; @@ -33,7 +34,7 @@ class HybridQueryScoreDocsMerger { * @param newScoreDocs new score docs that we need to merge into existing scores * @return merged array of ScoreDocs objects */ - public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator) { + public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator, final boolean isSortEnabled) { if (Objects.requireNonNull(sourceScoreDocs, "score docs cannot be null").length < 
MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC || Objects.requireNonNull(newScoreDocs, "score docs cannot be null").length < MIN_NUMBER_OF_ELEMENTS_IN_SCORE_DOC) { throw new IllegalArgumentException("cannot merge top docs because it does not have enough elements"); @@ -58,7 +59,7 @@ public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Compar && isHybridQueryScoreDocElement(sourceScoreDocs[sourcePointer]) && newPointer < newScoreDocs.length && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { - if (comparator.compare(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer]) >= 0) { + if (compareCondition(sourceScoreDocs[sourcePointer], newScoreDocs[newPointer], comparator, isSortEnabled)) { mergedScoreDocs.add(sourceScoreDocs[sourcePointer]); sourcePointer++; } else { @@ -78,6 +79,23 @@ && isHybridQueryScoreDocElement(newScoreDocs[newPointer])) { } // mark end of hybrid query results by end element mergedScoreDocs.add(sourceScoreDocs[sourceScoreDocs.length - 1]); + if (isSortEnabled) { + return mergedScoreDocs.toArray((T[]) new FieldDoc[0]); + } return mergedScoreDocs.toArray((T[]) new ScoreDoc[0]); } + + private boolean compareCondition( + final ScoreDoc oldScoreDoc, + final ScoreDoc secondScoreDoc, + final Comparator comparator, + final boolean isSortEnabled + ) { + // If sorting is enabled then the compare condition is different than for a normal HybridQuery + if (isSortEnabled) { + return comparator.compare((T) oldScoreDoc, (T) secondScoreDoc) < 0; + } else { + return comparator.compare((T) oldScoreDoc, (T) secondScoreDoc) >= 0; + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 0e6adfb1a..a041149fe 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -7,13 +7,16 @@ import 
com.google.common.annotations.VisibleForTesting; import lombok.AccessLevel; import lombok.RequiredArgsConstructor; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import java.util.Comparator; import java.util.Objects; +import org.opensearch.search.sort.SortAndFormats; /** * Utility class for merging TopDocs and MaxScore across multiple search queries @@ -21,13 +24,29 @@ @RequiredArgsConstructor(access = AccessLevel.PACKAGE) class TopDocsMerger { - private final HybridQueryScoreDocsMerger scoreDocsMerger; + private HybridQueryScoreDocsMerger scoreDocsMerger; + private HybridQueryScoreDocsMerger fieldDocsMerger; @VisibleForTesting - protected static final Comparator SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + protected static Comparator SCORE_DOC_BY_SCORE_COMPARATOR; + @VisibleForTesting + protected static HybridQueryFieldDocComparator FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR; + private final Comparator MERGING_TIE_BREAKER = (o1, o2) -> { + int docIdComparison = Integer.compare(o1.doc, o2.doc); + return docIdComparison; + }; + /** * Uses hybrid query score docs merger to merge internal score docs */ - static final TopDocsMerger TOP_DOCS_MERGER_TOP_SCORES = new TopDocsMerger(new HybridQueryScoreDocsMerger<>()); + TopDocsMerger(final SortAndFormats sortAndFormats) { + if (sortAndFormats != null) { + fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER); + } else { + scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); + } + } /** * Merge TopDocs and MaxScore from multiple search queries into a 
single TopDocsAndMaxScore object. @@ -35,7 +54,7 @@ class TopDocsMerger { * @param newTopDocs TopDocsAndMaxScore for the new query * @return merged TopDocsAndMaxScore object */ - public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { return source; } @@ -52,7 +71,8 @@ public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore ne ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( source.topDocs.scoreDocs, newTopDocs.topDocs.scoreDocs, - SCORE_DOC_BY_SCORE_COMPARATOR + SCORE_DOC_BY_SCORE_COMPARATOR, + false ); TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); TopDocsAndMaxScore result = new TopDocsAndMaxScore( @@ -62,7 +82,45 @@ public TopDocsAndMaxScore merge(TopDocsAndMaxScore source, TopDocsAndMaxScore ne return result; } - private TotalHits getMergedTotalHits(TopDocsAndMaxScore source, TopDocsAndMaxScore newTopDocs) { + /** + * Merge TopFieldDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. + * @param source TopDocsAndMaxScore for the original query + * @param newTopDocs TopDocsAndMaxScore for the new query + * @return merged TopDocsAndMaxScore object + */ + public TopDocsAndMaxScore mergeFieldDocs( + final TopDocsAndMaxScore source, + final TopDocsAndMaxScore newTopDocs, + final SortAndFormats sortAndFormats + ) { + if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { + return source; + } + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 | [1] + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... 
+ // doc_id | magic_number_1 | [1] + FieldDoc[] mergedScoreDocs = fieldDocsMerger.merge( + (FieldDoc[]) source.topDocs.scoreDocs, + (FieldDoc[]) newTopDocs.topDocs.scoreDocs, + FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, + true + ); + TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); + TopDocsAndMaxScore result = new TopDocsAndMaxScore( + new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()), + Math.max(source.maxScore, newTopDocs.maxScore) + ); + return result; + } + + private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { // merged value is a lower bound - if both are equal_to than merged will also be equal_to, // otherwise assign greater_than_or_equal TotalHits.Relation mergedHitsRelation = source.topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 8ea464e9a..de9c6006b 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.search.query; import com.carrotsearch.randomizedtesting.RandomizedTest; +import java.util.Arrays; import lombok.SneakyThrows; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; @@ -608,4 +609,129 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD reader2.close(); directory2.close(); } + + @SneakyThrows + public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedDocsWithSort_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType 
fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery(List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext))); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(2); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("id", SortField.Type.DOC); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + when(searchContext.sort()).thenReturn(sortAndFormats); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int[] docIds = new int[] { docId1, docId2, docId3 }; + Arrays.sort(docIds); + + w.addDocument(getDocument(TEXT_FIELD_NAME, docIds[0], TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docIds[1], TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + SearchContext searchContext2 = mock(SearchContext.class); + + ContextIndexSearcher indexSearcher2 = mock(ContextIndexSearcher.class); + IndexReader indexReader2 = mock(IndexReader.class); + when(indexReader2.numDocs()).thenReturn(1); + when(indexSearcher2.getIndexReader()).thenReturn(indexReader); + when(searchContext2.searcher()).thenReturn(indexSearcher2); + when(searchContext2.size()).thenReturn(1); + + when(searchContext2.queryCollectorManagers()).thenReturn(new HashMap<>()); + when(searchContext2.shouldUseConcurrentSearch()).thenReturn(true); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory2 = newDirectory(); + final IndexWriter w2 = new IndexWriter(directory2, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft2 = new FieldType(TextField.TYPE_NOT_STORED); + ft2.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft2.setOmitNorms(random().nextBoolean()); + ft2.freeze(); + + w2.addDocument(getDocument(TEXT_FIELD_NAME, docIds[2], TEST_DOC_TEXT2, ft)); + w2.flush(); + w2.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + IndexReader reader2 = DirectoryReader.open(w2); + IndexSearcher searcher2 = newSearcher(reader2); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + SimpleFieldCollector collector1 = (SimpleFieldCollector) hybridCollectorManager.newCollector(); + SimpleFieldCollector collector2 = (SimpleFieldCollector) hybridCollectorManager.newCollector(); + + Weight weight1 = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + Weight weight2 = new HybridQueryWeight(hybridQueryWithTerm, searcher2, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector1.setWeight(weight1); + collector2.setWeight(weight2); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); + + LeafReaderContext leafReaderContext2 = searcher2.getIndexReader().leaves().get(0); + LeafCollector leafCollector2 = collector2.getLeafCollector(leafReaderContext2); + BulkScorer scorer = weight1.bulkScorer(leafReaderContext); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); + BulkScorer scorer2 = weight2.bulkScorer(leafReaderContext2); + scorer2.score(leafCollector2, leafReaderContext2.reader().getLiveDocs()); + leafCollector2.finish(); + + Object results = hybridCollectorManager.reduce(List.of(collector1, collector2)); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + 
reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(3, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + FieldDoc[] fieldDocs = (FieldDoc[]) topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(5, fieldDocs.length); + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + assertEquals(fieldDocs[2].doc, fieldDocs[2].fields[0]); + assertEquals(fieldDocs[3].doc, fieldDocs[3].fields[0]); + assertEquals(1, fieldDocs[4].fields[0]); + + w.close(); + reader.close(); + directory.close(); + w2.close(); + reader2.close(); + directory2.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 2147578c9..3d7477233 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -4,14 +4,20 @@ */ package org.opensearch.neuralsearch.search.query; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import static org.opensearch.neuralsearch.search.query.TopDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; import static 
org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.sort.SortAndFormats; public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { @@ -19,7 +25,7 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scores = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), createDelimiterElementForHybridSearchResults(2), @@ -28,24 +34,27 @@ public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail( NullPointerException exception = assertThrows( NullPointerException.class, - () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(scores, null, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("score docs cannot be null", exception.getMessage()); - exception = assertThrows(NullPointerException.class, () -> scoreDocsMerger.merge(scores, null, SCORE_DOC_BY_SCORE_COMPARATOR)); + exception = assertThrows( + NullPointerException.class, + () -> scoreDocsMerger.merge(scores, null, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) + ); assertEquals("score docs cannot be null", exception.getMessage()); ScoreDoc[] lessElementsScoreDocs = new ScoreDoc[] { createStartStopElementForHybridSearchResults(2), new ScoreDoc(1, 0.7f) }; 
IllegalArgumentException notEnoughException = assertThrows( IllegalArgumentException.class, - () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(lessElementsScoreDocs, scores, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); notEnoughException = assertThrows( IllegalArgumentException.class, - () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR) + () -> scoreDocsMerger.merge(scores, lessElementsScoreDocs, topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, false) ); assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); } @@ -71,7 +80,13 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { new ScoreDoc(4, 0.6f), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(10, mergedScoreDocs.length); @@ -91,6 +106,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { createStartStopElementForHybridSearchResults(0), @@ -107,7 +123,12 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf new ScoreDoc(4, 0.6f), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = 
scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(8, mergedScoreDocs.length); @@ -124,6 +145,7 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); ScoreDoc[] scoreDocsOriginal = new ScoreDoc[] { createStartStopElementForHybridSearchResults(0), @@ -136,7 +158,12 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { createDelimiterElementForHybridSearchResults(2), createStartStopElementForHybridSearchResults(2) }; - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge(scoreDocsOriginal, scoreDocsNew, SCORE_DOC_BY_SCORE_COMPARATOR); + ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( + scoreDocsOriginal, + scoreDocsNew, + topDocsMerger.SCORE_DOC_BY_SCORE_COMPARATOR, + false + ); assertNotNull(mergedScoreDocs); assertEquals(4, mergedScoreDocs.length); @@ -147,8 +174,184 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { assertEquals(MAGIC_NUMBER_START_STOP, mergedScoreDocs[3].score, 0); } + public void testIncorrectInput_whenFieldDocsAreNullOrNotEnoughElements_thenFail() { + HybridQueryScoreDocsMerger fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + FieldDoc[] scores = new FieldDoc[] { + 
createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 100 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }; + + NullPointerException exception = assertThrows( + NullPointerException.class, + () -> fieldDocsMerger.merge(scores, null, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true) + ); + assertEquals("score docs cannot be null", exception.getMessage()); + + exception = assertThrows( + NullPointerException.class, + () -> fieldDocsMerger.merge(scores, null, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true) + ); + assertEquals("score docs cannot be null", exception.getMessage()); + + FieldDoc[] lessElementsScoreDocs = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 100 }) }; + + IllegalArgumentException notEnoughException = assertThrows( + IllegalArgumentException.class, + () -> fieldDocsMerger.merge(lessElementsScoreDocs, scores, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + + notEnoughException = assertThrows( + IllegalArgumentException.class, + () -> fieldDocsMerger.merge(scores, lessElementsScoreDocs, topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true) + ); + assertEquals("cannot merge top docs because it does not have enough elements", notEnoughException.getMessage()); + } + + public void testMergeFieldDocs_whenBothTopDocsHasHits_thenSuccessful() { + HybridQueryScoreDocsMerger fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + + FieldDoc[] fieldDocsOriginal = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + new 
FieldDoc(0, 0.5f, new Object[] { 100 }), + new FieldDoc(2, 0.3f, new Object[] { 80 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }; + + FieldDoc[] fieldDocsNew = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 10 }), + new FieldDoc(4, 0.3f, new Object[] { 5 }), + new FieldDoc(5, 0.05f, new Object[] { 2 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(4, 0.6f, new Object[] { 5 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }; + + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge( + fieldDocsOriginal, + fieldDocsNew, + topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, + true + ); + + assertNotNull(mergedFieldDocs); + assertEquals(10, mergedFieldDocs.length); + + // check format, all elements one by one + assertEquals(1, mergedFieldDocs[0].fields[0]); + assertEquals(1, mergedFieldDocs[1].fields[0]); + assertFieldDoc(mergedFieldDocs[2], 0, 100); + assertFieldDoc(mergedFieldDocs[3], 2, 80); + assertFieldDoc(mergedFieldDocs[4], 1, 10); + assertFieldDoc(mergedFieldDocs[5], 4, 5); + assertFieldDoc(mergedFieldDocs[6], 5, 2); + assertEquals(1, mergedFieldDocs[7].fields[0]); + assertFieldDoc(mergedFieldDocs[8], 4, 5); + assertEquals(1, mergedFieldDocs[9].fields[0]); + } + + public void testMergeFieldDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + 
HybridQueryScoreDocsMerger fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + FieldDoc[] fieldDocsOriginal = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }; + FieldDoc[] fieldDocsNew = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 100 }), + new FieldDoc(4, 0.3f, new Object[] { 80 }), + new FieldDoc(5, 0.05f, new Object[] { 20 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(4, 0.6f, new Object[] { 50 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }; + + FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge( + fieldDocsOriginal, + fieldDocsNew, + topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, + true + ); + + assertNotNull(mergedFieldDocs); + assertEquals(8, mergedFieldDocs.length); + + assertEquals(1, mergedFieldDocs[0].fields[0]); + assertEquals(1, mergedFieldDocs[1].fields[0]); + assertFieldDoc(mergedFieldDocs[2], 1, 100); + assertFieldDoc(mergedFieldDocs[3], 4, 80); + assertFieldDoc(mergedFieldDocs[4], 5, 20); + assertEquals(1, mergedFieldDocs[5].fields[0]); + assertFieldDoc(mergedFieldDocs[6], 4, 50); + assertEquals(1, mergedFieldDocs[7].fields[0]); + } + + public void 
testMergeFieldDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + HybridQueryScoreDocsMerger fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + FieldDoc[] fieldDocsOriginal = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }; + FieldDoc[] fieldDocsNew = new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }; + + FieldDoc[] mergedFieldDocs = fieldDocsMerger.merge( + fieldDocsOriginal, + fieldDocsNew, + topDocsMerger.FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, + true + ); + + assertNotNull(mergedFieldDocs); + assertEquals(4, mergedFieldDocs.length); + // check format, all elements one by one + assertEquals(1, mergedFieldDocs[0].fields[0]); + assertEquals(1, mergedFieldDocs[1].fields[0]); + assertEquals(1, mergedFieldDocs[2].fields[0]); + assertEquals(1, mergedFieldDocs[3].fields[0]); + } + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { assertEquals(expectedDocId, scoreDoc.doc); assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); } + + private void assertFieldDoc(FieldDoc fieldDoc, int expectedDocId, int expectedSortValue) { + 
assertEquals(expectedDocId, fieldDoc.doc); + assertEquals(expectedSortValue, fieldDoc.fields[0]); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 5a99f3f3a..597d8e29d 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -5,8 +5,12 @@ package org.opensearch.neuralsearch.search.query; import lombok.SneakyThrows; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; @@ -15,6 +19,11 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults; + +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.sort.SortAndFormats; public class TopDocsMergerTests extends OpenSearchQueryTestCase { @@ -22,8 +31,7 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase { @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { - 
HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); TopDocs topDocsOriginal = new TopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -78,8 +86,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { @SneakyThrows public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { - HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); TopDocs topDocsOriginal = new TopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -130,8 +137,7 @@ public void testMergeScoreDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { - HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); TopDocs topDocsOriginal = new TopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -172,8 +178,7 @@ public void testMergeScoreDocs_whenBothTopDocsHasNoHits_thenSuccessful() { @SneakyThrows public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() { - HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); - TopDocsMerger topDocsMerger = new TopDocsMerger(scoreDocsMerger); + TopDocsMerger topDocsMerger = new TopDocsMerger(null); TopDocs topDocsOriginal = new TopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -248,8 +253,277 @@ public void testThreeSequentialMerges_whenAllTopDocsHasHits_thenSuccessful() { assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[12].score, 0); } + @SneakyThrows + public void testMergeFieldDocs_whenBothTopDocsHasHits_thenSuccessful() 
{ + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + TopDocs topDocsOriginal = new TopFieldDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + new FieldDoc(0, 0.5f, new Object[] { 100 }), + new FieldDoc(2, 0.3f, new Object[] { 80 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopFieldDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(1, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(1, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 70 }), + new FieldDoc(4, 0.3f, new Object[] { 60 }), + new FieldDoc(5, 0.05f, new Object[] { 30 }), + createFieldDocDelimiterElementForHybridSearchResults(1, new Object[] { 1 }), + new FieldDoc(4, 0.6f, new Object[] { 40 }), + createFieldDocStartStopElementForHybridSearchResults(1, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew, + sortAndFormats + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, 
mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(6, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 5 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 5 + 1 + 2 + 2 = 10 + assertEquals(10, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + assertFieldDoc(fieldDocs[2], 0, 100); + assertFieldDoc(fieldDocs[3], 2, 80); + assertFieldDoc(fieldDocs[4], 1, 70); + assertFieldDoc(fieldDocs[5], 4, 60); + assertFieldDoc(fieldDocs[6], 5, 30); + assertEquals(1, fieldDocs[7].fields[0]); + assertFieldDoc(fieldDocs[8], 4, 40); + assertEquals(1, fieldDocs[9].fields[0]); + } + + @SneakyThrows + public void testMergeFieldDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessful() { + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + TopDocs topDocsOriginal = new TopFieldDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + 
TopDocs topDocsNew = new TopFieldDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 100 }), + new FieldDoc(4, 0.3f, new Object[] { 60 }), + new FieldDoc(5, 0.05f, new Object[] { 30 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(4, 0.6f, new Object[] { 80 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew, + sortAndFormats + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0.7f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 3 from sub-query1 and 1 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 3 + 1 + 2 + 2 = 8 + assertEquals(8, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + assertFieldDoc(fieldDocs[2], 1, 100); + assertFieldDoc(fieldDocs[3], 4, 60); + assertFieldDoc(fieldDocs[4], 5, 30); + assertEquals(1, fieldDocs[5].fields[0]); + assertFieldDoc(fieldDocs[6], 4, 80); + assertEquals(1, fieldDocs[7].fields[0]); + } + + @SneakyThrows + public void testMergeFieldDocs_whenBothTopDocsHasNoHits_thenSuccessful() { + DocValueFormat 
docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + TopDocs topDocsOriginal = new TopFieldDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0); + TopDocs topDocsNew = new TopFieldDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew, + sortAndFormats + ); + + assertNotNull(mergedTopDocsAndMaxScore); + + assertEquals(0f, mergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(0, mergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, mergedTopDocsAndMaxScore.topDocs.totalHits.relation); + assertEquals(4, mergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + 
FieldDoc[] fieldDocs = (FieldDoc[]) mergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + assertEquals(1, fieldDocs[2].fields[0]); + assertEquals(1, fieldDocs[3].fields[0]); + } + + @SneakyThrows + public void testThreeSequentialMergesWithFieldDocs_whenAllTopDocsHasHits_thenSuccessful() { + DocValueFormat docValueFormat[] = new DocValueFormat[] { DocValueFormat.RAW }; + SortField sortField = new SortField("stock", SortField.Type.INT, true); + Sort sort = new Sort(sortField); + SortAndFormats sortAndFormats = new SortAndFormats(sort, docValueFormat); + TopDocsMerger topDocsMerger = new TopDocsMerger(sortAndFormats); + + TopDocs topDocsOriginal = new TopFieldDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + new FieldDoc(0, 0.5f, new Object[] { 100 }), + new FieldDoc(2, 0.3f, new Object[] { 20 }), + createFieldDocDelimiterElementForHybridSearchResults(0, new Object[] { 1 }), + createFieldDocStartStopElementForHybridSearchResults(0, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreOriginal = new TopDocsAndMaxScore(topDocsOriginal, 0.5f); + TopDocs topDocsNew = new TopFieldDocs( + new TotalHits(4, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(1, 0.7f, new Object[] { 80 }), + new FieldDoc(4, 0.3f, new Object[] { 30 }), + new FieldDoc(5, 0.05f, new Object[] { 10 }), + createFieldDocDelimiterElementForHybridSearchResults(2, new Object[] { 1 }), + new FieldDoc(4, 0.6f, new Object[] { 30 }), + createFieldDocStartStopElementForHybridSearchResults(2, new Object[] { 1 }) }, + 
sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); + TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + topDocsAndMaxScoreOriginal, + topDocsAndMaxScoreNew, + sortAndFormats + ); + + assertNotNull(firstMergedTopDocsAndMaxScore); + + // merge results from collector 3 + TopDocs topDocsThirdCollector = new TopFieldDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + + new FieldDoc[] { + createFieldDocStartStopElementForHybridSearchResults(3, new Object[] { 1 }), + createFieldDocDelimiterElementForHybridSearchResults(3, new Object[] { 1 }), + new FieldDoc(3, 0.4f, new Object[] { 90 }), + createFieldDocDelimiterElementForHybridSearchResults(3, new Object[] { 1 }), + new FieldDoc(7, 0.85f, new Object[] { 60 }), + new FieldDoc(9, 0.2f, new Object[] { 50 }), + createFieldDocStartStopElementForHybridSearchResults(3, new Object[] { 1 }) }, + sortAndFormats.sort.getSort() + ); + TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f); + TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + firstMergedTopDocsAndMaxScore, + topDocsAndMaxScoreThirdCollector, + sortAndFormats + ); + + assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); + assertEquals(9, finalMergedTopDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, finalMergedTopDocsAndMaxScore.topDocs.totalHits.relation); + // expected number of rows is 6 from sub-query1 and 3 from sub-query2, plus 2 start-stop elements + 2 delimiters + // 6 + 3 + 2 + 2 = 13 + assertEquals(13, finalMergedTopDocsAndMaxScore.topDocs.scoreDocs.length); + // check format, all elements one by one + FieldDoc[] fieldDocs = (FieldDoc[]) finalMergedTopDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(1, fieldDocs[0].fields[0]); + assertEquals(1, fieldDocs[1].fields[0]); + 
assertFieldDoc(fieldDocs[2], 0, 100); + assertFieldDoc(fieldDocs[3], 3, 90); + assertFieldDoc(fieldDocs[4], 1, 80); + assertFieldDoc(fieldDocs[5], 4, 30); + assertFieldDoc(fieldDocs[6], 2, 20); + assertFieldDoc(fieldDocs[7], 5, 10); + assertEquals(1, fieldDocs[8].fields[0]); + assertFieldDoc(fieldDocs[9], 7, 60); + assertFieldDoc(fieldDocs[10], 9, 50); + assertFieldDoc(fieldDocs[11], 4, 30); + assertEquals(1, fieldDocs[12].fields[0]); + } + private void assertScoreDoc(ScoreDoc scoreDoc, int expectedDocId, float expectedScore) { assertEquals(expectedDocId, scoreDoc.doc); assertEquals(expectedScore, scoreDoc.score, DELTA_FOR_ASSERTION); } + + private void assertFieldDoc(FieldDoc fieldDoc, int expectedDocId, int expectedSortValue) { + assertEquals(expectedDocId, fieldDoc.doc); + assertEquals(expectedSortValue, fieldDoc.fields[0]); + } } From 02f93ab58393c34253e46c2799bd52740f99cda7 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Wed, 3 Jul 2024 21:25:51 -0700 Subject: [PATCH 03/10] Functional Interface Signed-off-by: Varun Jain --- .../collector/HybridSearchCollector.java | 17 ++++++++ .../HybridTopFieldDocSortCollector.java | 3 +- .../collector/HybridTopScoreDocCollector.java | 3 +- .../search/query/HybridCollectorManager.java | 41 ++++++++++--------- 4 files changed, 40 insertions(+), 24 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java new file mode 100644 index 000000000..775941b35 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.collector; + +import java.util.List; +import org.apache.lucene.search.Collector; 
+import org.apache.lucene.search.TopDocs; + +public interface HybridSearchCollector extends Collector { + List topDocs(); + + int getTotalHits(); + + float getMaxScore(); +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java index 2b1f3171d..2e268d37b 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopFieldDocSortCollector.java @@ -13,7 +13,6 @@ import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.FieldValueHitQueue; import org.apache.lucene.search.ScoreDoc; @@ -38,7 +37,7 @@ The individual query results are sorted as per the sort criteria sent in the search request. 
*/ @Log4j2 -public abstract class HybridTopFieldDocSortCollector implements Collector { +public abstract class HybridTopFieldDocSortCollector implements HybridSearchCollector { private final int numHits; private final HitsThresholdChecker hitsThresholdChecker; private final Sort sort; diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java index 1c01f905a..01a4cdfff 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridTopScoreDocCollector.java @@ -12,7 +12,6 @@ import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Scorable; @@ -30,7 +29,7 @@ * Collects the TopDocs after executing hybrid query. 
Uses HybridQueryTopDocs as DTO to handle each sub query results */ @Log4j2 -public class HybridTopScoreDocCollector implements Collector { +public class HybridTopScoreDocCollector implements HybridSearchCollector { private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); private int docBase; private final HitsThresholdChecker hitsThresholdChecker; diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 5bd2fa79d..4ed02a277 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -22,6 +22,7 @@ import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.collector.HybridSearchCollector; import org.opensearch.neuralsearch.search.collector.HybridTopFieldDocSortCollector; import org.opensearch.neuralsearch.search.collector.HybridTopScoreDocCollector; import org.opensearch.neuralsearch.search.collector.SimpleFieldCollector; @@ -148,59 +149,59 @@ private Collector getHybridQueryCollector() { */ @Override public ReduceableSearchResult reduce(Collection collectors) { - final List hybridSearchCollectors = getHybridSearchCollectors(collectors); + final List hybridSearchCollectors = getHybridSearchCollectors(collectors); if (hybridSearchCollectors.isEmpty()) { throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); } return reduceSearchResults(getSearchResults(hybridSearchCollectors)); } - private List getSearchResults(final List hybridSearchCollectors) { + private List getSearchResults(final List hybridSearchCollectors) { 
List results = new ArrayList<>(); DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats); - for (Collector collector : hybridSearchCollectors) { + for (HybridSearchCollector collector : hybridSearchCollectors) { TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats); results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats)); } return results; } - private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final Collector collector, final DocValueFormat[] docValueFormats) { - float maxScore; + private TopDocsAndMaxScore getTopDocsAndAndMaxScore( + final HybridSearchCollector hybridSearchCollector, + final DocValueFormat[] docValueFormats + ) { TopDocs newTopDocs; + List topDocs = hybridSearchCollector.topDocs(); if (docValueFormats != null) { - HybridTopFieldDocSortCollector hybridTopFieldDocSortCollector = (HybridTopFieldDocSortCollector) collector; - List topFieldDocs = hybridTopFieldDocSortCollector.topDocs(); - maxScore = hybridTopFieldDocSortCollector.getMaxScore(); newTopDocs = getNewTopFieldDocs( - getTotalHits(this.trackTotalHitsUpTo, topFieldDocs, hybridTopFieldDocSortCollector.getTotalHits()), - topFieldDocs, + getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), + (List) topDocs, sortAndFormats.sort.getSort() ); } else { - HybridTopScoreDocCollector hybridTopScoreDocCollector = (HybridTopScoreDocCollector) collector; - List topDocs = hybridTopScoreDocCollector.topDocs(); - maxScore = hybridTopScoreDocCollector.getMaxScore(); - newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridTopScoreDocCollector.getTotalHits()), topDocs); + newTopDocs = getNewTopDocs( + getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), + (List) topDocs + ); } - return new TopDocsAndMaxScore(newTopDocs, maxScore); + return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore()); 
} - private List getHybridSearchCollectors(final Collection collectors) { - final List hybridSearchCollectors = new ArrayList<>(); + private List getHybridSearchCollectors(final Collection collectors) { + final List hybridSearchCollectors = new ArrayList<>(); for (final Collector collector : collectors) { if (collector instanceof MultiCollectorWrapper) { for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { if (sub instanceof HybridTopScoreDocCollector || sub instanceof HybridTopFieldDocSortCollector) { - hybridSearchCollectors.add(sub); + hybridSearchCollectors.add((HybridSearchCollector) sub); } } } else if (collector instanceof HybridTopScoreDocCollector || collector instanceof HybridTopFieldDocSortCollector) { - hybridSearchCollectors.add(collector); + hybridSearchCollectors.add((HybridSearchCollector) collector); } else if (collector instanceof FilteredCollector && (((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector || ((FilteredCollector) collector).getCollector() instanceof HybridTopFieldDocSortCollector)) { - hybridSearchCollectors.add(((FilteredCollector) collector).getCollector()); + hybridSearchCollectors.add((HybridSearchCollector) ((FilteredCollector) collector).getCollector()); } } return hybridSearchCollectors; From d2cb1f2537d81a73e76af5a0f4891ff7973d3ec5 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Sun, 7 Jul 2024 11:19:52 -0700 Subject: [PATCH 04/10] Addressing Martin Comments Signed-off-by: Varun Jain --- CHANGELOG.md | 1 - .../query/NeuralQueryBuilder.java | 17 ++- .../collector/HybridSearchCollector.java | 12 ++ .../search/query/HybridCollectorManager.java | 14 +- .../query/HybridQueryScoreDocsMerger.java | 2 + .../search/query/TopDocsMerger.java | 130 ++++++++++-------- .../search/query/TopDocsMergerTests.java | 29 +--- 7 files changed, 105 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b6df7e66..1a72fcdaa 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -16,7 +16,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes -- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 986d6d96c..e7e081f2b 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -292,15 +292,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { if (vectorSupplier().get() == null) { return this; } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName(), vectorSupplier.get()).filter(filter()); - if (maxDistance != null) { - knnQueryBuilder.maxDistance(maxDistance); - } else if (minScore != null) { - knnQueryBuilder.minScore(minScore); - } else { - knnQueryBuilder.k(k); - } - return knnQueryBuilder; + return KNNQueryBuilder.builder() + .fieldName(fieldName()) + .vector(vectorSupplier.get()) + .filter(filter()) + .maxDistance(maxDistance) + .minScore(minScore) + .k(k) + .build(); } SetOnce vectorSetOnce = new SetOnce<>(); diff --git a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java index 775941b35..c1702996d 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/collector/HybridSearchCollector.java @@ -8,10 +8,22 @@ import org.apache.lucene.search.Collector; import org.apache.lucene.search.TopDocs; +/** + * Common interface class for Hybrid search collectors + */ public interface 
HybridSearchCollector extends Collector { + /** + * @return List of topDocs which contains topDocs of individual subqueries. + */ List topDocs(); + /** + * @return count of total hits per shard + */ int getTotalHits(); + /** + * @return maxScore found on a shard + */ float getMaxScore(); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 4ed02a277..2bcca6689 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -388,18 +388,8 @@ private void reduceCollectorResults( } // we need to do actual merge because query result and current collector both have some score hits TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs(); - result.topDocs(getMergeTopDocsAndMaxScores(originalTotalDocsAndHits, topDocsAndMaxScore), docValueFormats); - } - - private TopDocsAndMaxScore getMergeTopDocsAndMaxScores( - final TopDocsAndMaxScore originalTotalDocsAndHits, - final TopDocsAndMaxScore topDocsAndMaxScore - ) { - if (sortAndFormats != null) { - return topDocsMerger.mergeFieldDocs(originalTotalDocsAndHits, topDocsAndMaxScore, sortAndFormats); - } else { - return topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); - } + TopDocsAndMaxScore mergeTopDocsAndMaxScores = topDocsMerger.merge(originalTotalDocsAndHits, topDocsAndMaxScore); + result.topDocs(mergeTopDocsAndMaxScores, docValueFormats); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java index 662cb2242..1895d1d79 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java +++ 
b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMerger.java @@ -32,6 +32,8 @@ class HybridQueryScoreDocsMerger { * Method returns new object and doesn't mutate original ScoreDocs arrays. * @param sourceScoreDocs original score docs from query result * @param newScoreDocs new score docs that we need to merge into existing scores + * @param comparator comparator to compare the score docs + * @param isSortEnabled flag that show if sort is enabled or disabled * @return merged array of ScoreDocs objects */ public T[] merge(final T[] sourceScoreDocs, final T[] newScoreDocs, final Comparator comparator, final boolean isSortEnabled) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index a041149fe..4c9ada011 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -26,6 +26,7 @@ class TopDocsMerger { private HybridQueryScoreDocsMerger scoreDocsMerger; private HybridQueryScoreDocsMerger fieldDocsMerger; + private SortAndFormats sortAndFormats; @VisibleForTesting protected static Comparator SCORE_DOC_BY_SCORE_COMPARATOR; @VisibleForTesting @@ -39,7 +40,8 @@ class TopDocsMerger { * Uses hybrid query score docs merger to merge internal score docs */ TopDocsMerger(final SortAndFormats sortAndFormats) { - if (sortAndFormats != null) { + this.sortAndFormats = sortAndFormats; + if (this.sortAndFormats != null) { fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER); } else { @@ -58,67 +60,51 @@ public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAn if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { return source; } - 
// we need to merge hits per individual sub-query - // format of results in both new and source TopDocs is following - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_1 - ScoreDoc[] mergedScoreDocs = scoreDocsMerger.merge( - source.topDocs.scoreDocs, - newTopDocs.topDocs.scoreDocs, - SCORE_DOC_BY_SCORE_COMPARATOR, - false - ); TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); TopDocsAndMaxScore result = new TopDocsAndMaxScore( - new TopDocs(mergedTotalHits, mergedScoreDocs), + getTopDocs(getMergedScoreDocs(source.topDocs.scoreDocs, newTopDocs.topDocs.scoreDocs), mergedTotalHits), Math.max(source.maxScore, newTopDocs.maxScore) ); return result; } - /** - * Merge TopFieldDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. - * @param source TopDocsAndMaxScore for the original query - * @param newTopDocs TopDocsAndMaxScore for the new query - * @return merged TopDocsAndMaxScore object - */ - public TopDocsAndMaxScore mergeFieldDocs( - final TopDocsAndMaxScore source, - final TopDocsAndMaxScore newTopDocs, - final SortAndFormats sortAndFormats - ) { - if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { - return source; - } - // we need to merge hits per individual sub-query - // format of results in both new and source TopDocs is following - // doc_id | magic_number_1 | [1] - // doc_id | magic_number_2 | [1] - // ... - // doc_id | magic_number_2 | [1] - // ... - // doc_id | magic_number_2 | [1] - // ... 
- // doc_id | magic_number_1 | [1] - FieldDoc[] mergedScoreDocs = fieldDocsMerger.merge( - (FieldDoc[]) source.topDocs.scoreDocs, - (FieldDoc[]) newTopDocs.topDocs.scoreDocs, - FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, - true - ); - TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); - TopDocsAndMaxScore result = new TopDocsAndMaxScore( - new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()), - Math.max(source.maxScore, newTopDocs.maxScore) - ); - return result; - } + // /** + // * Merge TopFieldDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. + // * @param source TopDocsAndMaxScore for the original query + // * @param newTopDocs TopDocsAndMaxScore for the new query + // * @return merged TopDocsAndMaxScore object + // */ + // public TopDocsAndMaxScore mergeFieldDocs( + // final TopDocsAndMaxScore source, + // final TopDocsAndMaxScore newTopDocs, + // final SortAndFormats sortAndFormats + // ) { + // if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { + // return source; + // } + // // we need to merge hits per individual sub-query + // // format of results in both new and source TopDocs is following + // // doc_id | magic_number_1 | [1] + // // doc_id | magic_number_2 | [1] + // // ... + // // doc_id | magic_number_2 | [1] + // // ... + // // doc_id | magic_number_2 | [1] + // // ... 
+ // // doc_id | magic_number_1 | [1] + // FieldDoc[] mergedScoreDocs = fieldDocsMerger.merge( + // (FieldDoc[]) source.topDocs.scoreDocs, + // (FieldDoc[]) newTopDocs.topDocs.scoreDocs, + // FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, + // true + // ); + // TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); + // TopDocsAndMaxScore result = new TopDocsAndMaxScore( + // new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()), + // Math.max(source.maxScore, newTopDocs.maxScore) + // ); + // return result; + // } private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { // merged value is a lower bound - if both are equal_to than merged will also be equal_to, @@ -129,4 +115,38 @@ private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopD : TotalHits.Relation.EQUAL_TO; return new TotalHits(source.topDocs.totalHits.value + newTopDocs.topDocs.totalHits.value, mergedHitsRelation); } + + private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits) { + if (sortAndFormats != null) { + return new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()); + } + return new TopDocs(mergedTotalHits, mergedScoreDocs); + } + + private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs) { + if (sortAndFormats != null) { + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 | [1] + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... 
+ // doc_id | magic_number_1 | [1] + return fieldDocsMerger.merge((FieldDoc[]) source, (FieldDoc[]) newScoreDocs, FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true); + } + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + return scoreDocsMerger.merge(source, newScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR, false); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 597d8e29d..d10ca0668 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -289,11 +289,7 @@ public void testMergeFieldDocs_whenBothTopDocsHasHits_thenSuccessful() { sortAndFormats.sort.getSort() ); TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); - TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( - topDocsAndMaxScoreOriginal, - topDocsAndMaxScoreNew, - sortAndFormats - ); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); assertNotNull(mergedTopDocsAndMaxScore); @@ -352,11 +348,7 @@ public void testMergeFieldDocs_whenOneTopDocsHasHitsAndOtherIsEmpty_thenSuccessf sortAndFormats.sort.getSort() ); TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); - TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( - topDocsAndMaxScoreOriginal, - topDocsAndMaxScoreNew, - sortAndFormats - ); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); assertNotNull(mergedTopDocsAndMaxScore); @@ 
-409,11 +401,7 @@ public void testMergeFieldDocs_whenBothTopDocsHasNoHits_thenSuccessful() { sortAndFormats.sort.getSort() ); TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0); - TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( - topDocsAndMaxScoreOriginal, - topDocsAndMaxScoreNew, - sortAndFormats - ); + TopDocsAndMaxScore mergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); assertNotNull(mergedTopDocsAndMaxScore); @@ -465,11 +453,7 @@ public void testThreeSequentialMergesWithFieldDocs_whenAllTopDocsHasHits_thenSuc sortAndFormats.sort.getSort() ); TopDocsAndMaxScore topDocsAndMaxScoreNew = new TopDocsAndMaxScore(topDocsNew, 0.7f); - TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( - topDocsAndMaxScoreOriginal, - topDocsAndMaxScoreNew, - sortAndFormats - ); + TopDocsAndMaxScore firstMergedTopDocsAndMaxScore = topDocsMerger.merge(topDocsAndMaxScoreOriginal, topDocsAndMaxScoreNew); assertNotNull(firstMergedTopDocsAndMaxScore); @@ -488,10 +472,9 @@ public void testThreeSequentialMergesWithFieldDocs_whenAllTopDocsHasHits_thenSuc sortAndFormats.sort.getSort() ); TopDocsAndMaxScore topDocsAndMaxScoreThirdCollector = new TopDocsAndMaxScore(topDocsThirdCollector, 0.85f); - TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.mergeFieldDocs( + TopDocsAndMaxScore finalMergedTopDocsAndMaxScore = topDocsMerger.merge( firstMergedTopDocsAndMaxScore, - topDocsAndMaxScoreThirdCollector, - sortAndFormats + topDocsAndMaxScoreThirdCollector ); assertEquals(0.85f, finalMergedTopDocsAndMaxScore.maxScore, DELTA_FOR_ASSERTION); From 9f4ac2eca1582989e10af9cd9b239cac6efe7695 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 8 Jul 2024 12:38:43 -0700 Subject: [PATCH 05/10] Removing comments Signed-off-by: Varun Jain --- .../search/query/TopDocsMerger.java | 38 ------------------- 1 file changed, 38 deletions(-) diff --git 
a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 4c9ada011..38d532a8f 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -68,44 +68,6 @@ public TopDocsAndMaxScore merge(final TopDocsAndMaxScore source, final TopDocsAn return result; } - // /** - // * Merge TopFieldDocs and MaxScore from multiple search queries into a single TopDocsAndMaxScore object. - // * @param source TopDocsAndMaxScore for the original query - // * @param newTopDocs TopDocsAndMaxScore for the new query - // * @return merged TopDocsAndMaxScore object - // */ - // public TopDocsAndMaxScore mergeFieldDocs( - // final TopDocsAndMaxScore source, - // final TopDocsAndMaxScore newTopDocs, - // final SortAndFormats sortAndFormats - // ) { - // if (Objects.isNull(newTopDocs) || Objects.isNull(newTopDocs.topDocs) || newTopDocs.topDocs.totalHits.value == 0) { - // return source; - // } - // // we need to merge hits per individual sub-query - // // format of results in both new and source TopDocs is following - // // doc_id | magic_number_1 | [1] - // // doc_id | magic_number_2 | [1] - // // ... - // // doc_id | magic_number_2 | [1] - // // ... - // // doc_id | magic_number_2 | [1] - // // ... 
- // // doc_id | magic_number_1 | [1] - // FieldDoc[] mergedScoreDocs = fieldDocsMerger.merge( - // (FieldDoc[]) source.topDocs.scoreDocs, - // (FieldDoc[]) newTopDocs.topDocs.scoreDocs, - // FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, - // true - // ); - // TotalHits mergedTotalHits = getMergedTotalHits(source, newTopDocs); - // TopDocsAndMaxScore result = new TopDocsAndMaxScore( - // new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()), - // Math.max(source.maxScore, newTopDocs.maxScore) - // ); - // return result; - // } - private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopDocsAndMaxScore newTopDocs) { // merged value is a lower bound - if both are equal_to than merged will also be equal_to, // otherwise assign greater_than_or_equal From 8c7ed1a83635dd6bac86e424b4d18f81c2c49577 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 8 Jul 2024 22:48:06 -0700 Subject: [PATCH 06/10] Addressing Martin Comments Signed-off-by: Varun Jain --- .../neuralsearch/search/query/HybridCollectorManager.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 2bcca6689..427d73953 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -62,7 +62,7 @@ public abstract class HybridCollectorManager implements CollectorManager collectors) { final List hybridSearchCollectors = getHybridSearchCollectors(collectors); if (hybridSearchCollectors.isEmpty()) { - throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper collectors"); } return 
reduceSearchResults(getSearchResults(hybridSearchCollectors)); } From ec029af81cbfd8c33d7d70206b1bad4d9e962389 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 9 Jul 2024 09:13:34 -0700 Subject: [PATCH 07/10] Addressing Martin Comments Signed-off-by: Varun Jain --- .../search/query/HybridCollectorManager.java | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 427d73953..75e9c070e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -171,18 +171,15 @@ private TopDocsAndMaxScore getTopDocsAndAndMaxScore( final DocValueFormat[] docValueFormats ) { TopDocs newTopDocs; - List topDocs = hybridSearchCollector.topDocs(); + List topDocs = hybridSearchCollector.topDocs(); if (docValueFormats != null) { newTopDocs = getNewTopFieldDocs( getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), - (List) topDocs, + topDocs, sortAndFormats.sort.getSort() ); } else { - newTopDocs = getNewTopDocs( - getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), - (List) topDocs - ); + newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs); } return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore()); } From a7afa97ca8e1d2d5612262c7a32401390dd2a158 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 9 Jul 2024 09:54:03 -0700 Subject: [PATCH 08/10] Addressing Martin comments Signed-off-by: Varun Jain --- .../search/query/TopDocsMerger.java | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java 
b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 38d532a8f..799dd2a4e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -24,8 +24,7 @@ @RequiredArgsConstructor(access = AccessLevel.PACKAGE) class TopDocsMerger { - private HybridQueryScoreDocsMerger scoreDocsMerger; - private HybridQueryScoreDocsMerger fieldDocsMerger; + private HybridQueryScoreDocsMerger scoreDocsMerger; private SortAndFormats sortAndFormats; @VisibleForTesting protected static Comparator SCORE_DOC_BY_SCORE_COMPARATOR; @@ -42,7 +41,7 @@ class TopDocsMerger { TopDocsMerger(final SortAndFormats sortAndFormats) { this.sortAndFormats = sortAndFormats; if (this.sortAndFormats != null) { - fieldDocsMerger = new HybridQueryScoreDocsMerger<>(); + scoreDocsMerger = new HybridQueryScoreDocsMerger(); FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER); } else { scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); @@ -86,19 +85,19 @@ private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits } private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs) { - if (sortAndFormats != null) { - // we need to merge hits per individual sub-query - // format of results in both new and source TopDocs is following - // doc_id | magic_number_1 | [1] - // doc_id | magic_number_2 | [1] - // ... - // doc_id | magic_number_2 | [1] - // ... - // doc_id | magic_number_2 | [1] - // ... 
- // doc_id | magic_number_1 | [1] - return fieldDocsMerger.merge((FieldDoc[]) source, (FieldDoc[]) newScoreDocs, FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR, true); - } + // Case 1 when sorting is enabled then below will be the TopDocs format + // we need to merge hits per individual sub-query + // format of results in both new and source TopDocs is following + // doc_id | magic_number_1 | [1] + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_2 | [1] + // ... + // doc_id | magic_number_1 | [1] + + // Case 2 when sorting is disabled then below will be the TopDocs format // we need to merge hits per individual sub-query // format of results in both new and source TopDocs is following // doc_id | magic_number_1 @@ -109,6 +108,10 @@ private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs // doc_id | magic_number_2 // ... // doc_id | magic_number_1 - return scoreDocsMerger.merge(source, newScoreDocs, SCORE_DOC_BY_SCORE_COMPARATOR, false); + return scoreDocsMerger.merge(source, newScoreDocs, comparator(), sortAndFormats != null); + } + + private Comparator comparator() { + return sortAndFormats != null ? 
FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR : SCORE_DOC_BY_SCORE_COMPARATOR; } } From b9a62001876ae761550f3b3701c5158fbda02544 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 9 Jul 2024 11:19:41 -0700 Subject: [PATCH 09/10] Address Martin Comments Signed-off-by: Varun Jain --- .../neuralsearch/search/query/TopDocsMerger.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 799dd2a4e..154db9798 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -24,7 +24,7 @@ @RequiredArgsConstructor(access = AccessLevel.PACKAGE) class TopDocsMerger { - private HybridQueryScoreDocsMerger scoreDocsMerger; + private HybridQueryScoreDocsMerger docsMerger; private SortAndFormats sortAndFormats; @VisibleForTesting protected static Comparator SCORE_DOC_BY_SCORE_COMPARATOR; @@ -40,11 +40,11 @@ class TopDocsMerger { */ TopDocsMerger(final SortAndFormats sortAndFormats) { this.sortAndFormats = sortAndFormats; - if (this.sortAndFormats != null) { - scoreDocsMerger = new HybridQueryScoreDocsMerger(); + if (isSortingEnabled()) { + docsMerger = new HybridQueryScoreDocsMerger(); FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR = new HybridQueryFieldDocComparator(sortAndFormats.sort.getSort(), MERGING_TIE_BREAKER); } else { - scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); + docsMerger = new HybridQueryScoreDocsMerger<>(); SCORE_DOC_BY_SCORE_COMPARATOR = Comparator.comparing((scoreDoc) -> scoreDoc.score); } } @@ -108,10 +108,14 @@ private ScoreDoc[] getMergedScoreDocs(ScoreDoc[] source, ScoreDoc[] newScoreDocs // doc_id | magic_number_2 // ... 
// doc_id | magic_number_1 - return scoreDocsMerger.merge(source, newScoreDocs, comparator(), sortAndFormats != null); + return docsMerger.merge(source, newScoreDocs, comparator(), isSortingEnabled()); } private Comparator comparator() { return sortAndFormats != null ? FIELD_DOC_BY_SORT_CRITERIA_COMPARATOR : SCORE_DOC_BY_SCORE_COMPARATOR; } + + private boolean isSortingEnabled() { + return sortAndFormats != null; + } } From 8a933e3f082ebe34eba1c45f37850639ca5bd11e Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Tue, 9 Jul 2024 11:21:07 -0700 Subject: [PATCH 10/10] Address Martin Comments Signed-off-by: Varun Jain --- .../org/opensearch/neuralsearch/search/query/TopDocsMerger.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java index 154db9798..4efb1a2fa 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/TopDocsMerger.java @@ -78,7 +78,7 @@ private TotalHits getMergedTotalHits(final TopDocsAndMaxScore source, final TopD } private TopDocs getTopDocs(ScoreDoc[] mergedScoreDocs, TotalHits mergedTotalHits) { - if (sortAndFormats != null) { + if (isSortingEnabled()) { return new TopFieldDocs(mergedTotalHits, mergedScoreDocs, sortAndFormats.sort.getSort()); } return new TopDocs(mergedTotalHits, mergedScoreDocs);