Skip to content

Commit

Permalink
Adding simple aggregation processor, custom scenario works, base is f…
Browse files Browse the repository at this point in the history
…ailing

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 27, 2024
1 parent 3d881ac commit dc95b85
Show file tree
Hide file tree
Showing 4 changed files with 317 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
* Collects the TopDocs after executing a hybrid query. Uses HybridQueryTopDocs as a DTO to handle the results of each sub-query
*/
@Log4j2
public class HybridTopScoreDocCollector implements Collector {
public class HybridTopScoreDocCollector<T extends ScoreDoc> implements Collector {
private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
private int docBase;
private final HitsThresholdChecker hitsThresholdChecker;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.search.query;

import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

/**
 * Collector manager for hybrid queries.
 * <p>
 * It hands out {@link HybridTopScoreDocCollector} instances (wrapped in a {@link FilteredCollector} when a
 * post filter weight is present) and, in {@link #reduce(Collection)}, converts the per-sub-query
 * {@link TopDocs} gathered by such a collector into a single {@link TopDocs} that uses the hybrid search
 * result layout (start/stop and delimiter elements), which is then written into the {@link QuerySearchResult}.
 * <p>
 * Use {@link #createHybridCollectorManager(SearchContext)} to obtain the implementation matching the
 * search mode (concurrent or non-concurrent) of a search context.
 */
@RequiredArgsConstructor
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

    private final int numHits;
    private final HitsThresholdChecker hitsThresholdChecker;
    private final boolean isSingleShard;
    private final int trackTotalHitsUpTo;
    private final SortAndFormats sortAndFormats;
    private final Optional<Weight> filteringWeightOptional;

    /**
     * Creates the collector manager that matches the search mode of the given context.
     *
     * @param searchContext context of the current search request
     * @return a {@link HybridCollectorConcurrentSearchManager} when the context reports concurrent search,
     *         a {@link HybridCollectorNonConcurrentManager} otherwise
     * @throws IOException if rewriting the post filter query or creating its weight fails
     */
    public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
        final ContextIndexSearcher searcher = searchContext.searcher();
        final IndexReader reader = searcher.getIndexReader();
        final int totalNumDocs = Math.max(0, reader.numDocs());
        final boolean isSingleShard = searchContext.numberOfShards() == 1;
        // never ask for more hits than there are documents in the index reader
        final int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
        final int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

        // a post filter, when present, is turned into a non-scoring weight that later wraps the collector
        Weight filterWeight = null;
        if (Objects.nonNull(searchContext.parsedPostFilter())) {
            final Query filterQuery = searchContext.parsedPostFilter().query();
            filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
        }

        final HitsThresholdChecker hitsThresholdChecker = new HitsThresholdChecker(Math.max(numDocs, trackTotalHitsUpTo));

        return searchContext.shouldUseConcurrentSearch()
            ? new HybridCollectorConcurrentSearchManager(
                numDocs,
                hitsThresholdChecker,
                isSingleShard,
                trackTotalHitsUpTo,
                searchContext.sort(),
                filterWeight
            )
            : new HybridCollectorNonConcurrentManager(
                numDocs,
                hitsThresholdChecker,
                isSingleShard,
                trackTotalHitsUpTo,
                searchContext.sort(),
                filterWeight
            );
    }

    @Override
    public abstract Collector newCollector();

    /**
     * Builds a fresh collector for hybrid query scores.
     *
     * @return a new {@link HybridTopScoreDocCollector}, wrapped in a {@link FilteredCollector} when a
     *         post filter weight was supplied to this manager
     */
    Collector getCollector() {
        final Collector hybridCollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
        // without a post filter this is the plain hybrid query scores collector
        return filteringWeightOptional.<Collector>map(weight -> new FilteredCollector(hybridCollector, weight)).orElse(hybridCollector);
    }

