Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Dec 22, 2023
1 parent e7d8af9 commit 029686f
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,7 @@
*
* @opensearch.internal
*/
public class FilterRewriteHelper {

/**
* Saves the objects that will be used to try fast filter optimization
*/
public static class FilterContext {
public final DateFieldMapper.DateFieldType fieldType;
public final Weight[] filters;

public FilterContext(DateFieldMapper.DateFieldType fieldType, Weight[] filters) {
this.fieldType = fieldType;
this.filters = filters;
}
}
public class FastFilterRewriteHelper {

private static final int MAX_NUM_FILTER_BUCKETS = 1024;
private static final Map<Class<?>, Function<Query, Query>> queryWrappers;
Expand Down Expand Up @@ -133,7 +120,7 @@ private static Weight[] createFilterForAggregations(
final DateFieldMapper.DateFieldType fieldType,
long low,
final long high,
long afterKey
final long afterKey
) throws IOException {
final OptionalLong intervalOpt = Rounding.getInterval(rounding);
if (intervalOpt.isEmpty()) {
Expand Down Expand Up @@ -191,75 +178,89 @@ protected String toString(int dimension, byte[] value) {
}

/**
* The pre-conditions to initiate fast filter optimization on aggregations are:
* 1. The query with aggregation has to be PointRangeQuery on the same date field
* 2. No parent/sub aggregations
* 3. No missing value/bucket
* 4. No script
*
* @param computeBounds get the lower and upper bound of the field in a shard search
* @param roundingFunction produce Rounding that will provide the interval
* @param preparedRoundingSupplier produce PreparedRounding that will do the rounding
*/
public static FilterContext buildFastFilterContext(
final Object parent,
final int subAggLength,
public static void buildFastFilterContext(
SearchContext context,
Function<long[], Rounding> roundingFunction,
Supplier<Rounding.Prepared> preparedRoundingSupplier,
ValueSourceContext valueSourceContext,
CheckedFunction<ValueSourceContext, long[], IOException> computeBounds
FastFilterContext fastFilterContext,
CheckedFunction<FastFilterContext, long[], IOException> computeBounds
) throws IOException {
if (parent == null && subAggLength == 0 && !valueSourceContext.missing && !valueSourceContext.hasScript) {
MappedFieldType fieldType = valueSourceContext.fieldType;
if (fieldType != null) {
final String fieldName = fieldType.name();
final long[] bounds = computeBounds.apply(valueSourceContext);
if (bounds != null) {
assert fieldType instanceof DateFieldMapper.DateFieldType;
final Rounding rounding = roundingFunction.apply(bounds);
final Weight[] filters = FilterRewriteHelper.createFilterForAggregations(
context,
rounding,
preparedRoundingSupplier.get(),
fieldName,
(DateFieldMapper.DateFieldType) fieldType,
bounds[0],
bounds[1],
valueSourceContext.afterKey
);
return new FilterContext((DateFieldMapper.DateFieldType) fieldType, filters);
}
}
assert fastFilterContext.fieldType instanceof DateFieldMapper.DateFieldType;
DateFieldMapper.DateFieldType fieldType = (DateFieldMapper.DateFieldType) fastFilterContext.fieldType;
final String fieldName = fieldType.name();
final long[] bounds = computeBounds.apply(fastFilterContext);
if (bounds != null) {
final Rounding rounding = roundingFunction.apply(bounds);
final Weight[] filters = FastFilterRewriteHelper.createFilterForAggregations(
context,
rounding,
preparedRoundingSupplier.get(),
fieldName,
fieldType,
bounds[0],
bounds[1],
fastFilterContext.afterKey
);
fastFilterContext.setFilters(filters);
}
return null;
}

/**
* Encapsulates metadata about a value source needed to rewrite
*/
public static class ValueSourceContext {
public static class FastFilterContext {
private final boolean missing;
private final boolean hasScript;
private final MappedFieldType fieldType;
private final long afterKey;

private long afterKey = -1L;
private int size = Integer.MAX_VALUE; // only used by composite aggregation for pagination
private Weight[] filters = null;

/**
* @param missing whether missing value/bucket is set
* @param hasScript whether script is used
* @param fieldType null if the field doesn't exist
* @param afterKey used to paginate for composite aggregation, pass in -1 if not used
*/
public ValueSourceContext(boolean missing, boolean hasScript, MappedFieldType fieldType, long afterKey) {
public FastFilterContext(boolean missing, boolean hasScript, MappedFieldType fieldType) {
this.missing = missing;
this.hasScript = hasScript;
this.fieldType = fieldType;
this.afterKey = afterKey;
}

public MappedFieldType getFieldType() {
return fieldType;
}

public void setSize(int size) {
this.size = size;
}

public void setFilters(Weight[] filters) {
this.filters = filters;
}

public void setAfterKey(long afterKey) {
this.afterKey = afterKey;
}

/**
* The pre-conditions to initiate fast filter optimization on aggregations are:
* 1. The query with aggregation has to be PointRangeQuery on the same date field
* 2. No parent/sub aggregations
* 3. No missing value/bucket
* 4. No script
*/
public boolean isRewriteable(Object parent, int subAggLength) {
if (parent == null && subAggLength == 0 && !missing && !hasScript) {
return fieldType != null && fieldType instanceof DateFieldMapper.DateFieldType;
}
return false;
}
}

public static long getBucketOrd(long bucketOrd) {
Expand All @@ -272,18 +273,17 @@ public static long getBucketOrd(long bucketOrd) {

/**
* This should be executed for each segment
*
* @param size the maximum number of buckets needed
*/
public static boolean tryFastFilterAggregation(
final LeafReaderContext ctx,
final Weight[] filters,
final DateFieldMapper.DateFieldType fieldType,
final BiConsumer<Long, Integer> incrementDocCount,
final int size
FastFilterContext fastFilterContext,
final BiConsumer<Long, Integer> incrementDocCount
) throws IOException {
if (filters == null) return false;
if (fastFilterContext == null) return false;
if (fastFilterContext.filters == null) return false;

final Weight[] filters = fastFilterContext.filters;
final DateFieldMapper.DateFieldType fieldType = (DateFieldMapper.DateFieldType) fastFilterContext.fieldType;
final int[] counts = new int[filters.length];
int i;
for (i = 0; i < filters.length; i++) {
Expand All @@ -305,7 +305,7 @@ public static boolean tryFastFilterAggregation(
counts[i]
);
s++;
if (s > size) return true;
if (s > fastFilterContext.size) return true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import org.opensearch.common.Rounding;
import org.opensearch.common.lease.Releasables;
import org.opensearch.index.IndexSortConfig;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.lucene.queries.SearchAfterSortedDocQuery;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
Expand All @@ -73,7 +72,7 @@
import org.opensearch.search.aggregations.MultiBucketCollector;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.FilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
import org.opensearch.search.internal.SearchContext;
Expand Down Expand Up @@ -116,10 +115,9 @@ final class CompositeAggregator extends BucketsAggregator {

private boolean earlyTerminated;

private Weight[] filters = null;
private LongKeyedBucketOrds bucketOrds = null;
private DateFieldMapper.DateFieldType fieldType = null;
private Rounding.Prepared preparedRounding = null;
private FastFilterRewriteHelper.FastFilterContext fastFilterContext = null;

CompositeAggregator(
String name,
Expand Down Expand Up @@ -170,34 +168,27 @@ final class CompositeAggregator extends BucketsAggregator {
RoundingValuesSource dateHistogramSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
preparedRounding = dateHistogramSource.getPreparedRounding();
long afterValue = 0;
fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
sourceConfigs[0].missingBucket(),
sourceConfigs[0].hasScript(),
sourceConfigs[0].fieldType()
);
if (rawAfterKey != null) {
assert rawAfterKey.size() == 1 && formats.size() == 1;
afterValue = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
long afterValue = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> {
throw new IllegalArgumentException("now() is not supported in [after] key");
});
fastFilterContext.setAfterKey(afterValue);
}
FilterRewriteHelper.ValueSourceContext dateHistogramSourceContext = new FilterRewriteHelper.ValueSourceContext(
sourceConfigs[0].missingBucket(),
sourceConfigs[0].hasScript(),
sourceConfigs[0].fieldType(),
afterValue
);
FilterRewriteHelper.FilterContext filterContext = FilterRewriteHelper.buildFastFilterContext(
parent,
subAggregators.length,
context,
x -> dateHistogramSource.getRounding(),
() -> preparedRounding,
dateHistogramSourceContext,
fc -> FilterRewriteHelper.getAggregationBounds(context, fc.getFieldType().name())
);
if (filterContext != null) {
fieldType = filterContext.fieldType;
filters = filterContext.filters;
} else {
filters = null;
fieldType = null;
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
fastFilterContext.setSize(size);
FastFilterRewriteHelper.buildFastFilterContext(
context,
x -> dateHistogramSource.getRounding(),
() -> preparedRounding,
fastFilterContext,
fc -> FastFilterRewriteHelper.getAggregationBounds(context, fc.getFieldType().name())
);
}
}
}
Expand Down Expand Up @@ -532,9 +523,10 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = FilterRewriteHelper.tryFastFilterAggregation(ctx, filters, fieldType, (key, count) -> {
incrementBucketDocCount(FilterRewriteHelper.getBucketOrd(bucketOrds.add(0, preparedRounding.round(key))), count);
}, size);
boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(ctx, fastFilterContext, (key, count) -> {
incrementBucketDocCount(FastFilterRewriteHelper.
getBucketOrd(bucketOrds.add(0, preparedRounding.round(key))), count);
});
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.CollectionUtil;
import org.opensearch.common.Rounding;
import org.opensearch.common.Rounding.Prepared;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.IntArray;
import org.opensearch.common.util.LongArray;
import org.opensearch.core.common.util.ByteArray;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
Expand All @@ -54,7 +52,7 @@
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
import org.opensearch.search.aggregations.bucket.DeferringBucketCollector;
import org.opensearch.search.aggregations.bucket.FilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
import org.opensearch.search.aggregations.bucket.MergingBucketsDeferringCollector;
import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder.RoundingInfo;
import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
Expand Down Expand Up @@ -129,14 +127,14 @@ static AutoDateHistogramAggregator build(
* {@link MergingBucketsDeferringCollector#mergeBuckets(long[])}.
*/
private MergingBucketsDeferringCollector deferringCollector;
private final Weight[] filters;
private final DateFieldMapper.DateFieldType fieldType;

protected final RoundingInfo[] roundingInfos;
protected final int targetBuckets;
protected int roundingIdx;
protected Rounding.Prepared preparedRounding;

private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;

private AutoDateHistogramAggregator(
String name,
AggregatorFactories factories,
Expand All @@ -158,29 +156,21 @@ private AutoDateHistogramAggregator(
this.roundingPreparer = roundingPreparer;
this.preparedRounding = prepareRounding(0);

FilterRewriteHelper.ValueSourceContext dateHistogramSourceContext = new FilterRewriteHelper.ValueSourceContext(
fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
valuesSourceConfig.missing() != null,
valuesSourceConfig.script() != null,
valuesSourceConfig.fieldType(),
-1
);
FilterRewriteHelper.FilterContext filterContext = FilterRewriteHelper.buildFastFilterContext(
parent(),
subAggregators.length,
context,
b -> getMinimumRounding(b[0], b[1]),
// Passing prepared rounding as supplier to ensure the correct prepared
// rounding is set as it is done during getMinimumRounding
() -> preparedRounding,
dateHistogramSourceContext,
fc -> FilterRewriteHelper.getAggregationBounds(context, fc.getFieldType().name())
valuesSourceConfig.fieldType()
);
if (filterContext != null) {
fieldType = filterContext.fieldType;
filters = filterContext.filters;
} else {
fieldType = null;
filters = null;
if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
FastFilterRewriteHelper.buildFastFilterContext(
context,
b -> getMinimumRounding(b[0], b[1]),
// Passing prepared rounding as supplier to ensure the correct prepared
// rounding is set as it is done during getMinimumRounding
() -> preparedRounding,
fastFilterContext,
fc -> FastFilterRewriteHelper.getAggregationBounds(context, fc.getFieldType().name())
);
}
}

Expand Down Expand Up @@ -234,9 +224,9 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = FilterRewriteHelper.tryFastFilterAggregation(ctx, filters, fieldType, (key, count) -> {
incrementBucketDocCount(FilterRewriteHelper.getBucketOrd(getBucketOrds().add(0, preparedRounding.round(key))), count);
}, Integer.MAX_VALUE);
boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(ctx, fastFilterContext, (key, count) -> {
incrementBucketDocCount(FastFilterRewriteHelper.getBucketOrd(getBucketOrds().add(0, preparedRounding.round(key))), count);
});
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Loading

0 comments on commit 029686f

Please sign in to comment.