From fe4dfae8d0fbb144cf81733a46bdaaad490739f7 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 11 Mar 2024 17:06:07 -0700 Subject: [PATCH] Adding aggregations in hybrid query Signed-off-by: Martin Gaievski --- CHANGELOG.md | 4 +- .../processor/combination/ScoreCombiner.java | 18 +- .../search/HitsThresholdChecker.java | 2 +- .../query/HybridAggregationProcessor.java | 7 +- .../query/HybridQueryPhaseSearcher.java | 49 +- .../neuralsearch/util/HybridQueryUtil.java | 68 +++ .../processor/NormalizationProcessorIT.java | 8 +- .../ScoreCombinationTechniqueTests.java | 2 +- .../neuralsearch/query/HybridQueryIT.java | 544 +++++++++++++++++- .../query/HybridQueryPhaseSearcherTests.java | 80 ++- .../util/AggregationsTestUtils.java | 43 ++ .../neuralsearch/BaseNeuralSearchIT.java | 138 ++++- 12 files changed, 888 insertions(+), 75 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dcdc721b..120640aa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,19 +8,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes - Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) -- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) -- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) - Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578)) - Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615)) ### Infrastructure ### Documentation ### Maintenance -- Added support for jdk-21 ([#500](https://github.com/opensearch-project/neural-search/pull/500))) ### Refactoring ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.12...2.x) ### Features ### Enhancements +- Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630)) ### Bug Fixes - Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index c9e0551e2..278d2fdfc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.search.ScoreDoc; @@ -131,13 +132,18 @@ private void updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits)); } + /** + * Get max hits as number of unique doc ids from results of all sub-queries + * @param topDocsPerSubQuery list of topDocs objects for one shard + * @return number of unique doc ids + */ protected int getMaxHits(final List topDocsPerSubQuery) { - int maxHits = 0; - for (TopDocs topDocs : topDocsPerSubQuery) { - int hits = topDocs.scoreDocs.length; - maxHits = Math.max(maxHits, hits); - } - return maxHits; + Set docIds = topDocsPerSubQuery.stream() + .filter(topDocs -> Objects.nonNull(topDocs.scoreDocs)) + .flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs)) + .map(scoreDoc -> scoreDoc.doc) + .collect(Collectors.toSet()); + return docIds.size(); } private TotalHits getTotalHits(final List topDocsPerSubQuery, int maxHits) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java index 1299537bb..2e8b365e2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java @@ -34,7 +34,7 @@ protected boolean isThresholdReached() { return hitCount >= getTotalHitsThreshold(); } - protected ScoreMode scoreMode() { + public ScoreMode scoreMode() { return ScoreMode.TOP_SCORES; } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java index 4e9070748..13353614e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -6,6 +6,7 @@ import lombok.AllArgsConstructor; import org.apache.lucene.search.CollectorManager; +import org.opensearch.neuralsearch.util.HybridQueryUtil; import org.opensearch.search.aggregations.AggregationInitializationException; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.SearchContext; @@ -16,8 +17,6 @@ import java.io.IOException; import java.util.List; -import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; - /** * Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom * collector manager for hybris query (pre phase) and reducing results (post phase) @@ -31,7 +30,7 @@ public class HybridAggregationProcessor implements AggregationProcessor { public void preProcess(SearchContext context) { delegateAggsProcessor.preProcess(context); - if (isHybridQuery(context.query(), context)) { + if (HybridQueryUtil.isHybridQuery(context.query(), context)) { // adding collector manager for hybrid query CollectorManager collectorManager; try { @@ -45,7 +44,7 @@ public void preProcess(SearchContext context) { @Override public void postProcess(SearchContext context) { - if (isHybridQuery(context.query(), context)) { + if (HybridQueryUtil.isHybridQuery(context.query(), context)) { // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector // managers is not called // (https://github.com/opensearch-project/OpenSearch/blob/2.12/server/src/main/java/org/opensearch/search/query/QueryPhase.java#L333-L373) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 6461c698e..b22059f14 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -11,13 +11,12 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.util.HybridQueryUtil; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; @@ -34,10 +33,6 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { - public HybridQueryPhaseSearcher() { - super(); - } - public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, @@ -46,7 +41,7 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (!isHybridQuery(query, searchContext)) { + if (!HybridQueryUtil.isHybridQuery(query, searchContext)) { validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } else { @@ -55,46 +50,6 @@ public boolean searchWith( } } - @VisibleForTesting - static boolean isHybridQuery(final Query query, final SearchContext searchContext) { - if (query instanceof HybridQuery) { - return true; - } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { - /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code - https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. - main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks - hybrid query for indexes with nested field types. - in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for - this search request. - below is sample structure of such query: - - Boolean { - should: { - hybrid: { - sub_query1 {} - sub_query2 {} - } - } - filter: { - exists: { - field: "_primary_term" - } - } - } - TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ - // we have already checked if query in instance of Boolean in higher level else if condition - return ((BooleanQuery) query).clauses() - .stream() - .filter(clause -> !(clause.getQuery() instanceof HybridQuery)) - .allMatch(clause -> { - return clause.getOccur() == BooleanClause.Occur.FILTER - && clause.getQuery() instanceof FieldExistsQuery - && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()); - }); - } - return false; - } - private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java new file mode 100644 index 000000000..930778bf7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.SeqNoFieldMapper; +import org.opensearch.index.search.NestedHelper; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.search.internal.SearchContext; + +/** + * Utility class for anything related to hybrid query + */ +public class HybridQueryUtil { + + public static boolean isHybridQuery(final Query query, final SearchContext searchContext) { + if (query instanceof HybridQuery) { + return true; + } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { + /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code + https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. + main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks + hybrid query for indexes with nested field types. + in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for + this search request. + below is sample structure of such query: + + Boolean { + should: { + hybrid: { + sub_query1 {} + sub_query2 {} + } + } + filter: { + exists: { + field: "_primary_term" + } + } + } + TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */ + // we have already checked if query in instance of Boolean in higher level else if condition + return ((BooleanQuery) query).clauses() + .stream() + .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .allMatch(clause -> { + return clause.getOccur() == BooleanClause.Occur.FILTER + && clause.getQuery() instanceof FieldExistsQuery + && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()); + }); + } + return false; + } + + private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); + } + + private static boolean isWrappedHybridQuery(final Query query) { + return query instanceof BooleanQuery + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index b1f0de9d3..e4f2c77ae 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -52,6 +52,8 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT { private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + private final float[] testVector5 = createRandomVector(TEST_DIMENSION); + private final float[] testVector6 = createRandomVector(TEST_DIMENSION); @Before public void setUp() throws Exception { @@ -318,7 +320,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, "5", Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(Floats.asList(testVector5).toArray()), Collections.singletonList(TEST_TEXT_FIELD_NAME_1), Collections.singletonList(TEST_DOC_TEXT4) ); @@ -365,7 +367,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, "5", Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(Floats.asList(testVector5).toArray()), Collections.singletonList(TEST_TEXT_FIELD_NAME_1), Collections.singletonList(TEST_DOC_TEXT4) ); @@ -373,7 +375,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, "6", Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(Floats.asList(testVector6).toArray()), Collections.singletonList(TEST_TEXT_FIELD_NAME_1), Collections.singletonList(TEST_DOC_TEXT5) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index d2c1ddb4f..4f76c666e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -63,7 +63,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); - assertEquals(3, queryTopDocs.get(0).getScoreDocs().size()); + assertEquals(5, queryTopDocs.get(0).getScoreDocs().size()); assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(0).score, DELTA_FOR_SCORE_ASSERTION); assertEquals(1, queryTopDocs.get(0).getScoreDocs().get(0).doc); assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(1).score, DELTA_FOR_SCORE_ASSERTION); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 487a378df..70772456a 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -12,6 +12,10 @@ import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; import java.io.IOException; import java.util.ArrayList; @@ -28,6 +32,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; @@ -35,6 +40,15 @@ import com.google.common.primitives.Floats; import lombok.SneakyThrows; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MinBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder; public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-hybrid-basic-index"; @@ -43,6 +57,9 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-hybrid-multi-doc-single-shard-index"; private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-hybrid-multi-doc-nested-type-single-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-neural-aggs-pipeline-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -63,6 +80,42 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "People keep telling me orange but I still prefer pink"; + private static final String TEST_DOC_TEXT6 = "She traveled because it cost the same as therapy and was a lot more enjoyable"; + private static final String INTEGER_FIELD_1 = "doc_index"; + private static final int INTEGER_FIELD_1_VALUE = 1234; + private static final int INTEGER_FIELD_2_VALUE = 2345; + private static final int INTEGER_FIELD_3_VALUE = 3456; + private static final int INTEGER_FIELD_4_VALUE = 4567; + private static final String KEYWORD_FIELD_1 = "doc_keyword"; + private static final String KEYWORD_FIELD_1_VALUE = "workable"; + private static final String KEYWORD_FIELD_2_VALUE = "angry"; + private static final String KEYWORD_FIELD_3_VALUE = "likeable"; + private static final String KEYWORD_FIELD_4_VALUE = "entire"; + private static final String DATE_FIELD_1 = "doc_date"; + private static final String DATE_FIELD_1_VALUE = "01/03/1995"; + private static final String DATE_FIELD_2_VALUE = "05/02/2015"; + private static final String DATE_FIELD_3_VALUE = "07/23/2007"; + private static final String DATE_FIELD_4_VALUE = "08/21/2012"; + private static final String INTEGER_FIELD_PRICE = "doc_price"; + private static final int INTEGER_FIELD_PRICE_1_VALUE = 130; + private static final int INTEGER_FIELD_PRICE_2_VALUE = 100; + private static final int INTEGER_FIELD_PRICE_3_VALUE = 200; + private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; + private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; + private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; + private static final String BUCKET_AGG_DOC_COUNT_FIELD = "doc_count"; + private static final String KEY = "key"; + private static final String BUCKET_AGG_KEY_AS_STRING = "key_as_string"; + private static final String SUM_AGGREGATION_NAME = "sum_aggs"; + private static final String MAX_AGGREGATION_NAME = "max_aggs"; + private static final String DATE_AGGREGATION_NAME = "date_aggregation"; + private static final String GENERIC_AGGREGATION_NAME = "my_aggregation"; + private static final String BUCKETS_AGGREGATION_NAME_1 = "date_buckets_1"; + private static final String BUCKETS_AGGREGATION_NAME_2 = "date_buckets_2"; + private static final String BUCKETS_AGGREGATION_NAME_3 = "date_buckets_3"; + private static final String BUCKETS_AGGREGATION_NAME_4 = "date_buckets_4"; @Before public void setUp() throws Exception { @@ -362,10 +415,8 @@ public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { @SneakyThrows public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess() { - String modelId = null; try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); - modelId = prepareModel(); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); NestedQueryBuilder nestedQueryBuilder = QueryBuilders.nestedQuery( @@ -395,7 +446,200 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess( assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, null, modelId, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testAvgSumMinMaxAggs(); + } + + @SneakyThrows + public void testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testAvgSumMinMaxAggs(); + } + + @SneakyThrows + public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testMaxAggsOnSingleShardCluster(); + } + + @SneakyThrows + public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testMaxAggsOnSingleShardCluster(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testDateRange(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testDateRange(); + } + + @SneakyThrows + public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.sampler(GENERIC_AGGREGATION_NAME) + .shardSize(2) + .subAggregation(AggregationBuilders.terms(BUCKETS_AGGREGATION_NAME_1).field(KEYWORD_FIELD_1)); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggsBuilder), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(2, aggValue.size()); + assertEquals(3, aggValue.get(BUCKET_AGG_DOC_COUNT_FIELD)); + Map nestedAggs = getAggregationValues(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertNotNull(nestedAggs); + assertEquals(0, nestedAggs.get("doc_count_error_upper_bound")); + List> buckets = getAggregationBuckets(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertEquals(2, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("likeable", firstBucket.get(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("workable", secondBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void testAvgSumMinMaxAggs() { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.dateHistogram(GENERIC_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD_1) + .subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_1)); + + BucketMetricsPipelineAggregationBuilder aggAvgBucket = PipelineAggregatorBuilders + .avgBucket(BUCKETS_AGGREGATION_NAME_1, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggSumBucket = PipelineAggregatorBuilders + .sumBucket(BUCKETS_AGGREGATION_NAME_2, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMinBucket = PipelineAggregatorBuilders + .minBucket(BUCKETS_AGGREGATION_NAME_3, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMaxBucket = PipelineAggregatorBuilders + .maxBucket(BUCKETS_AGGREGATION_NAME_4, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + Map searchResponseAsMapAnngsBoolQuery = executeQueryAndGetAggsResults( + List.of(aggsBuilder, aggAvgBucket, aggSumBucket, aggMinBucket, aggMaxBucket), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + assertResultsOfPipelineSumtoDateHistogramAggs(searchResponseAsMapAnngsBoolQuery); + + // test only aggregation without query (handled as match_all query) + Map searchResponseAsMapAggsNoQuery = executeQueryAndGetAggsResults( + List.of(aggsBuilder, aggAvgBucket), + null, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 6 + ); + + assertResultsOfPipelineSumtoDateHistogramAggsForMatchAllQuery(searchResponseAsMapAggsNoQuery); + + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testMaxAggsOnSingleShardCluster() throws Exception { + try { + prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + AggregationBuilder aggsBuilder = AggregationBuilders.max(MAX_AGGREGATION_NAME).field(INTEGER_FIELD_1); + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + List.of(aggsBuilder) + ); + + assertHitResultsFromQuery(2, searchResponseAsMap); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(MAX_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, MAX_AGGREGATION_NAME); + assertTrue(maxAggsValue >= 0); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + private void testDateRange() throws IOException { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + // try { + // prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.dateRange(DATE_AGGREGATION_NAME) + .field(DATE_FIELD_1) + .format("MM-yyyy") + .addRange("01-2014", "02-2024"); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggsBuilder), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(1, buckets.size()); + + Map bucket = buckets.get(0); + + assertEquals(6, bucket.size()); + assertEquals("01-2014", bucket.get("from_as_string")); + assertEquals(2, bucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("02-2024", bucket.get("to_as_string")); + assertTrue(bucket.containsKey("from")); + assertTrue(bucket.containsKey("to")); + assertTrue(bucket.containsKey(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); } } @@ -490,6 +734,157 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) ); } + + if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(DATE_FIELD_1), 3), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_1_VALUE, INTEGER_FIELD_PRICE_1_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_1_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_1_VALUE) + ); + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_2_VALUE, INTEGER_FIELD_PRICE_2_VALUE), + List.of(), + List.of(), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_2_VALUE) + ); + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2), + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_3_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_2_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_3_VALUE) + ); + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_4_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_3_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_2_VALUE) + ); + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_5_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_4_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_4_VALUE) + ); + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT6), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_4_VALUE, INTEGER_FIELD_PRICE_6_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_4_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_4_VALUE) + ); + } + } + + @SneakyThrows + private void initializeIndexWithOneShardIfNotExists(String indexName) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(), 1), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_1_VALUE), + List.of(), + List.of(), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_2_VALUE), + List.of(), + List.of(), + List.of(), + List.of() + ); + } } private void addDocsToIndex(final String testMultiDocIndexName) { @@ -532,4 +927,147 @@ private Optional getMaxScore(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } + + @SneakyThrows + void prepareResources(String indexName, String pipelineName) { + initializeIndexIfNotExist(indexName); + createSearchPipelineWithResultsPostProcessor(pipelineName); + } + + @SneakyThrows + void prepareResourcesForSingleShardIndex(String indexName, String pipelineName) { + initializeIndexWithOneShardIfNotExists(indexName); + createSearchPipelineWithResultsPostProcessor(pipelineName); + } + + private void assertResultsOfPipelineSumtoDateHistogramAggs(Map searchResponseAsMap) { + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + double aggValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_1); + assertEquals(3517.5, aggValue, DELTA_FOR_SCORE_ASSERTION); + + double sumValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_2); + assertEquals(7035.0, sumValue, DELTA_FOR_SCORE_ASSERTION); + + double minValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_3); + assertEquals(1234.0, minValue, DELTA_FOR_SCORE_ASSERTION); + + double maxValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_4); + assertEquals(5801.0, maxValue, DELTA_FOR_SCORE_ASSERTION); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } + + private void assertResultsOfPipelineSumtoDateHistogramAggsForMatchAllQuery(Map searchResponseAsMap) { + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + double aggValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_1); + assertEquals(3764.5, aggValue, DELTA_FOR_SCORE_ASSERTION); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } + + private Map executeQueryAndGetAggsResults(final List aggsBuilders, String indexName, int expectedHitsNumber) { + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + return executeQueryAndGetAggsResults(aggsBuilders, hybridQueryBuilderNeuralThenTerm, indexName, expectedHitsNumber); + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName, + int expectedHits + ) { + Map searchResponseAsMap = search( + indexName, + queryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilders + ); + + assertHitResultsFromQuery(expectedHits, searchResponseAsMap); + return searchResponseAsMap; + } + + private void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { + assertEquals(expected, getHitCount(searchResponseAsMap)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(expected, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 2aebbb5d8..055301832 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -66,12 +66,12 @@ import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QuerySearchResult; import com.carrotsearch.randomizedtesting.RandomizedTest; import lombok.SneakyThrows; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -810,6 +810,82 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); } + @SneakyThrows + public void testAggregations_whenMetricAggregation_thenSuccessful() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + releaseResources(directory, w, reader); + + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); diff --git a/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java b/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java new file mode 100644 index 000000000..fbb53a918 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import java.util.List; +import java.util.Map; + +/** + * Util class for routines associated with aggregations testing + */ +public class AggregationsTestUtils { + + public static List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + public static Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + public static Map getAggregations(final Map searchResponseAsMap) { + Map aggsMap = (Map) searchResponseAsMap.get("aggregations"); + return aggsMap; + } + + public static T getAggregationValue(final Map aggsMap, final String aggName) { + Map aggValues = (Map) aggsMap.get(aggName); + return (T) aggValues.get("value"); + } + + public static T getAggregationBuckets(final Map aggsMap, final String aggName) { + Map aggValues = (Map) aggsMap.get(aggName); + return (T) aggValues.get("buckets"); + } + + public static T getAggregationValues(final Map aggsMap, final String aggName) { + return (T) aggsMap.get(aggName); + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index ffbbed2bc..622327fa7 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -413,14 +413,54 @@ protected Map search( final int resultSize, final Map requestParams ) { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); - queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + return search(index, queryBuilder, rescorer, resultSize, requestParams, null); + } + + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + List aggs + ) { + return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null); + } + + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + List aggs, + QueryBuilder postFilterBuilder + ) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + + if (queryBuilder != null) { + builder.field("query"); + queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + } if (rescorer != null) { builder.startObject("rescore").startObject("query").field("query_weight", 0.0f).field("rescore_query"); rescorer.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject().endObject(); } + if (Objects.nonNull(aggs)) { + builder.startObject("aggs"); + for (Object agg : aggs) { + builder.value(agg); + } + builder.endObject(); + } + if (Objects.nonNull(postFilterBuilder)) { + builder.field("post_filter"); + postFilterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + } builder.endObject(); @@ -463,6 +503,35 @@ protected void addKnnDoc( addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts, + List nestedFieldNames, + List> nestedFields + ) { + addKnnDoc( + index, + docId, + vectorFieldNames, + vectors, + textFieldNames, + texts, + nestedFieldNames, + nestedFields, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList() + ); + } + /** * Add a set of knn vectors and text to an index * @@ -484,7 +553,13 @@ protected void addKnnDoc( final List textFieldNames, final List texts, final List nestedFieldNames, - final List> nestedFields + final List> nestedFields, + final List integerFieldNames, + final List integerFieldValues, + final List keywordFieldNames, + final List keywordFieldValues, + final List dateFieldNames, + final List dateFieldValues ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -505,6 +580,18 @@ protected void addKnnDoc( } builder.endObject(); } + + for (int i = 0; i < integerFieldNames.size(); i++) { + builder.field(integerFieldNames.get(i), integerFieldValues.get(i)); + } + + for (int i = 0; i < keywordFieldNames.size(); i++) { + builder.field(keywordFieldNames.get(i), keywordFieldValues.get(i)); + } + + for (int i = 0; i < dateFieldNames.size(); i++) { + builder.field(dateFieldNames.get(i), dateFieldValues.get(i)); + } builder.endObject(); request.setJsonEntity(builder.toString()); @@ -667,6 +754,25 @@ protected String buildIndexConfiguration( final List knnFieldConfigs, final List nestedFields, final int numberOfShards + ) { + return buildIndexConfiguration( + knnFieldConfigs, + nestedFields, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + numberOfShards + ); + } + + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final List intFields, + final List keywordFields, + final List dateFields, + final int numberOfShards ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() @@ -688,9 +794,31 @@ protected String buildIndexConfiguration( .endObject() .endObject(); } + // treat the list in a manner that first element is always the type name and all others are keywords + if (!nestedFields.isEmpty()) { + String nestedFieldName = nestedFields.get(0); + xContentBuilder.startObject(nestedFieldName).field("type", "nested"); + if (nestedFields.size() > 1) { + xContentBuilder.startObject("properties"); + for (int i = 1; i < nestedFields.size(); i++) { + String innerNestedTypeField = nestedFields.get(i); + xContentBuilder.startObject(innerNestedTypeField).field("type", "keyword").endObject(); + } + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + + for (String intField : intFields) { + xContentBuilder.startObject(intField).field("type", "integer").endObject(); + } + + for (String keywordField : keywordFields) { + xContentBuilder.startObject(keywordField).field("type", "keyword").endObject(); + } - for (String nestedField : nestedFields) { - xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); + for (String dateField : dateFields) { + xContentBuilder.startObject(dateField).field("type", "date").field("format", "MM/dd/yyyy").endObject(); } xContentBuilder.endObject().endObject().endObject();