    /**
     * Locates the hybrid score collector among the given collectors and converts its results into a
     * {@link ReduceableSearchResult} that sets the top docs and max score on a {@link QuerySearchResult}.
     *
     * @param collectors collectors that took part in the search; hybrid collectors are recognised directly,
     *                   wrapped in a {@link FilteredCollector}, or as members of a {@link MultiCollectorWrapper}
     * @return reducer that writes hybrid top docs into the query result
     * @throws IllegalStateException if none of the collectors is a hybrid score collector
     */
    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) {
        final List<HybridTopScoreDocCollector<?>> hybridCollectors = new ArrayList<>();
        for (final Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (final Collector sub : ((MultiCollectorWrapper) collector).getCollectors()) {
                    addIfHybridCollector(sub, hybridCollectors);
                }
            } else {
                addIfHybridCollector(collector, hybridCollectors);
            }
        }

        if (hybridCollectors.isEmpty()) {
            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
        }

        // NOTE(review): only the first hybrid collector is consumed; when several collectors are passed in
        // (e.g. one per slice in concurrent search) the results of the others are ignored - confirm this is intended
        final List<TopDocs> topDocs = hybridCollectors.get(0).topDocs();
        final TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
        final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, getMaxScore(topDocs));
        return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
    }

    /**
     * Adds the collector to the sink if it is a hybrid score collector, looking through a
     * {@link FilteredCollector} wrapper when there is one. Anything else is ignored.
     */
    private static void addIfHybridCollector(final Collector collector, final List<HybridTopScoreDocCollector<?>> sink) {
        final Collector unwrapped = collector instanceof FilteredCollector ? ((FilteredCollector) collector).getCollector() : collector;
        if (unwrapped instanceof HybridTopScoreDocCollector) {
            sink.add((HybridTopScoreDocCollector<?>) unwrapped);
        }
    }

    /**
     * Flattens per-sub-query results into one {@link TopDocs} using the following layout:
     * <pre>
     *  start/stop element
     *  delimiter element
     *  ... score docs of sub-query 1 ...
     *  delimiter element
     *  ... score docs of sub-query 2 ...
     *  ...
     *  start/stop element
     * </pre>
     * A sub-query whose results are {@code null} contributes only its delimiter. When the list is
     * {@code null} or no sub-query matched any document, the result has no score docs at all.
     *
     * @param totalHits total hits to report in the returned top docs
     * @param topDocs   results per sub-query, may be {@code null} and may contain {@code null} elements
     * @return top docs in hybrid search result layout
     */
    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
        if (Objects.isNull(topDocs)) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }
        // for a single shard case we need to do score processing at coordinator level.
        // this is workaround for current core behaviour, for single shard fetch phase is executed
        // right after query phase and processors are called after actual fetch is done.
        // special elements need some valid doc id: take the first one found, or -1 if there is not a single match
        final int delimiterDocId = topDocs.stream()
            .filter(Objects::nonNull)
            .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
            .map(topDoc -> topDoc.scoreDocs)
            .filter(scoreDocs -> scoreDocs.length > 0)
            .map(scoreDocs -> scoreDocs[0].doc)
            .findFirst()
            .orElse(-1);
        if (delimiterDocId == -1) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }

        final List<ScoreDoc> result = new ArrayList<>();
        result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
        for (final TopDocs topDoc : topDocs) {
            // every sub-query gets its delimiter, even when it produced no results
            result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
            if (Objects.nonNull(topDoc) && Objects.nonNull(topDoc.scoreDocs)) {
                result.addAll(Arrays.asList(topDoc.scoreDocs));
            }
        }
        result.add(createStartStopElementForHybridSearchResults(delimiterDocId));

        // copy every element into a plain ScoreDoc so the returned array does not share instances with the inputs
        final ScoreDoc[] scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        return new TopDocs(totalHits, scoreDocs);
    }

    /**
     * Computes total hits as the number of distinct doc ids across all sub-query results.
     *
     * @param trackTotalHitsUpTo track-total-hits setting; when it equals
     *                           {@link SearchContext#TRACK_TOTAL_HITS_DISABLED} the relation is
     *                           {@code GREATER_THAN_OR_EQUAL_TO}, otherwise {@code EQUAL_TO}
     * @param topDocs            results per sub-query, may be {@code null}, empty, or contain {@code null} elements
     * @param isSingleShard      not used by the current implementation
     * @return total hits with the count of unique doc ids, or 0 when there are no results
     */
    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0, relation);
        }

        // the same document can be matched by several sub-queries, count it once
        final Set<Integer> uniqueDocIds = topDocs.stream()
            .filter(Objects::nonNull)
            .map(topDoc -> topDoc.scoreDocs)
            .filter(Objects::nonNull)
            .flatMap(Arrays::stream)
            .map(scoreDoc -> scoreDoc.doc)
            .collect(Collectors.toCollection(HashSet::new));

        return new TotalHits(uniqueDocIds.size(), relation);
    }

    /**
     * Finds the highest leading score among the sub-query results. Only the first score doc of each
     * sub-query is inspected; a sub-query without results (null, null array or empty array) counts as 0.
     *
     * @param topDocs results per sub-query, may be {@code null} or contain {@code null} elements
     * @return the largest leading score, or 0 when there are no sub-query results
     */
    private float getMaxScore(final List<TopDocs> topDocs) {
        if (topDocs == null) {
            return 0.0f;
        }
        return topDocs.stream().map(HybridCollectorManager::leadingScore).max(Float::compare).orElse(0.0f);
    }

    /**
     * Returns the score of the first score doc, or 0 when the given top docs hold no score docs.
     */
    private static float leadingScore(final TopDocs docs) {
        final boolean hasHits = docs != null && docs.scoreDocs != null && docs.scoreDocs.length > 0;
        return hasHits ? docs.scoreDocs[0].score : 0.0f;
    }

    /**
     * @return formats of the sort values, or {@code null} when no sort is defined
     */
    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }

    /**
     * Manager used when the search is not executed concurrently.
     * <p>
     * Calls to {@link #newCollector()} work in pairs: a call made while nothing is cached creates a
     * collector, remembers it and returns it; the following call returns that same instance and clears
     * the cached reference.
     * NOTE(review): presumably the non-concurrent search path requests the collector twice and both
     * requests must see the same instance - confirm against the caller.
     */
    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
        // collector handed out by the previous newCollector() call, or null when the next call must create one
        Collector maxScoreCollector;

        public HybridCollectorNonConcurrentManager(
            int numHits,
            HitsThresholdChecker hitsThresholdChecker,
            boolean isSingleShard,
            int trackTotalHitsUpTo,
            SortAndFormats sortAndFormats,
            Weight filteringWeight
        ) {
            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
        }

        /**
         * @return a newly created collector when none is cached, otherwise the cached collector
         *         (which is then forgotten)
         */
        @Override
        public Collector newCollector() {
            if (Objects.nonNull(maxScoreCollector)) {
                final Collector cachedCollector = maxScoreCollector;
                maxScoreCollector = null;
                return cachedCollector;
            }
            maxScoreCollector = getCollector();
            return maxScoreCollector;
        }
    }

    /**
     * Manager used for concurrent search: every {@link #newCollector()} call returns a new collector.
     */
    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {

        public HybridCollectorConcurrentSearchManager(
            int numHits,
            HitsThresholdChecker hitsThresholdChecker,
            boolean isSingleShard,
            int trackTotalHitsUpTo,
            SortAndFormats sortAndFormats,
            Weight filteringWeight
        ) {
            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
        }

        @Override
        public Collector newCollector() {
            return getCollector();
        }
    }
}
Loading

0 comments on commit dc95b85

Please sign in to comment.