Skip to content

Commit

Permalink
Parallelize build agg
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Deng committed Mar 18, 2024
1 parent 1f5df54 commit 2f3f541
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.opensearch.action.admin.indices.segments.IndicesSegmentResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.health.ClusterHealthStatus;
import org.opensearch.cluster.metadata.IndexMetadata;
Expand All @@ -26,6 +27,7 @@
import java.util.Collection;
import java.util.List;

import static org.opensearch.indices.IndicesRequestCache.INDEX_CACHE_REQUEST_ENABLED_SETTING;
import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
Expand All @@ -50,26 +52,31 @@ public void setupSuiteScopeCluster() throws Exception {
assertAcked(
prepareCreate(
"idx",
Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(INDEX_CACHE_REQUEST_ENABLED_SETTING.getKey(), false)
).setMapping("type", "type=keyword", "num", "type=integer", "score", "type=integer")
);
waitForRelocation(ClusterHealthStatus.GREEN);

client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "1", "score", "5").get();
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "11", "score", "50").get();
refresh("idx");
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "1", "score", "2").get();
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "12", "score", "20").get();
refresh("idx");
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "3", "score", "10").get();
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "13", "score", "15").get();
refresh("idx");
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "3", "score", "1").get();
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "13", "score", "100").get();
refresh("idx");
indexRandom(
true,
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "1", "score", "5"),
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "11", "score", "50"),
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "1", "score", "2"),
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "12", "score", "20"),
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "3", "score", "10"),
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "13", "score", "15"),
client().prepareIndex("idx").setId("1").setSource("type", "type1", "num", "3", "score", "1"),
client().prepareIndex("idx").setId("1").setSource("type", "type2", "num", "13", "score", "100")
);

waitForRelocation(ClusterHealthStatus.GREEN);
refresh();

IndicesSegmentResponse segmentResponse = client().admin().indices().prepareSegments("idx").get();
System.out.println("Segments: " + segmentResponse.getIndices().get("idx").getShards().get(0).getShards()[0].getSegments().size());
}

public void testCompositeAggWithNoSubAgg() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

/**
* Common {@link CollectorManager} used by both concurrent and non-concurrent aggregation path and also for global and non-global
Expand Down Expand Up @@ -54,19 +54,43 @@ public String getCollectorReason() {

public abstract String getCollectorName();

// @Override
// public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
// final List<Aggregator> aggregators = context.bucketCollectorProcessor().toAggregators(collectors);
// final List<InternalAggregation> internals = new ArrayList<>(aggregators.size());
// context.aggregations().resetBucketMultiConsumer();
// for (Aggregator aggregator : aggregators) {
// try {
// // post collection is called in ContextIndexSearcher after search on leaves are completed
// internals.add(aggregator.buildTopLevel());
// } catch (IOException e) {
// throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e);
// }
// }
//
// final InternalAggregations internalAggregations = InternalAggregations.from(internals);
// return buildAggregationResult(internalAggregations);
// }

@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
final List<Aggregator> aggregators = context.bucketCollectorProcessor().toAggregators(collectors);
final List<InternalAggregation> internals = new ArrayList<>(aggregators.size());
context.aggregations().resetBucketMultiConsumer();
for (Aggregator aggregator : aggregators) {
try {
// post collection is called in ContextIndexSearcher after search on leaves are completed
internals.add(aggregator.buildTopLevel());
} catch (IOException e) {
throw new AggregationExecutionException("Failed to build aggregation [" + aggregator.name() + "]", e);
List<InternalAggregation> internals = context.bucketCollectorProcessor().toInternalAggregations(collectors);

// collect does not get called whenever there are no leaves on a shard. Since we build the InternalAggregation in postCollection
// that will not get called in such cases. Therefore we need to manually call it again here to build empty Internal Aggregation
// objects for this collector tree.
if (internals.stream().allMatch(Objects::isNull)) {
List<Aggregator> aggregators = context.bucketCollectorProcessor().toAggregators(collectors);
for (Aggregator a : aggregators) {
// // c could be a MultiBucketCollector
if (a instanceof AggregatorBase) {
((AggregatorBase) a).buildAndSetInternalAggregation();
}
}
internals = context.bucketCollectorProcessor().toInternalAggregations(collectors);
}
assert !internals.stream().allMatch(Objects::isNull);
context.aggregations().resetBucketMultiConsumer(); // Not sure if this is thread safe

final InternalAggregations internalAggregations = InternalAggregations.from(internals);
return buildAggregationResult(internalAggregations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package org.opensearch.search.aggregations;

import org.opensearch.OpenSearchParseException;
import org.opensearch.common.SetOnce;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.lease.Releasable;
import org.opensearch.core.ParseField;
Expand Down Expand Up @@ -61,6 +62,8 @@
@PublicApi(since = "1.0.0")
public abstract class Aggregator extends BucketCollector implements Releasable {

private final SetOnce<InternalAggregation> internalAggregation = new SetOnce<>();

/**
* Parses the aggregation request and creates the appropriate aggregator factory for it.
*
Expand All @@ -83,6 +86,14 @@ public interface Parser {
AggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException;
}

/**
 * Records the {@link InternalAggregation} built for this aggregator in its set-once holder.
 * <p>
 * NOTE(review): the holder is a {@code SetOnce}, so a second call is presumably rejected; confirm against
 * {@code SetOnce} before calling this more than once per aggregator.
 *
 * @param internalAggregation the built aggregation to record
 */
