Testing approach with aggregationProcessor pre and post phases
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Feb 20, 2024
1 parent 633d330 commit 02e0b09
Showing 9 changed files with 640 additions and 128 deletions.
@@ -138,6 +138,14 @@ private void updateOriginalFetchResults(
// 3. update original scores to normalized and combined values
// 4. order scores based on normalized and combined values
FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
+        // checking case when results are cached
+        // boolean requestCache = Objects.isNull(querySearchResults.get(0).getShardSearchRequest())
+        // || querySearchResults.get(0).getShardSearchRequest().requestCache();
+        SearchHits searchHits = fetchSearchResult.hits();
+        // if (requestCache && (searchHits.getHits().length != docIds.size())) {
+        // return;
+        // }
+
SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);

// create map of docId to index of search hits. This solves (2), duplicates are from
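The commented-out lines sketch a guard for cached results: when the shard request cache served the query phase, the fetch result can hold fewer hits than the normalized docIds list, and the rewrite step would then be skipped. If enabled, the guard would read roughly as follows (a sketch reassembled from the commented code above; the requestCache semantics are part of this commit's testing approach, not a settled API):

    // Hypothetical guard, reassembled from the commented-out lines above.
    // Assumption: a cached query phase can return fewer fetch hits than docIds.
    boolean requestCache = Objects.isNull(querySearchResults.get(0).getShardSearchRequest())
        || querySearchResults.get(0).getShardSearchRequest().requestCache();
    SearchHits searchHits = fetchSearchResult.hits();
    if (requestCache && (searchHits.getHits().length != docIds.size())) {
        return; // skip rewriting fetch results that came from the cache
    }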
@@ -5,11 +5,21 @@
package org.opensearch.neuralsearch.search.query;

import lombok.AllArgsConstructor;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.opensearch.neuralsearch.util.HybridQueryUtil;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
import org.opensearch.search.aggregations.DefaultAggregationProcessor;
import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.QueryPhaseExecutionException;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;

@AllArgsConstructor
public class HybridAggregationProcessor implements AggregationProcessor {
@@ -20,7 +30,22 @@ public class HybridAggregationProcessor implements AggregationProcessor {
@Override
public void preProcess(SearchContext context) {
if (shouldUseMultiCollectorManager(context)) {
-            concurrentAggregationProcessor.preProcess(context);
+            if (context.shouldUseConcurrentSearch()) {
+                concurrentAggregationProcessor.preProcess(context);
+            } else {
+                defaultAggregationProcessor.preProcess(context);
+            }
+            if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
+                HybridCollectorManager collectorManager;
+                try {
+                    collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+                Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByManagerClass = context
+                    .queryCollectorManagers();
+                collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager);
+            }
} else {
defaultAggregationProcessor.preProcess(context);
}
@@ -29,7 +54,28 @@ public void preProcess(SearchContext context) {
@Override
public void postProcess(SearchContext context) {
if (shouldUseMultiCollectorManager(context)) {
-            concurrentAggregationProcessor.postProcess(context);
+            if (HybridQueryUtil.isHybridQuery(context.query(), context)) {
+                CollectorManager<?, ReduceableSearchResult> collectorManager = context.queryCollectorManagers()
+                    .get(HybridCollectorManager.class);
+                if (!context.shouldUseConcurrentSearch()) {
+                    try {
+                        final Collection collectors = List.of(collectorManager.newCollector());
+                        collectorManager.reduce(collectors).reduce(context.queryResult());
+                    } catch (IOException e) {
+                        throw new QueryPhaseExecutionException(
+                            context.shardTarget(),
+                            "failed to execute hybrid query aggregation processor",
+                            e
+                        );
+                    }
+                }
+            }
+            updateQueryResult(context.queryResult(), context);
+            if (context.shouldUseConcurrentSearch()) {
+                concurrentAggregationProcessor.postProcess(context);
+            } else {
+                defaultAggregationProcessor.postProcess(context);
+            }
} else {
defaultAggregationProcessor.postProcess(context);
}
@@ -38,4 +84,11 @@ public void postProcess(SearchContext context) {
private boolean shouldUseMultiCollectorManager(SearchContext context) {
return HybridQueryUtil.isHybridQuery(context.query(), context) || context.shouldUseConcurrentSearch();
}

+    private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        if (isSingleShard) {
+            searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);
+        }
+    }
}
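Read together, the two phases divide the work like this: preProcess delegates to the concurrent or default aggregation processor and, for hybrid queries, registers a HybridCollectorManager in the context's queryCollectorManagers map; postProcess first reduces the hybrid collector's hits into the QuerySearchResult (on the non-concurrent path), trims the context size to the number of collected docs on single-shard searches, and only then lets the delegate processor finish the aggregations. A minimal sketch of the intended call sequence (the delegate construction and SearchContext setup are assumed, and the field declarations sit outside this diff, so the constructor shape is an assumption):

    // Sketch: expected driver-side sequence, with names taken from this diff.
    AggregationProcessor processor = new HybridAggregationProcessor(
        concurrentAggregationProcessor,   // delegate for concurrent segment search
        defaultAggregationProcessor       // delegate for the sequential path
    );
    processor.preProcess(context);   // registers HybridCollectorManager for hybrid queries
    // ... query phase runs, collectors gather hits ...
    processor.postProcess(context);  // reduces hybrid hits, then finishes aggregations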
@@ -4,23 +4,29 @@
*/
package org.opensearch.neuralsearch.search.query;

-import lombok.AllArgsConstructor;
+import lombok.RequiredArgsConstructor;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.Weight;
+import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -33,34 +39,69 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

-@AllArgsConstructor
+@RequiredArgsConstructor
public class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

-    private int numHits;
-    private HitsThresholdChecker hitsThresholdChecker;
-    private boolean isSingleShard;
-    private int trackTotalHitsUpTo;
-    private SortAndFormats sortAndFormats;
+    private final int numHits;
+    private final HitsThresholdChecker hitsThresholdChecker;
+    private final boolean isSingleShard;
+    private final int trackTotalHitsUpTo;
+    private final SortAndFormats sortAndFormats;
+    Collector maxScoreCollector;
+    final Weight filteringWeight;

-    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) {
+    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        final int totalNumDocs = Math.max(0, reader.numDocs());
boolean isSingleShard = searchContext.numberOfShards() == 1;
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

+        // check for post filter
+        if (Objects.nonNull(searchContext.parsedPostFilter())) {
+            Query filterQuery = searchContext.parsedPostFilter().query();
+            ContextIndexSearcher searcher = searchContext.searcher();
+            final Weight filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
+            return new HybridCollectorManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            );
+        }

return new HybridCollectorManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
-            searchContext.sort()
+            searchContext.sort(),
+            null
);
}

@Override
public org.apache.lucene.search.Collector newCollector() {
-        HybridTopScoreDocCollector<?> maxScoreCollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
-        return maxScoreCollector;
+        if (Objects.isNull(maxScoreCollector)) {
+            maxScoreCollector = getCollector();
+            return maxScoreCollector;
+        } else {
+            Collector toReturnCollector = maxScoreCollector;
+            maxScoreCollector = null;
+            return toReturnCollector;
+        }
}

+    private Collector getCollector() {
+        Collector hybridcollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
+        if (Objects.isNull(filteringWeight)) {
+            // this is plain hybrid query scores collector
+            return hybridcollector;
+        }
+        // this is hybrid query scores collector with post filter applied
+        return new FilteredCollector(hybridcollector, filteringWeight);
+    }

@Override
@@ -76,7 +117,10 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
}
} else if (collector instanceof HybridTopScoreDocCollector) {
hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) collector);
-            }
+            } else if (collector instanceof FilteredCollector
+                && ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
+                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) ((FilteredCollector) collector).getCollector());
+            }
}

if (!hybridTopScoreDocCollectors.isEmpty()) {
@@ -87,9 +131,9 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
float maxScore = getMaxScore(topDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
-            return (QuerySearchResult result) -> result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats));
+            return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
}
-        return null;
+        throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
}

private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
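The non-concurrent branch in HybridAggregationProcessor.postProcess leans on the standard Lucene CollectorManager contract together with the caching trick in newCollector(): the first call hands back the instance that already collected hits during the query phase, so reducing it writes the real top docs into the query result, and reduce() now fails loudly instead of returning null when no hybrid collector is present. Condensed, the interaction looks like this (a sketch; the SearchContext setup is assumed):

    // Sketch of the reduce path used by HybridAggregationProcessor.postProcess.
    CollectorManager<?, ReduceableSearchResult> manager =
        context.queryCollectorManagers().get(HybridCollectorManager.class);
    Collection collectors = List.of(manager.newCollector()); // returns the cached collector
    ReduceableSearchResult reduced = manager.reduce(collectors);
    reduced.reduce(context.queryResult()); // writes TopDocs and max score into the result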
@@ -21,7 +21,6 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
-import org.opensearch.neuralsearch.util.HybridQueryUtil;
import org.opensearch.search.aggregations.AggregationProcessor;
import org.opensearch.search.aggregations.ConcurrentAggregationProcessor;
import org.opensearch.search.aggregations.DefaultAggregationProcessor;
@@ -63,11 +62,12 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
-        if (HybridQueryUtil.isHybridQuery(query, searchContext)) {
+        /*if (HybridQueryUtil.isHybridQuery(query, searchContext)) {
Query hybridQuery = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
}
validateQuery(searchContext, query);
+        */
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

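With the hybrid branch commented out, searchWith now always falls through to super.searchWith; hybrid queries are instead expected to be picked up through the HybridCollectorManager that preProcess registered in queryCollectorManagers, which is the approach this commit is testing.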
@@ -52,6 +52,8 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT {
private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
private final float[] testVector4 = createRandomVector(TEST_DIMENSION);
+    private final float[] testVector5 = createRandomVector(TEST_DIMENSION);
+    private final float[] testVector6 = createRandomVector(TEST_DIMENSION);

@Before
public void setUp() throws Exception {
@@ -318,7 +320,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
"5",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-            Collections.singletonList(Floats.asList(testVector4).toArray()),
+            Collections.singletonList(Floats.asList(testVector5).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT4)
);
@@ -365,15 +367,15 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
"5",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-            Collections.singletonList(Floats.asList(testVector4).toArray()),
+            Collections.singletonList(Floats.asList(testVector5).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT4)
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
"6",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-            Collections.singletonList(Floats.asList(testVector4).toArray()),
+            Collections.singletonList(Floats.asList(testVector6).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT5)
);
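The test fixture change gives documents 5 and 6 their own random vectors (testVector5 and testVector6) instead of reusing testVector4, so the single-shard and three-shard indices no longer carry accidentally identical vectors across documents.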