
Commit: Addressing comments
Signed-off-by: Jay Deng <[email protected]>
jed326 authored and Jay Deng committed Aug 9, 2023
1 parent 5a9a704 commit f239aec
Showing 15 changed files with 116 additions and 112 deletions.
@@ -26,7 +26,7 @@
* aggregation operators
*/
class AggregationCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
private final SearchContext context;
protected final SearchContext context;
private final CheckedFunction<SearchContext, List<Aggregator>, IOException> aggProvider;
private final String collectorReason;

@@ -63,18 +63,11 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IO
}

final InternalAggregations internalAggregations = InternalAggregations.from(internals);
// Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce iff multiple slices
// were created to execute this request and it used concurrent segment search path
// TODO: Add the check for flag that the request was executed using concurrent search
if (collectors.size() >= 1) {
// using reduce is fine here instead of topLevelReduce as pipeline aggregation is evaluated on the coordinator after all
// documents are collected across shards for an aggregation
return new AggregationReduceableSearchResult(
InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
);
} else {
return new AggregationReduceableSearchResult(internalAggregations);
}
return buildAggregationResult(internalAggregations);
}

public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
return new AggregationReduceableSearchResult(internalAggregations);
}

static Collector createCollector(SearchContext context, List<Aggregator> collectors, String reason) throws IOException {
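The hunk above replaces the inline slice-level reduce with an overridable buildAggregationResult hook: the base manager now returns the per-collector aggregations unchanged, and concurrent-search variants override the hook to reduce across slices before the shard responds to the coordinator. A minimal self-contained sketch of that template-method shape, using illustrative stand-in types rather than the real OpenSearch classes:

import java.util.List;

// Illustrative stand-ins only; these are not the real OpenSearch types.
final class SliceResult {
    final List<Long> perSliceCounts;
    SliceResult(List<Long> perSliceCounts) { this.perSliceCounts = perSliceCounts; }
}

class BaseAggCollectorManager {
    // Default: no shard level reduce, per-slice results are forwarded as-is.
    SliceResult buildAggregationResult(List<Long> perSliceCounts) {
        return new SliceResult(perSliceCounts);
    }

    final SliceResult reduce(List<Long> perSliceCounts) {
        // ...merge the collector outputs, then delegate the optional shard level reduce:
        return buildAggregationResult(perSliceCounts);
    }
}

class ConcurrentSearchAggCollectorManager extends BaseAggCollectorManager {
    @Override
    SliceResult buildAggregationResult(List<Long> perSliceCounts) {
        // Concurrent segment search: reduce across slices before sending to the coordinator.
        long merged = perSliceCounts.stream().mapToLong(Long::longValue).sum();
        return new SliceResult(List.of(merged));
    }
}

The collector manager subclasses changed below override the hook in exactly this way, delegating to InternalAggregations.reduce with the shard-level (partialOnShard) reduce context.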
Next changed file
@@ -14,6 +14,7 @@
import org.opensearch.search.profile.query.CollectorResult;

import java.io.IOException;
import java.util.Collections;
import java.util.Objects;

/**
@@ -38,4 +39,13 @@ public Collector newCollector() throws IOException {
return super.newCollector();
}
}

@Override
public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
// Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce as long as any slices
// were created so that we can apply shard level bucket count thresholds in the reduce phase.
return new AggregationReduceableSearchResult(
InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
);
}
}
Next changed file
@@ -40,6 +40,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.rest.action.search.RestSearchAction;
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.support.AggregationPath;
@@ -160,6 +161,16 @@ public boolean isSliceLevel() {
return this.isSliceLevel;
}

// For slice level partial reduce we apply the shard level `shard_size` and `shard_min_doc_count` limits, whereas coordinator
// level partial reduce uses the top level `size` and `min_doc_count`.
public int getRequiredSizeLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
return isSliceLevel() ? bucketCountThresholds.getShardSize() : bucketCountThresholds.getRequiredSize();
}

public long getMinDocCountLocal(TermsAggregator.BucketCountThresholds bucketCountThresholds) {
return isSliceLevel() ? bucketCountThresholds.getShardMinDocCount() : bucketCountThresholds.getMinDocCount();
}

public BigArrays bigArrays() {
return bigArrays;
}
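The new getRequiredSizeLocal and getMinDocCountLocal helpers centralize the choice between shard-level and top-level bucket count thresholds, keyed off isSliceLevel(). A small standalone sketch of that selection, with a toy record standing in for TermsAggregator.BucketCountThresholds:

// Toy threshold holder; the real class is TermsAggregator.BucketCountThresholds.
record BucketThresholds(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {}

class ReduceContextSketch {
    private final boolean isSliceLevel;

    ReduceContextSketch(boolean isSliceLevel) { this.isSliceLevel = isSliceLevel; }

    // Slice-level (shard-local) reduce applies shard_size / shard_min_doc_count,
    // coordinator-level reduce applies the top-level size / min_doc_count.
    int getRequiredSizeLocal(BucketThresholds t) {
        return isSliceLevel ? t.shardSize() : t.requiredSize();
    }

    long getMinDocCountLocal(BucketThresholds t) {
        return isSliceLevel ? t.shardMinDocCount() : t.minDocCount();
    }

    public static void main(String[] args) {
        BucketThresholds t = new BucketThresholds(1, 0, 10, 25);
        System.out.println(new ReduceContextSketch(true).getRequiredSizeLocal(t));  // 25 (shard_size)
        System.out.println(new ReduceContextSketch(false).getRequiredSizeLocal(t)); // 10 (size)
    }
}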
Next changed file
@@ -14,6 +14,7 @@
import org.opensearch.search.profile.query.CollectorResult;

import java.io.IOException;
import java.util.Collections;
import java.util.Objects;

/**
@@ -38,4 +39,13 @@ public Collector newCollector() throws IOException {
return super.newCollector();
}
}

@Override
public AggregationReduceableSearchResult buildAggregationResult(InternalAggregations internalAggregations) {
// Reduce the aggregations across slices before sending to the coordinator. We will perform shard level reduce as long as any slices
// were created so that we can apply shard level bucket count thresholds in the reduce phase.
return new AggregationReduceableSearchResult(
InternalAggregations.reduce(Collections.singletonList(internalAggregations), context.partialOnShard())
);
}
}
Next changed file
@@ -603,15 +603,6 @@ abstract class ResultStrategy<
TB extends InternalMultiBucketAggregation.InternalBucket> implements Releasable {

private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
int requiredSizeLocal;
long minDocCountLocal;
if (context.isConcurrentSegmentSearchEnabled()) {
requiredSizeLocal = Integer.MAX_VALUE;
minDocCountLocal = 0;
} else {
requiredSizeLocal = bucketCountThresholds.getShardSize();
minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
}
if (valueCount == 0) { // no context in this reader
InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
@@ -624,11 +615,11 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
long[] otherDocCount = new long[owningBucketOrds.length];
for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) {
final int size;
if (minDocCountLocal == 0) {
if (context.getMinDocCountLocal(bucketCountThresholds) == 0) {
// if minDocCount == 0 then we can end up with more buckets than maxBucketOrd() returns
size = (int) Math.min(valueCount, requiredSizeLocal);
size = (int) Math.min(valueCount, context.getRequiredSizeLocal(bucketCountThresholds));
} else {
size = (int) Math.min(maxBucketOrd(), requiredSizeLocal);
size = (int) Math.min(maxBucketOrd(), context.getRequiredSizeLocal(bucketCountThresholds));
}
PriorityQueue<TB> ordered = buildPriorityQueue(size);
final int finalOrdIdx = ordIdx;
@@ -639,7 +630,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws
@Override
public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException {
otherDocCount[finalOrdIdx] += docCount;
if (docCount >= minDocCountLocal) {
if (docCount >= context.getMinDocCountLocal(bucketCountThresholds)) {
if (spare == null) {
spare = buildEmptyTemporaryBucket();
}
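With the thresholds resolved through the reduce context, buildAggregations above sizes its priority queue and filters buckets via context.getRequiredSizeLocal and context.getMinDocCountLocal instead of flags computed up front. A self-contained sketch of that bounded top-N selection, using java.util.PriorityQueue in place of the Lucene queue and illustrative names throughout:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

record Bucket(String term, long docCount) {}

final class TopBucketsSketch {
    // Keep the highest-docCount buckets, bounded by the locally resolved size and
    // filtered by the locally resolved min doc count (shard-level values for a slice
    // reduce, top-level values for a coordinator reduce).
    static List<Bucket> topBuckets(List<Bucket> candidates, int requiredSizeLocal, long minDocCountLocal) {
        int size = Math.min(candidates.size(), requiredSizeLocal);
        // Min-heap: the smallest surviving bucket sits at the head and is evicted on overflow.
        PriorityQueue<Bucket> ordered = new PriorityQueue<>(Math.max(1, size), Comparator.comparingLong(Bucket::docCount));
        for (Bucket b : candidates) {
            if (b.docCount() < minDocCountLocal) {
                continue; // below the local threshold; counted as "other" docs in the real code
            }
            ordered.offer(b);
            if (ordered.size() > size) {
                ordered.poll(); // evict the current smallest
            }
        }
        List<Bucket> result = new ArrayList<>(ordered);
        result.sort(Comparator.comparingLong(Bucket::docCount).reversed());
        return result;
    }
}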
Next changed file
@@ -52,6 +52,7 @@
*/
public abstract class InternalMappedTerms<A extends InternalTerms<A, B>, B extends InternalTerms.Bucket<B>> extends InternalTerms<A, B> {
protected final DocValueFormat format;
protected final int shardSize;
protected final boolean showTermDocCountError;
protected final long otherDocCount;
protected final List<B> buckets;
@@ -73,6 +74,7 @@ protected InternalMappedTerms(
) {
super(name, reduceOrder, order, bucketCountThresholds, metadata);
this.format = format;
this.shardSize = bucketCountThresholds.getShardSize();
this.showTermDocCountError = showTermDocCountError;
this.otherDocCount = otherDocCount;
this.docCountError = docCountError;
Next changed file
@@ -220,6 +220,7 @@ public int compare(List<Object> thisObjects, List<Object> thatObjects) {
}
}

private final int shardSize;
private final boolean showTermDocCountError;
private final long otherDocCount;
private final List<DocValueFormat> termFormats;
Next changed file
@@ -39,7 +39,6 @@
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.InternalMultiBucketAggregation;
import org.opensearch.search.aggregations.bucket.BucketUtils;
import org.opensearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;

import java.io.IOException;
@@ -196,8 +195,6 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params)

protected final int requiredSize;
protected final long minDocCount;
protected int shardSize;
protected final long shardMinDocCount;
protected final TermsAggregator.BucketCountThresholds bucketCountThresholds;

protected InternalSignificantTerms(
@@ -208,8 +205,6 @@ protected InternalSignificantTerms(
super(name, metadata);
this.requiredSize = bucketCountThresholds.getRequiredSize();
this.minDocCount = bucketCountThresholds.getMinDocCount();
this.shardSize = bucketCountThresholds.getShardSize();
this.shardMinDocCount = bucketCountThresholds.getShardMinDocCount();
this.bucketCountThresholds = bucketCountThresholds;
}

@@ -220,9 +215,9 @@ protected InternalSignificantTerms(StreamInput in) throws IOException {
super(in);
requiredSize = readSize(in);
minDocCount = in.readVLong();
shardSize = BucketUtils.suggestShardSideQueueSize(requiredSize);
shardMinDocCount = 0;
bucketCountThresholds = new TermsAggregator.BucketCountThresholds(minDocCount, shardMinDocCount, requiredSize, shardSize);
// shardMinDocCount and shardSize are not used on the coordinator, so they are not deserialized. We use
// CoordinatorBucketCountThresholds which will throw an exception if they are accessed.
bucketCountThresholds = new TermsAggregator.CoordinatorBucketCountThresholds(minDocCount, -1, requiredSize, -1);
}

protected final void doWriteTo(StreamOutput out) throws IOException {
@@ -238,15 +233,6 @@ protected final void doWriteTo(StreamOutput out) throws IOException {

@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
int requiredSizeLocal;
long minDocCountLocal;
if (reduceContext.isSliceLevel()) {
requiredSizeLocal = bucketCountThresholds.getShardSize();
minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
} else {
requiredSizeLocal = bucketCountThresholds.getRequiredSize();
minDocCountLocal = bucketCountThresholds.getMinDocCount();
}

long globalSubsetSize = 0;
long globalSupersetSize = 0;
@@ -289,17 +275,21 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
}
}
SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext);
final int size = (reduceContext.isFinalReduce() || reduceContext.isSliceLevel())
? Math.min(requiredSizeLocal, buckets.size())
boolean isCoordinatorPartialReduce = reduceContext.isFinalReduce() == false && reduceContext.isSliceLevel() == false;
// Do not apply size threshold on coordinator partial reduce
final int size = !isCoordinatorPartialReduce
? Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), buckets.size())
: buckets.size();
BucketSignificancePriorityQueue<B> ordered = new BucketSignificancePriorityQueue<>(size);
for (Map.Entry<String, List<B>> entry : buckets.entrySet()) {
List<B> sameTermBuckets = entry.getValue();
final B b = reduceBucket(sameTermBuckets, reduceContext);
b.updateScore(heuristic);
if (!(reduceContext.isFinalReduce() || reduceContext.isSliceLevel()) // Don't apply thresholds for partial reduce
|| (reduceContext.isSliceLevel() && (b.subsetDf >= minDocCountLocal)) // Score needs to be evaluated only at the coordinator
|| ((b.score > 0) && (b.subsetDf >= minDocCountLocal))) {
// For the concurrent search case we do not apply bucket count thresholds in buildAggregation; instead they are applied here
// during reduce. However, the bucket score is only evaluated at the final coordinator reduce.
boolean meetsThresholds = (b.subsetDf >= reduceContext.getMinDocCountLocal(bucketCountThresholds))
&& (((b.score > 0) || reduceContext.isSliceLevel()));
if (isCoordinatorPartialReduce || meetsThresholds) {
B removed = ordered.insertWithOverflow(b);
if (removed == null) {
reduceContext.consumeBucketsAndMaybeBreak(1);
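The reduce change above folds the bucket count thresholds into three cases: a coordinator partial reduce keeps every bucket, a slice-level reduce applies only the doc-count threshold, and the final coordinator reduce additionally requires a positive significance score. A small sketch of that predicate under assumed parameter names:

final class SignificantTermsThresholdSketch {
    /**
     * Decide whether a candidate bucket survives this reduce round.
     *
     * @param isFinalReduce    true only for the final reduce on the coordinator
     * @param isSliceLevel     true for the shard-local reduce across concurrent slices
     * @param subsetDf         the bucket's doc count in the subset
     * @param minDocCountLocal shard_min_doc_count at slice level, min_doc_count otherwise
     * @param score            significance score, only meaningful at the final coordinator reduce
     */
    static boolean keepBucket(boolean isFinalReduce, boolean isSliceLevel,
                              long subsetDf, long minDocCountLocal, double score) {
        boolean isCoordinatorPartialReduce = !isFinalReduce && !isSliceLevel;
        if (isCoordinatorPartialReduce) {
            return true; // no thresholds on a coordinator partial reduce
        }
        // The doc-count threshold always applies; the score check is skipped at slice level
        // because the score is only evaluated at the final coordinator reduce.
        return subsetDf >= minDocCountLocal && (score > 0 || isSliceLevel);
    }
}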
Next changed file
@@ -45,7 +45,6 @@
import org.opensearch.search.aggregations.InternalMultiBucketAggregation;
import org.opensearch.search.aggregations.InternalOrder;
import org.opensearch.search.aggregations.KeyComparable;
import org.opensearch.search.aggregations.bucket.BucketUtils;
import org.opensearch.search.aggregations.bucket.IteratorAndCurrent;
import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation;

@@ -58,7 +57,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.search.aggregations.InternalOrder.isKeyAsc;
import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder;
@@ -225,8 +223,6 @@ public int hashCode() {
protected final BucketOrder order;
protected final int requiredSize;
protected final long minDocCount;
protected int shardSize;
protected final long shardMinDocCount;
protected final TermsAggregator.BucketCountThresholds bucketCountThresholds;

/**
Expand All @@ -250,8 +246,6 @@ protected InternalTerms(
this.bucketCountThresholds = bucketCountThresholds;
this.requiredSize = bucketCountThresholds.getRequiredSize();
this.minDocCount = bucketCountThresholds.getMinDocCount();
this.shardSize = bucketCountThresholds.getShardSize();
this.shardMinDocCount = bucketCountThresholds.getShardMinDocCount();
}

/**
Expand All @@ -263,9 +257,9 @@ protected InternalTerms(StreamInput in) throws IOException {
order = InternalOrder.Streams.readOrder(in);
requiredSize = readSize(in);
minDocCount = in.readVLong();
shardSize = BucketUtils.suggestShardSideQueueSize(requiredSize);
shardMinDocCount = 0;
bucketCountThresholds = new TermsAggregator.BucketCountThresholds(minDocCount, shardMinDocCount, requiredSize, shardSize);
// shardMinDocCount and shardSize are not used on the coordinator, so they are not deserialized. We use
// CoordinatorBucketCountThresholds which will throw an exception if they are accessed.
bucketCountThresholds = new TermsAggregator.CoordinatorBucketCountThresholds(minDocCount, -1, requiredSize, getShardSize());
}

@Override
Expand Down Expand Up @@ -394,16 +388,6 @@ private List<B> reduceLegacy(List<InternalAggregation> aggregations, ReduceConte
}

public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
int requiredSizeLocal;
long minDocCountLocal;
if (reduceContext.isSliceLevel()) {
requiredSizeLocal = bucketCountThresholds.getShardSize();
minDocCountLocal = bucketCountThresholds.getShardMinDocCount();
} else {
requiredSizeLocal = bucketCountThresholds.getRequiredSize();
minDocCountLocal = bucketCountThresholds.getMinDocCount();
}

long sumDocCountError = 0;
long otherDocCount = 0;
InternalTerms<A, B> referenceTerms = null;
@@ -464,7 +448,7 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
}
final B[] list;
if (reduceContext.isFinalReduce() || reduceContext.isSliceLevel()) {
final int size = Math.min(requiredSizeLocal, reducedBuckets.size());
final int size = Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), reducedBuckets.size());
// final comparator
final BucketPriorityQueue<B> ordered = new BucketPriorityQueue<>(size, order.comparator());
for (B bucket : reducedBuckets) {
Expand All @@ -474,7 +458,7 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
final long finalSumDocCountError = sumDocCountError;
bucket.setDocCountError(docCountError -> docCountError + finalSumDocCountError);
}
if (bucket.getDocCount() >= minDocCountLocal) {
if (bucket.getDocCount() >= reduceContext.getMinDocCountLocal(bucketCountThresholds)) {
B removed = ordered.insertWithOverflow(bucket);
if (removed != null) {
otherDocCount += removed.getDocCount();
Expand All @@ -493,8 +477,8 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
} else {
// we can prune the list on partial reduce if the aggregation is ordered by key
// and not filtered (minDocCount == 0)
int size = isKeyOrder(order) && minDocCountLocal == 0
? Math.min(requiredSizeLocal, reducedBuckets.size())
int size = isKeyOrder(order) && reduceContext.getMinDocCountLocal(bucketCountThresholds) == 0
? Math.min(reduceContext.getRequiredSizeLocal(bucketCountThresholds), reducedBuckets.size())
: reducedBuckets.size();
list = createBucketsArray(size);
for (int i = 0; i < size; i++) {
@@ -515,14 +499,11 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
docCountError = aggregations.size() == 1 ? 0 : sumDocCountError;
}

// Shards must return buckets sorted by key, so we apply the sort here
List<B> resultList;
// Shards must return buckets sorted by key, so we apply the sort here in shard level reduce
if (reduceContext.isSliceLevel()) {
resultList = Arrays.stream(list).sorted(thisReduceOrder.comparator()).collect(Collectors.toList());
} else {
resultList = Arrays.asList(list);
Arrays.sort(list, thisReduceOrder.comparator());
}
return create(name, resultList, reduceContext.isFinalReduce() ? order : thisReduceOrder, docCountError, otherDocCount);
return create(name, Arrays.asList(list), reduceContext.isFinalReduce() ? order : thisReduceOrder, docCountError, otherDocCount);
}

@Override
Expand Down
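Both InternalSignificantTerms and InternalTerms now deserialize into a CoordinatorBucketCountThresholds that rejects access to the shard-only values rather than silently reconstructing them from defaults (the old BucketUtils.suggestShardSideQueueSize path). A toy sketch of that fail-fast guard, with illustrative class names rather than the real TermsAggregator inner classes:

// Toy model of the guard idea behind CoordinatorBucketCountThresholds; illustrative only.
class ThresholdsSketch {
    protected final long minDocCount;
    protected final long shardMinDocCount;
    protected final int requiredSize;
    protected final int shardSize;

    ThresholdsSketch(long minDocCount, long shardMinDocCount, int requiredSize, int shardSize) {
        this.minDocCount = minDocCount;
        this.shardMinDocCount = shardMinDocCount;
        this.requiredSize = requiredSize;
        this.shardSize = shardSize;
    }

    long getMinDocCount() { return minDocCount; }
    int getRequiredSize() { return requiredSize; }
    long getShardMinDocCount() { return shardMinDocCount; }
    int getShardSize() { return shardSize; }
}

// Shard-only thresholds are never valid on the coordinator, so fail fast instead of using stale values.
class CoordinatorThresholdsSketch extends ThresholdsSketch {
    CoordinatorThresholdsSketch(long minDocCount, int requiredSize) {
        super(minDocCount, -1, requiredSize, -1);
    }

    @Override
    long getShardMinDocCount() {
        throw new UnsupportedOperationException("shard_min_doc_count is not available on the coordinator");
    }

    @Override
    int getShardSize() {
        throw new UnsupportedOperationException("shard_size is not available on the coordinator");
    }
}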
(remaining changed files not loaded)
