From c81668c917323b928afcf13ddd3d1db644057680 Mon Sep 17 00:00:00 2001 From: Sorabh Date: Tue, 6 Jun 2023 09:33:34 -0700 Subject: [PATCH] [Concurrent Segment Search]: Implement concurrent aggregations support without profile option (#7514) * Refactoring of AggregationReduceContext to use in SearchContext. This will be used for performing shard level reduce of aggregations during concurrent segment search usecase Signed-off-by: Sorabh Hamirwasia * Support for non global aggregations with concurrent segment search. This PR does not include the support for profile option with aggregations to work with concurrent model Signed-off-by: Sorabh Hamirwasia * Implement AggregationCollectorManager's reduce Signed-off-by: Andriy Redko * Use CollectorManager for both concurrent and non concurrent use case Add CollectorManager for Global Aggregations to support concurrent use case Signed-off-by: Sorabh Hamirwasia * Address review comments Signed-off-by: Sorabh Hamirwasia --------- Signed-off-by: Sorabh Hamirwasia Signed-off-by: Andriy Redko Co-authored-by: Andriy Redko --- CHANGELOG.md | 1 + .../metrics/ScriptedMetricIT.java | 20 +- .../action/search/SearchPhaseController.java | 6 +- .../action/search/TransportSearchAction.java | 2 +- .../search/DefaultSearchContext.java | 13 +- .../org/opensearch/search/SearchModule.java | 9 +- .../org/opensearch/search/SearchService.java | 18 +- .../AggregationCollectorManager.java | 115 +++++++++++ .../search/aggregations/AggregationPhase.java | 184 ------------------ .../aggregations/AggregationProcessor.java | 29 +++ .../AggregationReduceableSearchResult.java | 39 ++++ .../aggregations/AggregatorFactories.java | 48 ++++- .../ConcurrentAggregationProcessor.java | 84 ++++++++ .../DefaultAggregationProcessor.java | 90 +++++++++ .../GlobalAggCollectorManager.java | 41 ++++ ...ggCollectorManagerWithSingleCollector.java | 45 +++++ .../aggregations/InternalAggregations.java | 9 + .../aggregations/MultiBucketCollector.java | 4 + .../NonGlobalAggCollectorManager.java | 41 ++++ ...ggCollectorManagerWithSingleCollector.java | 45 +++++ .../SearchContextAggregations.java | 15 -- .../filter/FiltersAggregatorFactory.java | 36 ++-- .../internal/FilteredSearchContext.java | 6 + .../search/internal/SearchContext.java | 3 + .../query/ConcurrentQueryPhaseSearcher.java | 22 +-- .../search/query/QueryCollectorContext.java | 2 +- .../query/QueryCollectorManagerContext.java | 37 ++-- .../opensearch/search/query/QueryPhase.java | 40 +++- .../search/query/QueryPhaseSearcher.java | 10 + .../search/DefaultSearchContextTests.java | 21 +- .../opensearch/search/SearchModuleTests.java | 98 ++++++++++ .../opensearch/search/SearchServiceTests.java | 2 +- .../AggregationCollectorManagerTests.java | 125 ++++++++++++ .../AggregationCollectorTests.java | 67 ++++--- .../AggregationProcessorTests.java | 172 ++++++++++++++++ .../aggregations/AggregationSetupTests.java | 47 +++++ .../search/query/QueryPhaseTests.java | 6 +- .../opensearch/test/TestSearchContext.java | 6 + 38 files changed, 1241 insertions(+), 317 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java delete mode 100644 server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/AggregationProcessor.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/AggregationReduceableSearchResult.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/ConcurrentAggregationProcessor.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/DefaultAggregationProcessor.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManagerWithSingleCollector.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java create mode 100644 server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManagerWithSingleCollector.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorManagerTests.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/AggregationProcessorTests.java create mode 100644 server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 749aa99845da7..32cafe50f1b61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Allow mmap to use new JDK-19 preview APIs in Apache Lucene 9.4+ ([#5151](https://github.com/opensearch-project/OpenSearch/pull/5151)) - Add events correlation engine plugin ([#6854](https://github.com/opensearch-project/OpenSearch/issues/6854)) - Add support for ignoring missing Javadoc on generated code using annotation ([#7604](https://github.com/opensearch-project/OpenSearch/pull/7604)) +- Implement concurrent aggregations support without profile option ([#7514](https://github.com/opensearch-project/OpenSearch/pull/7514)) ### Dependencies - Bump `log4j-core` from 2.18.0 to 2.19.0 diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/ScriptedMetricIT.java index 27dbc56cf3b79..2065b122aac87 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -435,7 +435,9 @@ public void testMap() { assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); List aggregationList = (List) scriptedMetricAggregation.aggregation(); - assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries)); + // with script based aggregation, if it does not support reduce then aggregationList size + // will be numShards * slicesCount + assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries)); int numShardsRun = 0; for (Object object : aggregationList) { assertThat(object, notNullValue()); @@ -483,7 +485,9 @@ public void testMapWithParams() { assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); List aggregationList = (List) scriptedMetricAggregation.aggregation(); - assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries)); + // with script based aggregation, if it does not support reduce then aggregationList size + // will be numShards * slicesCount + assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries)); int numShardsRun = 0; for (Object object : aggregationList) { assertThat(object, notNullValue()); @@ -535,7 +539,9 @@ public void testInitMutatesParams() { assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); List aggregationList = (List) scriptedMetricAggregation.aggregation(); - assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries)); + // with script based aggregation, if it does not support reduce then aggregationList size + // will be numShards * slicesCount + assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries)); long totalCount = 0; for (Object object : aggregationList) { assertThat(object, notNullValue()); @@ -588,7 +594,9 @@ public void testMapCombineWithParams() { assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); List aggregationList = (List) scriptedMetricAggregation.aggregation(); - assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries)); + // with script based aggregation, if it does not support reduce then aggregationList size + // will be numShards * slicesCount + assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries)); long totalCount = 0; for (Object object : aggregationList) { assertThat(object, notNullValue()); @@ -651,7 +659,9 @@ public void testInitMapCombineWithParams() { assertThat(scriptedMetricAggregation.aggregation(), notNullValue()); assertThat(scriptedMetricAggregation.aggregation(), instanceOf(ArrayList.class)); List aggregationList = (List) scriptedMetricAggregation.aggregation(); - assertThat(aggregationList.size(), equalTo(getNumShards("idx").numPrimaries)); + // with script based aggregation, if it does not support reduce then aggregationList size + // will be numShards * slicesCount + assertThat(aggregationList.size(), greaterThanOrEqualTo(getNumShards("idx").numPrimaries)); long totalCount = 0; for (Object object : aggregationList) { assertThat(object, notNullValue()); diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index f142f0afbf92a..a4984db7c4095 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -95,11 +95,11 @@ public final class SearchPhaseController { private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0]; private final NamedWriteableRegistry namedWriteableRegistry; - private final Function requestToAggReduceContextBuilder; + private final Function requestToAggReduceContextBuilder; public SearchPhaseController( NamedWriteableRegistry namedWriteableRegistry, - Function requestToAggReduceContextBuilder + Function requestToAggReduceContextBuilder ) { this.namedWriteableRegistry = namedWriteableRegistry; this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder; @@ -737,7 +737,7 @@ public InternalSearchResponse buildResponse(SearchHits hits) { } InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) { - return requestToAggReduceContextBuilder.apply(request); + return requestToAggReduceContextBuilder.apply(request.source()); } /** diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index fd42089a4a9d5..69f529fe1d00c 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -443,7 +443,7 @@ private void executeRequest( localIndices, remoteClusterIndices, timeProvider, - searchService.aggReduceContextBuilder(searchRequest), + searchService.aggReduceContextBuilder(searchRequest.source()), remoteClusterService, threadPool, listener, diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 40081c087f09a..fb6cda4af00cd 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -64,7 +64,9 @@ import org.opensearch.index.search.NestedHelper; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.similarity.SimilarityService; +import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.SearchContextAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsSearchResult; import org.opensearch.search.fetch.FetchPhase; @@ -99,6 +101,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import java.util.function.Function; import java.util.function.LongSupplier; /** @@ -175,6 +178,7 @@ final class DefaultSearchContext extends SearchContext { private final Map, CollectorManager> queryCollectorManagers = new HashMap<>(); private final QueryShardContext queryShardContext; private final FetchPhase fetchPhase; + private final Function requestToAggReduceContextBuilder; DefaultSearchContext( ReaderContext readerContext, @@ -188,7 +192,8 @@ final class DefaultSearchContext extends SearchContext { boolean lowLevelCancellation, Version minNodeVersion, boolean validate, - Executor executor + Executor executor, + Function requestToAggReduceContextBuilder ) throws IOException { this.readerContext = readerContext; this.request = request; @@ -225,6 +230,7 @@ final class DefaultSearchContext extends SearchContext { ); queryBoost = request.indexBoost(); this.lowLevelCancellation = lowLevelCancellation; + this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder; } @Override @@ -886,4 +892,9 @@ public boolean isCancelled() { public ReaderContext readerContext() { return readerContext; } + + @Override + public InternalAggregation.ReduceContext partial() { + return requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction(); + } } diff --git a/server/src/main/java/org/opensearch/search/SearchModule.java b/server/src/main/java/org/opensearch/search/SearchModule.java index 3f49ef0c4bfcd..a4aa1cbf0d3c2 100644 --- a/server/src/main/java/org/opensearch/search/SearchModule.java +++ b/server/src/main/java/org/opensearch/search/SearchModule.java @@ -1290,7 +1290,14 @@ public FetchPhase getFetchPhase() { } public QueryPhase getQueryPhase() { - return (queryPhaseSearcher == null) ? new QueryPhase() : new QueryPhase(queryPhaseSearcher); + QueryPhase queryPhase; + if (queryPhaseSearcher == null) { + // use the defaults + queryPhase = new QueryPhase(); + } else { + queryPhase = new QueryPhase(queryPhaseSearcher); + } + return queryPhase; } public @Nullable ExecutorService getIndexSearcherExecutor(ThreadPool pool) { diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index efb5800879495..bc13acf5afe64 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -44,7 +44,6 @@ import org.opensearch.action.search.DeletePitResponse; import org.opensearch.action.search.ListPitInfo; import org.opensearch.action.search.PitSearchContextIdForNode; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.action.search.UpdatePitContextRequest; @@ -1038,7 +1037,8 @@ private DefaultSearchContext createSearchContext(ReaderContext reader, ShardSear lowLevelCancellation, clusterService.state().nodes().getMinNodeVersion(), validate, - indexSearcherExecutor + indexSearcherExecutor, + this::aggReduceContextBuilder ); // we clone the query shard context here just for rewriting otherwise we // might end up with incorrect state since we are using now() or script services @@ -1620,22 +1620,22 @@ public IndicesService getIndicesService() { /** * Returns a builder for {@link InternalAggregation.ReduceContext}. This - * builder retains a reference to the provided {@link SearchRequest}. + * builder retains a reference to the provided {@link SearchSourceBuilder}. */ - public InternalAggregation.ReduceContextBuilder aggReduceContextBuilder(SearchRequest request) { + public InternalAggregation.ReduceContextBuilder aggReduceContextBuilder(SearchSourceBuilder searchSourceBuilder) { return new InternalAggregation.ReduceContextBuilder() { @Override public InternalAggregation.ReduceContext forPartialReduction() { return InternalAggregation.ReduceContext.forPartialReduction( bigArrays, scriptService, - () -> requestToPipelineTree(request) + () -> requestToPipelineTree(searchSourceBuilder) ); } @Override public ReduceContext forFinalReduction() { - PipelineTree pipelineTree = requestToPipelineTree(request); + PipelineTree pipelineTree = requestToPipelineTree(searchSourceBuilder); return InternalAggregation.ReduceContext.forFinalReduction( bigArrays, scriptService, @@ -1646,11 +1646,11 @@ public ReduceContext forFinalReduction() { }; } - private static PipelineTree requestToPipelineTree(SearchRequest request) { - if (request.source() == null || request.source().aggregations() == null) { + private static PipelineTree requestToPipelineTree(SearchSourceBuilder searchSourceBuilder) { + if (searchSourceBuilder == null || searchSourceBuilder.aggregations() == null) { return PipelineTree.EMPTY; } - return request.source().aggregations().buildPipelineTree(); + return searchSourceBuilder.aggregations().buildPipelineTree(); } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java new file mode 100644 index 0000000000000..03519b335bbea --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.common.CheckedFunction; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.InternalProfileCollector; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; + +/** + * Common {@link CollectorManager} used by both concurrent and non-concurrent aggregation path and also for global and non-global + * aggregation operators + */ +class AggregationCollectorManager implements CollectorManager { + private final SearchContext context; + private final CheckedFunction, IOException> aggProvider; + private final String collectorReason; + + AggregationCollectorManager( + SearchContext context, + CheckedFunction, IOException> aggProvider, + String collectorReason + ) { + this.context = context; + this.aggProvider = aggProvider; + this.collectorReason = collectorReason; + } + + @Override + public Collector newCollector() throws IOException { + final Collector collector = createCollector(context, aggProvider.apply(context), collectorReason); + // For Aggregations we should not have a NO_OP_Collector + assert collector != BucketCollector.NO_OP_COLLECTOR; + return collector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + List aggregators = new ArrayList<>(); + + final Deque allCollectors = new LinkedList<>(collectors); + while (!allCollectors.isEmpty()) { + final Collector currentCollector = allCollectors.pop(); + if (currentCollector instanceof Aggregator) { + aggregators.add((Aggregator) currentCollector); + } else if (currentCollector instanceof InternalProfileCollector) { + if (((InternalProfileCollector) currentCollector).getCollector() instanceof Aggregator) { + aggregators.add((Aggregator) ((InternalProfileCollector) currentCollector).getCollector()); + } else if (((InternalProfileCollector) currentCollector).getCollector() instanceof MultiBucketCollector) { + allCollectors.addAll( + Arrays.asList(((MultiBucketCollector) ((InternalProfileCollector) currentCollector).getCollector()).getCollectors()) + ); + } + } else if (currentCollector instanceof MultiBucketCollector) { + allCollectors.addAll(Arrays.asList(((MultiBucketCollector) currentCollector).getCollectors())); + } + } + + final List internals = new ArrayList<>(aggregators.size()); + context.aggregations().resetBucketMultiConsumer(); + for (Aggregator aggregator : aggregators) { + try { + aggregator.postCollection(); + internals.add(aggregator.buildTopLevel()); + } catch (IOException e) { + throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e); + } + } + + final InternalAggregations internalAggregations = InternalAggregations.from(internals); + // Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce iff multiple slices + // were created to execute this request and it used concurrent segment search path + // TODO: Add the check for flag that the request was executed using concurrent search + if (collectors.size() > 1) { + // using reduce is fine here instead of topLevelReduce as pipeline aggregation is evaluated on the coordinator after all + // documents are collected across shards for an aggregation + return new AggregationReduceableSearchResult( + InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partial()) + ); + } else { + return new AggregationReduceableSearchResult(internalAggregations); + } + } + + static Collector createCollector(SearchContext context, List collectors, String reason) throws IOException { + Collector collector = MultiBucketCollector.wrap(collectors); + ((BucketCollector) collector).preCollection(); + if (context.getProfilers() != null) { + collector = new InternalProfileCollector( + collector, + reason, + // TODO: report on child aggs as well + Collections.emptyList() + ); + } + return collector; + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java deleted file mode 100644 index fe7f90703f776..0000000000000 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregationPhase.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.search.aggregations; - -import org.apache.lucene.search.Collector; -import org.apache.lucene.search.CollectorManager; -import org.apache.lucene.search.Query; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.lucene.search.Queries; -import org.opensearch.search.aggregations.bucket.global.GlobalAggregator; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.profile.aggregation.ProfilingAggregator; -import org.opensearch.search.profile.query.CollectorResult; -import org.opensearch.search.profile.query.InternalProfileCollector; -import org.opensearch.search.query.QueryPhaseExecutionException; -import org.opensearch.search.query.ReduceableSearchResult; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -/** - * Aggregation phase of a search request, used to collect aggregations - * - * @opensearch.internal - */ -public class AggregationPhase { - - @Inject - public AggregationPhase() {} - - public void preProcess(SearchContext context) { - if (context.aggregations() != null) { - List collectors = new ArrayList<>(); - Aggregator[] aggregators; - try { - AggregatorFactories factories = context.aggregations().factories(); - aggregators = factories.createTopLevelAggregators(context); - for (int i = 0; i < aggregators.length; i++) { - if (!isGlobalAggregator(context, aggregators[i])) { - collectors.add(aggregators[i]); - } - } - context.aggregations().aggregators(aggregators); - if (!collectors.isEmpty()) { - final Collector collector = createCollector(context, collectors); - context.queryCollectorManagers().put(AggregationPhase.class, new CollectorManager() { - @Override - public Collector newCollector() throws IOException { - return collector; - } - - @Override - public ReduceableSearchResult reduce(Collection collectors) throws IOException { - throw new UnsupportedOperationException("The concurrent aggregation over index segments is not supported"); - } - }); - } - } catch (IOException e) { - throw new AggregationInitializationException("Could not initialize aggregators", e); - } - } - } - - public void execute(SearchContext context) { - if (context.aggregations() == null) { - context.queryResult().aggregations(null); - return; - } - - if (context.queryResult().hasAggs()) { - // no need to compute the aggs twice, they should be computed on a per context basis - return; - } - - Aggregator[] aggregators = context.aggregations().aggregators(); - List globals = new ArrayList<>(); - for (int i = 0; i < aggregators.length; i++) { - if (isGlobalAggregator(context, aggregators[i])) { - globals.add(aggregators[i]); - } - } - - // optimize the global collector based execution - if (!globals.isEmpty()) { - BucketCollector globalsCollector = MultiBucketCollector.wrap(globals); - Query query = context.buildFilteredQuery(Queries.newMatchAllQuery()); - - try { - final Collector collector; - if (context.getProfilers() == null) { - collector = globalsCollector; - } else { - InternalProfileCollector profileCollector = new InternalProfileCollector( - globalsCollector, - CollectorResult.REASON_AGGREGATION_GLOBAL, - // TODO: report on sub collectors - Collections.emptyList() - ); - collector = profileCollector; - // start a new profile with this collector - context.getProfilers().addQueryProfiler().setCollector(profileCollector); - } - globalsCollector.preCollection(); - context.searcher().search(query, collector); - } catch (Exception e) { - throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e); - } - } - - List aggregations = new ArrayList<>(aggregators.length); - context.aggregations().resetBucketMultiConsumer(); - for (Aggregator aggregator : context.aggregations().aggregators()) { - try { - aggregator.postCollection(); - aggregations.add(aggregator.buildTopLevel()); - } catch (IOException e) { - throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e); - } - } - context.queryResult().aggregations(new InternalAggregations(aggregations)); - - // disable aggregations so that they don't run on next pages in case of scrolling - context.aggregations(null); - context.queryCollectorManagers().remove(AggregationPhase.class); - } - - private Collector createCollector(SearchContext context, List collectors) throws IOException { - Collector collector = MultiBucketCollector.wrap(collectors); - ((BucketCollector) collector).preCollection(); - if (context.getProfilers() != null) { - collector = new InternalProfileCollector( - collector, - CollectorResult.REASON_AGGREGATION, - // TODO: report on child aggs as well - Collections.emptyList() - ); - } - return collector; - } - - /** - * Checks if passed in aggregator is of type {@link GlobalAggregator}. This method takes care of Aggregator wrapped in - * {@link ProfilingAggregator} too - * @param context {@link SearchContext} - * @param aggregator input {@link Aggregator} instance to evaluate - * @return true input is {@link GlobalAggregator} instance or false otherwise - */ - private boolean isGlobalAggregator(SearchContext context, Aggregator aggregator) { - return (aggregator instanceof GlobalAggregator - || (context.getProfilers() != null && ProfilingAggregator.unwrap(aggregator) instanceof GlobalAggregator)); - } -} diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationProcessor.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationProcessor.java new file mode 100644 index 0000000000000..5b3e2f2542dc2 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationProcessor.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.opensearch.search.internal.SearchContext; + +/** + * Interface to define different stages of aggregation processing before and after document collection + */ +public interface AggregationProcessor { + + /** + * Callback invoked before collection of documents are done + * @param context {@link SearchContext} for the request + */ + void preProcess(SearchContext context); + + /** + * Callback invoked after collection of documents are done + * @param context {@link SearchContext} for the request + */ + void postProcess(SearchContext context); +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationReduceableSearchResult.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationReduceableSearchResult.java new file mode 100644 index 0000000000000..27c08f1221f6d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationReduceableSearchResult.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; + +/** + * {@link ReduceableSearchResult} returned by the {@link AggregationCollectorManager} which merges the aggregation with the one present in + * query results + */ +public class AggregationReduceableSearchResult implements ReduceableSearchResult { + private final InternalAggregations aggregations; + + public AggregationReduceableSearchResult(InternalAggregations aggregations) { + this.aggregations = aggregations; + } + + @Override + public void reduce(QuerySearchResult result) throws IOException { + if (!result.hasAggs()) { + result.aggregations(aggregations); + } else { + // the aggregations result from reduce of either global or non-global aggs is present so lets combine it with other aggs + // as well + final InternalAggregations existingAggregations = result.aggregations().expand(); + final InternalAggregations finalReducedAggregations = InternalAggregations.merge(existingAggregations, aggregations); + result.aggregations(finalReducedAggregations); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java index 1fae9d35823fd..f760070a9b650 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java @@ -48,6 +48,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.Rewriteable; import org.opensearch.search.aggregations.bucket.global.GlobalAggregationBuilder; +import org.opensearch.search.aggregations.bucket.global.GlobalAggregatorFactory; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree; @@ -59,6 +60,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -70,6 +72,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -237,6 +240,13 @@ private static AggregatorFactories.Builder parseAggregators(XContentParser parse public static final AggregatorFactories EMPTY = new AggregatorFactories(new AggregatorFactory[0]); + private static final Predicate GLOBAL_AGGREGATOR_FACTORY_PREDICATE = new Predicate<>() { + @Override + public boolean test(AggregatorFactory o) { + return o instanceof GlobalAggregatorFactory; + } + }; + private AggregatorFactory[] factories; public static Builder builder() { @@ -268,24 +278,48 @@ public Aggregator[] createSubAggregators(SearchContext searchContext, Aggregator return aggregators; } - public Aggregator[] createTopLevelAggregators(SearchContext searchContext) throws IOException { + public List createTopLevelAggregators(SearchContext searchContext) throws IOException { + return createTopLevelAggregators(searchContext, (aggregatorFactory) -> true); + } + + public List createTopLevelGlobalAggregators(SearchContext searchContext) throws IOException { + return createTopLevelAggregators(searchContext, GLOBAL_AGGREGATOR_FACTORY_PREDICATE); + } + + public List createTopLevelNonGlobalAggregators(SearchContext searchContext) throws IOException { + return createTopLevelAggregators(searchContext, GLOBAL_AGGREGATOR_FACTORY_PREDICATE.negate()); + } + + private List createTopLevelAggregators(SearchContext searchContext, Predicate factoryFilter) + throws IOException { // These aggregators are going to be used with a single bucket ordinal, no need to wrap the PER_BUCKET ones - Aggregator[] aggregators = new Aggregator[factories.length]; + List aggregators = new ArrayList<>(); for (int i = 0; i < factories.length; i++) { /* * Top level aggs only collect from owningBucketOrd 0 which is * *exactly* what CardinalityUpperBound.ONE *means*. */ - Aggregator factory = factories[i].create(searchContext, null, CardinalityUpperBound.ONE); - Profilers profilers = factory.context().getProfilers(); - if (profilers != null) { - factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler()); + Aggregator factory; + if (factoryFilter.test(factories[i])) { + factory = factories[i].create(searchContext, null, CardinalityUpperBound.ONE); + Profilers profilers = factory.context().getProfilers(); + if (profilers != null) { + factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler()); + } + aggregators.add(factory); } - aggregators[i] = factory; } return aggregators; } + public boolean hasNonGlobalAggregator() { + return Arrays.stream(factories).anyMatch(GLOBAL_AGGREGATOR_FACTORY_PREDICATE.negate()); + } + + public boolean hasGlobalAggregator() { + return Arrays.stream(factories).anyMatch(GLOBAL_AGGREGATOR_FACTORY_PREDICATE); + } + /** * @return the number of sub-aggregator factories */ diff --git a/server/src/main/java/org/opensearch/search/aggregations/ConcurrentAggregationProcessor.java b/server/src/main/java/org/opensearch/search/aggregations/ConcurrentAggregationProcessor.java new file mode 100644 index 0000000000000..592fb8cc6e674 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/ConcurrentAggregationProcessor.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.opensearch.common.lucene.search.Queries; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.CollectorResult; +import org.opensearch.search.profile.query.InternalProfileCollectorManager; +import org.opensearch.search.profile.query.InternalProfileComponent; +import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.Collections; + +/** + * {@link AggregationProcessor} implementation to be used with {@link org.opensearch.search.query.ConcurrentQueryPhaseSearcher}. It takes + * care of performing shard level reduce on Aggregation results collected as part of concurrent execution among slices. This is done to + * avoid the increase in aggregation result sets returned by each shard to coordinator where final reduce happens for results received from + * all the shards + */ +public class ConcurrentAggregationProcessor extends DefaultAggregationProcessor { + + @Override + public void preProcess(SearchContext context) { + try { + if (context.aggregations() != null) { + if (context.aggregations().factories().hasNonGlobalAggregator()) { + context.queryCollectorManagers().put(NonGlobalAggCollectorManager.class, new NonGlobalAggCollectorManager(context)); + } + // initialize global aggregators as well, such that any failure to initialize can be caught before executing the request + if (context.aggregations().factories().hasGlobalAggregator()) { + context.queryCollectorManagers().put(GlobalAggCollectorManager.class, new GlobalAggCollectorManager(context)); + } + } + } catch (IOException ex) { + throw new AggregationInitializationException("Could not initialize aggregators", ex); + } + } + + @Override + public void postProcess(SearchContext context) { + if (context.aggregations() == null) { + context.queryResult().aggregations(null); + return; + } + + // for concurrent case we will perform only global aggregation in post process as QueryResult is already populated with results of + // processing the non-global aggregation + CollectorManager globalCollectorManager = context.queryCollectorManagers() + .get(GlobalAggCollectorManager.class); + try { + if (globalCollectorManager != null) { + Query query = context.buildFilteredQuery(Queries.newMatchAllQuery()); + globalCollectorManager = new InternalProfileCollectorManager( + globalCollectorManager, + CollectorResult.REASON_AGGREGATION_GLOBAL, + Collections.emptyList() + ); + if (context.getProfilers() != null) { + context.getProfilers().addQueryProfiler().setCollector((InternalProfileComponent) globalCollectorManager); + } + final ReduceableSearchResult result = context.searcher().search(query, globalCollectorManager); + result.reduce(context.queryResult()); + } + } catch (Exception e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e); + } + + // disable aggregations so that they don't run on next pages in case of scrolling + context.aggregations(null); + context.queryCollectorManagers().remove(NonGlobalAggCollectorManager.class); + context.queryCollectorManagers().remove(GlobalAggCollectorManager.class); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/DefaultAggregationProcessor.java b/server/src/main/java/org/opensearch/search/aggregations/DefaultAggregationProcessor.java new file mode 100644 index 0000000000000..05aa4a9acb270 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/DefaultAggregationProcessor.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Query; +import org.opensearch.common.lucene.search.Queries; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.InternalProfileComponent; +import org.opensearch.search.query.QueryPhaseExecutionException; + +import java.io.IOException; +import java.util.List; + +/** + * {@link AggregationProcessor} implementation which is used with {@link org.opensearch.search.query.QueryPhase.DefaultQueryPhaseSearcher}. + * This is the default implementation which works when collection for aggregations happen in sequential manner. It doesn't perform any + * reduce on the collected documents at shard level + */ +public class DefaultAggregationProcessor implements AggregationProcessor { + + @Override + public void preProcess(SearchContext context) { + try { + if (context.aggregations() != null) { + if (context.aggregations().factories().hasNonGlobalAggregator()) { + context.queryCollectorManagers() + .put(NonGlobalAggCollectorManager.class, new NonGlobalAggCollectorManagerWithSingleCollector(context)); + } + // initialize global aggregators as well, such that any failure to initialize can be caught before executing the request + if (context.aggregations().factories().hasGlobalAggregator()) { + context.queryCollectorManagers() + .put(GlobalAggCollectorManager.class, new GlobalAggCollectorManagerWithSingleCollector(context)); + } + } + } catch (IOException ex) { + throw new AggregationInitializationException("Could not initialize aggregators", ex); + } + } + + @Override + public void postProcess(SearchContext context) { + if (context.aggregations() == null) { + context.queryResult().aggregations(null); + return; + } + + if (context.queryResult().hasAggs()) { + // no need to compute the aggs twice, they should be computed on a per context basis + return; + } + + final AggregationCollectorManager nonGlobalCollectorManager = (AggregationCollectorManager) context.queryCollectorManagers() + .get(NonGlobalAggCollectorManager.class); + final AggregationCollectorManager globalCollectorManager = (AggregationCollectorManager) context.queryCollectorManagers() + .get(GlobalAggCollectorManager.class); + try { + if (nonGlobalCollectorManager != null) { + nonGlobalCollectorManager.reduce(List.of()).reduce(context.queryResult()); + } + + try { + if (globalCollectorManager != null) { + Query query = context.buildFilteredQuery(Queries.newMatchAllQuery()); + if (context.getProfilers() != null) { + context.getProfilers() + .addQueryProfiler() + .setCollector((InternalProfileComponent) globalCollectorManager.newCollector()); + } + context.searcher().search(query, globalCollectorManager.newCollector()); + globalCollectorManager.reduce(List.of()).reduce(context.queryResult()); + } + } catch (Exception e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "Failed to execute global aggregators", e); + } + } catch (IOException ex) { + throw new QueryPhaseExecutionException(context.shardTarget(), "Post processing failed for aggregators", ex); + } + + // disable aggregations so that they don't run on next pages in case of scrolling + context.aggregations(null); + context.queryCollectorManagers().remove(NonGlobalAggCollectorManager.class); + context.queryCollectorManagers().remove(GlobalAggCollectorManager.class); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java new file mode 100644 index 0000000000000..56f53a57a8573 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.CollectorResult; + +import java.io.IOException; +import java.util.Objects; + +/** + * {@link CollectorManager} to take care of global aggregation operators in case of concurrent segment search + */ +public class GlobalAggCollectorManager extends AggregationCollectorManager { + + private Collector collector; + + public GlobalAggCollectorManager(SearchContext context) throws IOException { + super(context, context.aggregations().factories()::createTopLevelGlobalAggregators, CollectorResult.REASON_AGGREGATION_GLOBAL); + collector = Objects.requireNonNull(super.newCollector(), "collector instance is null"); + } + + @Override + public Collector newCollector() throws IOException { + if (collector != null) { + final Collector toReturn = collector; + collector = null; + return toReturn; + } else { + return super.newCollector(); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManagerWithSingleCollector.java b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManagerWithSingleCollector.java new file mode 100644 index 0000000000000..f126f27c68855 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManagerWithSingleCollector.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.CollectorResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * {@link CollectorManager} to take care of global aggregation operators in case of non-concurrent segment search. This CollectorManager + * returns the same collector instance (i.e. created in constructor of super class) on each newCollector call + */ +public class GlobalAggCollectorManagerWithSingleCollector extends AggregationCollectorManager { + + private final Collector collector; + + public GlobalAggCollectorManagerWithSingleCollector(SearchContext context) throws IOException { + super(context, context.aggregations().factories()::createTopLevelGlobalAggregators, CollectorResult.REASON_AGGREGATION_GLOBAL); + collector = Objects.requireNonNull(super.newCollector(), "collector instance is null"); + } + + @Override + public Collector newCollector() throws IOException { + return collector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + assert collectors.isEmpty() : "Reduce on GlobalAggregationCollectorManagerWithCollector called with non-empty collectors"; + return super.reduce(List.of(collector)); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java index 16d7898118fc3..228360b872042 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java @@ -207,6 +207,15 @@ public long getSerializedSize() { } } + public static InternalAggregations merge(InternalAggregations first, InternalAggregations second) { + final List fromFirst = first.getInternalAggregations(); + final List fromSecond = second.getInternalAggregations(); + final List mergedAggregation = new ArrayList<>(fromFirst.size() + fromSecond.size()); + mergedAggregation.addAll(fromFirst); + mergedAggregation.addAll(fromSecond); + return new InternalAggregations(mergedAggregation); + } + /** * A counting stream output * diff --git a/server/src/main/java/org/opensearch/search/aggregations/MultiBucketCollector.java b/server/src/main/java/org/opensearch/search/aggregations/MultiBucketCollector.java index cac3a6151bd78..9915848ee2e1d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/MultiBucketCollector.java +++ b/server/src/main/java/org/opensearch/search/aggregations/MultiBucketCollector.java @@ -121,6 +121,10 @@ private MultiBucketCollector(BucketCollector... collectors) { this.cacheScores = numNeedsScores >= 2; } + public BucketCollector[] getCollectors() { + return collectors; + } + @Override public ScoreMode scoreMode() { ScoreMode scoreMode = null; diff --git a/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java new file mode 100644 index 0000000000000..3729734c48ed7 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.CollectorResult; + +import java.io.IOException; +import java.util.Objects; + +/** + * {@link CollectorManager} to take care of non-global aggregation operators in case of concurrent segment search + */ +public class NonGlobalAggCollectorManager extends AggregationCollectorManager { + + private Collector collector; + + public NonGlobalAggCollectorManager(SearchContext context) throws IOException { + super(context, context.aggregations().factories()::createTopLevelNonGlobalAggregators, CollectorResult.REASON_AGGREGATION); + collector = Objects.requireNonNull(super.newCollector(), "collector instance is null"); + } + + @Override + public Collector newCollector() throws IOException { + if (collector != null) { + final Collector toReturn = collector; + collector = null; + return toReturn; + } else { + return super.newCollector(); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManagerWithSingleCollector.java b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManagerWithSingleCollector.java new file mode 100644 index 0000000000000..433f6b6a05b22 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManagerWithSingleCollector.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.profile.query.CollectorResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Objects; + +/** + * {@link CollectorManager} to take care of non-global aggregation operators in case of non-concurrent segment search. This + * CollectorManager returns the same collector instance (i.e. created in constructor of super class) on each newCollector call + */ +public class NonGlobalAggCollectorManagerWithSingleCollector extends AggregationCollectorManager { + + private final Collector collector; + + public NonGlobalAggCollectorManagerWithSingleCollector(SearchContext context) throws IOException { + super(context, context.aggregations().factories()::createTopLevelNonGlobalAggregators, CollectorResult.REASON_AGGREGATION); + collector = Objects.requireNonNull(super.newCollector(), "collector instance is null"); + } + + @Override + public Collector newCollector() throws IOException { + return collector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) throws IOException { + assert collectors.isEmpty() : "Reduce on NonGlobalAggregationCollectorManagerWithCollector called with non-empty collectors"; + return super.reduce(List.of(collector)); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/SearchContextAggregations.java b/server/src/main/java/org/opensearch/search/aggregations/SearchContextAggregations.java index f51d5af23b049..16339713dc83e 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/SearchContextAggregations.java +++ b/server/src/main/java/org/opensearch/search/aggregations/SearchContextAggregations.java @@ -39,10 +39,8 @@ * @opensearch.internal */ public class SearchContextAggregations { - private final AggregatorFactories factories; private final MultiBucketConsumer multiBucketConsumer; - private Aggregator[] aggregators; /** * Creates a new aggregation context with the parsed aggregator factories @@ -56,19 +54,6 @@ public AggregatorFactories factories() { return factories; } - public Aggregator[] aggregators() { - return aggregators; - } - - /** - * Registers all the created aggregators (top level aggregators) for the search execution context. - * - * @param aggregators The top level aggregators of the search execution. - */ - public void aggregators(Aggregator[] aggregators) { - this.aggregators = aggregators; - } - /** * Returns a consumer for multi bucket aggregation that checks the total number of buckets * created in the response diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregatorFactory.java index 8741213f98811..795f81a08d8d5 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregatorFactory.java @@ -58,7 +58,7 @@ public class FiltersAggregatorFactory extends AggregatorFactory { private final String[] keys; private final Query[] filters; - private Weight[] weights; + private volatile Weight[] weights; private final boolean keyed; private final boolean otherBucket; private final String otherBucketKey; @@ -93,19 +93,33 @@ public FiltersAggregatorFactory( * created if the aggregation collects documents reducing the overhead of * the aggregation in the case where no documents are collected. * - * Note that as aggregations are initialsed and executed in a serial manner, - * no concurrency considerations are necessary here. + * Note: With concurrent segment search use case, multiple aggregation collectors executing + * on different threads will try to fetch the weights. To handle the race condition there is + * a synchronization block */ public Weight[] getWeights(SearchContext searchContext) { - if (weights == null) { - try { - IndexSearcher contextSearcher = searchContext.searcher(); - weights = new Weight[filters.length]; - for (int i = 0; i < filters.length; ++i) { - this.weights[i] = contextSearcher.createWeight(contextSearcher.rewrite(filters[i]), ScoreMode.COMPLETE_NO_SCORES, 1); + if (weights != null) { + return weights; + } + + // This will happen only for the first segment access in the slices. After that for other segments + // weights will be non-null and returned from above + synchronized (this) { + if (weights == null) { + try { + final Weight[] filterWeights = new Weight[filters.length]; + IndexSearcher contextSearcher = searchContext.searcher(); + for (int i = 0; i < filters.length; ++i) { + filterWeights[i] = contextSearcher.createWeight( + contextSearcher.rewrite(filters[i]), + ScoreMode.COMPLETE_NO_SCORES, + 1 + ); + } + weights = filterWeights; + } catch (IOException e) { + throw new AggregationInitializationException("Failed to initialze filters for aggregation [" + name() + "]", e); } - } catch (IOException e) { - throw new AggregationInitializationException("Failed to initialse filters for aggregation [" + name() + "]", e); } } return weights; diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java index ffb180614c3b2..9bfc0e8b6fea5 100644 --- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java @@ -50,6 +50,7 @@ import org.opensearch.index.similarity.SimilarityService; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsSearchResult; @@ -542,4 +543,9 @@ public void addRescore(RescoreContext rescore) { public ReaderContext readerContext() { return in.readerContext(); } + + @Override + public InternalAggregation.ReduceContext partial() { + return in.partial(); + } } diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index 76d0d7b72c6b4..d8202f1c36800 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -53,6 +53,7 @@ import org.opensearch.search.RescoreDocIds; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsSearchResult; @@ -429,4 +430,6 @@ public String toString() { } public abstract ReaderContext readerContext(); + + public abstract InternalAggregation.ReduceContext partial(); } diff --git a/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java b/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java index 1b4e8c5e3e56f..dbdcb8132bec0 100644 --- a/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java +++ b/server/src/main/java/org/opensearch/search/query/ConcurrentQueryPhaseSearcher.java @@ -13,6 +13,8 @@ import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Query; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.aggregations.ConcurrentAggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.profile.query.ProfileCollectorManager; @@ -47,12 +49,6 @@ protected boolean searchWithCollector( ) throws IOException { boolean couldUseConcurrentSegmentSearch = allowConcurrentSegmentSearch(searcher); - // TODO: support aggregations - if (searchContext.aggregations() != null) { - couldUseConcurrentSegmentSearch = false; - LOGGER.debug("Unable to use concurrent search over index segments (experimental): aggregations are present"); - } - if (couldUseConcurrentSegmentSearch) { LOGGER.debug("Using concurrent search over index segments (experimental)"); return searchWithCollectorManager(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); @@ -77,19 +73,14 @@ private static boolean searchWithCollectorManager( final QuerySearchResult queryResult = searchContext.queryResult(); final CollectorManager collectorManager; - // TODO: support aggregations in concurrent segment search flow - if (searchContext.aggregations() != null) { - throw new UnsupportedOperationException("The concurrent segment search does not support aggregations yet"); - } - if (searchContext.getProfilers() != null) { final ProfileCollectorManager profileCollectorManager = QueryCollectorManagerContext.createQueryCollectorManagerWithProfiler(collectorContexts); searchContext.getProfilers().getCurrentQueryProfiler().setCollector(profileCollectorManager); collectorManager = profileCollectorManager; } else { - // Create multi collector manager instance - collectorManager = QueryCollectorManagerContext.createMultiCollectorManager(collectorContexts); + // Create collector manager tree + collectorManager = QueryCollectorManagerContext.createQueryCollectorManager(collectorContexts); } try { @@ -112,6 +103,11 @@ private static boolean searchWithCollectorManager( return topDocsFactory.shouldRescore(); } + @Override + public AggregationProcessor newAggregationProcessor() { + return new ConcurrentAggregationProcessor(); + } + private static boolean allowConcurrentSegmentSearch(final ContextIndexSearcher searcher) { return (searcher.getExecutor() != null); } diff --git a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java index 9ce4a73c97c8d..c611587e879d6 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java @@ -215,7 +215,7 @@ CollectorManager createManager( final List> managers = new ArrayList<>(); managers.add(in); managers.addAll(subs); - return QueryCollectorManagerContext.createOpaqueCollectorManager(managers); + return QueryCollectorManagerContext.createMultiCollectorManager(managers); } }; } diff --git a/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java b/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java index 3fa5003cdbfbe..29e6244e05163 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java +++ b/server/src/main/java/org/opensearch/search/query/QueryCollectorManagerContext.java @@ -15,7 +15,6 @@ import org.opensearch.search.profile.query.ProfileCollectorManager; import java.io.IOException; -import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -59,34 +58,24 @@ protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] resul } } - private static class OpaqueQueryCollectorManager extends QueryCollectorManager { - private OpaqueQueryCollectorManager(Collection> managers) { - super(managers); - } - - @Override - protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] results) { - return (QuerySearchResult result) -> {}; - } - } - - public static CollectorManager createOpaqueCollectorManager( - List> managers + /** + * Create query {@link CollectorManager} tree using the provided query collector contexts + * @param collectorContexts list of {@link QueryCollectorContext} + * @return {@link CollectorManager} representing the manager tree for the query + */ + public static CollectorManager createQueryCollectorManager( + List collectorContexts ) throws IOException { - return new OpaqueQueryCollectorManager(managers); - } - - public static CollectorManager createMultiCollectorManager( - List collectors - ) throws IOException { - final Collection> managers = new ArrayList<>(); - CollectorManager manager = null; - for (QueryCollectorContext ctx : collectors) { + for (QueryCollectorContext ctx : collectorContexts) { manager = ctx.createManager(manager); - managers.add(manager); } + return manager; + } + public static CollectorManager createMultiCollectorManager( + List> managers + ) { return new QueryCollectorManager(managers); } diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index 4e9bc25df3a1a..82de8b4088fb7 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -54,7 +54,8 @@ import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchContextSourcePrinter; import org.opensearch.search.SearchService; -import org.opensearch.search.aggregations.AggregationPhase; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.aggregations.GlobalAggCollectorManager; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; @@ -69,8 +70,10 @@ import java.io.IOException; import java.util.LinkedList; +import java.util.Map; import java.util.Objects; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import static org.opensearch.search.query.QueryCollectorContext.createEarlyTerminationCollectorContext; import static org.opensearch.search.query.QueryCollectorContext.createFilteredCollectorContext; @@ -89,9 +92,8 @@ public class QueryPhase { // TODO: remove this property public static final boolean SYS_PROP_REWRITE_SORT = Booleans.parseBoolean(System.getProperty("opensearch.search.rewrite_sort", "true")); public static final QueryPhaseSearcher DEFAULT_QUERY_PHASE_SEARCHER = new DefaultQueryPhaseSearcher(); - private final QueryPhaseSearcher queryPhaseSearcher; - private final AggregationPhase aggregationPhase; + private final AggregationProcessor aggregationProcessor; private final SuggestPhase suggestPhase; private final RescorePhase rescorePhase; @@ -101,7 +103,10 @@ public QueryPhase() { public QueryPhase(QueryPhaseSearcher queryPhaseSearcher) { this.queryPhaseSearcher = Objects.requireNonNull(queryPhaseSearcher, "QueryPhaseSearcher is required"); - this.aggregationPhase = new AggregationPhase(); + this.aggregationProcessor = Objects.requireNonNull( + queryPhaseSearcher.newAggregationProcessor(), + "AggregationProcessor is required" + ); this.suggestPhase = new SuggestPhase(); this.rescorePhase = new RescorePhase(); } @@ -145,14 +150,14 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep // Pre-process aggregations as late as possible. In the case of a DFS_Q_T_F // request, preProcess is called on the DFS phase phase, this is why we pre-process them // here to make sure it happens during the QUERY phase - aggregationPhase.preProcess(searchContext); + aggregationProcessor.preProcess(searchContext); boolean rescore = executeInternal(searchContext, queryPhaseSearcher); if (rescore) { // only if we do a regular search rescorePhase.execute(searchContext); } suggestPhase.execute(searchContext); - aggregationPhase.execute(searchContext); + aggregationProcessor.postProcess(searchContext); if (searchContext.getProfilers() != null) { ProfileShardResult shardResults = SearchProfileShardResults.buildShardResults( @@ -163,6 +168,16 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep } } + // making public for testing + public QueryPhaseSearcher getQueryPhaseSearcher() { + return queryPhaseSearcher; + } + + // making public for testing + public AggregationProcessor getAggregationProcessor() { + return aggregationProcessor; + } + /** * In a package-private method so that it can be tested without having to * wire everything (mapperService, etc.) @@ -228,8 +243,17 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q hasFilterCollector = true; } if (searchContext.queryCollectorManagers().isEmpty() == false) { - // plug in additional collectors, like aggregations - collectors.add(createMultiCollectorContext(searchContext.queryCollectorManagers().values())); + // plug in additional collectors, like aggregations except global aggregations + collectors.add( + createMultiCollectorContext( + searchContext.queryCollectorManagers() + .entrySet() + .stream() + .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class))) + .map(Map.Entry::getValue) + .collect(Collectors.toList()) + ) + ); } if (searchContext.minimumScore() != null) { // apply the minimum score after multi collector so we filter aggs as well diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java b/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java index 1995137e8b52e..93bc29e9d8cb9 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhaseSearcher.java @@ -10,6 +10,8 @@ import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Query; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.aggregations.DefaultAggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; @@ -40,4 +42,12 @@ boolean searchWith( boolean hasFilterCollector, boolean hasTimeout ) throws IOException; + + /** + * {@link AggregationProcessor} to use to setup and post process aggregation related collectors during search request + * @return {@link AggregationProcessor} to use + */ + default AggregationProcessor newAggregationProcessor() { + return new DefaultAggregationProcessor(); + } } diff --git a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java index d2242b7d3f07e..82ebae65a147b 100644 --- a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java @@ -213,7 +213,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); contextWithoutScroll.from(300); contextWithoutScroll.close(); @@ -255,7 +256,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); context1.from(300); exception = expectThrows(IllegalArgumentException.class, () -> context1.preProcess(false)); @@ -325,7 +327,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); SliceBuilder sliceBuilder = mock(SliceBuilder.class); @@ -364,7 +367,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); ParsedQuery parsedQuery = ParsedQuery.parsedMatchAllQuery(); context3.sliceBuilder(null).parsedQuery(parsedQuery).preProcess(false); @@ -399,7 +403,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); context4.sliceBuilder(new SliceBuilder(1, 2)).parsedQuery(parsedQuery).preProcess(false); Query query1 = context4.query(); @@ -429,7 +434,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); int numSlicesForPit = maxSlicesPerPit + randomIntBetween(1, 100); when(sliceBuilder.getMax()).thenReturn(numSlicesForPit); @@ -526,7 +532,8 @@ protected Engine.Searcher acquireSearcherInternal(String source) { false, Version.CURRENT, false, - executor + executor, + null ); assertThat(context.searcher().hasCancellations(), is(false)); context.searcher().addQueryCancellation(() -> {}); diff --git a/server/src/test/java/org/opensearch/search/SearchModuleTests.java b/server/src/test/java/org/opensearch/search/SearchModuleTests.java index d8c178caf7da8..7ca5441564e1c 100644 --- a/server/src/test/java/org/opensearch/search/SearchModuleTests.java +++ b/server/src/test/java/org/opensearch/search/SearchModuleTests.java @@ -36,6 +36,7 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; @@ -50,6 +51,8 @@ import org.opensearch.search.aggregations.AggregatorFactories.Builder; import org.opensearch.search.aggregations.AggregatorFactory; import org.opensearch.search.aggregations.BaseAggregationBuilder; +import org.opensearch.search.aggregations.ConcurrentAggregationProcessor; +import org.opensearch.search.aggregations.DefaultAggregationProcessor; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregation.ReduceContext; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; @@ -73,6 +76,9 @@ import org.opensearch.search.fetch.subphase.highlight.Highlighter; import org.opensearch.search.fetch.subphase.highlight.PlainHighlighter; import org.opensearch.search.fetch.subphase.highlight.UnifiedHighlighter; +import org.opensearch.search.query.ConcurrentQueryPhaseSearcher; +import org.opensearch.search.query.QueryPhase; +import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.search.rescore.RescoreContext; import org.opensearch.search.rescore.RescorerBuilder; @@ -86,6 +92,7 @@ import org.opensearch.search.suggest.term.TermSuggestion; import org.opensearch.search.suggest.term.TermSuggestionBuilder; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.ArrayList; @@ -93,6 +100,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static java.util.Collections.emptyList; @@ -102,6 +110,9 @@ import static java.util.stream.Collectors.toSet; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; public class SearchModuleTests extends OpenSearchTestCase { @@ -409,6 +420,93 @@ public List> getRescorers() { ); } + public void testDefaultQueryPhaseSearcher() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + QueryPhase queryPhase = searchModule.getQueryPhase(); + assertTrue(queryPhase.getQueryPhaseSearcher() instanceof QueryPhase.DefaultQueryPhaseSearcher); + assertTrue(queryPhase.getAggregationProcessor() instanceof DefaultAggregationProcessor); + } + + public void testConcurrentQueryPhaseSearcher() { + Settings settings = Settings.builder().put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + SearchModule searchModule = new SearchModule(settings, Collections.emptyList()); + QueryPhase queryPhase = searchModule.getQueryPhase(); + assertTrue(queryPhase.getQueryPhaseSearcher() instanceof ConcurrentQueryPhaseSearcher); + assertTrue(queryPhase.getAggregationProcessor() instanceof ConcurrentAggregationProcessor); + FeatureFlags.initializeFeatureFlags(Settings.EMPTY); + } + + public void testPluginQueryPhaseSearcher() { + Settings settings = Settings.builder().put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + QueryPhaseSearcher queryPhaseSearcher = (searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout) -> false; + SearchPlugin plugin1 = new SearchPlugin() { + @Override + public Optional getQueryPhaseSearcher() { + return Optional.of(queryPhaseSearcher); + } + }; + SearchModule searchModule = new SearchModule(settings, Collections.singletonList(plugin1)); + QueryPhase queryPhase = searchModule.getQueryPhase(); + assertEquals(queryPhaseSearcher, queryPhase.getQueryPhaseSearcher()); + assertTrue(queryPhase.getAggregationProcessor() instanceof DefaultAggregationProcessor); + FeatureFlags.initializeFeatureFlags(Settings.EMPTY); + } + + public void testMultiplePluginRegisterQueryPhaseSearcher() { + SearchPlugin plugin1 = new SearchPlugin() { + @Override + public Optional getQueryPhaseSearcher() { + return Optional.of(mock(QueryPhaseSearcher.class)); + } + }; + SearchPlugin plugin2 = new SearchPlugin() { + @Override + public Optional getQueryPhaseSearcher() { + return Optional.of(new ConcurrentQueryPhaseSearcher()); + } + }; + List searchPlugins = new ArrayList<>(); + searchPlugins.add(plugin1); + searchPlugins.add(plugin2); + expectThrows(IllegalStateException.class, () -> new SearchModule(Settings.EMPTY, searchPlugins)); + } + + public void testIndexSearcher() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + ThreadPool threadPool = mock(ThreadPool.class); + assertNull(searchModule.getIndexSearcherExecutor(threadPool)); + verify(threadPool, times(0)).executor(ThreadPool.Names.INDEX_SEARCHER); + + // enable concurrent segment search feature flag + Settings settings = Settings.builder().put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + searchModule = new SearchModule(settings, Collections.emptyList()); + searchModule.getIndexSearcherExecutor(threadPool); + verify(threadPool).executor(ThreadPool.Names.INDEX_SEARCHER); + FeatureFlags.initializeFeatureFlags(Settings.EMPTY); + } + + public void testMultiplePluginRegisterIndexSearcherProvider() { + SearchPlugin plugin1 = new SearchPlugin() { + @Override + public Optional getIndexSearcherExecutorProvider() { + return Optional.of(mock(ExecutorServiceProvider.class)); + } + }; + SearchPlugin plugin2 = new SearchPlugin() { + @Override + public Optional getIndexSearcherExecutorProvider() { + return Optional.of(mock(ExecutorServiceProvider.class)); + } + }; + List searchPlugins = new ArrayList<>(); + searchPlugins.add(plugin1); + searchPlugins.add(plugin2); + expectThrows(IllegalStateException.class, () -> new SearchModule(Settings.EMPTY, searchPlugins)); + } + private static final String[] NON_DEPRECATED_QUERIES = new String[] { "bool", "boosting", diff --git a/server/src/test/java/org/opensearch/search/SearchServiceTests.java b/server/src/test/java/org/opensearch/search/SearchServiceTests.java index a17834bccb238..72c74ddb71725 100644 --- a/server/src/test/java/org/opensearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/opensearch/search/SearchServiceTests.java @@ -1129,7 +1129,7 @@ public void testExpandSearchFrozen() { public void testCreateReduceContext() { SearchService service = getInstanceFromNode(SearchService.class); - InternalAggregation.ReduceContextBuilder reduceContextBuilder = service.aggReduceContextBuilder(new SearchRequest()); + InternalAggregation.ReduceContextBuilder reduceContextBuilder = service.aggReduceContextBuilder(new SearchSourceBuilder()); { InternalAggregation.ReduceContext reduceContext = reduceContextBuilder.forFinalReduction(); expectThrows( diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorManagerTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorManagerTests.java new file mode 100644 index 0000000000000..7fcf2216040c9 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorManagerTests.java @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.opensearch.search.aggregations.bucket.global.GlobalAggregator; + +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.mock; + +public class AggregationCollectorManagerTests extends AggregationSetupTests { + + public void testNonGlobalCollectorManagers() throws Exception { + final AggregatorFactories aggregatorFactories = getAggregationFactories(multipleNonGlobalAggs); + final SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + int expectedAggCount = 2; + final AggregationCollectorManager testAggCollectorManager = new NonGlobalAggCollectorManagerWithSingleCollector(context); + Collector aggCollector = testAggCollectorManager.newCollector(); + assertTrue(aggCollector instanceof MultiBucketCollector); + assertEquals(expectedAggCount, ((MultiBucketCollector) aggCollector).getCollectors().length); + testCollectorManagerCommon(testAggCollectorManager); + + // test NonGlobalCollectorManager which will be used in concurrent segment search case + testCollectorManagerCommon(new NonGlobalAggCollectorManager(context)); + } + + public void testGlobalCollectorManagers() throws Exception { + final AggregatorFactories aggregatorFactories = getAggregationFactories(globalAgg); + final SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + final AggregationCollectorManager testAggCollectorManager = new GlobalAggCollectorManagerWithSingleCollector(context); + testCollectorManagerCommon(testAggCollectorManager); + Collector aggCollector = testAggCollectorManager.newCollector(); + assertTrue(aggCollector instanceof BucketCollector); + + // test GlobalAggCollectorManager which will be used in concurrent segment search case + testCollectorManagerCommon(new GlobalAggCollectorManager(context)); + } + + public void testAggCollectorManagersWithBothGlobalNonGlobalAggregators() throws Exception { + final AggregatorFactories aggregatorFactories = getAggregationFactories(globalNonGlobalAggs); + final SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + final AggregationCollectorManager testAggCollectorManager = new NonGlobalAggCollectorManagerWithSingleCollector(context); + Collector aggCollector = testAggCollectorManager.newCollector(); + assertTrue(aggCollector instanceof BucketCollector); + assertFalse(aggCollector instanceof GlobalAggregator); + + final AggregationCollectorManager testGlobalAggCollectorManager = new GlobalAggCollectorManagerWithSingleCollector(context); + Collector globalAggCollector = testGlobalAggCollectorManager.newCollector(); + assertTrue(globalAggCollector instanceof BucketCollector); + assertTrue(globalAggCollector instanceof GlobalAggregator); + + testCollectorManagerCommon(testAggCollectorManager); + testCollectorManagerCommon(testGlobalAggCollectorManager); + } + + public void testAssertionWhenCollectorManagerCreatesNoOPCollector() throws Exception { + AggregatorFactories aggregatorFactories = getAggregationFactories(globalAgg); + SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + expectThrows(AssertionError.class, () -> new NonGlobalAggCollectorManagerWithSingleCollector(context)); + expectThrows(AssertionError.class, () -> new NonGlobalAggCollectorManager(context)); + + aggregatorFactories = getAggregationFactories(multipleNonGlobalAggs); + contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + expectThrows(AssertionError.class, () -> new GlobalAggCollectorManagerWithSingleCollector(context)); + expectThrows(AssertionError.class, () -> new GlobalAggCollectorManager(context)); + } + + public void testAssertionInSingleCollectorCMReduce() throws Exception { + AggregatorFactories aggregatorFactories = getAggregationFactories(globalNonGlobalAggs); + SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + List collectorsList = new ArrayList<>(); + collectorsList.add(mock(Collector.class)); + context.aggregations(contextAggregations); + AggregationCollectorManager globalCM = new GlobalAggCollectorManagerWithSingleCollector(context); + AggregationCollectorManager nonGlobalCM = new NonGlobalAggCollectorManagerWithSingleCollector(context); + expectThrows(AssertionError.class, () -> globalCM.reduce(collectorsList)); + expectThrows(AssertionError.class, () -> nonGlobalCM.reduce(collectorsList)); + } + + private void testCollectorManagerCommon(AggregationCollectorManager collectorManager) throws Exception { + final Collector expectedCollector = collectorManager.newCollector(); + for (int i = 0; i < randomIntBetween(2, 5); ++i) { + final Collector newCollector = collectorManager.newCollector(); + if (collectorManager instanceof GlobalAggCollectorManagerWithSingleCollector + || collectorManager instanceof NonGlobalAggCollectorManagerWithSingleCollector) { + // calling the newCollector multiple times should return the same instance each time + assertSame(expectedCollector, newCollector); + } else if (collectorManager instanceof GlobalAggCollectorManager || collectorManager instanceof NonGlobalAggCollectorManager) { + // calling the newCollector multiple times should not return the same instance each time + assertNotSame(expectedCollector, newCollector); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorTests.java index 3c8701d10fd83..a39a261f01ccc 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationCollectorTests.java @@ -32,24 +32,18 @@ package org.opensearch.search.aggregations; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.index.IndexService; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.search.aggregations.bucket.global.GlobalAggregator; +import org.opensearch.search.aggregations.bucket.terms.NumericTermsAggregator; import java.io.IOException; +import java.util.List; -public class AggregationCollectorTests extends OpenSearchSingleNodeTestCase { +public class AggregationCollectorTests extends AggregationSetupTests { public void testNeedsScores() throws Exception { - IndexService index = createIndex("idx"); - client().prepareIndex("idx").setId("1").setSource("f", 5).execute().get(); - client().admin().indices().prepareRefresh("idx").get(); - // simple field aggregation, no scores needed String fieldAgg = "{ \"my_terms\": {\"terms\": {\"field\": \"f\"}}}"; - assertFalse(needsScores(index, fieldAgg)); + assertFalse(needsScores(fieldAgg)); // agg on a script => scores are needed // TODO: can we use a mock script service here? @@ -61,23 +55,50 @@ public void testNeedsScores() throws Exception { // make sure the information is propagated to sub aggregations String subFieldAgg = "{ \"my_outer_terms\": { \"terms\": { \"field\": \"f\" }, \"aggs\": " + fieldAgg + "}}"; - assertFalse(needsScores(index, subFieldAgg)); + assertFalse(needsScores(subFieldAgg)); // top_hits is a particular example of an aggregation that needs scores String topHitsAgg = "{ \"my_hits\": {\"top_hits\": {}}}"; - assertTrue(needsScores(index, topHitsAgg)); + assertTrue(needsScores(topHitsAgg)); + } + + public void testNonGlobalTopLevelAggregators() throws Exception { + // simple field aggregation + String fieldAgg = "{ \"my_terms\": {\"terms\": {\"field\": \"f\"}}}"; + final List aggregators = createNonGlobalAggregators(fieldAgg); + final List topLevelAggregators = createTopLevelAggregators(fieldAgg); + assertEquals(topLevelAggregators.size(), aggregators.size()); + assertEquals(topLevelAggregators.get(0).name(), aggregators.get(0).name()); + assertTrue(aggregators.get(0) instanceof NumericTermsAggregator); + } + + public void testGlobalAggregators() throws Exception { + // global aggregation + final List aggregators = createGlobalAggregators(globalAgg); + final List topLevelAggregators = createTopLevelAggregators(globalAgg); + assertEquals(topLevelAggregators.size(), aggregators.size()); + assertEquals(topLevelAggregators.get(0).name(), aggregators.get(0).name()); + assertTrue(aggregators.get(0) instanceof GlobalAggregator); } - private boolean needsScores(IndexService index, String agg) throws IOException { - try (XContentParser aggParser = createParser(JsonXContent.jsonXContent, agg)) { - aggParser.nextToken(); - SearchContext context = createSearchContext(index); - final AggregatorFactories factories = AggregatorFactories.parseAggregators(aggParser) - .build(context.getQueryShardContext(), null); - final Aggregator[] aggregators = factories.createTopLevelAggregators(context); - assertEquals(1, aggregators.length); - return aggregators[0].scoreMode().needsScores(); - } + private boolean needsScores(String agg) throws IOException { + final List aggregators = createTopLevelAggregators(agg); + assertEquals(1, aggregators.size()); + return aggregators.get(0).scoreMode().needsScores(); } + private List createTopLevelAggregators(String agg) throws IOException { + final AggregatorFactories factories = getAggregationFactories(agg); + return factories.createTopLevelAggregators(context); + } + + private List createNonGlobalAggregators(String agg) throws IOException { + final AggregatorFactories factories = getAggregationFactories(agg); + return factories.createTopLevelNonGlobalAggregators(context); + } + + private List createGlobalAggregators(String agg) throws IOException { + final AggregatorFactories factories = getAggregationFactories(agg); + return factories.createTopLevelGlobalAggregators(context); + } } diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationProcessorTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationProcessorTests.java new file mode 100644 index 0000000000000..cff83b36ce884 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationProcessorTests.java @@ -0,0 +1,172 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.mockito.ArgumentMatchers; +import org.opensearch.search.aggregations.bucket.global.GlobalAggregator; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.test.TestSearchContext; + +import java.util.ArrayList; +import java.util.Collection; + +import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AggregationProcessorTests extends AggregationSetupTests { + private final AggregationProcessor testAggregationProcessor = new ConcurrentAggregationProcessor(); + + public void testPreProcessWithNoAggregations() { + testAggregationProcessor.preProcess(context); + assertTrue(context.queryCollectorManagers().isEmpty()); + } + + public void testPreProcessWithOnlyGlobalAggregator() throws Exception { + testPreProcessCommon(globalAgg, 1, 0); + } + + public void testPreProcessWithGlobalAndNonGlobalAggregators() throws Exception { + testPreProcessCommon(globalNonGlobalAggs, 1, 1); + } + + public void testPreProcessWithOnlyNonGlobalAggregators() throws Exception { + testPreProcessCommon(multipleNonGlobalAggs, 0, 2); + } + + public void testPostProcessWithNonGlobalAggregatorsAndSingleSlice() throws Exception { + testPostProcessCommon(multipleNonGlobalAggs, 1, 0, 2); + } + + public void testPostProcessWithNonGlobalAggregatorsAndMultipleSlices() throws Exception { + testPostProcessCommon(multipleNonGlobalAggs, randomIntBetween(2, 5), 0, 2); + } + + public void testPostProcessGlobalAndNonGlobalAggregators() throws Exception { + testPostProcessCommon(globalNonGlobalAggs, randomIntBetween(2, 5), 1, 1); + } + + private void testPreProcessCommon(String agg, int expectedGlobalAggs, int expectedNonGlobalAggs) throws Exception { + testPreProcessCommon(agg, expectedGlobalAggs, expectedNonGlobalAggs, new ArrayList<>(), new ArrayList<>()); + } + + private void testPreProcessCommon( + String agg, + int expectedGlobalAggs, + int expectedNonGlobalAggs, + Collection createdNonGlobalCollectors, + Collection createdGlobalCollectors + ) throws Exception { + final AggregatorFactories aggregatorFactories = getAggregationFactories(agg); + final SearchContextAggregations contextAggregations = new SearchContextAggregations( + aggregatorFactories, + mock(MultiBucketConsumerService.MultiBucketConsumer.class) + ); + context.aggregations(contextAggregations); + testAggregationProcessor.preProcess(context); + CollectorManager globalCollectorManager = null; + CollectorManager nonGlobalCollectorManager = null; + if (expectedNonGlobalAggs == 0 && expectedGlobalAggs == 0) { + assertTrue(context.queryCollectorManagers().isEmpty()); + return; + } else if (expectedGlobalAggs > 0 && expectedNonGlobalAggs > 0) { + assertTrue(context.queryCollectorManagers().containsKey(NonGlobalAggCollectorManager.class)); + assertTrue(context.queryCollectorManagers().containsKey(GlobalAggCollectorManager.class)); + globalCollectorManager = context.queryCollectorManagers().get(GlobalAggCollectorManager.class); + nonGlobalCollectorManager = context.queryCollectorManagers().get(NonGlobalAggCollectorManager.class); + } else if (expectedGlobalAggs == 0) { + assertTrue(context.queryCollectorManagers().containsKey(NonGlobalAggCollectorManager.class)); + assertFalse(context.queryCollectorManagers().containsKey(GlobalAggCollectorManager.class)); + nonGlobalCollectorManager = context.queryCollectorManagers().get(NonGlobalAggCollectorManager.class); + } else { + assertTrue(context.queryCollectorManagers().containsKey(GlobalAggCollectorManager.class)); + assertFalse(context.queryCollectorManagers().containsKey(NonGlobalAggCollectorManager.class)); + globalCollectorManager = context.queryCollectorManagers().get(GlobalAggCollectorManager.class); + } + + Collector aggCollector; + if (expectedGlobalAggs == 1) { + aggCollector = globalCollectorManager.newCollector(); + createdGlobalCollectors.add(aggCollector); + assertTrue(aggCollector instanceof BucketCollector); + assertTrue(aggCollector instanceof GlobalAggregator); + } else if (expectedGlobalAggs > 1) { + aggCollector = globalCollectorManager.newCollector(); + createdGlobalCollectors.add(aggCollector); + assertTrue(aggCollector instanceof MultiBucketCollector); + for (Collector currentCollector : ((MultiBucketCollector) aggCollector).getCollectors()) { + assertTrue(currentCollector instanceof GlobalAggregator); + } + } + + if (expectedNonGlobalAggs == 1) { + aggCollector = nonGlobalCollectorManager.newCollector(); + createdNonGlobalCollectors.add(aggCollector); + assertTrue(aggCollector instanceof BucketCollector); + assertFalse(aggCollector instanceof GlobalAggregator); + } else if (expectedNonGlobalAggs > 1) { + aggCollector = nonGlobalCollectorManager.newCollector(); + createdNonGlobalCollectors.add(aggCollector); + assertTrue(aggCollector instanceof MultiBucketCollector); + for (Collector currentCollector : ((MultiBucketCollector) aggCollector).getCollectors()) { + assertFalse(currentCollector instanceof GlobalAggregator); + } + } + } + + private void testPostProcessCommon(String aggs, int numSlices, int expectedGlobalAggs, int expectedNonGlobalAggsPerSlice) + throws Exception { + final Collection nonGlobalCollectors = new ArrayList<>(); + final Collection globalCollectors = new ArrayList<>(); + testPreProcessCommon(aggs, expectedGlobalAggs, expectedNonGlobalAggsPerSlice, nonGlobalCollectors, globalCollectors); + // newCollector is initialized once in the collector manager constructor + for (int i = 1; i < numSlices; ++i) { + if (expectedNonGlobalAggsPerSlice > 0) { + nonGlobalCollectors.add(context.queryCollectorManagers().get(NonGlobalAggCollectorManager.class).newCollector()); + } + if (expectedGlobalAggs > 0) { + globalCollectors.add(context.queryCollectorManagers().get(GlobalAggCollectorManager.class).newCollector()); + } + } + final ContextIndexSearcher testSearcher = mock(ContextIndexSearcher.class); + final IndexSearcher.LeafSlice[] slicesToReturn = new IndexSearcher.LeafSlice[numSlices]; + when(testSearcher.getSlices()).thenReturn(slicesToReturn); + ((TestSearchContext) context).setSearcher(testSearcher); + AggregationCollectorManager collectorManager; + if (expectedNonGlobalAggsPerSlice > 0) { + collectorManager = (AggregationCollectorManager) context.queryCollectorManagers().get(NonGlobalAggCollectorManager.class); + collectorManager.reduce(nonGlobalCollectors).reduce(context.queryResult()); + } + if (expectedGlobalAggs > 0) { + collectorManager = (AggregationCollectorManager) context.queryCollectorManagers().get(GlobalAggCollectorManager.class); + ReduceableSearchResult result = collectorManager.reduce(globalCollectors); + when(testSearcher.search(nullable(Query.class), ArgumentMatchers.>any())) + .thenReturn(result); + } + assertTrue(context.queryResult().hasAggs()); + testAggregationProcessor.postProcess(context); + assertTrue(context.queryResult().hasAggs()); + // for global aggs verify that search.search is called with CollectionManager + if (expectedGlobalAggs > 0) { + verify(testSearcher, times(1)).search(nullable(Query.class), ArgumentMatchers.>any()); + } + // after shard level reduce it should have only 1 InternalAggregation instance for each agg in request and internal aggregation + // will be equal to sum of expected global and nonglobal aggs + assertEquals(expectedNonGlobalAggsPerSlice + expectedGlobalAggs, context.queryResult().aggregations().expand().aggregations.size()); + assertNull(context.aggregations()); + assertTrue(context.queryCollectorManagers().isEmpty()); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java new file mode 100644 index 0000000000000..0095fd097d3f5 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexService; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +import java.io.IOException; + +public class AggregationSetupTests extends OpenSearchSingleNodeTestCase { + protected IndexService index; + + protected SearchContext context; + + protected final String globalNonGlobalAggs = "{ \"my_terms\": {\"terms\": {\"field\": \"f\"}}, " + + "\"all_products\": {\"global\": {}, \"aggs\": {\"avg_price\": {\"avg\": { \"field\": \"f\"}}}}}"; + + protected final String multipleNonGlobalAggs = "{ \"my_terms\": {\"terms\": {\"field\": \"f\"}}, " + + "\"avg_price\": {\"avg\": { \"field\": \"f\"}}}"; + + protected final String globalAgg = "{ \"all_products\": {\"global\": {}, \"aggs\": {\"avg_price\": {\"avg\": { \"field\": \"f\"}}}}}"; + + @Override + public void setUp() throws Exception { + super.setUp(); + index = createIndex("idx"); + client().prepareIndex("idx").setId("1").setSource("f", 5).execute().get(); + client().admin().indices().prepareRefresh("idx").get(); + context = createSearchContext(index); + } + + protected AggregatorFactories getAggregationFactories(String agg) throws IOException { + try (XContentParser aggParser = createParser(JsonXContent.jsonXContent, agg)) { + aggParser.nextToken(); + return AggregatorFactories.parseAggregators(aggParser).build(context.getQueryShardContext(), null); + } + } +} diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index 414584ae19f5c..7e6d31a51bd4d 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -1274,15 +1274,15 @@ public TotalHitCountCollector newCollector() throws IOException { @Override public ReduceableSearchResult reduce(Collection collectors) throws IOException { - final ReduceableSearchResult result = super.reduce(collectors); totalHits = collectors.stream().mapToInt(TotalHitCountCollector::getTotalHits).sum(); if (teminateAfter != null) { assertThat(totalHits, greaterThanOrEqualTo(teminateAfter)); totalHits = Math.min(totalHits, teminateAfter); } - - return result; + // this collector should not participate in reduce as it is added for test purposes to capture the totalHits count + // returning a ReduceableSearchResult modifies the QueryResult which is not expected + return (result) -> {}; } public int getTotalHits() { diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index 0b2235a0afedd..2b7e1450b9fbc 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -52,6 +52,7 @@ import org.opensearch.index.similarity.SimilarityService; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsSearchResult; @@ -639,6 +640,11 @@ public ReaderContext readerContext() { throw new UnsupportedOperationException(); } + @Override + public InternalAggregation.ReduceContext partial() { + return InternalAggregationTestCase.emptyReduceContextBuilder().forPartialReduction(); + } + /** * Clean the query results by consuming all of it */