public void setInternalAggregation(final InternalAggregation internalAggregation) {
    this.internalAggregation.set(internalAggregation);
}

/**
 * Returns the {@link InternalAggregation} held by this aggregator's set-once holder.
 *
 * @return the value the holder currently yields (presumably {@code null} when nothing has been recorded yet;
 *         confirm against {@code SetOnce})
 */
public InternalAggregation getInternalAggregation() {
    final InternalAggregation recorded = this.internalAggregation.get();
    return recorded;
}

/**
* Return the name of this aggregator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.common.SetOnce;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
Expand Down Expand Up @@ -72,6 +73,8 @@ public abstract class AggregatorBase extends Aggregator {
private final CircuitBreakerService breakerService;
private long requestBytesUsed;

private final SetOnce<InternalAggregation> internalAggregation = new SetOnce<>();

/**
* Constructs a new Aggregator.
*
Expand Down Expand Up @@ -279,6 +282,13 @@ public void postCollection() throws IOException {
collectableSubAggregators.postCollection();
}

/**
 * Builds this aggregator's result with {@code buildTopLevel()} and stores it in the set-once holder, but only
 * when this aggregator has no parent.
 *
 * @throws IOException if {@code buildTopLevel()} fails
 */
public void buildAndSetInternalAggregation() throws IOException {
    if (parent != null) {
        // Not a top-level aggregator: its result is built when the top-level aggregator builds its child aggs.
        return;
    }
    internalAggregation.set(buildTopLevel());
}

