forked from opensearch-project/neural-search
Commit efc4a2c (1 parent: e0d756e)
Collector manager approach, initial POC version
Signed-off-by: Martin Gaievski <[email protected]>
Showing 9 changed files with 544 additions and 119 deletions.
src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
83 changes: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.search.query;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.search.aggregations.AggregationCollectorManager;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.aggregations.BucketCollectorProcessor;
import org.opensearch.search.aggregations.GlobalAggCollectorManager;
import org.opensearch.search.aggregations.NonGlobalAggCollectorManager;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.query.InternalProfileCollectorManager;
import org.opensearch.search.profile.query.InternalProfileComponent;
import org.opensearch.search.query.QueryPhaseExecutionException;
import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.Collections;

public class HybridAggregationProcessor implements AggregationProcessor {

    private final BucketCollectorProcessor bucketCollectorProcessor = new BucketCollectorProcessor();

    @Override
    public void preProcess(SearchContext context) {
        try {
            if (context.aggregations() != null) {
                // update the bucket collector processor as there is an aggregation in the request
                context.setBucketCollectorProcessor(bucketCollectorProcessor);
                if (context.aggregations().factories().hasNonGlobalAggregator()) {
                    context.queryCollectorManagers().put(NonGlobalAggCollectorManager.class, new NonGlobalAggCollectorManager(context));
                }
                // initialize global aggregators as well, so that any failure to initialize is caught before executing the request
                if (context.aggregations().factories().hasGlobalAggregator()) {
                    context.queryCollectorManagers().put(GlobalAggCollectorManager.class, new GlobalAggCollectorManager(context));
                }
            }
        } catch (IOException ex) {
            throw new AggregationInitializationException("Could not initialize aggregators", ex);
        }
    }

    @Override
    public void postProcess(SearchContext context) {
        if (context.aggregations() == null) {
            context.queryResult().aggregations(null);
            return;
        }

        // for the concurrent case only global aggregations need to run in post process, as the QueryResult
        // is already populated with the results of the non-global aggregations
        CollectorManager<? extends Collector, ReduceableSearchResult> globalCollectorManager = context.queryCollectorManagers()
            .get(GlobalAggCollectorManager.class);
        try {
            if (globalCollectorManager != null) {
                Query query = context.buildFilteredQuery(Queries.newMatchAllQuery());
                if (context.getProfilers() != null) {
                    globalCollectorManager = new InternalProfileCollectorManager(
                        globalCollectorManager,
                        ((AggregationCollectorManager) globalCollectorManager).getCollectorReason(),
                        Collections.emptyList()
                    );
                    context.getProfilers().addQueryProfiler().setCollector((InternalProfileComponent) globalCollectorManager);
                }
                final ReduceableSearchResult result = context.searcher().search(query, globalCollectorManager);
                result.reduce(context.queryResult());
            }
        } catch (Exception e) {
            throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e);
        }

        // disable aggregations so that they don't run on next pages in case of scrolling
        context.aggregations(null);
        context.queryCollectorManagers().remove(NonGlobalAggCollectorManager.class);
        context.queryCollectorManagers().remove(GlobalAggCollectorManager.class);
    }
}
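For orientation, a minimal sketch of where this processor would plug in. It assumes the aggregationProcessor() hook that QueryPhaseSearcher implementations expose in recent OpenSearch cores; the class name mirrors the repository's existing HybridQueryPhaseSearcher, but this exact override is illustrative and not part of this commit:

import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryPhase;

// Illustrative wiring: hand the hybrid-aware processor to the query phase, so that
// preProcess() runs before collection and postProcess() runs global aggregations after it.
public class HybridQueryPhaseSearcher extends QueryPhase.DefaultQueryPhaseSearcher {

    private final AggregationProcessor aggregationProcessor = new HybridAggregationProcessor();

    @Override
    public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
        return aggregationProcessor;
    }
}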
src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
180 changes: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.search.query;

import lombok.AllArgsConstructor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

@AllArgsConstructor
public class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

    private int numHits;
    private FieldDoc searchAfter;
    private int hitCountThreshold;
    private HitsThresholdChecker hitsThresholdChecker;
    private boolean isSingleShard;
    private int trackTotalHitsUpTo;
    private SortAndFormats sortAndFormats;

    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) {
        final IndexReader reader = searchContext.searcher().getIndexReader();
        final int totalNumDocs = Math.max(0, reader.numDocs());
        FieldDoc searchAfter = searchContext.searchAfter();
        boolean isSingleShard = searchContext.numberOfShards() == 1;
        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
        return new HybridCollectorManager(
            numDocs,
            searchAfter,
            Integer.MAX_VALUE,
            new HitsThresholdChecker(Math.max(numDocs, trackTotalHitsUpTo)),
            isSingleShard,
            trackTotalHitsUpTo,
            searchContext.sort()
        );
    }

    @Override
    public Collector newCollector() throws IOException {
        return new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
    }

    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
        final List<HybridTopScoreDocCollector<?>> hybridTopScoreDocCollectors = new ArrayList<>();

        for (final Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (final Collector sub : ((MultiCollectorWrapper) collector).getCollectors()) {
                    if (sub instanceof HybridTopScoreDocCollector) {
                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) sub);
                    }
                }
            } else if (collector instanceof HybridTopScoreDocCollector) {
                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) collector);
            }
        }

        if (!hybridTopScoreDocCollectors.isEmpty()) {
            HybridTopScoreDocCollector<?> hybridTopScoreDocCollector = hybridTopScoreDocCollectors.get(0);
            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
            TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
            float maxScore = getMaxScore(topDocs);
            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
            return (QuerySearchResult result) -> result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats));
        }
        return null;
    }

    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
        ScoreDoc[] scoreDocs = new ScoreDoc[0];
        if (Objects.nonNull(topDocs)) {
            // for the single shard case we need to do score processing at the coordinator level.
            // this is a workaround for current core behaviour: for a single shard the fetch phase is executed
            // right after the query phase, and processors are called after the actual fetch is done
            // find any valid doc id, or set it to -1 if there is not a single match
            int delimiterDocId = topDocs.stream()
                .filter(Objects::nonNull)
                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
                .map(topDoc -> topDoc.scoreDocs)
                .filter(scoreDoc -> scoreDoc.length > 0)
                .map(scoreDoc -> scoreDoc[0].doc)
                .findFirst()
                .orElse(-1);
            if (delimiterDocId == -1) {
                return new TopDocs(totalHits, scoreDocs);
            }
            // format scores using the following template:
            // doc_id | magic_number_1
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_2
            // ...
            // doc_id | magic_number_1
            List<ScoreDoc> result = new ArrayList<>();
            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
            for (TopDocs topDoc : topDocs) {
                // every sub-query gets a delimiter, even when it did not match any docs
                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
                if (Objects.nonNull(topDoc) && Objects.nonNull(topDoc.scoreDocs)) {
                    result.addAll(Arrays.asList(topDoc.scoreDocs));
                }
            }
            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        }
        return new TopDocs(totalHits, scoreDocs);
    }

    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0, relation);
        }
        long maxTotalHits = topDocs.get(0).totalHits.value;
        for (TopDocs topDoc : topDocs) {
            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
        }

        return new TotalHits(maxTotalHits, relation);
    }

    /*private int totalSize(final List<TopDocs> topDocs) {
        int totalSize = 0;
        for (TopDocs topDoc : topDocs) {
            totalSize += topDoc.totalHits.value + 1;
        }
        // add 1 per sub-query, plus 2 for the start and stop delimiters
        totalSize += 2;
        return totalSize;
    }*/

    private float getMaxScore(final List<TopDocs> topDocs) {
        if (topDocs.isEmpty()) {
            return 0.0f;
        }
        return topDocs.stream()
            .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
            .map(scoreDoc -> scoreDoc.score)
            .max(Float::compare)
            .get();
    }

    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }
}
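To make the template above concrete, a hedged sketch of how a consumer could split the flattened ScoreDoc array produced by getNewTopDocs back into per-sub-query buckets. The sentinel scores and the class name are stand-ins for illustration; the real magic numbers live in HybridSearchResultFormatUtil:

import org.apache.lucene.search.ScoreDoc;

import java.util.ArrayList;
import java.util.List;

public final class HybridResultDecoderSketch {

    // stand-in sentinel scores, illustrative only; real values come from HybridSearchResultFormatUtil
    private static final float MAGIC_NUMBER_START_STOP = -1.0e+30f;
    private static final float MAGIC_NUMBER_DELIMITER = -2.0e+30f;

    // input layout: start | delim | docs of sub-query 1 | delim | docs of sub-query 2 | ... | stop
    public static List<List<ScoreDoc>> decode(final ScoreDoc[] scoreDocs) {
        List<List<ScoreDoc>> perSubQuery = new ArrayList<>();
        if (scoreDocs.length < 2 || scoreDocs[0].score != MAGIC_NUMBER_START_STOP) {
            return perSubQuery; // not in the hybrid format
        }
        List<ScoreDoc> current = null;
        // skip element 0 (start marker) and the last element (stop marker)
        for (int i = 1; i < scoreDocs.length - 1; i++) {
            if (scoreDocs[i].score == MAGIC_NUMBER_DELIMITER) {
                current = new ArrayList<>();
                perSubQuery.add(current); // each delimiter opens the bucket for the next sub-query
            } else if (current != null) {
                current.add(scoreDocs[i]);
            }
        }
        return perSubQuery;
    }
}

The empty-bucket case falls out naturally: a sub-query with no matches contributes a delimiter followed immediately by the next delimiter or the stop marker, which decodes to an empty list.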