Separate collector managers for concurrent and non-concurrent searches
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Feb 21, 2024
1 parent ec3f7dc commit e3e6071
Showing 3 changed files with 85 additions and 71 deletions.
HybridAggregationProcessor.java
@@ -9,8 +9,6 @@
 import org.apache.lucene.search.CollectorManager;
 import org.opensearch.neuralsearch.util.HybridQueryUtil;
 import org.opensearch.search.aggregations.AggregationProcessor;
-import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
-import org.opensearch.search.aggregations.DefaultAggregationProcessor;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryPhaseExecutionException;
 import org.opensearch.search.query.QuerySearchResult;
@@ -24,19 +22,15 @@
 @AllArgsConstructor
 public class HybridAggregationProcessor implements AggregationProcessor {
 
-    private final ConcurrentAggregationProcessor concurrentAggregationProcessor;
-    private final DefaultAggregationProcessor defaultAggregationProcessor;
+    private final AggregationProcessor delegateAggsProcessor;
 
     @Override
     public void preProcess(SearchContext context) {
-        if (context.shouldUseConcurrentSearch()) {
-            concurrentAggregationProcessor.preProcess(context);
-        } else {
-            defaultAggregationProcessor.preProcess(context);
-        }
+        delegateAggsProcessor.preProcess(context);
 
         if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
-            HybridCollectorManager collectorManager;
+            // adding collector manager for hybrid query
+            CollectorManager collectorManager;
             try {
                 collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
             } catch (IOException e) {
@@ -52,31 +46,22 @@ public void preProcess(SearchContext context) {
     public void postProcess(SearchContext context) {
         if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
             if (!context.shouldUseConcurrentSearch()) {
-                CollectorManager<?, ReduceableSearchResult> collectorManager = context.queryCollectorManagers()
-                    .get(HybridCollectorManager.class);
-                try {
-                    final Collection collectors = List.of(collectorManager.newCollector());
-                    collectorManager.reduce(collectors).reduce(context.queryResult());
-                } catch (IOException e) {
-                    throw new QueryPhaseExecutionException(
-                        context.shardTarget(),
-                        "failed to execute hybrid query aggregation processor",
-                        e
-                    );
-                }
+                reduceCollectorResults(context);
             }
             updateQueryResult(context.queryResult(), context);
         }
 
-        if (context.shouldUseConcurrentSearch()) {
-            concurrentAggregationProcessor.postProcess(context);
-        } else {
-            defaultAggregationProcessor.postProcess(context);
-        }
+        delegateAggsProcessor.postProcess(context);
     }
 
-    private boolean shouldUseMultiCollectorManager(SearchContext context) {
-        return HybridQueryUtil.isHybridQuery(context.query(), context) || context.shouldUseConcurrentSearch();
+    private void reduceCollectorResults(SearchContext context) {
+        CollectorManager<?, ReduceableSearchResult> collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class);
+        try {
+            final Collection collectors = List.of(collectorManager.newCollector());
+            collectorManager.reduce(collectors).reduce(context.queryResult());
+        } catch (IOException e) {
+            throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e);
+        }
     }
 
     private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
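Note: the reduceCollectorResults path above leans on Lucene's CollectorManager contract, where newCollector() is invoked once per search slice under concurrent segment search (once in total otherwise) and reduce() folds all per-collector state into one result. A minimal, hand-written sketch of that contract, using Lucene's TotalHitCountCollector purely as an illustration (this is not code from this commit):

import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.TotalHitCountCollector;

// Counts hits across slices; mirrors the newCollector()/reduce() life cycle
// that HybridCollectorManager implements for hybrid results.
class HitCountManager implements CollectorManager<TotalHitCountCollector, Integer> {

    @Override
    public TotalHitCountCollector newCollector() {
        // One collector per slice when concurrent segment search is enabled.
        return new TotalHitCountCollector();
    }

    @Override
    public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
        // Merge per-slice state into the final, shard-level answer.
        int total = 0;
        for (TotalHitCountCollector collector : collectors) {
            total += collector.getTotalHits();
        }
        return total;
    }
}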
HybridCollectorManager.java
@@ -33,75 +33,68 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
 
 @RequiredArgsConstructor
-public class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
+public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
 
     private final int numHits;
     private final HitsThresholdChecker hitsThresholdChecker;
     private final boolean isSingleShard;
     private final int trackTotalHitsUpTo;
     private final SortAndFormats sortAndFormats;
-    Collector maxScoreCollector;
-    final Weight filteringWeight;
+    private final Optional<Weight> filteringWeightOptional;
 
-    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+    public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
         final IndexReader reader = searchContext.searcher().getIndexReader();
         final int totalNumDocs = Math.max(0, reader.numDocs());
         boolean isSingleShard = searchContext.numberOfShards() == 1;
         int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
         int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
 
+        Weight filterWeight = null;
+        // check for post filter
         if (Objects.nonNull(searchContext.parsedPostFilter())) {
             Query filterQuery = searchContext.parsedPostFilter().query();
             ContextIndexSearcher searcher = searchContext.searcher();
-            final Weight filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
-            return new HybridCollectorManager(
-                numDocs,
-                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
-                isSingleShard,
-                trackTotalHitsUpTo,
-                searchContext.sort(),
-                filterWeight
-            );
-        }
+            filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
+        }
 
-        return new HybridCollectorManager(
-            numDocs,
-            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
-            isSingleShard,
-            trackTotalHitsUpTo,
-            searchContext.sort(),
-            null
-        );
+        return searchContext.shouldUseConcurrentSearch()
+            ? new HybridCollectorConcurrentSearchManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            )
+            : new HybridCollectorNonConcurrentManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            );
     }
 
     @Override
-    public org.apache.lucene.search.Collector newCollector() {
-        if (Objects.isNull(maxScoreCollector)) {
-            maxScoreCollector = getCollector();
-            return maxScoreCollector;
-        } else {
-            Collector toReturnCollector = maxScoreCollector;
-            maxScoreCollector = null;
-            return toReturnCollector;
-        }
-    }
+    abstract public org.apache.lucene.search.Collector newCollector();
 
-    private Collector getCollector() {
+    Collector getCollector() {
         Collector hybridcollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
-        if (Objects.isNull(filteringWeight)) {
+        if (filteringWeightOptional.isEmpty()) {
             // this is plain hybrid query scores collector
             return hybridcollector;
         }
         // this is hybrid query scores collector with post filter applied
-        return new FilteredCollector(hybridcollector, filteringWeight);
+        return new FilteredCollector(hybridcollector, filteringWeightOptional.get());
     }
 
     @Override
@@ -196,10 +189,6 @@ private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs
             uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
         }
         long maxTotalHits = uniqueDocIds.size();
-        /*long maxTotalHits = topDocs.get(0).totalHits.value;
-        for (TopDocs topDoc : topDocs) {
-            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
-        }*/
 
         return new TotalHits(maxTotalHits, relation);
     }
@@ -219,4 +208,50 @@ private float getMaxScore(final List<TopDocs> topDocs) {
     private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
         return sortAndFormats == null ? null : sortAndFormats.formats;
     }
+
+    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
+        Collector maxScoreCollector;
+
+        public HybridCollectorNonConcurrentManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            if (Objects.isNull(maxScoreCollector)) {
+                maxScoreCollector = getCollector();
+                return maxScoreCollector;
+            } else {
+                Collector toReturnCollector = maxScoreCollector;
+                maxScoreCollector = null;
+                return toReturnCollector;
+            }
+        }
+    }
+
+    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {
+
+        public HybridCollectorConcurrentSearchManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            return getCollector();
+        }
+    }
 }
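Note: HybridCollectorNonConcurrentManager deliberately hands the same collector instance to two callers. The query phase asks for a collector first, and HybridAggregationProcessor later calls newCollector() again from reduceCollectorResults, expecting to receive the instance that already collected hits. A stripped-down sketch of that hand-off (illustrative names; an assumed simplification of the class above, not the plugin's source):

import java.util.Objects;
import org.apache.lucene.search.Collector;

// First call creates and caches the collector (handed to the query phase);
// the second call returns that same instance and clears the cache, so the
// later reduce() operates on hits that were actually collected.
abstract class SingleUseCollectorCache {
    private Collector cached;

    abstract Collector createCollector();

    Collector nextCollector() {
        if (Objects.isNull(cached)) {
            cached = createCollector();
            return cached;
        }
        Collector handOff = cached;
        cached = null;
        return handOff;
    }
}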
HybridQueryPhaseSearcher.java
@@ -22,8 +22,6 @@
 import org.opensearch.index.search.NestedHelper;
 import org.opensearch.neuralsearch.query.HybridQuery;
 import org.opensearch.search.aggregations.AggregationProcessor;
-import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
-import org.opensearch.search.aggregations.DefaultAggregationProcessor;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryCollectorContext;
@@ -45,11 +43,6 @@
 @Log4j2
 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {
 
-    private final AggregationProcessor aggregationProcessor = new HybridAggregationProcessor(
-        new ConcurrentAggregationProcessor(),
-        new DefaultAggregationProcessor()
-    );
-
     public HybridQueryPhaseSearcher() {
         super();
     }
@@ -149,7 +142,7 @@ protected boolean searchWithCollector(
     ) throws IOException {
         log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());
 
-        HybridCollectorManager collectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        CollectorManager collectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
         Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByManagerClass = searchContext
             .queryCollectorManagers();
         collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager);
@@ -205,6 +198,7 @@ private int getMaxDepthLimit(final SearchContext searchContext) {

     @Override
     public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
-        return aggregationProcessor;
+        AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
+        return new HybridAggregationProcessor(coreAggProcessor);
     }
 }
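Note: after this change the searcher wraps whichever AggregationProcessor the core query phase supplies (concurrent or default, selected by OpenSearch itself) instead of constructing both concrete processors up front. A self-contained sketch of the wrapper wiring, with stand-in types rather than the plugin's real interfaces:

// Stand-in types; the real interfaces live in org.opensearch.search.
interface SearchContext {}

interface AggregationProcessor {
    void preProcess(SearchContext context);
    void postProcess(SearchContext context);
}

// The hybrid processor runs the core hooks unconditionally and layers
// hybrid-query handling around them, so the concurrent/non-concurrent
// decision stays with the core query phase.
class DelegatingHybridProcessor implements AggregationProcessor {
    private final AggregationProcessor delegate;

    DelegatingHybridProcessor(AggregationProcessor delegate) {
        this.delegate = delegate;
    }

    @Override
    public void preProcess(SearchContext context) {
        delegate.preProcess(context);
        // hybrid-only setup (e.g. registering the hybrid collector manager) goes here
    }

    @Override
    public void postProcess(SearchContext context) {
        // hybrid-only result reduction goes here
        delegate.postProcess(context);
    }
}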
