From e3e60714694b1cad3d7a9a52000406b30d93be82 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Wed, 21 Feb 2024 12:20:08 -0800
Subject: [PATCH] Separate collector managers for concurrent and non-concurrent searches

Signed-off-by: Martin Gaievski
---
 .../query/HybridAggregationProcessor.java    |  43 +++-----
 .../search/query/HybridCollectorManager.java | 101 ++++++++++++------
 .../query/HybridQueryPhaseSearcher.java      |  12 +--
 3 files changed, 85 insertions(+), 71 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
index f68579422..815cdc621 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
@@ -9,8 +9,6 @@
 import org.apache.lucene.search.CollectorManager;
 import org.opensearch.neuralsearch.util.HybridQueryUtil;
 import org.opensearch.search.aggregations.AggregationProcessor;
-import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
-import org.opensearch.search.aggregations.DefaultAggregationProcessor;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryPhaseExecutionException;
 import org.opensearch.search.query.QuerySearchResult;
@@ -24,19 +22,15 @@
 @AllArgsConstructor
 public class HybridAggregationProcessor implements AggregationProcessor {
 
-    private final ConcurrentAggregationProcessor concurrentAggregationProcessor;
-    private final DefaultAggregationProcessor defaultAggregationProcessor;
+    private final AggregationProcessor delegateAggsProcessor;
 
     @Override
     public void preProcess(SearchContext context) {
-        if (context.shouldUseConcurrentSearch()) {
-            concurrentAggregationProcessor.preProcess(context);
-        } else {
-            defaultAggregationProcessor.preProcess(context);
-        }
+        delegateAggsProcessor.preProcess(context);
         if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
-            HybridCollectorManager collectorManager;
+            // adding collector manager for hybrid query
+            CollectorManager collectorManager;
             try {
                 collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
             } catch (IOException e) {
@@ -52,31 +46,22 @@ public void preProcess(SearchContext context) {
     public void postProcess(SearchContext context) {
         if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
             if (!context.shouldUseConcurrentSearch()) {
-                CollectorManager collectorManager = context.queryCollectorManagers()
-                    .get(HybridCollectorManager.class);
-                try {
-                    final Collection collectors = List.of(collectorManager.newCollector());
-                    collectorManager.reduce(collectors).reduce(context.queryResult());
-                } catch (IOException e) {
-                    throw new QueryPhaseExecutionException(
-                        context.shardTarget(),
-                        "failed to execute hybrid query aggregation processor",
-                        e
-                    );
-                }
+                reduceCollectorResults(context);
             }
             updateQueryResult(context.queryResult(), context);
         }
-        if (context.shouldUseConcurrentSearch()) {
-            concurrentAggregationProcessor.postProcess(context);
-        } else {
-            defaultAggregationProcessor.postProcess(context);
-        }
+        delegateAggsProcessor.postProcess(context);
     }
 
-    private boolean shouldUseMultiCollectorManager(SearchContext context) {
-        return HybridQueryUtil.isHybridQuery(context.query(), context) || context.shouldUseConcurrentSearch();
+    private void reduceCollectorResults(SearchContext context) {
+        CollectorManager collectorManager = context.queryCollectorManagers()
+            .get(HybridCollectorManager.class);
+        try {
+            final Collection collectors = List.of(collectorManager.newCollector());
+            collectorManager.reduce(collectors).reduce(context.queryResult());
+        } catch (IOException e) {
+            throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e);
+        }
     }
 
     private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index 4505158f9..dfa2d9b12 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -33,6 +33,7 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -40,29 +41,40 @@
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
 
 @RequiredArgsConstructor
-public class HybridCollectorManager implements CollectorManager {
+public abstract class HybridCollectorManager implements CollectorManager {
 
     private final int numHits;
     private final HitsThresholdChecker hitsThresholdChecker;
     private final boolean isSingleShard;
    private final int trackTotalHitsUpTo;
     private final SortAndFormats sortAndFormats;
-    Collector maxScoreCollector;
-    final Weight filteringWeight;
+    private final Optional<Weight> filteringWeightOptional;
 
-    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+    public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
         final IndexReader reader = searchContext.searcher().getIndexReader();
         final int totalNumDocs = Math.max(0, reader.numDocs());
         boolean isSingleShard = searchContext.numberOfShards() == 1;
         int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
         int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+        Weight filterWeight = null;
         // check for post filter
         if (Objects.nonNull(searchContext.parsedPostFilter())) {
             Query filterQuery = searchContext.parsedPostFilter().query();
             ContextIndexSearcher searcher = searchContext.searcher();
-            final Weight filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
-            return new HybridCollectorManager(
+            filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
+        }
+
+        return searchContext.shouldUseConcurrentSearch()
+            ? new HybridCollectorConcurrentSearchManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            )
+            : new HybridCollectorNonConcurrentManager(
                 numDocs,
                 new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
                 isSingleShard,
@@ -70,38 +82,19 @@ public static HybridCollectorManager createHybridCollectorManager(final SearchCo
                 searchContext.sort(),
                 filterWeight
             );
-        }
-
-        return new HybridCollectorManager(
-            numDocs,
-            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
-            isSingleShard,
-            trackTotalHitsUpTo,
-            searchContext.sort(),
-            null
-        );
     }
 
     @Override
-    public org.apache.lucene.search.Collector newCollector() {
-        if (Objects.isNull(maxScoreCollector)) {
-            maxScoreCollector = getCollector();
-            return maxScoreCollector;
-        } else {
-            Collector toReturnCollector = maxScoreCollector;
-            maxScoreCollector = null;
-            return toReturnCollector;
-        }
-    }
+    abstract public org.apache.lucene.search.Collector newCollector();
 
-    private Collector getCollector() {
+    Collector getCollector() {
         Collector hybridcollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
-        if (Objects.isNull(filteringWeight)) {
+        if (filteringWeightOptional.isEmpty()) {
             // this is plain hybrid query scores collector
             return hybridcollector;
         }
         // this is hybrid query scores collector with post filter applied
-        return new FilteredCollector(hybridcollector, filteringWeight);
+        return new FilteredCollector(hybridcollector, filteringWeightOptional.get());
     }
 
     @Override
@@ -196,10 +189,6 @@ private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDo
             uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
         }
         long maxTotalHits = uniqueDocIds.size();
-        /*long maxTotalHits = topDocs.get(0).totalHits.value;
-        for (TopDocs topDoc : topDocs) {
-            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
-        }*/
         return new TotalHits(maxTotalHits, relation);
     }
@@ -219,4 +208,50 @@ private float getMaxScore(final List topDocs) {
     private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
         return sortAndFormats == null ? null : sortAndFormats.formats;
     }
+
+    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
+        Collector maxScoreCollector;
+
+        public HybridCollectorNonConcurrentManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            if (Objects.isNull(maxScoreCollector)) {
+                maxScoreCollector = getCollector();
+                return maxScoreCollector;
+            } else {
+                Collector toReturnCollector = maxScoreCollector;
+                maxScoreCollector = null;
+                return toReturnCollector;
+            }
+        }
+    }
+
+    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {
+
+        public HybridCollectorConcurrentSearchManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            return getCollector();
+        }
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index 83f71e7ce..465fe21a3 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -22,8 +22,6 @@
 import org.opensearch.index.search.NestedHelper;
 import org.opensearch.neuralsearch.query.HybridQuery;
 import org.opensearch.search.aggregations.AggregationProcessor;
-import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
-import org.opensearch.search.aggregations.DefaultAggregationProcessor;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryCollectorContext;
@@ -45,11 +43,6 @@
 @Log4j2
 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {
 
-    private final AggregationProcessor aggregationProcessor = new HybridAggregationProcessor(
-        new ConcurrentAggregationProcessor(),
-        new DefaultAggregationProcessor()
-    );
-
     public HybridQueryPhaseSearcher() {
         super();
     }
@@ -149,7 +142,7 @@ protected boolean searchWithCollector(
     ) throws IOException {
         log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());
 
-        HybridCollectorManager collectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        CollectorManager collectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
         Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByManagerClass = searchContext
             .queryCollectorManagers();
         collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager);
@@ -205,6 +198,7 @@ private int getMaxDepthLimit(final SearchContext searchContext) {
 
     @Override
     public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
-        return aggregationProcessor;
+        AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
+        return new HybridAggregationProcessor(coreAggProcessor);
     }
 }
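
The split above hinges on Lucene's CollectorManager contract: under concurrent segment search the searcher calls newCollector() once per slice and merges the per-slice collectors with reduce(), so HybridCollectorConcurrentSearchManager must hand out a fresh collector on every call, while HybridCollectorNonConcurrentManager can cache the single collector created during the query phase so HybridAggregationProcessor can reduce it again afterwards. The snippet below is an illustrative sketch of that contract and is not part of the patch; SimpleTopScoreManager is a hypothetical class, only the Lucene types are real.

import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;

// Hypothetical manager used only to illustrate the newCollector()/reduce() lifecycle.
class SimpleTopScoreManager implements CollectorManager<TopScoreDocCollector, TopDocs> {
    private final int numHits;

    SimpleTopScoreManager(int numHits) {
        this.numHits = numHits;
    }

    @Override
    public TopScoreDocCollector newCollector() {
        // Invoked once per search slice under concurrent search, so each call
        // must return an independent collector instance.
        return TopScoreDocCollector.create(numHits, Integer.MAX_VALUE);
    }

    @Override
    public TopDocs reduce(Collection<TopScoreDocCollector> collectors) throws IOException {
        // Merges the per-slice results into a single TopDocs, analogous to what
        // HybridCollectorManager#reduce does for hybrid sub-query results.
        TopDocs[] perSlice = collectors.stream().map(TopScoreDocCollector::topDocs).toArray(TopDocs[]::new);
        return TopDocs.merge(numHits, perSlice);
    }
}

A caller would hand such a manager to IndexSearcher#search(Query, CollectorManager), the same entry point that drives the hybrid collector managers registered in searchContext.queryCollectorManagers().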
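
On the aggregation side, the patch drops the hard-coded ConcurrentAggregationProcessor/DefaultAggregationProcessor pair because super.aggregationProcessor(searchContext) in HybridQueryPhaseSearcher already returns whichever core processor matches the search mode; HybridAggregationProcessor now only wraps that delegate and adds the hybrid-specific collector handling around it. A minimal sketch of the same wrapping pattern, assuming only the AggregationProcessor and SearchContext types named in the patch (DelegatingAggregationProcessor itself is hypothetical):

import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.SearchContext;

// Hypothetical wrapper showing the delegation style used by HybridAggregationProcessor.
class DelegatingAggregationProcessor implements AggregationProcessor {
    private final AggregationProcessor delegate;

    DelegatingAggregationProcessor(AggregationProcessor delegate) {
        this.delegate = delegate;
    }

    @Override
    public void preProcess(SearchContext context) {
        delegate.preProcess(context);
        // hybrid-specific setup, e.g. registering a collector manager, would go here
    }

    @Override
    public void postProcess(SearchContext context) {
        // hybrid-specific result handling would go here before handing control back
        delegate.postProcess(context);
    }
}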