Only build filter once at segment level
Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Feb 2, 2024
1 parent b6ec756 commit 8573fe5
Showing 2 changed files with 128 additions and 37 deletions.
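
In outline, the change does three things: it factors the query/index bounds intersection into a shared getBoundsWithRangeQuery helper that returns null when the ranges are disjoint, it reuses the shard-level filter builder at the segment level through CheckedBiFunction, and it adds a one-shot guard so that a failed segment-level filter build is not retried on every remaining segment of the shard. The following is a simplified, hypothetical sketch of that guard pattern (names mirror the patch; the real logic lives in FastFilterRewriteHelper), not the actual class:

// Simplified, self-contained sketch of the one-shot guard this commit introduces.
class FastFilterContextSketch {
    Object[] filters = null;                         // stands in for Lucene Weight[]
    boolean filtersBuiltAtShardLevel = false;        // set when the shard-level build succeeds
    boolean shouldBuildFiltersAtSegmentLevel = true; // one-shot guard for segment-level builds

    // Called per segment; returns true when the fast-filter path is usable.
    boolean tryUseFastFilter(boolean segmentMatchesAll) {
        if (!shouldBuildFiltersAtSegmentLevel) return false;
        if (!filtersBuiltAtShardLevel && !segmentMatchesAll) return false;
        if (filters == null) {
            filters = buildFilters(); // attempted at most once per shard
            if (filters == null) {
                // The inputs (query, shard-wide index bounds) are the same for
                // every segment, so a build that fails here would fail everywhere.
                shouldBuildFiltersAtSegmentLevel = false;
                return false;
            }
        }
        return true;
    }

    Object[] buildFilters() { return null; } // placeholder for the real builder
}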
FastFilterRewriteHelper.java
@@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
@@ -23,6 +24,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.CheckedBiFunction;
import org.opensearch.common.Rounding;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.index.mapper.DateFieldMapper;
@@ -97,7 +99,6 @@ private static long[] getIndexBounds(final SearchContext context, final String f
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
for (LeafReaderContext leaf : leaves) {
final PointValues values = leaf.reader().getPointValues(fieldName);
// a null "values" means this segment doesn't have any values for this field
if (values != null) {
min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0));
max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0));
@@ -119,16 +120,9 @@ public static long[] getDateHistoAggBounds(final SearchContext context, final St
final Query cq = unwrapIntoConcreteQuery(context.query());
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
final long[] indexBounds = getIndexBounds(context, fieldName);
if (indexBounds == null) return null;
return new long[] {
// Minimum bound for aggregation is the max between query and global
Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]),
// Maximum bound for aggregation is the min between query and global
Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]) };
}
final long[] indexBounds = getIndexBounds(context, fieldName);
if (indexBounds == null) return null;
return getBoundsWithRangeQuery(prq, fieldName, indexBounds);
} else if (cq instanceof MatchAllDocsQuery) {
return getIndexBounds(context, fieldName);
} else if (cq instanceof FieldExistsQuery) {
@@ -141,6 +135,32 @@ public static long[] getDateHistoAggBounds(final SearchContext context, final St
return null;
}

private static long[] getDateHistoAggBoundsSegLevel(final SearchContext context, final String fieldName) throws IOException {
final long[] indexBounds = getIndexBounds(context, fieldName);
if (indexBounds == null) return null;
final Query cq = unwrapIntoConcreteQuery(context.query());
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
return getBoundsWithRangeQuery(prq, fieldName, indexBounds);
}
return indexBounds;
}

private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) {
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
// Minimum bound for aggregation is the max between query and global
long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]);
// Maximum bound for aggregation is the min between query and global
long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]);
if (lower > upper) {
return null;
}
return new long[] { lower, upper };
}
return null;
}

/**
* Creates the date range filters for aggregations using the interval, min/max
* bounds and prepared rounding
@@ -192,7 +212,7 @@ private static Weight[] createFilterForAggregations(
filters[i++] = context.searcher().createWeight(new PointRangeQuery(fieldType.name(), lower, upper, 1) {
@Override
protected String toString(int dimension, byte[] value) {
return null;
return Long.toString(LongPoint.decodeDimension(value, 0));
}
}, ScoreMode.COMPLETE_NO_SCORES, 1);
}
@@ -210,6 +230,7 @@ public static class FastFilterContext {
private boolean rewriteable = false;
private Weight[] filters = null;
private boolean filtersBuiltAtShardLevel = false;
private boolean shouldBuildFiltersAtSegmentLevel = true;

private AggregationType aggregationType;
private final SearchContext context;
@@ -234,21 +255,17 @@ public boolean isRewriteable(final Object parent, final int subAggLength) {
}

public void buildFastFilter() throws IOException {
this.buildFastFilter(FastFilterRewriteHelper::getDateHistoAggBounds);
}

private void buildFastFilter(GetBounds<SearchContext, String, long[]> getBounds) throws IOException {
assert filters == null : "Filters should only be built once, but they are already built";
Weight[] filters = this.aggregationType.buildFastFilter(context, getBounds);
this.filters = this.buildFastFilter(FastFilterRewriteHelper::getDateHistoAggBounds);
if (filters != null) {
logger.debug("Fast filter built for shard {}", context.indexShard().shardId());
filtersBuiltAtShardLevel = true;
this.filters = filters;
}
}

private boolean filterBuilt() {
return filters != null;
// This method can also be used at segment level
private Weight[] buildFastFilter(CheckedBiFunction<SearchContext, String, long[], IOException> getBounds) throws IOException {
assert filters == null : "Filters should only be built once, but they are already built";
return this.aggregationType.buildFastFilter(context, getBounds);
}
}

@@ -259,15 +276,8 @@ interface AggregationType {

boolean isRewriteable(Object parent, int subAggLength);

Weight[] buildFastFilter(SearchContext ctx, GetBounds<SearchContext, String, long[]> getBounds) throws IOException;
}

/**
* Functional interface for getting bounds for date histogram aggregation
*/
@FunctionalInterface
interface GetBounds<T, U, R> {
R apply(T t, U u) throws IOException;
Weight[] buildFastFilter(SearchContext ctx, CheckedBiFunction<SearchContext, String, long[], IOException> getBounds)
throws IOException;
}

/**
@@ -301,7 +311,8 @@ public boolean isRewriteable(Object parent, int subAggLength) {
}

@Override
public Weight[] buildFastFilter(SearchContext context, GetBounds<SearchContext, String, long[]> getBounds) throws IOException {
public Weight[] buildFastFilter(SearchContext context, CheckedBiFunction<SearchContext, String, long[], IOException> getBounds)
throws IOException {
long[] bounds = getBounds.apply(context, fieldType.name());
bounds = processHardBounds(bounds);
logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId());
@@ -383,7 +394,9 @@ public static boolean tryFastFilterAggregation(
final BiConsumer<Long, Integer> incrementDocCount
) throws IOException {
if (fastFilterContext == null) return false;
if (!fastFilterContext.rewriteable) return false;
if (!fastFilterContext.rewriteable || !fastFilterContext.shouldBuildFiltersAtSegmentLevel) {
return false;
}

NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
@@ -397,18 +410,25 @@

// if no filters were built at the shard level (see getDateHistoAggBounds method for possible reasons),
// check whether the query is functionally match-all at the segment level
if (!fastFilterContext.filtersBuiltAtShardLevel && !segmentMatchAll(fastFilterContext.context, ctx)) return false;
if (!fastFilterContext.filterBuilt()) {
if (!fastFilterContext.filtersBuiltAtShardLevel && !segmentMatchAll(fastFilterContext.context, ctx)) {
return false;
}
Weight[] filters = fastFilterContext.filters;
if (filters == null) {
logger.debug(
"Shard {} segment {} functionally match all documents. Build the fast filter",
fastFilterContext.context.indexShard().shardId(),
ctx.ord
);
fastFilterContext.buildFastFilter(FastFilterRewriteHelper::getIndexBounds);
filters = fastFilterContext.buildFastFilter(FastFilterRewriteHelper::getDateHistoAggBoundsSegLevel);
if (filters == null) {
// At the segment level, building the filter should only be attempted once,
// since the conditions for building it won't change across segments
fastFilterContext.shouldBuildFiltersAtSegmentLevel = false;
return false;
}
}
if (!fastFilterContext.filterBuilt()) return false;

final Weight[] filters = fastFilterContext.filters;
final int[] counts = new int[filters.length];
int i;
for (i = 0; i < filters.length; i++) {
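A second behavioral fix in this file is easy to miss: getBoundsWithRangeQuery now returns null when the query range and the index bounds are disjoint (lower > upper), where the previous inline version could hand an inverted range to the filter builder. A minimal standalone sketch of that intersection, assuming the same sortable-long encoding the patch uses:

// Hypothetical helper, not part of the patch: intersect the query's bounds with
// the shard-wide index bounds; null means the ranges are disjoint and the
// fast-filter path should be skipped.
static long[] intersectBounds(long queryLower, long queryUpper, long[] indexBounds) {
    long lower = Math.max(queryLower, indexBounds[0]);
    long upper = Math.min(queryUpper, indexBounds[1]);
    if (lower > upper) {
        return null; // no document can satisfy both ranges
    }
    return new long[] { lower, upper };
}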
DateHistogramAggregatorTests.java
@@ -1254,6 +1254,76 @@ public void testHardBoundsNotOverlapping() throws IOException {
);
}

public void testRangeQuery() throws IOException {
testSearchCase(
LongPoint.newRangeQuery(SEARCHABLE_DATE, asLong("2018-01-01"), asLong("2020-01-01")),
Arrays.asList("2017-02-01", "2017-02-02", "2017-02-02", "2017-02-03", "2017-02-03", "2017-02-03", "2017-02-05"),
aggregation -> aggregation.calendarInterval(DateHistogramInterval.DAY)
.field(AGGREGABLE_DATE),
histogram -> {
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
assertEquals(0, buckets.size());
},
false
);

testSearchCase(
LongPoint.newRangeQuery(SEARCHABLE_DATE, asLong("2016-01-01"), asLong("2017-01-01")),
Arrays.asList("2017-02-01", "2017-02-02", "2017-02-02", "2017-02-03", "2017-02-03", "2017-02-03", "2017-02-05"),
aggregation -> aggregation.calendarInterval(DateHistogramInterval.DAY)
.field(AGGREGABLE_DATE),
histogram -> {
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
assertEquals(0, buckets.size());
},
false
);

testSearchCase(
LongPoint.newRangeQuery(SEARCHABLE_DATE, asLong("2016-01-01"), asLong("2017-02-02")),
Arrays.asList("2017-02-01", "2017-02-02", "2017-02-02", "2017-02-03", "2017-02-03", "2017-02-03", "2017-02-05"),
aggregation -> aggregation.calendarInterval(DateHistogramInterval.DAY)
.field(AGGREGABLE_DATE),
histogram -> {
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
assertEquals(2, buckets.size());

Histogram.Bucket bucket = buckets.get(0);
assertEquals("2017-02-01T00:00:00.000Z", bucket.getKeyAsString());
assertEquals(1, bucket.getDocCount());

bucket = buckets.get(1);
assertEquals("2017-02-02T00:00:00.000Z", bucket.getKeyAsString());
assertEquals(2, bucket.getDocCount());
},
false
);

testSearchCase(
LongPoint.newRangeQuery(SEARCHABLE_DATE, asLong("2017-02-03"), asLong("2020-01-01")),
Arrays.asList("2017-02-01", "2017-02-02", "2017-02-02", "2017-02-03", "2017-02-03", "2017-02-03", "2017-02-05"),
aggregation -> aggregation.calendarInterval(DateHistogramInterval.DAY)
.field(AGGREGABLE_DATE),
histogram -> {
List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
assertEquals(3, buckets.size());

Histogram.Bucket bucket = buckets.get(0);
assertEquals("2017-02-03T00:00:00.000Z", bucket.getKeyAsString());
assertEquals(3, bucket.getDocCount());

bucket = buckets.get(1);
assertEquals("2017-02-04T00:00:00.000Z", bucket.getKeyAsString());
assertEquals(0, bucket.getDocCount());

bucket = buckets.get(2);
assertEquals("2017-02-05T00:00:00.000Z", bucket.getKeyAsString());
assertEquals(1, bucket.getDocCount());
},
false
);
}

public void testDocCountField() throws IOException {
testSearchCase(
new MatchAllDocsQuery(),
@@ -1323,6 +1393,7 @@ private void testSearchCase(
boolean useDocCountField
) throws IOException {
boolean aggregableDateIsSearchable = randomBoolean();
logger.debug("Aggregable date is searchable {}", aggregableDateIsSearchable);
DateFieldMapper.DateFieldType fieldType = aggregableDateFieldType(useNanosecondResolution, aggregableDateIsSearchable);

try (Directory directory = newDirectory()) {
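As a sanity check on the last testRangeQuery case above: the documents span 2017-02-01 through 2017-02-05 and the query range is [2017-02-03, 2020-01-01], so the effective bounds clamp to 2017-02-03..2017-02-05 and a DAY interval yields three buckets with counts 3, 0, and 1. The small, hypothetical program below reproduces those expectations without Lucene:

import java.time.LocalDate;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;

// Hypothetical cross-check, not part of the patch: clamp the test documents to
// the query range and bucket them by day, filling in the empty days between.
public class RangeBucketCheck {
    public static void main(String[] args) {
        List<LocalDate> docs = List.of(
            LocalDate.parse("2017-02-01"), LocalDate.parse("2017-02-02"),
            LocalDate.parse("2017-02-02"), LocalDate.parse("2017-02-03"),
            LocalDate.parse("2017-02-03"), LocalDate.parse("2017-02-03"),
            LocalDate.parse("2017-02-05"));
        LocalDate lower = LocalDate.parse("2017-02-03");
        LocalDate upper = LocalDate.parse("2020-01-01");

        SortedMap<LocalDate, Integer> buckets = new TreeMap<>();
        for (LocalDate d : docs) {
            if (!d.isBefore(lower) && !d.isAfter(upper)) buckets.merge(d, 1, Integer::sum);
        }
        // A date histogram also emits empty buckets between the first and last key.
        for (LocalDate d = buckets.firstKey(); d.isBefore(buckets.lastKey()); d = d.plusDays(1)) {
            buckets.putIfAbsent(d, 0);
        }
        System.out.println(buckets); // {2017-02-03=3, 2017-02-04=0, 2017-02-05=1}
    }
}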
