
Commit

Collector manager approach, initial POC version
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Jan 19, 2024
1 parent e0d756e commit efc4a2c
Showing 9 changed files with 544 additions and 119 deletions.
2 changes: 2 additions & 0 deletions build.gradle
@@ -370,6 +370,8 @@ testClusters.integTest {
     // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
     // to ml-commons memory circuit breaker exception
     jvmArgs("-Xms1g", "-Xmx4g")
+
+    systemProperty('opensearch.experimental.feature.concurrent_segment_search.enabled', 'true')
 }

 // Remote Integration Tests
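
Note: the system property added above is how the integ-test cluster opts in to experimental concurrent segment search. As a minimal sketch (not part of this commit; the FeatureFlags lookup is an assumption about how OpenSearch core reads such flags), code can check whether the flag is set like this:

import org.opensearch.common.util.FeatureFlags;

// Sketch only: FeatureFlags consults JVM system properties such as the one
// set via systemProperty(...) in build.gradle above.
public final class ConcurrentSearchFlagCheck {
    public static void main(String[] args) {
        boolean enabled = FeatureFlags.isEnabled("opensearch.experimental.feature.concurrent_segment_search.enabled");
        System.out.println("concurrent segment search enabled: " + enabled);
    }
}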
src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -8,7 +8,6 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
-import java.util.Optional;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;

@@ -27,13 +26,12 @@

 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
-import org.opensearch.neuralsearch.search.HitsThresholdChecker;

 /**
  * Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
  */
 @Log4j2
-public class HybridTopScoreDocCollector implements Collector {
+public class HybridTopScoreDocCollector<T extends ScoreDoc> implements Collector {
     private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
     private int docBase;
     private final HitsThresholdChecker hitsThresholdChecker;
@@ -58,10 +56,10 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
             @Override
             public void setScorer(Scorable scorer) throws IOException {
                 super.setScorer(scorer);
+                // compoundQueryScorer = (HybridQueryScorer) scorer;
                 if (scorer instanceof HybridQueryScorer) {
                     compoundQueryScorer = (HybridQueryScorer) scorer;
-                }
-                else {
+                } else {
                     compoundQueryScorer = getHybridQueryScorer(scorer);
                 }
             }
@@ -82,14 +80,12 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException {
                 return null;
             }

-
-
             @Override
             public void collect(int doc) throws IOException {
-                if (compoundQueryScorer == null) {
+                /*if (compoundQueryScorer == null) {
                     scorer.score();
                     return;
-                }
+                }*/
                 float[] subScoresByQuery = compoundQueryScorer.hybridScores();
                 // iterate over results for each query
                 if (compoundScores == null) {
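
The body of getHybridQueryScorer is collapsed in the diff above. The motivation for it: under concurrent segment search the Scorable handed to setScorer may arrive wrapped (for example by MultiCollector or profiling collectors), so the collector can no longer blindly cast. Below is a hedged sketch of the kind of traversal such a method could perform, using Lucene's Scorable.getChildren(); the class and method names are illustrative, not the commit's actual implementation:

import java.io.IOException;
import org.apache.lucene.search.Scorable;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

final class ScorerUnwrapSketch {
    // Walk the child-scorer tree until a HybridQueryScorer is found, or give up with null.
    static HybridQueryScorer findHybridScorer(Scorable scorer) throws IOException {
        if (scorer == null) {
            return null;
        }
        if (scorer instanceof HybridQueryScorer) {
            return (HybridQueryScorer) scorer;
        }
        // Scorable.getChildren() exposes wrapped child scorers, e.g. under
        // the wrappers that concurrent segment search introduces.
        for (Scorable.ChildScorable child : scorer.getChildren()) {
            HybridQueryScorer result = findHybridScorer(child.child);
            if (result != null) {
                return result;
            }
        }
        return null;
    }
}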
src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java (new file, 83 additions)
@@ -0,0 +1,83 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.search.query;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.search.aggregations.AggregationCollectorManager;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.aggregations.BucketCollectorProcessor;
import org.opensearch.search.aggregations.GlobalAggCollectorManager;
import org.opensearch.search.aggregations.NonGlobalAggCollectorManager;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.query.InternalProfileCollectorManager;
import org.opensearch.search.profile.query.InternalProfileComponent;
import org.opensearch.search.query.QueryPhaseExecutionException;
import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.Collections;

public class HybridAggregationProcessor implements AggregationProcessor {

    private final BucketCollectorProcessor bucketCollectorProcessor = new BucketCollectorProcessor();

    @Override
    public void preProcess(SearchContext context) {
        try {
            if (context.aggregations() != null) {
                // update the bucket collector processor as there is an aggregation in the request
                context.setBucketCollectorProcessor(bucketCollectorProcessor);
                if (context.aggregations().factories().hasNonGlobalAggregator()) {
                    context.queryCollectorManagers().put(NonGlobalAggCollectorManager.class, new NonGlobalAggCollectorManager(context));
                }
                // initialize global aggregators as well, such that any failure to initialize can be caught before executing the request
                if (context.aggregations().factories().hasGlobalAggregator()) {
                    context.queryCollectorManagers().put(GlobalAggCollectorManager.class, new GlobalAggCollectorManager(context));
                }
            }
        } catch (IOException ex) {
            throw new AggregationInitializationException("Could not initialize aggregators", ex);
        }
    }

    @Override
    public void postProcess(SearchContext context) {
        if (context.aggregations() == null) {
            context.queryResult().aggregations(null);
            return;
        }

        // for the concurrent case we will perform only the global aggregation in post-process, as QueryResult is
        // already populated with the results of processing the non-global aggregations
        CollectorManager<? extends Collector, ReduceableSearchResult> globalCollectorManager = context.queryCollectorManagers()
            .get(GlobalAggCollectorManager.class);
        try {
            if (globalCollectorManager != null) {
                Query query = context.buildFilteredQuery(Queries.newMatchAllQuery());
                if (context.getProfilers() != null) {
                    globalCollectorManager = new InternalProfileCollectorManager(
                        globalCollectorManager,
                        ((AggregationCollectorManager) globalCollectorManager).getCollectorReason(),
                        Collections.emptyList()
                    );
                    context.getProfilers().addQueryProfiler().setCollector((InternalProfileComponent) globalCollectorManager);
                }
                final ReduceableSearchResult result = context.searcher().search(query, globalCollectorManager);
                result.reduce(context.queryResult());
            }
        } catch (Exception e) {
            throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e);
        }

        // disable aggregations so that they don't run on next pages in case of scrolling
        context.aggregations(null);
        context.queryCollectorManagers().remove(NonGlobalAggCollectorManager.class);
        context.queryCollectorManagers().remove(GlobalAggCollectorManager.class);
    }
}
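
The commit does not show where HybridAggregationProcessor gets wired in. One plausible hookup, assuming OpenSearch's QueryPhaseSearcher interface exposes an aggregationProcessor(SearchContext) extension point for concurrent-search-aware plugins and that the plugin's searcher extends core's QueryPhaseSearcherWrapper; the class name here is illustrative:

import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryPhaseSearcherWrapper;

// Hedged sketch, not shown in this commit: return the hybrid-aware processor so that
// aggregations are pre/post-processed through HybridAggregationProcessor instead of
// the default processor.
public class HybridQueryPhaseSearcherSketch extends QueryPhaseSearcherWrapper {
    private final AggregationProcessor aggregationProcessor = new HybridAggregationProcessor();

    @Override
    public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
        return aggregationProcessor;
    }
}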
src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java (new file, 180 additions)
@@ -0,0 +1,180 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.search.query;

import lombok.AllArgsConstructor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

@AllArgsConstructor
public class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

    private int numHits;
    private FieldDoc searchAfter;
    private int hitCountThreshold;
    private HitsThresholdChecker hitsThresholdChecker;
    private boolean isSingleShard;
    private int trackTotalHitsUpTo;
    private SortAndFormats sortAndFormats;

    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) {
        final IndexReader reader = searchContext.searcher().getIndexReader();
        final int totalNumDocs = Math.max(0, reader.numDocs());
        FieldDoc searchAfter = searchContext.searchAfter();
        boolean isSingleShard = searchContext.numberOfShards() == 1;
        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
        return new HybridCollectorManager(
            numDocs,
            searchAfter,
            Integer.MAX_VALUE,
            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
            isSingleShard,
            trackTotalHitsUpTo,
            searchContext.sort()
        );
    }

    @Override
    public org.apache.lucene.search.Collector newCollector() throws IOException {
        HybridTopScoreDocCollector<?> maxScoreCollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
        return maxScoreCollector;
    }

    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
        final List<HybridTopScoreDocCollector<?>> hybridTopScoreDocCollectors = new ArrayList<>();

        for (final Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (final Collector sub : ((MultiCollectorWrapper) collector).getCollectors()) {
                    if (sub instanceof HybridTopScoreDocCollector) {
                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) sub);
                    }
                }
            } else if (collector instanceof HybridTopScoreDocCollector) {
                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) collector);
            }
        }

        if (!hybridTopScoreDocCollectors.isEmpty()) {
            HybridTopScoreDocCollector<?> hybridTopScoreDocCollector = hybridTopScoreDocCollectors.get(0);
            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
            TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
            float maxScore = getMaxScore(topDocs);
            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
            return (QuerySearchResult result) -> result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats));
        }
        return null;
    }
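    // Illustrative walk-through (added commentary, not in the commit): with two sub-queries,
    // where sub-query 1 matched docs 3 and 7 and sub-query 2 matched doc 5, getNewTopDocs below
    // produces the compound ScoreDoc layout
    //   start_stop(3), delimiter(3), doc 3, doc 7, delimiter(3), doc 5, start_stop(3)
    // where 3 is simply the first valid doc id, reused as the carrier for the magic-number
    // markers that downstream processing uses to split results back per sub-query.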

    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
        ScoreDoc[] scoreDocs = new ScoreDoc[0];
        if (Objects.nonNull(topDocs)) {
            // for the single shard case we need to do score processing at the coordinator level.
            // this is a workaround for current core behaviour: for a single shard the fetch phase is executed
            // right after the query phase, and processors are called after the actual fetch is done
            // find any valid doc id, or set it to -1 if there is not a single match
            int delimiterDocId = topDocs.stream()
                .filter(Objects::nonNull)
                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
                .map(topDoc -> topDoc.scoreDocs)
                .filter(scoreDoc -> scoreDoc.length > 0)
                .map(scoreDoc -> scoreDoc[0].doc)
                .findFirst()
                .orElse(-1);
            if (delimiterDocId == -1) {
                return new TopDocs(totalHits, scoreDocs);
            }
            // format scores using the following template:
            // doc_id | magic_number_1
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_1
            List<ScoreDoc> result = new ArrayList<>();
            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
            for (TopDocs topDoc : topDocs) {
                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
                    continue;
                }
                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
                result.addAll(Arrays.asList(topDoc.scoreDocs));
            }
            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        }
        return new TopDocs(totalHits, scoreDocs);
    }

    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0, relation);
        }
        long maxTotalHits = topDocs.get(0).totalHits.value;
        for (TopDocs topDoc : topDocs) {
            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
        }

        return new TotalHits(maxTotalHits, relation);
    }
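    // Added note (illustrative): total hits are merged by taking the maximum across sub-queries
    // rather than the sum, since one document may match several sub-queries. E.g. sub-query
    // totals of 10 and 25 report 25, with relation EQUAL_TO unless total-hit tracking is disabled.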

    /*private int totalSize(final List<TopDocs> topDocs) {
        int totalSize = 0;
        for (TopDocs topDoc : topDocs) {
            totalSize += topDoc.totalHits.value + 1;
        }
        // add 1 per each sub-query, plus 2 for the start and stop delimiters
        totalSize += 2;
        return totalSize;
    }*/

    private float getMaxScore(final List<TopDocs> topDocs) {
        if (topDocs.isEmpty()) {
            return 0.0f;
        } else {
            return topDocs.stream()
                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
                .map(scoreDoc -> scoreDoc.score)
                .max(Float::compare)
                .get();
        }
    }

    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }
}
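
For readers unfamiliar with the CollectorManager contract this class implements: Lucene's IndexSearcher calls newCollector() once per search slice, runs the slices (possibly on several threads), then passes all collectors to reduce() for merging, which is exactly the shape of HybridCollectorManager.reduce above. A self-contained, plain-Lucene illustration of that contract (not part of this commit):

import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.TotalHitCountCollector;

// Minimal CollectorManager: one TotalHitCountCollector per slice, summed in reduce().
public final class TotalHitsManager implements CollectorManager<TotalHitCountCollector, Integer> {
    @Override
    public TotalHitCountCollector newCollector() {
        return new TotalHitCountCollector();
    }

    @Override
    public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
        int total = 0;
        for (TotalHitCountCollector c : collectors) {
            total += c.getTotalHits();
        }
        return total;
    }
}

With an IndexSearcher constructed over an executor, searcher.search(query, new TotalHitsManager()) collects each slice on its own thread and merges in reduce(); that concurrent path is what this POC targets for hybrid query.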
