From 65553e7864352abdf4c92e339c2a22e5d5b91b44 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 8 Mar 2024 13:59:04 -0800 Subject: [PATCH] Refactor non-concurrent collector manager Signed-off-by: Martin Gaievski --- .../query/HybridAggregationProcessor.java | 9 +++---- .../search/query/HybridCollectorManager.java | 27 +++++++------------ .../query/HybridCollectorManagerTests.java | 4 +-- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java index 42c27821f..4e9070748 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -6,6 +6,7 @@ import lombok.AllArgsConstructor; import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationInitializationException; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryPhaseExecutionException; @@ -13,7 +14,6 @@ import org.opensearch.search.query.ReduceableSearchResult; import java.io.IOException; -import java.util.Collection; import java.util.List; import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; @@ -36,8 +36,8 @@ public void preProcess(SearchContext context) { CollectorManager collectorManager; try { collectorManager = HybridCollectorManager.createHybridCollectorManager(context); - } catch (IOException e) { - throw new RuntimeException(e); + } catch (IOException exception) { + throw new AggregationInitializationException("could not initialize hybrid aggregation processor", exception); } context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); } @@ -67,8 +67,7 @@ public void postProcess(SearchContext context) { private void reduceCollectorResults(SearchContext context) { CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); try { - final Collection collectors = List.of(collectorManager.newCollector()); - collectorManager.reduce(collectors).reduce(context.queryResult()); + collectorManager.reduce(List.of()).reduce(context.queryResult()); } catch (IOException e) { throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index 40b10c5f3..a5de898ab 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -78,9 +78,7 @@ public static CollectorManager createHybridCollectorManager(final SearchContext } @Override - abstract public Collector newCollector(); - - Collector getCollector() { + public Collector newCollector() { Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker); return hybridcollector; } @@ -211,7 +209,7 @@ private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats * use saved state of collector */ static class HybridCollectorNonConcurrentManager extends HybridCollectorManager { - Collector maxScoreCollector; + private final Collector scoreCollector; public HybridCollectorNonConcurrentManager( int numHits, @@ -221,18 +219,18 @@ public HybridCollectorNonConcurrentManager( SortAndFormats sortAndFormats ) { super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @Override public Collector newCollector() { - if (Objects.isNull(maxScoreCollector)) { - maxScoreCollector = getCollector(); - return maxScoreCollector; - } else { - Collector toReturnCollector = maxScoreCollector; - maxScoreCollector = null; - return toReturnCollector; - } + return scoreCollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + assert collectors.isEmpty() : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors"; + return super.reduce(List.of(scoreCollector)); } } @@ -251,10 +249,5 @@ public HybridCollectorConcurrentSearchManager( ) { super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); } - - @Override - public Collector newCollector() { - return getCollector(); - } } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index f9d616716..65d6f3d8a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -39,7 +39,6 @@ import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ReduceableSearchResult; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -175,8 +174,7 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); leafCollector.finish(); - final Collection collectors = List.of(collector); - Object results = hybridCollectorManager.reduce(collectors); + Object results = hybridCollectorManager.reduce(List.of()); assertNotNull(results); ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results);