Sub agg support for fast filter optimization
Signed-off-by: Finn Carroll <[email protected]>
finnegancarroll committed Aug 14, 2024
1 parent 01acf1c commit 93f4a51
Showing 11 changed files with 340 additions and 284 deletions.
@@ -32,6 +32,7 @@

package org.opensearch.search.aggregations.bucket.composite;

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
@@ -89,7 +90,6 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;

@@ -190,9 +190,7 @@ protected boolean canOptimize() {
}

@Override
protected void prepare() throws IOException {
buildRanges(context);
}
protected void prepare() throws IOException { buildRanges(context); }

protected Rounding getRounding(final long low, final long high) {
return valuesSource.getRounding();
@@ -213,13 +211,21 @@ protected long[] processAfterKey(long[] bounds, long interval) {
}

@Override
protected int getSize() {
protected int rangeMax() {
return size;
}

@Override
protected Function<Long, Long> bucketOrdProducer() {
return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key));
protected long getOrd(int rangeIdx) {
long rangeStart = LongPoint.decodeDimension(filterRewriteOptimizationContext.getRanges().getLower(rangeIdx), 0);
rangeStart = this.getFieldType().convertNanosToMillis(rangeStart);
long ord = bucketOrds.add(0, getRoundingPrepared().round(rangeStart));

if (ord < 0) { // already seen
ord = -1 - ord;
}

return ord;
}
};
filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
@@ -557,7 +563,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, sub, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
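The negative-ordinal branch in getOrd above follows the LongKeyedBucketOrds convention: add returns a fresh ordinal for a first-seen key and -1 - existingOrd for a repeat. A minimal standalone sketch of the decode step (class and method names are illustrative, not part of this commit):

// Sketch: decoding the "already seen" encoding used by bucketOrds.add above.
class BucketOrdSketch {
    static long decodeBucketOrd(long ord) {
        return ord < 0 ? -1 - ord : ord;
    }

    public static void main(String[] args) {
        System.out.println(decodeBucketOrd(-6)); // existing key encoded as -1 - 5 -> 5
        System.out.println(decodeBucketOrd(7));  // first-seen ordinal passes through
    }
}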
@@ -11,11 +11,14 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.aggregations.LeafBucketCollector;

import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;

/**
* This abstract class provides a bridge between an aggregator and the optimization context, allowing
* the aggregator to provide data and optimize the aggregation process.
@@ -37,9 +40,9 @@ public abstract class AggregatorBridge {
*/
MappedFieldType fieldType;

Consumer<Ranges> setRanges;
Consumer<PackedValueRanges> setRanges;

void setRangesConsumer(Consumer<Ranges> setRanges) {
void setRangesConsumer(Consumer<PackedValueRanges> setRanges) {
this.setRanges = setRanges;
}

@@ -67,18 +70,62 @@ void setRangesConsumer(Consumer<Ranges> setRanges) {
*
* @param leaf the leaf reader context for the segment
*/
abstract Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException;
abstract PackedValueRanges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException;

/**
* @return the maximum number of ranges to collect before stopping.
* Used by aggregations that terminate collection early.
*/
protected int rangeMax() {
return Integer.MAX_VALUE;
}

/**
* Translates an index in the packed value ranges array to an aggregation bucket ordinal.
*/
protected long getOrd(int rangeIdx) {
return rangeIdx;
}

/**
* Attempts to build aggregation results for a segment
* Attempts to build aggregation results for a segment.
* With no sub agg, doc counts can be read in bulk without iterating docIds.
* If a sub agg is present we must iterate and collect each docId to support it.
*
* @param values the point values (index structure for numeric values) for a segment
* @param incrementDocCount a consumer to increment the document count for a range bucket. The first parameter is the bucket ordinal (key), the second is the document count
* @param ranges
* @param values the point values (index structure for numeric values) for a segment
* @param incrementDocCount a consumer to increment the document count for a range bucket. The first parameter is the bucket ordinal (key), the second is the document count
*/
abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
) throws IOException;
public final FilterRewriteOptimizationContext.DebugInfo tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, PackedValueRanges ranges, final LeafBucketCollector sub)
throws IOException {
PointTreeTraversal.RangeAwareIntersectVisitor treeVisitor;

if (sub != null) {
treeVisitor = new PointTreeTraversal.DocCollectRangeAwareIntersectVisitor(
values.getPointTree(),
ranges,
rangeMax(),
(activeIndex, docID) -> {
long ord = this.getOrd(activeIndex);
try {
incrementDocCount.accept(ord, (long) 1);
sub.collect(docID, ord);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
);
} else {
treeVisitor = new PointTreeTraversal.DocCountRangeAwareIntersectVisitor(
values.getPointTree(),
ranges,
rangeMax(),
(activeIndex, docCount) -> {
long ord = this.getOrd(activeIndex);
incrementDocCount.accept(ord, (long) docCount);
}
);
}

return multiRangesTraverse(treeVisitor);
}
}
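The per-doc branch above wraps the sub-collector call in try/catch because the callback's accept method cannot declare checked exceptions. A self-contained sketch of that pattern (class and method names are illustrative only):

import java.io.IOException;
import java.util.function.BiConsumer;

// Sketch: rethrowing a checked IOException unchecked inside a lambda, as the
// DocCollectRangeAwareIntersectVisitor callback does above.
class CheckedInLambdaSketch {
    static void collect(int docId, long ord) throws IOException {
        System.out.println("doc " + docId + " -> bucket " + ord);
    }

    public static void main(String[] args) {
        BiConsumer<Integer, Integer> perDoc = (rangeIdx, docId) -> {
            try {
                collect(docId, (long) rangeIdx);
            } catch (IOException ioe) {
                throw new RuntimeException(ioe); // surfaces at the traversal call site
            }
        };
        perDoc.accept(0, 42);
    }
}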
@@ -8,11 +8,10 @@

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import org.opensearch.common.Rounding;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
@@ -22,10 +21,6 @@

import java.io.IOException;
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;

/**
* For date histogram aggregation
@@ -54,12 +49,12 @@ protected void buildRanges(SearchContext context) throws IOException {
}

@Override
final Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException {
final PackedValueRanges tryBuildRangesFromSegment(LeafReaderContext leaf) throws IOException {
long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name());
return buildRanges(bounds, maxRewriteFilters);
}

private Ranges buildRanges(long[] bounds, int maxRewriteFilters) {
private PackedValueRanges buildRanges(long[] bounds, int maxRewriteFilters) {
bounds = processHardBounds(bounds);
if (bounds == null) {
return null;
@@ -116,47 +111,11 @@ protected long[] processHardBounds(long[] bounds, LongBounds hardBounds) {
return bounds;
}

private DateFieldMapper.DateFieldType getFieldType() {
public DateFieldMapper.DateFieldType getFieldType() {
assert fieldType instanceof DateFieldMapper.DateFieldType;
return (DateFieldMapper.DateFieldType) fieldType;
}

protected int getSize() {
return Integer.MAX_VALUE;
}

@Override
final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
Ranges ranges
) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(bucketOrd, (long) docCount);
};

return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
}

private static long getBucketOrd(long bucketOrd) {
if (bucketOrd < 0) { // already seen
bucketOrd = -1 - bucketOrd;
}

return bucketOrd;
}

/**
* Provides a function to produce bucket ordinals from the lower bound of the range
*/
protected abstract Function<Long, Long> bucketOrdProducer();

/**
* Checks whether the top level query matches all documents on the segment
*
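The bucket-key computation that used to live here (bucketOrdProducer plus getBucketOrd) now happens in the bridge's getOrd: a range's decoded lower bound is rounded down to the histogram interval, and the rounded value keys the bucket. A plain java.time illustration of that rounding step (this is not the OpenSearch Rounding API):

import java.time.Instant;
import java.time.temporal.ChronoUnit;

// Illustration: a range starting mid-interval keys the bucket at the interval's start.
class RoundingSketch {
    public static void main(String[] args) {
        long rangeStart = Instant.parse("2024-08-14T13:47:05Z").toEpochMilli();
        long bucketKey = Instant.ofEpochMilli(rangeStart)
                .truncatedTo(ChronoUnit.HOURS) // hourly interval: round down to the hour
                .toEpochMilli();               // -> 2024-08-14T13:00:00Z
        System.out.println(bucketKey);
    }
}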
@@ -14,6 +14,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.search.internal.SearchContext;

@@ -40,7 +41,7 @@ public final class FilterRewriteOptimizationContext {
private final AggregatorBridge aggregatorBridge;
private String shardId;

private Ranges ranges; // built at shard level
private PackedValueRanges ranges; // built at shard level

// debug info related fields
private final AtomicInteger leafNodeVisited = new AtomicInteger();
@@ -84,10 +85,14 @@ private boolean canOptimize(final Object parent, final int subAggLength, SearchC
return canOptimize;
}

void setRanges(Ranges ranges) {
public void setRanges(PackedValueRanges ranges) {
this.ranges = ranges;
}

public PackedValueRanges getRanges() {
return this.ranges;
}

/**
* Try to populate the bucket doc counts for aggregation
* <p>
@@ -96,7 +101,7 @@ void setRanges(Ranges ranges) {
* @param incrementDocCount consumes the doc_count result for a given bucket ordinal
* @param segmentMatchAll if your optimization can prepare ranges from a segment, pass this flag to decide whether to do so
*/
public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount, boolean segmentMatchAll)
public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount, LeafBucketCollector sub, boolean segmentMatchAll)
throws IOException {
segments.incrementAndGet();
if (!canOptimize) {
@@ -120,10 +125,10 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
return false;
}

Ranges ranges = getRanges(leafCtx, segmentMatchAll);
PackedValueRanges ranges = getRanges(leafCtx, segmentMatchAll);
if (ranges == null) return false;

consumeDebugInfo(aggregatorBridge.tryOptimize(values, incrementDocCount, ranges));
consumeDebugInfo(aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, sub));

optimizedSegments.incrementAndGet();
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
@@ -132,7 +137,7 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
return true;
}

Ranges getRanges(LeafReaderContext leafCtx, boolean segmentMatchAll) {
public PackedValueRanges getRanges(LeafReaderContext leafCtx, boolean segmentMatchAll) {
if (!preparedAtShardLevel) {
try {
return getRangesFromSegment(leafCtx, segmentMatchAll);
@@ -148,7 +153,7 @@ Ranges getRanges(LeafReaderContext leafCtx, boolean segmentMatchAll) {
* Even when ranges cannot be built at shard level, we can still build ranges
* at segment level when it's functionally match-all at segment level
*/
private Ranges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMatchAll) throws IOException {
private PackedValueRanges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMatchAll) throws IOException {
if (!segmentMatchAll) {
return null;
}
@@ -160,7 +165,7 @@ private Ranges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMa
/**
* Contains debug info of BKD traversal to show in profile
*/
static class DebugInfo {
public static class DebugInfo {
private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited
private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited

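DebugInfo keeps its counters in AtomicIntegers, presumably so the totals stay correct when several segments of one shard are traversed concurrently. A minimal sketch of that pattern (names are illustrative):

import java.util.concurrent.atomic.AtomicInteger;

// Sketch: multiple traversals reporting into one shared counter safely.
class DebugCounterSketch {
    public static void main(String[] args) throws InterruptedException {
        AtomicInteger leafNodeVisited = new AtomicInteger();
        Thread a = new Thread(() -> leafNodeVisited.addAndGet(3)); // segment 1 visits 3 leaves
        Thread b = new Thread(() -> leafNodeVisited.addAndGet(2)); // segment 2 visits 2 leaves
        a.start(); b.start();
        a.join(); b.join();
        System.out.println(leafNodeVisited.get()); // always 5
    }
}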
@@ -155,7 +155,7 @@ private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldN
* Creates the date ranges for a date histogram aggregation from its interval
* and min/max boundaries.
*/
static Ranges createRangesFromAgg(
static PackedValueRanges createRangesFromAgg(
final DateFieldMapper.DateFieldType fieldType,
final long interval,
final Rounding.Prepared preparedRounding,
@@ -208,6 +208,6 @@ static Ranges createRangesFromAgg(
uppers[i] = max;
}

return new Ranges(lowers, uppers);
return new PackedValueRanges(lowers, uppers);
}
}
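PackedValueRanges evidently stores each bound in Lucene's packed byte[] form; getOrd above recovers the long with LongPoint.decodeDimension. A round-trip sketch using that Lucene API (the timestamp value is arbitrary):

import org.apache.lucene.document.LongPoint;

// Sketch: encode a millis bound into packed bytes and decode it back, mirroring
// how getOrd reads getRanges().getLower(rangeIdx).
class PackedBoundSketch {
    public static void main(String[] args) {
        byte[] packed = new byte[Long.BYTES];
        LongPoint.encodeDimension(1_723_593_600_000L, packed, 0); // pack the bound
        long decoded = LongPoint.decodeDimension(packed, 0);      // == 1_723_593_600_000L
        System.out.println(decoded);
    }
}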
