Skip to content

Commit

Permalink
Do not evaluate shard_size and shard_min_doc_count at slice level for…
Browse files Browse the repository at this point in the history
… concurrent segment search

Signed-off-by: Jay Deng <[email protected]>
  • Loading branch information
jed326 authored and Jay Deng committed Aug 3, 2023
1 parent 4ad4182 commit 978b105
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Exclude 'benchmarks' from codecov report ([#8805](https://github.com/opensearch-project/OpenSearch/pull/8805))
- [Refactor] MediaTypeParser to MediaTypeParserRegistry ([#8636](https://github.com/opensearch-project/OpenSearch/pull/8636))
- Create separate SourceLookup instance per segment slice in SignificantTextAggregatorFactory ([#8807](https://github.com/opensearch-project/OpenSearch/pull/8807))
- Change shard_size and shard_min_doc_count evaluation to happen in shard level reduce phase ([#9085](https://github.com/opensearch-project/OpenSearch/pull/9085))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IO
}

final InternalAggregations internalAggregations = InternalAggregations.from(internals);
// Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce iff multiple slices
// were created to execute this request and it used concurrent segment search path
// Reduce the aggregations across slices before sending to the coordinator.
// We will perform shard level reduce if multiple slices were created to execute this request and the mustReduceOnSingleInternalAgg
// flag is true for the given Aggregation
// TODO: Add the check for flag that the request was executed using concurrent search
if (collectors.size() > 1) {
if (collectors.size() > 1 || ((InternalAggregation) internalAggregations.aggregations.get(0)).mustReduceOnSingleInternalAgg()) {
// using reduce is fine here instead of topLevelReduce as pipeline aggregation is evaluated on the coordinator after all
// documents are collected across shards for an aggregation
return new AggregationReduceableSearchResult(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public final double sortValue(AggregationPath.PathElement head, Iterator<Aggrega
}

@Override
protected boolean mustReduceOnSingleInternalAgg() {
public boolean mustReduceOnSingleInternalAgg() {
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,13 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
final int size;
if (bucketCountThresholds.getMinDocCount() == 0) {
// if minDocCount == 0 then we can end up with more buckets than maxBucketOrd() returns
size = (int) Math.min(valueCount, bucketCountThresholds.getShardSize());
size = context.isConcurrentSegmentSearchEnabled()
? (int) valueCount
: (int) Math.min(valueCount, bucketCountThresholds.getShardSize());
} else {
size = (int) Math.min(maxBucketOrd(), bucketCountThresholds.getShardSize());
size = context.isConcurrentSegmentSearchEnabled()
? (int) maxBucketOrd()
: (int) Math.min(maxBucketOrd(), bucketCountThresholds.getShardSize());
}
PriorityQueue<TB> ordered = buildPriorityQueue(size);
final int finalOrdIdx = ordIdx;
Expand All @@ -630,7 +634,8 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
@Override
public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException {
otherDocCount[finalOrdIdx] += docCount;
if (docCount >= bucketCountThresholds.getShardMinDocCount()) {
// Don't evaluate shard_min_doc_count at the slice level for concurrent segment search
if (context.isConcurrentSegmentSearchEnabled() || docCount >= bucketCountThresholds.getShardMinDocCount()) {
if (spare == null) {
spare = buildEmptyTemporaryBucket();
}
Expand Down Expand Up @@ -795,7 +800,7 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
} else {
reduceOrder = order;
}
return new StringTerms(
StringTerms stringTerms = new StringTerms(
name,
reduceOrder,
order,
Expand All @@ -809,6 +814,8 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
Arrays.asList(topBuckets),
0
);
stringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
return stringTerms;
}

@Override
Expand Down Expand Up @@ -922,7 +929,7 @@ void buildSubAggs(SignificantStringTerms.Bucket[][] topBucketsPreOrd) throws IOE

@Override
SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, SignificantStringTerms.Bucket[] topBuckets) {
return new SignificantStringTerms(
SignificantStringTerms significantStringTerms = new SignificantStringTerms(
name,
bucketCountThresholds.getRequiredSize(),
bucketCountThresholds.getMinDocCount(),
Expand All @@ -933,6 +940,9 @@ SignificantStringTerms buildResult(long owningBucketOrd, long otherDocCount, Sig
significanceHeuristic,
Arrays.asList(topBuckets)
);
significantStringTerms.setShardSize(bucketCountThresholds.getShardSize());
significantStringTerms.setShardMinDocCount(bucketCountThresholds.getShardMinDocCount());
return significantStringTerms;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,15 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)

protected final int requiredSize;
protected final long minDocCount;
protected int shardSize;
protected long shardMinDocCount;

/**
 * Stores {@code requiredSize} and {@code minDocCount} and initialises the
 * {@code shardSize} and {@code shardMinDocCount} fields to 0.
 *
 * @param name         passed through to the superclass constructor
 * @param requiredSize value kept in {@link #requiredSize}
 * @param minDocCount  value kept in {@link #minDocCount}
 * @param metadata     passed through to the superclass constructor
 */
protected InternalSignificantTerms(String name, int requiredSize, long minDocCount, Map<String, Object> metadata) {
    super(name, metadata);
    this.minDocCount = minDocCount;
    this.requiredSize = requiredSize;
    // The shard-level thresholds are not constructor arguments; they stay at 0
    // unless setShardSize / setShardMinDocCount is called afterwards.
    this.shardMinDocCount = 0;
    this.shardSize = 0;
}

/**
Expand All @@ -222,8 +226,32 @@ protected final void doWriteTo(StreamOutput out) throws IOException {
@Override
public abstract List<B> getBuckets();

/**
 * @return the current value of the {@code shardSize} field (0 unless {@link #setShardSize(int)} was called)
 */
public int getShardSize() {
    return this.shardSize;
}

/**
 * Overwrites the {@code shardSize} field.
 *
 * @param shardSize the new value to store
 */
public void setShardSize(int shardSize) {
    this.shardSize = shardSize;
}

/**
 * @return the current value of the {@code shardMinDocCount} field (0 unless {@link #setShardMinDocCount(long)} was called)
 */
public long getShardMinDocCount() {
    return this.shardMinDocCount;
}

/**
 * Overwrites the {@code shardMinDocCount} field.
 *
 * @param shardMinDocCount the new value to store
 */
public void setShardMinDocCount(long shardMinDocCount) {
    this.shardMinDocCount = shardMinDocCount;
}

/**
 * Routes the reduce to {@code reduceOnShard} when the context reports a slice-level
 * reduce, and to {@code reduceOnCoordinator} otherwise.
 *
 * @param aggregations  the aggregations to merge, handed unchanged to the chosen path
 * @param reduceContext the context that decides which path is taken
 * @return whatever the chosen reduce path returns
 */
@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
    return reduceContext.isSliceLevel()
        ? reduceOnShard(aggregations, reduceContext)
        : reduceOnCoordinator(aggregations, reduceContext);
}

private InternalAggregation reduceOnShard(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
long globalSubsetSize = 0;
long globalSupersetSize = 0;
// Compute the overall result set size and the corpus size using the
Expand All @@ -234,12 +262,72 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
globalSubsetSize += terms.getSubsetSize();
// supersetSize is a shard level count, if we sum it across slices we would produce num_slices_with_bucket * supersetSize where
// num_slices_with_bucket is the number of segment slices that have collected a bucket for the key
if (reduceContext.isSliceLevel()) {
globalSupersetSize = terms.getSupersetSize();
globalSupersetSize = terms.getSupersetSize();
}
Map<String, List<B>> buckets = new HashMap<>();
for (InternalAggregation aggregation : aggregations) {
@SuppressWarnings("unchecked")
InternalSignificantTerms<A, B> terms = (InternalSignificantTerms<A, B>) aggregation;
for (B bucket : terms.getBuckets()) {
List<B> existingBuckets = buckets.get(bucket.getKeyAsString());
if (existingBuckets == null) {
existingBuckets = new ArrayList<>(aggregations.size());
buckets.put(bucket.getKeyAsString(), existingBuckets);
}
// Adjust the buckets with the global stats representing the
// total size of the pots from which the stats are drawn
existingBuckets.add(
createBucket(
bucket.getSubsetDf(),
globalSubsetSize,
bucket.getSupersetDf(),
globalSupersetSize,
bucket.aggregations,
bucket
)
);
}
}
SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
// Apply shard_size limit at slice level reduce
final int size = Math.min(getShardSize(), buckets.size());
BucketSignificancePriorityQueue<B> ordered = new BucketSignificancePriorityQueue<>(size);
for (Map.Entry<String, List<B>> entry : buckets.entrySet()) {
List<B> sameTermBuckets = entry.getValue();
final B b = reduceBucket(sameTermBuckets, reduceContext);
b.updateScore(heuristic);
// this needs to be simplified greatly
if ((b.score > 0) && (b.subsetDf >= getShardMinDocCount())) {
B removed = ordered.insertWithOverflow(b);
if (removed == null) {
reduceContext.consumeBucketsAndMaybeBreak(1);
} else {
reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed));
}
} else {
globalSupersetSize += terms.getSupersetSize();
reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(b));
}
}
B[] list = createBucketsArray(ordered.size());
for (int i = ordered.size() - 1; i >= 0; i--) {
list[i] = ordered.pop();
}
return create(globalSubsetSize, globalSupersetSize, Arrays.asList(list));
}

private InternalAggregation reduceOnCoordinator(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
long globalSubsetSize = 0;
long globalSupersetSize = 0;
// Compute the overall result set size and the corpus size using the
// top-level Aggregations from each shard
for (InternalAggregation aggregation : aggregations) {
@SuppressWarnings("unchecked")
InternalSignificantTerms<A, B> terms = (InternalSignificantTerms<A, B>) aggregation;
globalSubsetSize += terms.getSubsetSize();
// supersetSize is a shard level count, so on the coordinator it is summed across shards. Summing it across the slices of one shard would
// instead produce num_slices_with_bucket * supersetSize, where num_slices_with_bucket is the number of segment slices that have collected
// a bucket for the key
globalSupersetSize += terms.getSupersetSize();
}
Map<String, List<B>> buckets = new HashMap<>();
for (InternalAggregation aggregation : aggregations) {
@SuppressWarnings("unchecked")
Expand All @@ -265,7 +353,13 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
}
}
SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
final int size = reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size());
// Truncate to requiredSize only on the final reduce; a non-final reduce keeps every bucket
final int size;
if (reduceContext.isFinalReduce()) {
size = Math.min(requiredSize, buckets.size());
} else {
size = buckets.size();
}
BucketSignificancePriorityQueue<B> ordered = new BucketSignificancePriorityQueue<>(size);
for (Map.Entry<String, List<B>> entry : buckets.entrySet()) {
List<B> sameTermBuckets = entry.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ public int hashCode() {
protected final BucketOrder order;
protected final int requiredSize;
protected final long minDocCount;
protected long shardMinDocCount;

/**
* Creates a new {@link InternalTerms}
Expand All @@ -246,6 +247,7 @@ protected InternalTerms(
this.order = order;
this.requiredSize = requiredSize;
this.minDocCount = minDocCount;
this.shardMinDocCount = 0;
}

/**
Expand Down Expand Up @@ -329,17 +331,28 @@ protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {
pq.add(new IteratorAndCurrent(terms.getBuckets().iterator()));
}
}
List<B> reducedBuckets = new ArrayList<>();
;
final BucketPriorityQueue<B> reducedBuckets;
// list of buckets coming from different shards that have the same key
List<B> currentBuckets = new ArrayList<>();

// Apply shard_size parameter at the slice reduce level if it is > 0
if (reduceContext.isSliceLevel() && getShardSize() > 0) {
reducedBuckets = new BucketPriorityQueue<>(getShardSize(), order.comparator());
} else {
reducedBuckets = new BucketPriorityQueue<>(requiredSize, order.comparator());
}

B lastBucket = null;
while (pq.size() > 0) {
final IteratorAndCurrent<B> top = pq.top();
assert lastBucket == null || cmp.compare(top.current(), lastBucket) >= 0;
if (lastBucket != null && cmp.compare(top.current(), lastBucket) != 0) {
// the key changes, reduce what we already buffered and reset the buffer for current buckets
final B reduced = reduceBucket(currentBuckets, reduceContext);
reducedBuckets.add(reduced);
if (!reduceContext.isSliceLevel() || reduced.getDocCount() >= getShardMinDocCount()) {
reducedBuckets.insertWithOverflow(reduced);
}
currentBuckets.clear();
}
lastBucket = top.current();
Expand All @@ -355,9 +368,17 @@ protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {

if (currentBuckets.isEmpty() == false) {
final B reduced = reduceBucket(currentBuckets, reduceContext);
reducedBuckets.add(reduced);
// Apply shard_min_doc_count parameter at the slice reduce level
if (!reduceContext.isSliceLevel() || reduced.getDocCount() >= getShardMinDocCount()) {
reducedBuckets.insertWithOverflow(reduced);
}
}
return reducedBuckets;

// Shards must return buckets sorted by key
List<B> result = new ArrayList<>();
reducedBuckets.forEach(result::add);
result.sort(cmp);
return result;
}

private List<B> reduceLegacy(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
Expand All @@ -376,12 +397,24 @@ private List<B> reduceLegacy(List<InternalAggregation> aggregations, ReduceConte
}
}
}
List<B> reducedBuckets = new ArrayList<>();
final BucketPriorityQueue<B> reducedBuckets;
// Apply shard_size parameter at the slice reduce level if it is > 0
if (reduceContext.isSliceLevel() && getShardSize() > 0) {
reducedBuckets = new BucketPriorityQueue<>(getShardSize(), order.comparator());
} else {
reducedBuckets = new BucketPriorityQueue<>(requiredSize, order.comparator());
}
for (List<B> sameTermBuckets : bucketMap.values()) {
final B b = reduceBucket(sameTermBuckets, reduceContext);
reducedBuckets.add(b);
// Apply shard_min_doc_count parameter at the slice reduce level
if (!reduceContext.isSliceLevel() || b.getDocCount() >= getShardMinDocCount()) {
reducedBuckets.insertWithOverflow(b);
}
}
return reducedBuckets;

List<B> result = new ArrayList<>();
reducedBuckets.forEach(result::add);
return result;
}

public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
Expand Down Expand Up @@ -521,6 +554,14 @@ protected B reduceBucket(List<B> buckets, ReduceContext context) {
return createBucket(docCount, aggs, docCountError, buckets.get(0));
}

/**
 * Overwrites the {@code shardMinDocCount} field, which the constructor initialises to 0.
 *
 * @param shardMinDocCount the new value to store
 */
protected void setShardMinDocCount(long shardMinDocCount) {
    this.shardMinDocCount = shardMinDocCount;
}

/**
 * @return the current value of the {@code shardMinDocCount} field (0 unless {@link #setShardMinDocCount(long)} was called)
 */
protected long getShardMinDocCount() {
    return this.shardMinDocCount;
}

protected abstract void setDocCountError(long docCountError);

protected abstract int getShardSize();
Expand Down
Loading

0 comments on commit 978b105

Please sign in to comment.