/** Called upon release of the aggregator. */
@Override
public void close() {
Expand All @@ -305,6 +315,10 @@ protected final InternalAggregations buildEmptySubAggregations() {
return InternalAggregations.from(aggs);
}

/**
 * Returns the {@link InternalAggregation} stored in this aggregator's own set-once holder.
 *
 * @return the value the holder currently yields (presumably {@code null} when nothing has been stored yet;
 *         confirm against {@code SetOnce})
 */
public InternalAggregation getInternalAggregation() {
    final InternalAggregation built = this.internalAggregation.get();
    return built;
}

@Override
public String toString() {
return name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.lucene.MinimumScoreCollector;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.aggregation.ProfilingAggregator;
import org.opensearch.search.profile.query.InternalProfileCollector;

import java.io.IOException;
Expand Down Expand Up @@ -63,6 +64,7 @@ public void processPostCollection(Collector collectorTree) throws IOException {
while (!collectors.isEmpty()) {
Collector currentCollector = collectors.poll();
if (currentCollector instanceof InternalProfileCollector) {
// Profile collector should be the top level one so we should be able to call buildAggregation on it here
collectors.offer(((InternalProfileCollector) currentCollector).getCollector());
} else if (currentCollector instanceof MinimumScoreCollector) {
collectors.offer(((MinimumScoreCollector) currentCollector).getCollector());
Expand All @@ -72,6 +74,19 @@ public void processPostCollection(Collector collectorTree) throws IOException {
}
} else if (currentCollector instanceof BucketCollector) {
((BucketCollector) currentCollector).postCollection();

// Call buildTopLevel here -- need to set the InternalAggregation in Aggregator because profiler extends that
if (currentCollector instanceof AggregatorBase) {
((AggregatorBase) currentCollector).buildAndSetInternalAggregation();
} else if (currentCollector instanceof ProfilingAggregator) {
((ProfilingAggregator) currentCollector).setInternalAggregation(
((ProfilingAggregator) currentCollector).buildTopLevel()
);
} else if (currentCollector instanceof MultiBucketCollector) {
for (Collector innerCollector : ((MultiBucketCollector) currentCollector).getCollectors()) {
collectors.offer(innerCollector);
}
}
}
}
}
Expand Down Expand Up @@ -106,4 +121,32 @@ public List<Aggregator> toAggregators(Collection<Collector> collectors) {
}
return aggregators;
}

/**
 * Collects the {@link InternalAggregation} stored on each aggregator reachable from the given collectors.
 * <p>
 * The collectors are walked with a work queue:
 * <ul>
 *   <li>an {@link AggregatorBase} or {@link ProfilingAggregator} contributes its stored aggregation;</li>
 *   <li>an {@link InternalProfileCollector} is unwrapped one level: a wrapped {@link Aggregator} contributes its
 *       stored aggregation, a wrapped {@link MultiBucketCollector} has its children queued;</li>
 *   <li>a {@link MultiBucketCollector} has its children queued;</li>
 *   <li>any other collector is ignored.</li>
 * </ul>
 * Entries are added exactly as returned by {@code getInternalAggregation()}; this method does not filter them.
 *
 * @param collectors the collectors to inspect
 * @return the stored aggregations, in the order the aggregators were visited
 */
public List<InternalAggregation> toInternalAggregations(Collection<Collector> collectors) {
    List<InternalAggregation> internalAggregations = new ArrayList<>();

    final Deque<Collector> allCollectors = new LinkedList<>(collectors);
    while (!allCollectors.isEmpty()) {
        final Collector currentCollector = allCollectors.pop();
        if (currentCollector instanceof AggregatorBase) {
            internalAggregations.add(((AggregatorBase) currentCollector).getInternalAggregation());
        } else if (currentCollector instanceof ProfilingAggregator) {
            internalAggregations.add(((ProfilingAggregator) currentCollector).getInternalAggregation());
        } else if (currentCollector instanceof InternalProfileCollector) {
            final Collector wrapped = ((InternalProfileCollector) currentCollector).getCollector();
            if (wrapped instanceof Aggregator) {
                // Cast to the type that was actually checked. The wrapped aggregator may be a ProfilingAggregator,
                // which is not an AggregatorBase, so the previous AggregatorBase cast could throw ClassCastException.
                internalAggregations.add(((Aggregator) wrapped).getInternalAggregation());
            } else if (wrapped instanceof MultiBucketCollector) {
                allCollectors.addAll(Arrays.asList(((MultiBucketCollector) wrapped).getCollectors()));
            }
        } else if (currentCollector instanceof MultiBucketCollector) {
            allCollectors.addAll(Arrays.asList(((MultiBucketCollector) currentCollector).getCollectors()));
        }
    }
    return internalAggregations;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public Collector newCollector() throws IOException {
/**
 * Reduces over the single collector held by this manager instead of the collection passed in, which is
 * asserted to be empty.
 *
 * @param collectors expected to be empty
 * @return the result of the parent {@code reduce} applied to this manager's own collector
 * @throws IOException if the parent {@code reduce} fails
 */
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
    assert collectors.isEmpty() : "Reduce on GlobalAggregationCollectorManagerWithCollector called with non-empty collectors";
    // TODO: the author flagged that something still needs to be done here.
    final Collection<Collector> ownCollector = List.of(collector);
    return super.reduce(ownCollector);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ public Collector newCollector() throws IOException {
return collector;
}

// Can we add post collection logic in here?
/**
 * Reduces over the single collector held by this manager instead of the collection passed in, which is
 * asserted to be empty.
 *
 * @param collectors expected to be empty
 * @return the result of the parent {@code reduce} applied to this manager's own collector
 * @throws IOException if the parent {@code reduce} fails
 */
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
    assert collectors.isEmpty() : "Reduce on NonGlobalAggregationCollectorManagerWithCollector called with non-empty collectors";
    final Collection<Collector> ownCollector = List.of(collector);
    return super.reduce(ownCollector);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ protected Aggregator createInternal(
// NOTE(review): this span shows both sides of the change: the return value goes from false to true, which
// presumably opts this aggregation in to concurrent segment search. Confirm that the problem tracked in the
// issue linked below is actually resolved before keeping the new value.
@Override
protected boolean supportsConcurrentSegmentSearch() {
// See https://github.com/opensearch-project/OpenSearch/issues/12331 for details
return false;
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ final class CompositeAggregator extends BucketsAggregator {

private final CompositeValuesSourceConfig[] sourceConfigs;
private final SingleDimensionValuesSource<?>[] sources;
private final CompositeValuesCollectorQueue queue;
private CompositeValuesCollectorQueue queue;

private final List<Entry> entries = new ArrayList<>();
private LeafReaderContext currentLeaf;
Expand Down Expand Up @@ -236,6 +236,16 @@ protected void doPreCollection() throws IOException {
/**
 * Post-collection hook for this aggregator; the only live statement delegates to {@code finishLeaf()}.
 *
 * @throws IOException if {@code finishLeaf()} fails
 */
@Override
protected void doPostCollection() throws IOException {
finishLeaf();
// NOTE(review): everything below is a disabled draft, not live code. Its stated intent is to re-create the
// values sources and the collector queue on the search thread for concurrent search. Finish it or delete it
// before merging.
// Re-create the ValuesSource on the search thread for concurrent search
// for (int i = 0; i < sourceConfigs.length; i++) {
// this.sources[i] = sourceConfigs[i].createValuesSource(
// context.bigArrays(),
// context.searcher().getIndexReader(),
// size,
// this::addRequestCircuitBreakerBytes
// );
// }
// this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ public class GlobalOrdinalsStringTermsAggregator extends AbstractStringTermsAggr
protected int segmentsWithSingleValuedOrds = 0;
protected int segmentsWithMultiValuedOrds = 0;

protected final IndexReader reader;

/**
* Lookup global ordinals
*
Expand Down Expand Up @@ -128,7 +130,9 @@ public GlobalOrdinalsStringTermsAggregator(
super(name, factories, context, parent, order, format, bucketCountThresholds, collectionMode, showTermDocCountError, metadata);
this.resultStrategy = resultStrategy.apply(this); // ResultStrategy needs a reference to the Aggregator to do its job.
this.valuesSource = valuesSource;
final IndexReader reader = context.searcher().getIndexReader();
reader = context.searcher().getIndexReader();
// valuesSource is shared across aggregators and the DocValues here are created when the collector is created.
// Need to delay this creation to when it's actually used in the index_search thread.
final SortedSetDocValues values = reader.leaves().size() > 0
? valuesSource.globalOrdinalsValues(context.searcher().getIndexReader().leaves().get(0))
: DocValues.emptySortedSet();
Expand Down Expand Up @@ -885,7 +889,12 @@ PriorityQueue<OrdBucket> buildPriorityQueue(int size) {
}

StringTerms.Bucket convertTempBucketToRealBucket(OrdBucket temp) throws IOException {
BytesRef term = BytesRef.deepCopyOf(lookupGlobalOrd.apply(temp.globalOrd));
// BytesRef term = BytesRef.deepCopyOf(lookupGlobalOrd.apply(temp.globalOrd));
SortedSetDocValues values = reader.leaves().size() > 0
? valuesSource.globalOrdinalsValues(context.searcher().getIndexReader().leaves().get(0))
: DocValues.emptySortedSet();
BytesRef term = BytesRef.deepCopyOf(values.lookupOrd(temp.globalOrd));

StringTerms.Bucket result = new StringTerms.Bucket(term, temp.docCount, null, showTermDocCountError, 0, format);
result.bucketOrd = temp.bucketOrd;
result.docCountError = 0;
Expand Down Expand Up @@ -1001,7 +1010,11 @@ BucketUpdater<SignificantStringTerms.Bucket> bucketUpdater(long owningBucketOrd)
long subsetSize = subsetSize(owningBucketOrd);
return (spare, globalOrd, bucketOrd, docCount) -> {
spare.bucketOrd = bucketOrd;
oversizedCopy(lookupGlobalOrd.apply(globalOrd), spare.termBytes);
// oversizedCopy(lookupGlobalOrd.apply(globalOrd), spare.termBytes);
SortedSetDocValues values = reader.leaves().size() > 0
? valuesSource.globalOrdinalsValues(context.searcher().getIndexReader().leaves().get(0))
: DocValues.emptySortedSet();
oversizedCopy(values.lookupOrd(globalOrd), spare.termBytes);
spare.subsetDf = docCount;
spare.subsetSize = subsetSize;
spare.supersetDf = backgroundFrequencies.freq(spare.termBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,12 @@
"LuceneFixedGap",
"LuceneVarGapFixedInterval",
"LuceneVarGapDocFreqInterval",
"Lucene50" })
"Lucene50",
"Lucene90",
"Lucene94",
"Lucene90",
"Lucene95",
"Lucene99" })
@LuceneTestCase.SuppressReproduceLine
public abstract class OpenSearchTestCase extends LuceneTestCase {

Expand Down

0 comments on commit 2f3f541

Please sign in to comment.