diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java
index 11052cd215fb6..5e67193081f03 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/AggregationCollectorManager.java
@@ -26,7 +26,7 @@
  * aggregation operators
  */
 class AggregationCollectorManager implements CollectorManager {
-    private final SearchContext context;
+    protected final SearchContext context;
     private final CheckedFunction, IOException> aggProvider;
     private final String collectorReason;

@@ -63,18 +63,11 @@ public ReduceableSearchResult reduce(Collection collectors) throws IO
         }

         final InternalAggregations internalAggregations = InternalAggregations.from(internals);
-        // Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce iff multiple slices
-        // were created to execute this request and it used concurrent segment search path
-        // TODO: Add the check for flag that the request was executed using concurrent search
-        if (collectors.size() >= 1) {
-            // using reduce is fine here instead of topLevelReduce as pipeline aggregation is evaluated on the coordinator after all
-            // documents are collected across shards for an aggregation
-            return new AggregationReduceableSearchResult(
-                InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
-            );
-        } else {
-            return new AggregationReduceableSearchResult(internalAggregations);
-        }
+        return buildAggregationResult(internalAggregations);
+    }
+
+    public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
+        return new AggregationReduceableSearchResult(internalAggregations);
     }

     static Collector createCollector(SearchContext context, List collectors, String reason) throws IOException {
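The hunk above replaces the inline slice-count check with an overridable `buildAggregationResult` hook that the global and non-global managers (next two files) customize. A minimal, self-contained sketch of that template-method shape, using hypothetical stand-in types rather than the real OpenSearch classes (assumes Java 17 for the record syntax):

```java
import java.util.Collection;
import java.util.List;

class CollectorManagerSketch {
    // Simplified stand-in for AggregationReduceableSearchResult.
    record Result(String value) {}

    // Base manager: reduce() is the template method, buildResult() the hook.
    static class BaseManager {
        Result reduce(Collection<String> sliceResults) {
            String merged = String.join("+", sliceResults);
            return buildResult(merged);
        }

        // Default: return the merged slice results as-is.
        Result buildResult(String merged) {
            return new Result(merged);
        }
    }

    // Subclass overrides only the hook, mirroring the Global/NonGlobal
    // managers that add an extra shard level reduce step.
    static class ShardReducingManager extends BaseManager {
        @Override
        Result buildResult(String merged) {
            return new Result("shardReduce(" + merged + ")");
        }
    }

    public static void main(String[] args) {
        System.out.println(new BaseManager().reduce(List.of("s0", "s1")).value());          // s0+s1
        System.out.println(new ShardReducingManager().reduce(List.of("s0", "s1")).value()); // shardReduce(s0+s1)
    }
}
```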
diff --git a/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java
index 56f53a57a8573..1dcaee7e2ea6b 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/GlobalAggCollectorManager.java
@@ -14,6 +14,7 @@
 import org.opensearch.search.profile.query.CollectorResult;

 import java.io.IOException;
+import java.util.Collections;
 import java.util.Objects;

 /**
@@ -38,4 +39,13 @@ public Collector newCollector() throws IOException {
             return super.newCollector();
         }
     }
+
+    @Override
+    public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
+        // Reduce the aggregations across slices before sending to the coordinator. We will perform the shard level reduce as long as
+        // any slices were created, so that the shard level bucket count thresholds can be applied in the reduce phase.
+        return new AggregationReduceableSearchResult(
+            InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
+        );
+    }
 }
diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java
index 999ed458f2388..7ac8ac9579a58 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java
@@ -40,6 +40,7 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.rest.action.search.RestSearchAction;
 import org.opensearch.script.ScriptService;
+import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
 import org.opensearch.search.aggregations.support.AggregationPath;
@@ -160,6 +161,16 @@ public boolean isSliceLevel() {
             return this.isSliceLevel;
         }

+        // For slice level partial reduce we will apply the shard level `shard_size` and `shard_min_doc_count` limits, whereas for
+        // coordinator level partial reduce we will use the top level `size` and `min_doc_count` limits.
+        public int getRequiredSizeLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
+            return isSliceLevel() ? bucketCountThresholds.getShardSize() : bucketCountThresholds.getRequiredSize();
+        }
+
+        public long getMinDocCountLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
+            return isSliceLevel() ? bucketCountThresholds.getShardMinDocCount() : bucketCountThresholds.getMinDocCount();
+        }
+
         public BigArrays bigArrays() {
             return bigArrays;
         }
diff --git a/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java
index 3729734c48ed7..adc7cfa775a97 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/NonGlobalAggCollectorManager.java
@@ -14,6 +14,7 @@
 import org.opensearch.search.profile.query.CollectorResult;

 import java.io.IOException;
+import java.util.Collections;
 import java.util.Objects;

 /**
@@ -38,4 +39,13 @@ public Collector newCollector() throws IOException {
             return super.newCollector();
         }
     }
+
+    @Override
+    public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
+        // Reduce the aggregations across slices before sending to the coordinator. We will perform the shard level reduce as long as
+        // any slices were created, so that the shard level bucket count thresholds can be applied in the reduce phase.
+        return new AggregationReduceableSearchResult(
+            InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
+        );
+    }
 }
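The two new `ReduceContext` helpers centralize the choice between shard level and request level thresholds. A standalone sketch of the same selection logic, with a hypothetical `Thresholds` record standing in for `TermsAggregator.BucketCountThresholds` (assumes Java 17):

```java
class ThresholdSelectionSketch {
    // Hypothetical stand-in for TermsAggregator.BucketCountThresholds.
    record Thresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {}

    // Mirrors the ReduceContext helpers: slice level reduce uses the shard
    // level limits, coordinator level reduce uses the request level limits.
    static int requiredSizeLocal(boolean sliceLevel, Thresholds t) {
        return sliceLevel ? t.shardSize() : t.requiredSize();
    }

    static long minDocCountLocal(boolean sliceLevel, Thresholds t) {
        return sliceLevel ? t.shardMinDocCount() : t.minDocCount();
    }

    public static void main(String[] args) {
        // size=10, shard_size=25, min_doc_count=2, shard_min_doc_count=0
        Thresholds t = new Thresholds(2, 0, 10, 25);
        System.out.println(requiredSizeLocal(true, t));  // 25: slice level keeps more candidates
        System.out.println(requiredSizeLocal(false, t)); // 10: coordinator trims to the requested size
        System.out.println(minDocCountLocal(true, t));   // 0
        System.out.println(minDocCountLocal(false, t));  // 2
    }
}
```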
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
index 5718ec08eafdd..eed3f518b089e 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
@@ -603,15 +603,6 @@ abstract class ResultStrategy<
         TB extends InternalMultiBucketAggregation.InternalBucket> implements Releasable {

         private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
-            int requiredSizeLocal;
-            long minDocCountLocal;
-            if (context.isConcurrentSegmentSearchEnabled()) {
-                requiredSizeLocal = Integer.MAX_VALUE;
-                minDocCountLocal = 0;
-            } else {
-                requiredSizeLocal = bucketCountThresholds.getShardSize();
-                minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-            }
             if (valueCount == 0) { // no context in this reader
                 InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length];
                 for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
@@ -624,11 +615,11 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             long[] otherDocCount = new long[owningBucketOrds.length];
             for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
                 final int size;
-                if (minDocCountLocal == 0) {
+                if (context.getMinDocCountLocal(bucketCountThresholds) == 0) {
                     // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns
-                    size = (int) Math.min(valueCount, requiredSizeLocal);
+                    size = (int) Math.min(valueCount, context.getRequiredSizeLocal(bucketCountThresholds));
                 } else {
-                    size = (int) Math.min(maxBucketOrd(), requiredSizeLocal);
+                    size = (int) Math.min(maxBucketOrd(), context.getRequiredSizeLocal(bucketCountThresholds));
                 }
                 PriorityQueue ordered = buildPriorityQueue(size);
                 final int finalOrdIdx = ordIdx;
@@ -639,7 +630,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
                     @Override
                     public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException {
                         otherDocCount[finalOrdIdx] += docCount;
-                        if (docCount >= minDocCountLocal) {
+                        if (docCount >= context.getMinDocCountLocal(bucketCountThresholds)) {
                             if (spare == null) {
                                 spare = buildEmptyTemporaryBucket();
                             }
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMappedTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMappedTerms.java
index 943d8e37eed15..8d779263490bb 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMappedTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMappedTerms.java
@@ -52,6 +52,7 @@
  */
 public abstract class InternalMappedTerms, B extends InternalTerms.Bucket> extends InternalTerms {
     protected final DocValueFormat format;
+    protected final int shardSize;
     protected final boolean showTermDocCountError;
     protected final long otherDocCount;
     protected final List buckets;
@@ -73,6 +74,7 @@ protected InternalMappedTerms(
     ) {
         super(name, reduceOrder, order, bucketCountThresholds, metadata);
         this.format = format;
+        this.shardSize = bucketCountThresholds.getShardSize();
         this.showTermDocCountError = showTermDocCountError;
         this.otherDocCount = otherDocCount;
         this.docCountError = docCountError;
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java
index 89fd488d21e23..915ed2d3870a2 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalMultiTerms.java
@@ -220,6 +220,7 @@ public int compare(List thisObjects, List thatObjects) {
         }
     }

+    private final int shardSize;
     private final boolean showTermDocCountError;
     private final long otherDocCount;
     private final List termFormats;
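The next file stops reconstructing `shardSize`/`shardMinDocCount` from the wire, since the coordinator never needs them; only the request level thresholds are serialized. A sketch of that wire-format idea using plain `DataOutput`/`DataInput` instead of OpenSearch's `StreamOutput`/`StreamInput` (hypothetical `TermsHeader` type, not the real class):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

class WireFormatSketch {
    static final class TermsHeader {
        final int requiredSize;  // serialized
        final long minDocCount;  // serialized
        final int shardSize;     // node-local only, never sent to the coordinator

        TermsHeader(int requiredSize, long minDocCount, int shardSize) {
            this.requiredSize = requiredSize;
            this.minDocCount = minDocCount;
            this.shardSize = shardSize;
        }

        void writeTo(DataOutputStream out) throws IOException {
            out.writeInt(requiredSize);
            out.writeLong(minDocCount);
            // shardSize intentionally omitted: the coordinator never uses it
        }

        static TermsHeader readFrom(DataInputStream in) throws IOException {
            // -1 marks the shard level field as absent on the coordinator
            return new TermsHeader(in.readInt(), in.readLong(), -1);
        }
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        new TermsHeader(10, 2, 25).writeTo(new DataOutputStream(bytes));
        TermsHeader read = TermsHeader.readFrom(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(read.requiredSize + " " + read.minDocCount + " " + read.shardSize); // 10 2 -1
    }
}
```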
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
index 8eb3919f4b88d..a3140be9f5bff 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
@@ -39,7 +39,6 @@
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.InternalAggregations;
 import org.opensearch.search.aggregations.InternalMultiBucketAggregation;
-import org.opensearch.search.aggregations.bucket.BucketUtils;
 import org.opensearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;

 import java.io.IOException;
@@ -196,8 +195,6 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)

     protected final int requiredSize;
     protected final long minDocCount;
-    protected int shardSize;
-    protected final long shardMinDocCount;
     protected final TermsAggregator.BucketCountThresholds bucketCountThresholds;

     protected InternalSignificantTerms(
@@ -208,8 +205,6 @@ protected InternalSignificantTerms(
         super(name, metadata);
         this.requiredSize = bucketCountThresholds.getRequiredSize();
         this.minDocCount = bucketCountThresholds.getMinDocCount();
-        this.shardSize = bucketCountThresholds.getShardSize();
-        this.shardMinDocCount = bucketCountThresholds.getShardMinDocCount();
         this.bucketCountThresholds = bucketCountThresholds;
     }

@@ -220,9 +215,9 @@ protected InternalSignificantTerms(StreamInput in) throws IOException {
         super(in);
         requiredSize = readSize(in);
         minDocCount = in.readVLong();
-        shardSize = BucketUtils.suggestShardSideQueueSize(requiredSize);
-        shardMinDocCount = 0;
-        bucketCountThresholds = new TermsAggregator.BucketCountThresholds(minDocCount, shardMinDocCount, requiredSize, shardSize);
+        // shardMinDocCount and shardSize are not used on the coordinator, so they are not deserialized. We use
+        // CoordinatorBucketCountThresholds, which will throw an exception if they are accessed.
+        bucketCountThresholds = new TermsAggregator.CoordinatorBucketCountThresholds(minDocCount, -1, requiredSize, -1);
     }

     protected final void doWriteTo(StreamOutput out) throws IOException {
@@ -238,15 +233,6 @@ protected final void doWriteTo(StreamOutput out) throws IOException {

     @Override
     public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) {
-        int requiredSizeLocal;
-        long minDocCountLocal;
-        if (reduceContext.isSliceLevel()) {
-            requiredSizeLocal = bucketCountThresholds.getShardSize();
-            minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-        } else {
-            requiredSizeLocal = bucketCountThresholds.getRequiredSize();
-            minDocCountLocal = bucketCountThresholds.getMinDocCount();
-        }
         long globalSubsetSize = 0;
         long globalSupersetSize = 0;

@@ -289,17 +275,21 @@ public InternalAggregation reduce(List aggregations, Reduce
             }
         }
         SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
-        final int size = (reduceContext.isFinalReduce() || reduceContext.isSliceLevel())
-            ? Math.min(requiredSizeLocal, buckets.size())
+        boolean isCoordinatorPartialReduce = reduceContext.isFinalReduce() == false && reduceContext.isSliceLevel() == false;
+        // Do not apply the size threshold on a coordinator partial reduce
+        final int size = !isCoordinatorPartialReduce
+            ? Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), buckets.size())
             : buckets.size();
         BucketSignificancePriorityQueue ordered = new BucketSignificancePriorityQueue<>(size);
         for (Map.Entry> entry : buckets.entrySet()) {
             List sameTermBuckets = entry.getValue();
             final B b = reduceBucket(sameTermBuckets, reduceContext);
             b.updateScore(heuristic);
-            if (!(reduceContext.isFinalReduce() || reduceContext.isSliceLevel()) // Don't apply thresholds for partial reduce
-                || (reduceContext.isSliceLevel() && (b.subsetDf >= minDocCountLocal)) // Score needs to be evaluated only at the coordinator
-                || ((b.score > 0) && (b.subsetDf >= minDocCountLocal))) {
+            // For the concurrent search case we do not apply the bucket count thresholds in buildAggregation; instead they are
+            // applied here during reduce. The bucket score, however, is only evaluated at the final coordinator reduce.
+            boolean meetsThresholds = (b.subsetDf >= reduceContext.getMinDocCountLocal(bucketCountThresholds))
+                && ((b.score > 0) || reduceContext.isSliceLevel());
+            if (isCoordinatorPartialReduce || meetsThresholds) {
                 B removed = ordered.insertWithOverflow(b);
                 if (removed == null) {
                     reduceContext.consumeBucketsAndMaybeBreak(1);
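The gating condition above is easier to see in isolation. A self-contained sketch of the same predicate (hypothetical method name; the parameters mirror the fields used in the hunk):

```java
class SignificantTermsGateSketch {
    // Keep a bucket if: (a) this is a coordinator partial reduce, where no
    // thresholds apply yet, or (b) it clears the local min doc count and, once
    // scores are meaningful (anything but a slice level reduce), has a
    // positive score.
    static boolean keepBucket(boolean finalReduce, boolean sliceLevel, long subsetDf, double score, long minDocCountLocal) {
        boolean coordinatorPartialReduce = !finalReduce && !sliceLevel;
        boolean meetsThresholds = subsetDf >= minDocCountLocal && (score > 0 || sliceLevel);
        return coordinatorPartialReduce || meetsThresholds;
    }

    public static void main(String[] args) {
        System.out.println(keepBucket(false, false, 0, 0.0, 5)); // true: coordinator partial reduce keeps everything
        System.out.println(keepBucket(false, true, 10, 0.0, 5)); // true: slice level ignores the score
        System.out.println(keepBucket(true, false, 10, 0.0, 5)); // false: final reduce requires a positive score
    }
}
```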
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
index 7489822371011..844f68f283eaf 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
@@ -45,7 +45,6 @@
 import org.opensearch.search.aggregations.InternalMultiBucketAggregation;
 import org.opensearch.search.aggregations.InternalOrder;
 import org.opensearch.search.aggregations.KeyComparable;
-import org.opensearch.search.aggregations.bucket.BucketUtils;
 import org.opensearch.search.aggregations.bucket.IteratorAndCurrent;
 import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;

@@ -58,7 +57,6 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.Function;
-import java.util.stream.Collectors;

 import static org.opensearch.search.aggregations.InternalOrder.isKeyAsc;
 import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder;
@@ -225,8 +223,6 @@ public int hashCode() {
     protected final BucketOrder order;
     protected final int requiredSize;
     protected final long minDocCount;
-    protected int shardSize;
-    protected final long shardMinDocCount;
     protected final TermsAggregator.BucketCountThresholds bucketCountThresholds;

     /**
@@ -250,8 +246,6 @@ protected InternalTerms(
         this.bucketCountThresholds = bucketCountThresholds;
         this.requiredSize = bucketCountThresholds.getRequiredSize();
         this.minDocCount = bucketCountThresholds.getMinDocCount();
-        this.shardSize = bucketCountThresholds.getShardSize();
-        this.shardMinDocCount = bucketCountThresholds.getShardMinDocCount();
     }

     /**
@@ -263,9 +257,9 @@ protected InternalTerms(StreamInput in) throws IOException {
         order = InternalOrder.Streams.readOrder(in);
         requiredSize = readSize(in);
         minDocCount = in.readVLong();
-        shardSize = BucketUtils.suggestShardSideQueueSize(requiredSize);
-        shardMinDocCount = 0;
-        bucketCountThresholds = new TermsAggregator.BucketCountThresholds(minDocCount, shardMinDocCount, requiredSize, shardSize);
+        // shardMinDocCount and shardSize are not used on the coordinator, so they are not deserialized. We use
+        // CoordinatorBucketCountThresholds, which will throw an exception if they are accessed.
+        bucketCountThresholds = new TermsAggregator.CoordinatorBucketCountThresholds(minDocCount, -1, requiredSize, getShardSize());
     }

     @Override
@@ -394,16 +388,6 @@ private List reduceLegacy(List aggregations, ReduceConte
     }

     public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) {
-        int requiredSizeLocal;
-        long minDocCountLocal;
-        if (reduceContext.isSliceLevel()) {
-            requiredSizeLocal = bucketCountThresholds.getShardSize();
-            minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-        } else {
-            requiredSizeLocal = bucketCountThresholds.getRequiredSize();
-            minDocCountLocal = bucketCountThresholds.getMinDocCount();
-        }
-
         long sumDocCountError = 0;
         long otherDocCount = 0;
         InternalTerms referenceTerms = null;
@@ -464,7 +448,7 @@ public InternalAggregation reduce(List aggregations, Reduce
         }
         final B[] list;
         if (reduceContext.isFinalReduce() || reduceContext.isSliceLevel()) {
-            final int size = Math.min(requiredSizeLocal, reducedBuckets.size());
+            final int size = Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), reducedBuckets.size());
             // final comparator
             final BucketPriorityQueue ordered = new BucketPriorityQueue<>(size, order.comparator());
             for (B bucket : reducedBuckets) {
@@ -474,7 +458,7 @@ public InternalAggregation reduce(List aggregations, Reduce
                 final long finalSumDocCountError = sumDocCountError;
                 bucket.setDocCountError(docCountError -> docCountError + finalSumDocCountError);
             }
-            if (bucket.getDocCount() >= minDocCountLocal) {
+            if (bucket.getDocCount() >= reduceContext.getMinDocCountLocal(bucketCountThresholds)) {
                 B removed = ordered.insertWithOverflow(bucket);
                 if (removed != null) {
                     otherDocCount += removed.getDocCount();
@@ -493,8 +477,8 @@ public InternalAggregation reduce(List aggregations, Reduce
         } else {
             // we can prune the list on partial reduce if the aggregation is ordered by key
             // and not filtered (minDocCount == 0)
-            int size = isKeyOrder(order) && minDocCountLocal == 0
-                ? Math.min(requiredSizeLocal, reducedBuckets.size())
+            int size = isKeyOrder(order) && reduceContext.getMinDocCountLocal(bucketCountThresholds) == 0
+                ? Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), reducedBuckets.size())
                 : reducedBuckets.size();
             list = createBucketsArray(size);
             for (int i = 0; i < size; i++) {
@@ -515,14 +499,11 @@ public InternalAggregation reduce(List aggregations, Reduce
             docCountError = aggregations.size() == 1 ? 0 : sumDocCountError;
         }

-        // Shards must return buckets sorted by key, so we apply the sort here
-        List resultList;
+        // Shards must return buckets sorted by key, so we apply the sort here in the shard level reduce
         if (reduceContext.isSliceLevel()) {
-            resultList = Arrays.stream(list).sorted(thisReduceOrder.comparator()).collect(Collectors.toList());
-        } else {
-            resultList = Arrays.asList(list);
+            Arrays.sort(list, thisReduceOrder.comparator());
         }
-        return create(name, resultList, reduceContext.isFinalReduce() ? order : thisReduceOrder, docCountError, otherDocCount);
+        return create(name, Arrays.asList(list), reduceContext.isFinalReduce() ? order : thisReduceOrder, docCountError, otherDocCount);
     }

     @Override
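The last hunk above also swaps a stream-based copy-and-sort for an in-place `Arrays.sort` on the slice level path. A before/after sketch of that micro-refactor on plain arrays:

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class SortRefactorSketch {
    public static void main(String[] args) {
        Integer[] buckets = { 3, 1, 2 };
        Comparator<Integer> byKey = Comparator.naturalOrder();

        // Before: a stream copy allocates a fresh list on every slice level reduce.
        List<Integer> copied = Arrays.stream(buckets).sorted(byKey).collect(Collectors.toList());

        // After: sort the existing array in place, then wrap it once.
        Arrays.sort(buckets, byKey);
        List<Integer> wrapped = Arrays.asList(buckets);

        System.out.println(copied);  // [1, 2, 3]
        System.out.println(wrapped); // [1, 2, 3]
    }
}
```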
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
index a4269a216050b..2537f9a9fd097 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
@@ -244,20 +244,11 @@ abstract class ResultStrategy

         private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
-            int requiredSizeLocal;
-            long minDocCountLocal;
-            if (context.isConcurrentSegmentSearchEnabled()) {
-                requiredSizeLocal = Integer.MAX_VALUE;
-                minDocCountLocal = 0;
-            } else {
-                requiredSizeLocal = bucketCountThresholds.getShardSize();
-                minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-            }
             B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length);
             long[] otherDocCounts = new long[owningBucketOrds.length];
             for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
                 collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
                 long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);

-                int size = (int) Math.min(bucketsInOrd, requiredSizeLocal);
+                int size = (int) Math.min(bucketsInOrd, context.getRequiredSizeLocal(bucketCountThresholds));
                 PriorityQueue ordered = buildPriorityQueue(size);
                 B spare = null;
@@ -266,7 +257,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < minDocCountLocal) {
+                if (docCount < context.getMinDocCountLocal(bucketCountThresholds)) {
                     continue;
                 }
                 if (spare == null) {
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
index 499e1743856dd..b3ff7f11a7460 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
@@ -118,22 +118,13 @@ public MultiTermsAggregator(

     @Override
     public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
-        int requiredSizeLocal;
-        long minDocCountLocal;
-        if (context.isConcurrentSegmentSearchEnabled()) {
-            requiredSizeLocal = Integer.MAX_VALUE;
-            minDocCountLocal = 0;
-        } else {
-            requiredSizeLocal = bucketCountThresholds.getShardSize();
-            minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-        }
         InternalMultiTerms.Bucket[][] topBucketsPerOrd = new InternalMultiTerms.Bucket[owningBucketOrds.length][];
         long[] otherDocCounts = new long[owningBucketOrds.length];
         for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
             collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
             long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);

-            int size = (int) Math.min(bucketsInOrd, requiredSizeLocal);
+            int size = (int) Math.min(bucketsInOrd, context.getRequiredSizeLocal(bucketCountThresholds));
             PriorityQueue ordered = new BucketPriorityQueue<>(size, partiallyBuiltBucketComparator);
             InternalMultiTerms.Bucket spare = null;
             BytesRef dest = null;
@@ -145,7 +136,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < minDocCountLocal) {
+                if (docCount < context.getMinDocCountLocal(bucketCountThresholds)) {
                     continue;
                 }
                 if (spare == null) {
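The common thread in these aggregator hunks: under concurrent segment search, per-slice `buildAggregations` keeps every bucket (size effectively unbounded, min doc count 0), and pruning happens once at the shard level reduce. A toy end-to-end sketch of that deferred-pruning idea (hypothetical names, plain maps instead of bucket ordinals):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class DeferredPruningSketch {
    // Prune to the top `size` buckets with at least `minDocCount` docs.
    static List<Map.Entry<String, Long>> topBuckets(Map<String, Long> counts, int size, long minDocCount) {
        return counts.entrySet().stream()
            .filter(e -> e.getValue() >= minDocCount)
            .sorted(Map.Entry.<String, Long>comparingByValue().reversed())
            .limit(size)
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Per-slice collection keeps every bucket, so a term that is
        // low-frequency in each slice is not dropped before slices combine.
        Map<String, Long> slice1 = Map.of("a", 2L, "b", 1L);
        Map<String, Long> slice2 = Map.of("a", 1L, "c", 3L);

        Map<String, Long> merged = new HashMap<>(slice1);
        slice2.forEach((term, count) -> merged.merge(term, count, Long::sum));

        // The shard level reduce then applies shard_size / shard_min_doc_count once.
        System.out.println(topBuckets(merged, 2, 2)); // a=3 and c=3 survive; b=1 is pruned
    }
}
```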
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
index 53e1732e729fd..6f7e0e4bf5afd 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
@@ -173,22 +173,13 @@ abstract class ResultStrategy

         private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
-            int requiredSizeLocal;
-            long minDocCountLocal;
-            if (context.isConcurrentSegmentSearchEnabled()) {
-                requiredSizeLocal = Integer.MAX_VALUE;
-                minDocCountLocal = 0;
-            } else {
-                requiredSizeLocal = bucketCountThresholds.getShardSize();
-                minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
-            }
             B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length);
             long[] otherDocCounts = new long[owningBucketOrds.length];
             for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
                 collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
                 long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);

-                int size = (int) Math.min(bucketsInOrd, requiredSizeLocal);
+                int size = (int) Math.min(bucketsInOrd, context.getRequiredSizeLocal(bucketCountThresholds));
                 PriorityQueue ordered = buildPriorityQueue(size);
                 B spare = null;
                 BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds[ordIdx]);
@@ -196,7 +187,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < minDocCountLocal) {
+                if (docCount < context.getMinDocCountLocal(bucketCountThresholds)) {
                     continue;
                 }
                 if (spare == null) {
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregator.java
index 9e2aa85bb1dd8..d9fe1eeefceea 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/TermsAggregator.java
@@ -39,6 +39,7 @@
 import org.opensearch.core.xcontent.ToXContentFragment;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.aggregations.AggregationExecutionException;
 import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.AggregatorFactories;
 import org.opensearch.search.aggregations.BucketOrder;
@@ -70,9 +71,9 @@ public abstract class TermsAggregator extends DeferableBucketAggregator {
      */
     public static class BucketCountThresholds implements Writeable, ToXContentFragment {
         private long minDocCount;
-        private long shardMinDocCount;
+        protected long shardMinDocCount;
         private int requiredSize;
-        private int shardSize;
+        protected int shardSize;

         public BucketCountThresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {
             this.minDocCount = minDocCount;
@@ -195,6 +196,29 @@ public boolean equals(Object obj) {
         }
     }

+    // BucketCountThresholds type that throws an exception when shardMinDocCount and shardSize are accessed. This is used for
+    // deserialization on the coordinator during reduce, as shardMinDocCount and shardSize should not be accessed this way on the
+    // coordinator.
+    public static class CoordinatorBucketCountThresholds extends BucketCountThresholds {
+
+        public CoordinatorBucketCountThresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {
+            super(minDocCount, shardMinDocCount, requiredSize, shardSize);
+        }
+
+        @Override
+        public long getShardMinDocCount() {
+            throw new AggregationExecutionException("shard_min_doc_count should not be accessed");
+        }
+
+        @Override
+        public int getShardSize() {
+            if (shardSize < 0) {
+                throw new AggregationExecutionException("Invalid shard_size accessed");
+            }
+            return shardSize;
+        }
+    }
+
     protected final DocValueFormat format;
     protected final BucketCountThresholds bucketCountThresholds;
     protected final BucketOrder order;
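`CoordinatorBucketCountThresholds` is a fail-fast guard: a subtype that poisons accessors whose values were deliberately not deserialized, so misuse surfaces as an exception rather than a silently wrong default. A generic sketch of the pattern (hypothetical names, `IllegalStateException` standing in for `AggregationExecutionException`):

```java
class FailFastThresholdsSketch {
    static class Thresholds {
        protected final int shardSize;
        Thresholds(int shardSize) { this.shardSize = shardSize; }
        int getShardSize() { return shardSize; }
    }

    // Coordinator-side variant: the field may legitimately be absent (-1),
    // and reading it in that state is a programming error.
    static class CoordinatorThresholds extends Thresholds {
        CoordinatorThresholds(int shardSize) { super(shardSize); }

        @Override
        int getShardSize() {
            if (shardSize < 0) {
                throw new IllegalStateException("shard_size is not available on the coordinator");
            }
            return shardSize;
        }
    }

    public static void main(String[] args) {
        System.out.println(new CoordinatorThresholds(25).getShardSize()); // 25
        try {
            new CoordinatorThresholds(-1).getShardSize();
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // fails loudly instead of returning -1
        }
    }
}
```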
diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java
index bc2a0658e5a6d..600f91b0dc42e 100644
--- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java
+++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java
@@ -35,6 +35,7 @@
 import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.util.ArrayUtil;
 import org.opensearch.action.search.SearchShardTask;
 import org.opensearch.action.search.SearchType;
 import org.opensearch.common.Nullable;
@@ -57,6 +58,7 @@
 import org.opensearch.search.aggregations.BucketCollectorProcessor;
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.SearchContextAggregations;
+import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
 import org.opensearch.search.collapse.CollapseContext;
 import org.opensearch.search.dfs.DfsSearchResult;
 import org.opensearch.search.fetch.FetchPhase;
@@ -400,6 +402,20 @@ public boolean isConcurrentSegmentSearchEnabled() {
         return false;
     }

+    /**
+     * Returns the local size threshold based on the search context
+     */
+    public int getRequiredSizeLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
+        return isConcurrentSegmentSearchEnabled() ? ArrayUtil.MAX_ARRAY_LENGTH - 1 : bucketCountThresholds.getShardSize();
+    }
+
+    /**
+     * Returns the local minDocCount threshold based on the search context
+     */
+    public long getMinDocCountLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
+        return isConcurrentSegmentSearchEnabled() ? 0 : bucketCountThresholds.getShardMinDocCount();
+    }
+
     /**
      * Adds a releasable that will be freed when this context is closed.
     */
diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
index 874d60a4097f2..2c18d70e6838c 100644
--- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java
@@ -350,6 +350,8 @@ public boolean shouldCache(Query query) {
         when(searchContext.aggregations()).thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
         when(searchContext.query()).thenReturn(query);
         when(searchContext.bucketCollectorProcessor()).thenReturn(new BucketCollectorProcessor());
+        when(searchContext.getMinDocCountLocal(any())).thenCallRealMethod();
+        when(searchContext.getRequiredSizeLocal(any())).thenCallRealMethod();
         /*
          * Always use the circuit breaking big arrays instance so that the CircuitBreakerService
          * we're passed gets a chance to break.
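The test-framework change routes the two new helper methods through their real implementations on an otherwise fully mocked `SearchContext`; without the `thenCallRealMethod()` stubs, the mock's default answers (0) would silently bypass the logic under test. A minimal, self-contained Mockito sketch of that partial-mock pattern (hypothetical `Ctx` type; assumes `mockito-core` on the classpath):

```java
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

// Hypothetical class standing in for SearchContext: most methods stay mocked,
// but the threshold helper should execute its real logic.
abstract class Ctx {
    abstract boolean concurrent();

    long minDocCountLocal(long shardMinDocCount) {
        return concurrent() ? 0 : shardMinDocCount;
    }
}

class PartialMockSketch {
    public static void main(String[] args) {
        Ctx ctx = mock(Ctx.class);
        when(ctx.concurrent()).thenReturn(false);
        // Route this one method to the real implementation so it consults the
        // (stubbed) concurrent() flag instead of returning the mock default.
        when(ctx.minDocCountLocal(anyLong())).thenCallRealMethod();
        System.out.println(ctx.minDocCountLocal(5L)); // 5: the real method ran
    }
}
```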