From 1f1f2b7b5c5d85c38a1a97692bf08ec6939c625b Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 17 Jan 2024 10:51:09 -0800 Subject: [PATCH] Collector manager approach, initial POC version Signed-off-by: Martin Gaievski --- build.gradle | 2 + .../search/HybridTopScoreDocCollector.java | 16 +- .../query/HybridAggregationProcessor.java | 83 +++++++ .../search/query/HybridCollectorManager.java | 221 ++++++++++++++++++ .../query/HybridQueryPhaseSearcher.java | 79 +++++-- .../neuralsearch/query/HybridQueryIT.java | 204 ++++++++++++---- .../HybridTopScoreDocCollectorTests.java | 10 +- .../query/HybridQueryPhaseSearcherTests.java | 23 +- .../neuralsearch/BaseNeuralSearchIT.java | 62 ++--- 9 files changed, 585 insertions(+), 115 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java diff --git a/build.gradle b/build.gradle index ac0ebaa9d..c2b146156 100644 --- a/build.gradle +++ b/build.gradle @@ -285,6 +285,8 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx4g") + + systemProperty('opensearch.experimental.feature.concurrent_segment_search.enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 6bc3c7be0..731320e10 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,7 +8,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -27,13 +26,12 @@ import lombok.Getter; import lombok.extern.log4j.Log4j2; -import org.opensearch.neuralsearch.search.HitsThresholdChecker; /** * Collects the TopDocs after executing hybrid query. 
 * Uses HybridQueryTopDocs as DTO to handle each sub query results
 */
 @Log4j2
-public class HybridTopScoreDocCollector implements Collector {
+public class HybridTopScoreDocCollector<C extends Collector> implements Collector {
     private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
     private int docBase;
     private final HitsThresholdChecker hitsThresholdChecker;
@@ -58,10 +56,10 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept
         @Override
         public void setScorer(Scorable scorer) throws IOException {
             super.setScorer(scorer);
+            // compoundQueryScorer = (HybridQueryScorer) scorer;
             if (scorer instanceof HybridQueryScorer) {
                 compoundQueryScorer = (HybridQueryScorer) scorer;
-            }
-            else {
+            } else {
                 compoundQueryScorer = getHybridQueryScorer(scorer);
             }
         }
@@ -82,14 +80,12 @@ private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOE
         return null;
     }
 
-
-
-    @Override
+    @Override
     public void collect(int doc) throws IOException {
-        if (compoundQueryScorer == null) {
+        /*if (compoundQueryScorer == null) {
             scorer.score();
             return;
-        }
+        }*/
         float[] subScoresByQuery = compoundQueryScorer.hybridScores();
         // iterate over results for each query
         if (compoundScores == null) {
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
new file mode 100644
index 000000000..1161784c3
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.query;
+
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Query;
+import org.opensearch.common.lucene.search.Queries;
+import org.opensearch.search.aggregations.AggregationCollectorManager;
+import org.opensearch.search.aggregations.AggregationInitializationException;
+import org.opensearch.search.aggregations.AggregationProcessor;
+import org.opensearch.search.aggregations.BucketCollectorProcessor;
+import org.opensearch.search.aggregations.GlobalAggCollectorManager;
+import org.opensearch.search.aggregations.NonGlobalAggCollectorManager;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.profile.query.InternalProfileCollectorManager;
+import org.opensearch.search.profile.query.InternalProfileComponent;
+import org.opensearch.search.query.QueryPhaseExecutionException;
+import org.opensearch.search.query.ReduceableSearchResult;
+
+import java.io.IOException;
+import java.util.Collections;
+
+public class HybridAggregationProcessor implements AggregationProcessor {
+
+    private final BucketCollectorProcessor bucketCollectorProcessor = new BucketCollectorProcessor();
+
+    @Override
+    public void preProcess(SearchContext context) {
+        try {
+            if (context.aggregations() != null) {
+                // update the bucket collector process as there is aggregation in the request
+                context.setBucketCollectorProcessor(bucketCollectorProcessor);
+                if (context.aggregations().factories().hasNonGlobalAggregator()) {
+                    context.queryCollectorManagers().put(NonGlobalAggCollectorManager.class, new NonGlobalAggCollectorManager(context));
+                }
+                // initialize global aggregators as well, such that any failure to initialize can be caught before executing the request
+                if (context.aggregations().factories().hasGlobalAggregator()) {
+                    context.queryCollectorManagers().put(GlobalAggCollectorManager.class, new GlobalAggCollectorManager(context));
+                }
+            }
+        } catch (IOException ex) {
+            throw new AggregationInitializationException("Could not initialize aggregators", ex);
+        }
+    }
+
+    @Override
+    public void postProcess(SearchContext context) {
+        if (context.aggregations() == null) {
+            context.queryResult().aggregations(null);
+            return;
+        }
+
+        // for concurrent case we will perform only global aggregation in post process as QueryResult is already populated with results of
+        // processing the non-global aggregation
+        CollectorManager<? extends Collector, ReduceableSearchResult> globalCollectorManager = context.queryCollectorManagers()
+            .get(GlobalAggCollectorManager.class);
+        try {
+            if (globalCollectorManager != null) {
+                Query query = context.buildFilteredQuery(Queries.newMatchAllQuery());
+                if (context.getProfilers() != null) {
+                    globalCollectorManager = new InternalProfileCollectorManager(
+                        globalCollectorManager,
+                        ((AggregationCollectorManager) globalCollectorManager).getCollectorReason(),
+                        Collections.emptyList()
+                    );
+                    context.getProfilers().addQueryProfiler().setCollector((InternalProfileComponent) globalCollectorManager);
+                }
+                final ReduceableSearchResult result = context.searcher().search(query, globalCollectorManager);
+                result.reduce(context.queryResult());
+            }
+        } catch (Exception e) {
+            throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e);
+        }
+
+        // disable aggregations so that they don't run on next pages in case of scrolling
+        context.aggregations(null);
+        context.queryCollectorManagers().remove(NonGlobalAggCollectorManager.class);
+        context.queryCollectorManagers().remove(GlobalAggCollectorManager.class);
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
new file mode 100644
index 000000000..4cecdb1b3
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -0,0 +1,221 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.query;
+
+import lombok.AllArgsConstructor;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopFieldDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.neuralsearch.search.HitsThresholdChecker;
+import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.MultiCollectorWrapper;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
+
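+/**
+ * CollectorManager for the hybrid query. newCollector() supplies one HybridTopScoreDocCollector
+ * per search slice, and reduce() folds the collected per-sub-query TopDocs into a single
+ * ReduceableSearchResult that is written back into the QuerySearchResult.
+ */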
+@AllArgsConstructor
+public class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
+
+    private int numHits;
+    private FieldDoc searchAfter;
+    private int hitCountThreshold;
+    private HitsThresholdChecker hitsThresholdChecker;
+    private boolean isSingleShard;
+    private int trackTotalHitsUpTo;
+
+    public static HybridCollectorManager createHybridCollectorManager(final SearchContext searchContext) {
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        final int totalNumDocs = Math.max(0, reader.numDocs());
+        FieldDoc searchAfter = searchContext.searchAfter();
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+        return new HybridCollectorManager(
+            numDocs,
+            searchAfter,
+            Integer.MAX_VALUE,
+            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+            isSingleShard,
+            trackTotalHitsUpTo
+        );
+    }
+
+    @Override
+    public org.apache.lucene.search.Collector newCollector() throws IOException {
+        HybridTopScoreDocCollector<?> maxScoreCollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
+
+        // return MultiCollectorWrapper.wrap(manager.newCollector(), maxScoreCollector);
+        return maxScoreCollector;
+    }
+
+    @Override
+    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
+        final List<HybridTopScoreDocCollector<?>> hybridTopScoreDocCollectors = new ArrayList<>();
+
+        for (final Collector collector : collectors) {
+            if (collector instanceof MultiCollectorWrapper) {
+                for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
+                    if (sub instanceof HybridTopScoreDocCollector) {
+                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) sub);
+                    }
+                }
+            } else if (collector instanceof HybridTopScoreDocCollector) {
+                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) collector);
+            }
+        }
+
+        if (!hybridTopScoreDocCollectors.isEmpty()) {
+            HybridTopScoreDocCollector<?> hybridTopScoreDocCollector = hybridTopScoreDocCollectors.get(0);
+            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
+            TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
+            float maxScore = getMaxScore(topDocs);
+            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
+            return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, null); };
+        }
+        return null;
+    }
+
+    protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
+        return (QuerySearchResult result) -> {
+            final TopDocsAndMaxScore topDocsAndMaxScore = newTopDocs(topDocs, maxScore, terminatedAfter);
+            result.topDocs(topDocsAndMaxScore, null);
+        };
+    }
+
+    TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
+        TotalHits totalHits = topDocs.totalHits;
+
+        // Since we cannot support early forced termination, we have to simulate it by
+        // artificially reducing the number of total hits and doc scores.
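+        // Example: terminatedAfter = 10 with 25 collected hits yields totalHits = 10
+        // (relation GREATER_THAN_OR_EQUAL_TO) and scoreDocs truncated to the first 10 entries.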
+        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+        if (terminatedAfter != null) {
+            if (totalHits.value > terminatedAfter) {
+                totalHits = new TotalHits(terminatedAfter, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
+            }
+
+            if (scoreDocs != null && scoreDocs.length > terminatedAfter) {
+                scoreDocs = Arrays.copyOf(scoreDocs, terminatedAfter);
+            }
+        }
+
+        final TopDocs newTopDocs;
+        if (topDocs instanceof TopFieldDocs) {
+            TopFieldDocs fieldDocs = (TopFieldDocs) topDocs;
+            newTopDocs = new TopFieldDocs(totalHits, scoreDocs, fieldDocs.fields);
+        } else {
+            newTopDocs = new TopDocs(totalHits, scoreDocs);
+        }
+
+        if (Float.isNaN(maxScore) && newTopDocs.scoreDocs.length > 0) {
+            return new TopDocsAndMaxScore(newTopDocs, newTopDocs.scoreDocs[0].score);
+        } else {
+            return new TopDocsAndMaxScore(newTopDocs, maxScore);
+        }
+    }
+
+    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
+        ScoreDoc[] scoreDocs = new ScoreDoc[0];
+        if (Objects.nonNull(topDocs)) {
+            // for a single shard case we need to do score processing at coordinator level.
+            // this is workaround for current core behaviour, for single shard fetch phase is executed
+            // right after query phase and processors are called after actual fetch is done
+            // find any valid doc Id, or set it to -1 if there is not a single match
+            int delimiterDocId = topDocs.stream()
+                .filter(Objects::nonNull)
+                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
+                .map(topDoc -> topDoc.scoreDocs)
+                .filter(scoreDoc -> scoreDoc.length > 0)
+                .map(scoreDoc -> scoreDoc[0].doc)
+                .findFirst()
+                .orElse(-1);
+            if (delimiterDocId == -1) {
+                return new TopDocs(totalHits, scoreDocs);
+            }
+            // format scores using following template:
+            // doc_id | magic_number_1
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_1
+            List<ScoreDoc> result = new ArrayList<>();
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            for (TopDocs topDoc : topDocs) {
+                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
+                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                    continue;
+                }
+                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                result.addAll(Arrays.asList(topDoc.scoreDocs));
+            }
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
+        }
+        return new TopDocs(totalHits, scoreDocs);
+    }
+
+    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
+        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+        if (topDocs == null || topDocs.isEmpty()) {
+            return new TotalHits(0, relation);
+        }
+        long maxTotalHits = topDocs.get(0).totalHits.value;
+        for (TopDocs topDoc : topDocs) {
+            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
+        }
+
+        return new TotalHits(maxTotalHits, relation);
+    }
+
+    private int totalSize(final List<TopDocs> topDocs) {
+        int totalSize = 0;
+        for (TopDocs topDoc : topDocs) {
+            totalSize += topDoc.totalHits.value + 1;
+        }
+        // add 1 qty per each sub-query and + 2 for start and stop delimiters
+        totalSize += 2;
+        return totalSize;
+    }
+
+    private float getMaxScore(final List<TopDocs> topDocs) {
+        if (topDocs.isEmpty()) {
+            return 0.0f;
+        } else {
+            return topDocs.stream()
+                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
+                .map(scoreDoc -> scoreDoc.score)
+                .max(Float::compare)
+                .get();
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index 616b1a652..780be4a9f 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -13,13 +13,15 @@
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
-import com.google.common.base.Throwables;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.MultiCollector;
 import org.apache.lucene.search.Query;
@@ -33,15 +35,18 @@
 import org.opensearch.index.mapper.SeqNoFieldMapper;
 import org.opensearch.index.search.NestedHelper;
 import org.opensearch.neuralsearch.query.HybridQuery;
-import org.opensearch.neuralsearch.search.HitsThresholdChecker;
 import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
 import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.aggregations.AggregationProcessor;
+import org.opensearch.search.aggregations.GlobalAggCollectorManager;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryCollectorContext;
+import org.opensearch.search.query.QueryCollectorManagerContext;
 import org.opensearch.search.query.QueryPhase;
 import org.opensearch.search.query.QueryPhaseSearcherWrapper;
 import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
 import org.opensearch.search.query.TopDocsCollectorContext;
 import org.opensearch.search.rescore.RescoreContext;
 import org.opensearch.search.sort.SortAndFormats;
@@ -57,6 +62,9 @@
 @Log4j2
 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {
 
+    private final AggregationProcessor aggregationProcessor = new HybridAggregationProcessor();
+    // private final AggregationProcessor aggregationProcessor = new DefaultAggregationProcessor();
+
     public HybridQueryPhaseSearcher() {
         super();
     }
@@ -194,6 +202,11 @@ protected boolean searchWithCollector(
     ) throws IOException {
         log.debug("searching with custom doc collector, shard {}",
            searchContext.shardTarget().getShardId());
+        HybridCollectorManager collectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext);
+        Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagers = searchContext
+            .queryCollectorManagers();
+        collectorManagers.put(HybridCollectorManager.class, collectorManager);
+
         final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
         collectors.addFirst(topDocsFactory);
         if (searchContext.size() == 0) {
@@ -213,43 +226,66 @@
 
         final QuerySearchResult queryResult = searchContext.queryResult();
 
-        Collector collector = new HybridTopScoreDocCollector(
-            numDocs,
-            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
-        );
+        // Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers =
+        //     searchContext.queryCollectorManagers();
+        // CollectorManager<?, ReduceableSearchResult> collectorManager =;
+        // List<CollectorManager<?, ReduceableSearchResult>> managers =
+        //     List.of(QueryCollectorManagerContext.createQueryCollectorManager(collectorsOurs),
+        //         QueryCollectorManagerContext.createQueryCollectorManager(collectors));
+        // final CollectorManager<?, ReduceableSearchResult> multiCollectorManager =
+        //     QueryCollectorManagerContext.createMultiCollectorManager(managers);
 
         // cannot use streams here as assignment of global variable inside the lambda will not be possible
-        for (int idx = 1; idx < collectors.size(); idx++) {
+        /*for (int idx = 1; idx < collectors.size(); idx++) {
             QueryCollectorContext collectorContext = collectors.get(idx);
             collector = collectorContext.create(collector);
-        }
-
-        searcher.search(query, collector);
+        }*/
+        /*Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = searchContext.queryCollectorManagers();
+        queryCollectorManagers.values().stream().forEach(e -> {
+            try {
+                e.newCollector();
+            } catch (IOException ex) {
+                throw new RuntimeException(ex);
+            }
+        });*/
+
+        final List<CollectorManager<? extends Collector, ReduceableSearchResult>> managersExceptGlobalAgg = collectorManagers.entrySet()
+            .stream()
+            .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
+            .map(Map.Entry::getValue)
+            .collect(Collectors.toList());
+        final ReduceableSearchResult result = searcher.search(
+            query,
+            QueryCollectorManagerContext.createMultiCollectorManager(managersExceptGlobalAgg)
+        );
+        // searcher.search(query, QueryCollectorManagerContext.createQueryCollectorManager(collectors));
 
         if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
             queryResult.terminatedEarly(false);
         }
+        result.reduce(queryResult);
 
-        setTopDocsInQueryResult(queryResult, collector, searchContext);
+        updateQueryResult(queryResult, searchContext);
 
-        collectors.stream().skip(1).forEach(ctx -> {
-            try {
-                ctx.postProcess(queryResult);
-            } catch (IOException e) {
-                Throwables.throwIfUnchecked(e);
-            }
-        });
+        // setTopDocsInQueryResult(queryResult, collector, searchContext);
 
         return shouldRescore;
     }
 
+    private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        if (isSingleShard) {
+            searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);
+        }
+    }
+
     private void setTopDocsInQueryResult(
         final QuerySearchResult queryResult,
         final Collector collector,
         final SearchContext searchContext
     ) {
         if (collector instanceof HybridTopScoreDocCollector) {
-            List<TopDocs> topDocs = ((HybridTopScoreDocCollector) collector).topDocs();
+            List<TopDocs> topDocs = ((HybridTopScoreDocCollector<?>) collector).topDocs();
             float maxScore = getMaxScore(topDocs);
             boolean
isSingleShard = searchContext.numberOfShards() == 1; TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); @@ -352,4 +388,9 @@ private int getMaxDepthLimit(final SearchContext searchContext) { Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); } + + @Override + public AggregationProcessor aggregationProcessor(SearchContext searchContext) { + return aggregationProcessor; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 2f279a4cc..78aa6b9ab 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -46,8 +46,10 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-neural-multi-doc-nested-type--single-shard-index"; - private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT = - "test-neural-multi-doc-text-and-int-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = + "test-neural-multi-doc-text-and-int-index-single-shard"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-neural-multi-doc-text-and-int-index-multiple-shards"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -56,6 +58,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_DOC_TEXT1 = "Hello world"; private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; @@ -67,9 +70,12 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String INTEGER_FIELD_1 = "doc_index"; private static final int INTEGER_FIELD_1_VALUE = 1234; private static final int INTEGER_FIELD_2_VALUE = 2345; + private static final int INTEGER_FIELD_3_VALUE = 3456; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private static final String MAX_AGGREGATION_NAME = "max_aggs"; + private static final String AVG_AGGREGATION_NAME = "avg_field"; private static final String SEARCH_PIPELINE = "phase-results-pipeline"; @Before @@ -414,8 +420,8 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess( * } */ @SneakyThrows - public void testAggregations_whenMetricAggregationsInQuery_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT); + public void testAggregationsSingleShard_whenMetricAggregationsInQuery_thenSuccessful() { + 
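+        // exercises a max metric aggregation on the single-shard index; both hybrid hits and the aggregation value must be present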
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); @@ -424,15 +430,14 @@ public void testAggregations_whenMetricAggregationsInQuery_thenSuccessful() { hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); - AggregationBuilder aggsBuilder = AggregationBuilders.max("max_aggs").field(INTEGER_FIELD_1); - //AggregationBuilder aggsBuilder = null; + AggregationBuilder aggsBuilder = AggregationBuilders.max(MAX_AGGREGATION_NAME).field(INTEGER_FIELD_1); Map searchResponseAsMap1 = search( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, - hybridQueryBuilderNeuralThenTerm, - null, - 10, - Map.of("search_pipeline", SEARCH_PIPELINE), - aggsBuilder + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilder ); assertEquals(1, getHitCount(searchResponseAsMap1)); @@ -455,6 +460,61 @@ public void testAggregations_whenMetricAggregationsInQuery_thenSuccessful() { assertEquals(1, total.get("value")); assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + Map aggregations = getAggregations(searchResponseAsMap1); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(MAX_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, MAX_AGGREGATION_NAME); + assertTrue(maxAggsValue >= 0); + } + + @SneakyThrows + public void testAggregationsMultipleShards_whenMetricAggregationsInQuery_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + AggregationBuilder aggsBuilder = AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1); + Map searchResponseAsMap1 = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilder + ); + + assertEquals(2, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(2, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + Map aggregations = getAggregations(searchResponseAsMap1); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(AVG_AGGREGATION_NAME)); + double maxAggsValue = 
getAggregationValue(aggregations, AVG_AGGREGATION_NAME); + assertEquals(maxAggsValue, 2345.0, DELTA_FOR_SCORE_ASSERTION); } @SneakyThrows @@ -549,43 +609,91 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { ); } - if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT.equals(indexName) - && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT)) { - createIndexWithConfiguration( - indexName, - buildIndexConfiguration( - List.of(), - List.of(), - List.of(INTEGER_FIELD_1), - 1 - ), - "" + if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD)) { + createIndexWithConfiguration(indexName, buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), 1), ""); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_1_VALUE) ); addKnnDoc( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, - "1", - List.of(), - List.of(), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT1), - List.of(), - List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_1_VALUE) + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_2_VALUE) ); + } + + if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS)) { + createIndexWithConfiguration(indexName, buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), 3), ""); addKnnDoc( - TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, - "2", - List.of(), - List.of(), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT3), - List.of(), - List.of(), - List.of(INTEGER_FIELD_1), - List.of(INTEGER_FIELD_2_VALUE) + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_1_VALUE) + ); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_2_VALUE) + ); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + "3", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2), + List.of(), + List.of(), + List.of(), + List.of() + ); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + "4", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_3_VALUE) ); } } @@ -630,4 +738,14 @@ private Optional getMaxScore(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return hitsMap.get("max_score") == null ? 
Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } + + private Map getAggregations(final Map searchResponseAsMap) { + Map aggsMap = (Map) searchResponseAsMap.get("aggregations"); + return aggsMap; + } + + private T getAggregationValue(final Map aggsMap, final String aggName) { + Map aggValues = (Map) aggsMap.get(aggName); + return (T) aggValues.get("value"); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java index b67a1ee05..7ee30805b 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -76,7 +76,7 @@ public void testBasics_whenCreateNewCollector_thenSuccessful() { LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( NUM_DOCS, new HitsThresholdChecker(TOTAL_HITS_UP_TO) ); @@ -111,7 +111,7 @@ public void testGetHybridScores_whenCreateNewAndGetScores_thenSuccessful() { LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector<>( NUM_DOCS, new HitsThresholdChecker(TOTAL_HITS_UP_TO) ); @@ -162,7 +162,7 @@ public void testTopDocs_whenCreateNewAndGetTopDocs_thenSuccessful() { LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector<>( NUM_DOCS, new HitsThresholdChecker(TOTAL_HITS_UP_TO) ); @@ -244,7 +244,7 @@ public void testTopDocs_whenMatchedDocsDifferentForEachSubQuery_thenSuccessful() LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector<>( NUM_DOCS, new HitsThresholdChecker(TOTAL_HITS_UP_TO) ); @@ -367,7 +367,7 @@ public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector<>( NUM_DOCS, new HitsThresholdChecker(Integer.MAX_VALUE) ); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 56bce8962..82d8ccb8f 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -63,7 +63,6 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; -import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import 
com.carrotsearch.randomizedtesting.RandomizedTest; @@ -860,21 +859,21 @@ public void testAggregations_whenMetricAggregation_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - null, - searchContext + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext ); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( - randomAlphaOfLength(10), - shardId, - randomAlphaOfLength(10), - OriginalIndices.NONE + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE ); when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index f8d274a15..2ab61c7d9 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -5,7 +5,6 @@ package org.opensearch.neuralsearch; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; -import static org.opensearch.search.aggregations.Aggregations.AGGREGATIONS_FIELD; import java.io.IOException; import java.nio.file.Files; @@ -382,12 +381,12 @@ protected Map search( @SneakyThrows protected Map search( - String index, - QueryBuilder queryBuilder, - QueryBuilder rescorer, - int resultSize, - Map requestParams, - AggregationBuilder aggsBuilder + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + AggregationBuilder aggsBuilder ) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -399,13 +398,13 @@ protected Map search( } if (aggsBuilder != null) { builder.startObject("aggs").value(aggsBuilder).endObject(); - /*builder.startObject("aggs") - .startObject("max_index") - .startObject("max") - .field("field", "test-text-field-1") - .endObject() - .endObject() - .endObject();*/ + /*builder.startObject("aggs") + .startObject("max_index") + .startObject("max") + .field("field", "test-text-field-1") + .endObject() + .endObject() + .endObject();*/ } builder.endObject(); @@ -451,16 +450,27 @@ protected void addKnnDoc( @SneakyThrows protected void addKnnDoc( - String index, - String docId, - List vectorFieldNames, - List vectors, - List textFieldNames, - List texts, - List nestedFieldNames, - List> nestedFields + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts, + List nestedFieldNames, + List> nestedFields ) { - addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, nestedFieldNames, nestedFields, Collections.emptyList(), Collections.emptyList()); + addKnnDoc( + index, + docId, + vectorFieldNames, + vectors, + textFieldNames, + texts, + nestedFieldNames, + nestedFields, + Collections.emptyList(), + Collections.emptyList() + ); } /** @@ -599,9 +609,9 @@ protected String buildIndexConfiguration(final List knnFieldConf @SneakyThrows protected String buildIndexConfiguration( - final 
List knnFieldConfigs, - final List nestedFields, - final int numberOfShards + final List knnFieldConfigs, + final List nestedFields, + final int numberOfShards ) { return buildIndexConfiguration(knnFieldConfigs, nestedFields, Collections.emptyList(), numberOfShards); }
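
Note on the collector-manager pattern used above: Lucene's CollectorManager splits collection into newCollector(), called once per searcher slice when concurrent segment search is enabled, and reduce(), which merges the per-slice collectors into one result. A minimal, self-contained sketch of that contract using only stock Lucene types (illustrative only, not code from this patch):

    import java.io.IOException;
    import java.util.Collection;
    import org.apache.lucene.search.CollectorManager;
    import org.apache.lucene.search.TotalHitCountCollector;

    // Counts hits across all slices: one TotalHitCountCollector per slice,
    // summed once after every slice has finished collecting.
    class HitCountManager implements CollectorManager<TotalHitCountCollector, Integer> {
        @Override
        public TotalHitCountCollector newCollector() {
            return new TotalHitCountCollector();
        }

        @Override
        public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
            int total = 0;
            for (TotalHitCountCollector collector : collectors) {
                total += collector.getTotalHits();
            }
            return total;
        }
    }

    // usage: Integer totalHits = indexSearcher.search(query, new HitCountManager());

HybridCollectorManager follows the same shape: newCollector() hands out HybridTopScoreDocCollector instances, and reduce() emits a ReduceableSearchResult that writes the merged TopDocs into the QuerySearchResult.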
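Note on the result encoding in getNewTopDocs(): the per-sub-query TopDocs are flattened into a single ScoreDoc[] framed by sentinel entries (the "magic number" elements produced by HybridSearchResultFormatUtil), so the normalization processor can later split the merged array back into per-sub-query buckets. A rough sketch of the framing; the sentinel values below are hypothetical placeholders, the real constants live in HybridSearchResultFormatUtil:

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.lucene.search.ScoreDoc;

    class HybridFramingSketch {
        // hypothetical sentinel scores; stand-ins for the real magic numbers
        private static final float MAGIC_NUMBER_START_STOP = -9.5e9f;
        private static final float MAGIC_NUMBER_DELIMITER = -4.4e9f;

        // anyValidDocId is any doc id that matched on this shard; only the score marks an
        // element as a sentinel, the doc id just has to be fetchable
        static ScoreDoc[] frame(int anyValidDocId, List<ScoreDoc[]> perSubQueryHits) {
            List<ScoreDoc> out = new ArrayList<>();
            out.add(new ScoreDoc(anyValidDocId, MAGIC_NUMBER_START_STOP)); // start marker
            for (ScoreDoc[] hits : perSubQueryHits) {
                out.add(new ScoreDoc(anyValidDocId, MAGIC_NUMBER_DELIMITER)); // one per sub-query
                for (ScoreDoc hit : hits) {
                    out.add(hit); // the sub-query's real hits, scores untouched
                }
            }
            out.add(new ScoreDoc(anyValidDocId, MAGIC_NUMBER_START_STOP)); // stop marker
            return out.toArray(new ScoreDoc[0]);
        }
    }

This mirrors the doc_id | magic_number template described in the comment inside getNewTopDocs().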