From 95e8eab530ad3c35307f6540f3d064b61dc6b8de Mon Sep 17 00:00:00 2001
From: Jay Deng
Date: Wed, 2 Aug 2023 22:00:01 -0700
Subject: [PATCH] Apply shard_size and shard_min_doc_count at the shard-level
 reduce for concurrent segment search

---
 .../GlobalOrdinalsStringTermsAggregator.java  | 16 ++++--
 .../terms/InternalSignificantTerms.java       | 57 ++++++++++++++++---
 .../bucket/terms/InternalTerms.java           | 37 ++++++++++--
 .../terms/MapStringTermsAggregator.java       | 20 +++++--
 .../bucket/terms/MultiTermsAggregator.java    | 15 ++++-
 .../bucket/terms/NumericTermsAggregator.java  | 28 +++++++--
 .../search/internal/ContextIndexSearcher.java |  7 +++
 .../test/OpenSearchIntegTestCase.java         |  1 +
 8 files changed, 152 insertions(+), 29 deletions(-)

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
index e0a22435b8f48..a4b6811d0e277 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java
@@ -617,9 +617,9 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             final int size;
             if (bucketCountThresholds.getMinDocCount() == 0) {
                 // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns
-                size = (int) Math.min(valueCount, bucketCountThresholds.getShardSize());
+                size = context.isConcurrentSegmentSearchEnabled() ? (int) valueCount : (int) Math.min(valueCount, bucketCountThresholds.getShardSize());
             } else {
-                size = (int) Math.min(maxBucketOrd(), bucketCountThresholds.getShardSize());
+                size = context.isConcurrentSegmentSearchEnabled() ? (int) maxBucketOrd() : (int) Math.min(maxBucketOrd(), bucketCountThresholds.getShardSize());
             }
             PriorityQueue<TB> ordered = buildPriorityQueue(size);
             final int finalOrdIdx = ordIdx;
@@ -630,7 +630,8 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
                 @Override
                 public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException {
                     otherDocCount[finalOrdIdx] += docCount;
-                    if (docCount >= bucketCountThresholds.getShardMinDocCount()) {
+                    // Don't evaluate shard_min_doc_count at the slice level for concurrent segment search
+                    if (context.isConcurrentSegmentSearchEnabled() || docCount >= bucketCountThresholds.getShardMinDocCount()) {
                         if (spare == null) {
                             spare = buildEmptyTemporaryBucket();
                         }
@@ -795,7 +796,7 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
             } else {
                 reduceOrder = order;
             }
-            return new StringTerms(
+            StringTerms stringTerms = new StringTerms(
                 name,
                 reduceOrder,
                 order,
@@ -809,6 +810,8 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
                 Arrays.asList(topBuckets),
                 0
             );
+            stringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+            return stringTerms;
         }

         @Override
@@ -922,7 +925,7 @@ void buildSubAggs(SignificantStringTerms.Bucket[][] topBucketsPreOrd) throws IOE

         @Override
         SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, SignificantStringTerms.Bucket[] topBuckets) {
-            return new SignificantStringTerms(
+            SignificantStringTerms significantStringTerms = new SignificantStringTerms(
                 name,
                 bucketCountThresholds.getRequiredSize(),
                 bucketCountThresholds.getMinDocCount(),
@@ -933,6 +936,9 @@ SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, Sig
                 significanceHeuristic,
                 Arrays.asList(topBuckets)
             );
+            significantStringTerms.setShardSize(bucketCountThresholds.getShardSize());
+            significantStringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+            return significantStringTerms;
         }

         @Override
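Note on the hunks above: the same sizing rule recurs in the Map/Multi/Numeric aggregators later in this patch. With concurrent segment search enabled, a slice keeps all of its buckets and the shard_size cut-off is deferred to the shard-level reduce, because a slice only sees part of the shard's documents and trimming eagerly could drop a bucket that would make the shard's top-N once slices merge. A minimal standalone sketch of just that decision (the class and method names here are illustrative, not from the patch):

// Illustrative sketch only, not part of the patch: shows why slice-level
// collection must not truncate to shard_size when slices reduce concurrently.
final class SliceSizingSketch {
    static long sliceLevelQueueSize(boolean concurrentSearchEnabled, long bucketCount, int shardSize) {
        // Sequential path trims eagerly; concurrent path defers trimming.
        return concurrentSearchEnabled ? bucketCount : Math.min(bucketCount, shardSize);
    }

    public static void main(String[] args) {
        System.out.println(sliceLevelQueueSize(false, 10_000, 100)); // 100
        System.out.println(sliceLevelQueueSize(true, 10_000, 100));  // 10000
    }
}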
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
index 84d148199a7f9..571c4e87e7e7b 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java
@@ -195,11 +195,15 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)

     protected final int requiredSize;
     protected final long minDocCount;
+    protected int shardSize;
+    protected long shardMinDocCount;

     protected InternalSignificantTerms(String name, int requiredSize, long minDocCount, Map<String, Object> metadata) {
         super(name, metadata);
         this.requiredSize = requiredSize;
         this.minDocCount = minDocCount;
+        shardSize = 0;
+        shardMinDocCount = 0;
     }

     /**
@@ -222,8 +226,25 @@ protected final void doWriteTo(StreamOutput out) throws IOException {
     @Override
     public abstract List<B> getBuckets();

+    public int getShardSize() {
+        return shardSize;
+    }
+
+    public void setShardSize(int shardSize) {
+        this.shardSize = shardSize;
+    }
+
+    public long getShardMinDocCount() {
+        return shardMinDocCount;
+    }
+
+    public void setShardMinDocCount(long shardMinDocCount) {
+        this.shardMinDocCount = shardMinDocCount;
+    }
+
     @Override
     public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
+        // TODO: Refactor this into slice-level and shard-level reduces in separate helper functions
         long globalSubsetSize = 0;
         long globalSupersetSize = 0;
         // Compute the overall result set size and the corpus size using the
@@ -265,21 +286,43 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
             }
         }
         SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
-        final int size = reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size());
+        // Apply the shard_size limit at the slice-level reduce
+        final int size;
+        if (reduceContext.isFinalReduce()) {
+            size = Math.min(requiredSize, buckets.size());
+        } else if (reduceContext.isSliceLevel()) {
+            size = Math.min(getShardSize(), buckets.size());
+        } else {
+            size = buckets.size();
+        }
         BucketSignificancePriorityQueue<B> ordered = new BucketSignificancePriorityQueue<>(size);
         for (Map.Entry<String, List<B>> entry : buckets.entrySet()) {
             List<B> sameTermBuckets = entry.getValue();
             final B b = reduceBucket(sameTermBuckets, reduceContext);
             b.updateScore(heuristic);
-            if (((b.score > 0) && (b.subsetDf >= minDocCount)) || reduceContext.isFinalReduce() == false) {
-                B removed = ordered.insertWithOverflow(b);
-                if (removed == null) {
-                    reduceContext.consumeBucketsAndMaybeBreak(1);
-                } else {
-                    reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed));
-                }
-            } else {
-                reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(b));
-            }
+            // TODO: simplify the slice-level vs. shard-level branching below
+            if (reduceContext.isSliceLevel()) {
+                if ((b.score > 0) && (b.subsetDf >= getShardMinDocCount())) {
+                    B removed = ordered.insertWithOverflow(b);
+                    if (removed == null) {
+                        reduceContext.consumeBucketsAndMaybeBreak(1);
+                    } else {
+                        reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed));
+                    }
+                } else {
+                    reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(b));
+                }
+            } else {
+                if (((b.score > 0) && (b.subsetDf >= minDocCount)) || reduceContext.isFinalReduce() == false) {
+                    B removed = ordered.insertWithOverflow(b);
+                    if (removed == null) {
+                        reduceContext.consumeBucketsAndMaybeBreak(1);
+                    } else {
+                        reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed));
+                    }
+                } else {
+                    reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(b));
+                }
+            }
         }
         B[] list = createBucketsArray(ordered.size());
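The reduce above now has three trim points: requiredSize at the final (coordinator) reduce, shard_size at the slice-level reduce, and no trimming at the intermediate shard-level reduce. A hypothetical sketch of that queue sizing, with ReduceLevel standing in for the isFinalReduce()/isSliceLevel() flags on ReduceContext:

// Hypothetical names; mirrors the three-way size selection added to
// InternalSignificantTerms#reduce in the hunk above.
final class SignificantTermsSizingSketch {
    enum ReduceLevel { SLICE, SHARD, FINAL }

    static int queueSize(ReduceLevel level, int requiredSize, int shardSize, int bucketCount) {
        switch (level) {
            case FINAL:
                return Math.min(requiredSize, bucketCount); // coordinator trims to the requested size
            case SLICE:
                return Math.min(shardSize, bucketCount);    // slices trim to shard_size
            default:
                return bucketCount;                         // shard-level reduce keeps every candidate
        }
    }

    public static void main(String[] args) {
        System.out.println(queueSize(ReduceLevel.SLICE, 10, 25, 100)); // 25
        System.out.println(queueSize(ReduceLevel.FINAL, 10, 25, 100)); // 10
    }
}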
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
index 9a80155eea51c..bd983fa5eb9e1 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java
@@ -223,6 +223,7 @@ public int hashCode() {
     protected final BucketOrder order;
     protected final int requiredSize;
     protected final long minDocCount;
+    protected long shardMinDocCount;

     /**
      * Creates a new {@link InternalTerms}
@@ -246,6 +247,7 @@ protected InternalTerms(
         this.order = order;
         this.requiredSize = requiredSize;
         this.minDocCount = minDocCount;
+        this.shardMinDocCount = 0;
     }

     /**
@@ -329,9 +331,18 @@ protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {
                 pq.add(new IteratorAndCurrent<>(terms.getBuckets().iterator()));
             }
         }
-        List<B> reducedBuckets = new ArrayList<>();
+        final BucketPriorityQueue<B> reducedBuckets;
         // list of buckets coming from different shards that have the same key
         List<B> currentBuckets = new ArrayList<>();
+
+        // Apply the shard_size parameter at the slice-level reduce if it is > 0
+        if (reduceContext.isSliceLevel() && getShardSize() > 0) {
+            final int size = Math.min(getShardSize(), aggregations.size());
+            reducedBuckets = new BucketPriorityQueue<>(size, order.comparator());
+        } else {
+            reducedBuckets = new BucketPriorityQueue<>(requiredSize, order.comparator());
+        }
+
         B lastBucket = null;
         while (pq.size() > 0) {
             final IteratorAndCurrent<B> top = pq.top();
@@ -339,7 +351,9 @@ protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {
             if (lastBucket != null && cmp.compare(top.current(), lastBucket) != 0) {
                 // the key changes, reduce what we already buffered and reset the buffer for current buckets
                 final B reduced = reduceBucket(currentBuckets, reduceContext);
-                reducedBuckets.add(reduced);
+                if (!reduceContext.isSliceLevel() || reduced.getDocCount() >= getShardMinDocCount()) {
+                    reducedBuckets.insertWithOverflow(reduced);
+                }
                 currentBuckets.clear();
             }
             lastBucket = top.current();
@@ -355,9 +369,16 @@ protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {

         if (currentBuckets.isEmpty() == false) {
             final B reduced = reduceBucket(currentBuckets, reduceContext);
-            reducedBuckets.add(reduced);
+            if (!reduceContext.isSliceLevel() || reduced.getDocCount() >= getShardMinDocCount()) {
+                reducedBuckets.insertWithOverflow(reduced);
+            }
         }
-        return reducedBuckets;
+
+        // Shards must return buckets sorted by key, so re-sort before returning
+        List<B> result = new ArrayList<>();
+        reducedBuckets.forEach(result::add);
+        result.sort(cmp);
+        return result;
     }

     private List<B> reduceLegacy(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
@@ -521,6 +542,14 @@ protected B reduceBucket(List<B> buckets, ReduceContext context) {
         return createBucket(docCount, aggs, docCountError, buckets.get(0));
     }

+    protected void setShardMinDocCount(long shardMinDocCount) {
+        this.shardMinDocCount = shardMinDocCount;
+    }
+
+    protected long getShardMinDocCount() {
+        return shardMinDocCount;
+    }
+
     protected abstract void setDocCountError(long docCountError);

     protected abstract int getShardSize();
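For the slice-level path added above, the flow is: drop buckets below shard_min_doc_count, keep at most shard_size survivors in a bounded priority queue, then re-sort by key, because the merge sort that consumes shard responses expects key order. A plain-JDK sketch of that flow under those assumptions (all names here are hypothetical, not the patch's types):

// Minimal sketch using JDK types instead of Lucene's PriorityQueue and the
// aggregation framework; illustrates filter -> bounded top-N -> key re-sort.
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

final class TermsSliceReduceSketch {
    record Bucket(String key, long docCount) {}

    static List<Bucket> sliceReduce(List<Bucket> merged, long shardMinDocCount, int shardSize) {
        // Min-heap on docCount keeps only the top shard_size buckets.
        PriorityQueue<Bucket> topN = new PriorityQueue<>(Comparator.comparingLong(Bucket::docCount));
        for (Bucket b : merged) {
            if (b.docCount() < shardMinDocCount) {
                continue; // shard_min_doc_count now applies here, not per segment slice
            }
            topN.add(b);
            if (topN.size() > shardSize) {
                topN.poll(); // evict the current smallest bucket
            }
        }
        List<Bucket> result = new ArrayList<>(topN);
        result.sort(Comparator.comparing(Bucket::key)); // restore key order for the caller
        return result;
    }

    public static void main(String[] args) {
        System.out.println(sliceReduce(List.of(new Bucket("a", 5), new Bucket("b", 1), new Bucket("c", 9)), 2, 1));
        // -> [Bucket[key=c, docCount=9]]
    }
}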
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
index bcdf1f4480a31..f09abfdf9b009 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java
@@ -248,8 +248,14 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
         long[] otherDocCounts = new long[owningBucketOrds.length];
         for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
             collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
-            int size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize());
+            // Do not apply shard_size at the slice level for concurrent segment search
+            int size;
+            if (context.isConcurrentSegmentSearchEnabled()) {
+                size = (int) bucketOrds.size();
+            } else {
+                size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize());
+            }
             PriorityQueue<B> ordered = buildPriorityQueue(size);
             B spare = null;
             BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds[ordIdx]);
@@ -257,7 +263,8 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < bucketCountThresholds.getShardMinDocCount()) {
+                // Don't evaluate shard_min_doc_count at the slice level for concurrent segment search
+                if (!context.isConcurrentSegmentSearchEnabled() && docCount < bucketCountThresholds.getShardMinDocCount()) {
                     continue;
                 }
                 if (spare == null) {
@@ -450,7 +457,7 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
         } else {
             reduceOrder = order;
         }
-        return new StringTerms(
+        StringTerms stringTerms = new StringTerms(
             name,
             reduceOrder,
             order,
@@ -464,6 +471,8 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
             Arrays.asList(topBuckets),
             0
         );
+        stringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        return stringTerms;
     }

     @Override
@@ -570,7 +579,7 @@ void buildSubAggs(SignificantStringTerms.Bucket[][] topBucketsPerOrd) throws IOE

     @Override
     SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, SignificantStringTerms.Bucket[] topBuckets) {
-        return new SignificantStringTerms(
+        SignificantStringTerms significantStringTerms = new SignificantStringTerms(
             name,
             bucketCountThresholds.getRequiredSize(),
             bucketCountThresholds.getMinDocCount(),
@@ -581,6 +590,9 @@ SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, Sig
             significanceHeuristic,
             Arrays.asList(topBuckets)
         );
+        significantStringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        significantStringTerms.setShardSize(bucketCountThresholds.getShardSize());
+        return significantStringTerms;
     }

     @Override
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
index 9d99c0b90a075..14d52322cf0e2 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java
@@ -124,7 +124,13 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
             collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
             long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);
-            int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize());
+            // Do not apply shard_size at the slice level for concurrent segment search
+            int size;
+            if (context.isConcurrentSegmentSearchEnabled()) {
+                size = (int) bucketsInOrd;
+            } else {
+                size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize());
+            }
             PriorityQueue<InternalMultiTerms.Bucket> ordered = new BucketPriorityQueue<>(size, partiallyBuiltBucketComparator);
             InternalMultiTerms.Bucket spare = null;
             BytesRef dest = null;
@@ -136,7 +142,8 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < bucketCountThresholds.getShardMinDocCount()) {
+                // Don't evaluate shard_min_doc_count at the slice level for concurrent segment search
+                if (!context.isConcurrentSegmentSearchEnabled() && docCount < bucketCountThresholds.getShardMinDocCount()) {
                     continue;
                 }
                 if (spare == null) {
@@ -178,7 +185,7 @@ InternalMultiTerms buildResult(long owningBucketOrd, long otherDocCount, Interna
         } else {
             reduceOrder = order;
         }
-        return new InternalMultiTerms(
+        InternalMultiTerms internalMultiTerms = new InternalMultiTerms(
             name,
             reduceOrder,
             order,
@@ -192,6 +199,8 @@ InternalMultiTerms buildResult(long owningBucketOrd, long otherDocCount, Interna
             formats,
             List.of(topBuckets)
         );
+        internalMultiTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        return internalMultiTerms;
     }

     @Override
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
index a0265135fe9d3..4bc9cc4cb8d63 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java
@@ -179,7 +179,13 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]);
             long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]);
-            int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize());
+            // Do not apply shard_size at the slice level for concurrent segment search
+            int size;
+            if (context.isConcurrentSegmentSearchEnabled()) {
+                size = (int) bucketsInOrd;
+            } else {
+                size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize());
+            }
             PriorityQueue<B> ordered = buildPriorityQueue(size);
             B spare = null;
             BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds[ordIdx]);
@@ -187,7 +193,8 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
             while (ordsEnum.next()) {
                 long docCount = bucketDocCount(ordsEnum.ord());
                 otherDocCounts[ordIdx] += docCount;
-                if (docCount < bucketCountThresholds.getShardMinDocCount()) {
+                // Don't evaluate shard_min_doc_count at the slice level for concurrent segment search
+                if (!context.isConcurrentSegmentSearchEnabled() && docCount < bucketCountThresholds.getShardMinDocCount()) {
                     continue;
                 }
                 if (spare == null) {
@@ -391,7 +398,7 @@ LongTerms buildResult(long owningBucketOrd, long otherDocCount, LongTerms.Bucket
         } else {
             reduceOrder = order;
         }
-        return new LongTerms(
+        LongTerms longTerms = new LongTerms(
             name,
             reduceOrder,
             order,
@@ -405,6 +412,8 @@ LongTerms buildResult(long owningBucketOrd, long otherDocCount, LongTerms.Bucket
             List.of(topBuckets),
             0
         );
+        longTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        return longTerms;
     }

     @Override
@@ -473,7 +482,7 @@ DoubleTerms buildResult(long owningBucketOrd, long otherDocCount, DoubleTerms.Bu
         } else {
             reduceOrder = order;
         }
-        return new DoubleTerms(
+        DoubleTerms doubleTerms = new DoubleTerms(
             name,
             reduceOrder,
             order,
@@ -487,6 +496,8 @@ DoubleTerms buildResult(long owningBucketOrd, long otherDocCount, DoubleTerms.Bu
             List.of(topBuckets),
             0
         );
+        doubleTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        return doubleTerms;
     }

     @Override
@@ -554,7 +565,7 @@ UnsignedLongTerms buildResult(long owningBucketOrd, long otherDocCount, Unsigned
         } else {
             reduceOrder = order;
         }
-        return new UnsignedLongTerms(
+        UnsignedLongTerms unsignedLongTerms = new UnsignedLongTerms(
             name,
             reduceOrder,
             order,
@@ -568,6 +579,8 @@ UnsignedLongTerms buildResult(long owningBucketOrd, long otherDocCount, Unsigned
             List.of(topBuckets),
             0
         );
+        unsignedLongTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        return unsignedLongTerms;
     }

     @Override
@@ -670,7 +683,7 @@ void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOException {}

     @Override
     SignificantLongTerms buildResult(long owningBucketOrd, long otherDocCoun, SignificantLongTerms.Bucket[] topBuckets) {
-        return new SignificantLongTerms(
+        SignificantLongTerms significantLongTerms = new SignificantLongTerms(
             name,
             bucketCountThresholds.getRequiredSize(),
             bucketCountThresholds.getMinDocCount(),
@@ -681,6 +694,9 @@ SignificantLongTerms buildResult(long owningBucketOrd, long otherDocCoun, Signif
             significanceHeuristic,
             List.of(topBuckets)
         );
+        significantLongTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
+        significantLongTerms.setShardSize(bucketCountThresholds.getShardSize());
+        return significantLongTerms;
     }

     @Override
diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
index e3ca932eb4699..0be1063cbc181 100644
--- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
+++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
@@ -518,4 +518,12 @@ private boolean shouldReverseLeafReaderContexts() {
         }
         return false;
     }
+
+    @Override
+    protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
+        // Manual-testing override: force one segment per slice so the slice-level reduce paths are exercised
+        final LeafSlice[] leafSlices = slices(leaves, 1, 1);
+        System.out.println("ManualTesting-========= slice count: " + leafSlices.length);
+        return leafSlices;
+    }
 }
diff --git a/test/framework/src/main/java/org/opensearch/test/OpenSearchIntegTestCase.java b/test/framework/src/main/java/org/opensearch/test/OpenSearchIntegTestCase.java
index 3564bd667ee2b..e47f01b33ffc3 100644
--- a/test/framework/src/main/java/org/opensearch/test/OpenSearchIntegTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/test/OpenSearchIntegTestCase.java
@@ -779,6 +779,7 @@ protected Settings featureFlagSettings() {
             featureSettings.put(builtInFlag.getKey(), builtInFlag.getDefaultRaw(Settings.EMPTY));
         }
         featureSettings.put(FeatureFlags.TELEMETRY_SETTING.getKey(), true);
+        featureSettings.put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH_SETTING.getKey(), true);
         return featureSettings.build();
     }
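Usage note: the OpenSearchIntegTestCase change above force-enables the concurrent segment search feature flag for the whole test framework, which suits a temporary patch but is too broad to merge. A single test class could opt in instead; the sketch below is modeled on the patch, not a guaranteed API, and the import paths follow OpenSearch 2.x conventions and should be verified against your version:

import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.test.OpenSearchIntegTestCase;

// Hypothetical per-class opt-in, reusing the same FeatureFlags constant and
// featureFlagSettings() hook that the patch touches above.
public class ConcurrentSearchTermsIT extends OpenSearchIntegTestCase {
    @Override
    protected Settings featureFlagSettings() {
        return Settings.builder()
            .put(super.featureFlagSettings())
            .put(FeatureFlags.CONCURRENT_SEGMENT_SEARCH_SETTING.getKey(), true)
            .build();
    }
}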