From 57fb50b22bf30148a632bd4c5e78bde53116f00f Mon Sep 17 00:00:00 2001
From: bowenlan-amzn
Date: Wed, 19 Jun 2024 15:48:48 -0700
Subject: [PATCH] Apply the date histogram rewrite optimization to range
 aggregation (#13865)

* Refactor the ranges representation

Signed-off-by: bowenlan-amzn

* Refactor try fast filter

Signed-off-by: bowenlan-amzn

* Main work finished; left the handling of different numeric data types

Signed-off-by: bowenlan-amzn

* buildRanges accepts field type

Signed-off-by: bowenlan-amzn

* first working draft probably

Signed-off-by: bowenlan-amzn

* add change log

Signed-off-by: bowenlan-amzn

* accommodate geo distance agg

Signed-off-by: bowenlan-amzn

* Fix test

Support all numeric types; minus one on the upper range.

Signed-off-by: bowenlan-amzn

* [Refactor] range is lower inclusive, right exclusive

Signed-off-by: bowenlan-amzn

* adding test

Signed-off-by: bowenlan-amzn

* Adding test and refactor

Signed-off-by: bowenlan-amzn

* refactor

Signed-off-by: bowenlan-amzn

* add test

Signed-off-by: bowenlan-amzn

* add test and update the compare logic in tree traversal

Signed-off-by: bowenlan-amzn

* fix test, add random test

Signed-off-by: bowenlan-amzn

* refactor to address comments

Signed-off-by: bowenlan-amzn

* small potential performance update

Signed-off-by: bowenlan-amzn

* fix precommit

Signed-off-by: bowenlan-amzn

* refactor

Signed-off-by: bowenlan-amzn

* refactor

Signed-off-by: bowenlan-amzn

* set refresh_interval to -1

Signed-off-by: bowenlan-amzn

* address comment

Signed-off-by: bowenlan-amzn

* address comment

Signed-off-by: bowenlan-amzn

* address comment

Signed-off-by: bowenlan-amzn

* Fix test

Fully understanding the double and BigDecimal usage in the scaled
float field will take more time.

Signed-off-by: bowenlan-amzn

---------

Signed-off-by: bowenlan-amzn
---
 CHANGELOG.md                                  |   1 +
 .../index/mapper/ScaledFloatFieldMapper.java  |  18 +-
 .../test/search.aggregation/40_range.yml      | 139 ++++
 .../index/mapper/DateFieldMapper.java         |   9 +-
 .../index/mapper/NumberFieldMapper.java       |  87 +++-
 .../index/mapper/NumericPointEncoder.java     |  16 +
 .../bucket/FastFilterRewriteHelper.java       | 470 +++++++++++-------
 .../bucket/composite/CompositeAggregator.java |  19 +-
 .../AutoDateHistogramAggregator.java          |  17 +-
 .../histogram/DateHistogramAggregator.java    |  17 +-
 .../range/AbstractRangeAggregatorFactory.java |   3 +-
 .../range/GeoDistanceAggregatorSupplier.java  |   4 +-
 .../GeoDistanceRangeAggregatorFactory.java    |   9 +-
 .../bucket/range/RangeAggregator.java         |  37 +-
 .../bucket/range/RangeAggregatorSupplier.java |   4 +-
 .../DateHistogramAggregatorTests.java         |   2 +-
 .../bucket/range/RangeAggregatorTests.java    | 282 ++++++++++-
 17 files changed, 902 insertions(+), 232 deletions(-)
 create mode 100644 server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 24a4f6fd1b1f1..8c3e63e36bc82 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Added
 - Add fingerprint ingest processor ([#13724](https://github.com/opensearch-project/OpenSearch/pull/13724))
 - [Remote Store] Rate limiter for remote store low priority uploads ([#14374](https://github.com/opensearch-project/OpenSearch/pull/14374/))
+- Apply the date histogram rewrite optimization to range aggregation ([#13865](https://github.com/opensearch-project/OpenSearch/pull/13865))
 
 ### Dependencies
 - Bump `org.gradle.test-retry` from 1.5.8 to 1.5.9 ([#13442](https://github.com/opensearch-project/OpenSearch/pull/13442))

diff --git a/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java b/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java
index 400d867296e5f..3115dce6c10a5 100644
--- a/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java
+++ b/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java
@@ -35,6 +35,7 @@
 import com.fasterxml.jackson.core.JsonParseException;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.LongPoint;
 import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.NumericDocValues;
@@ -165,7 +166,7 @@ public ScaledFloatFieldMapper build(BuilderContext context) {
 
     public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n, c.getSettings()));
 
-    public static final class ScaledFloatFieldType extends SimpleMappedFieldType {
+    public static final class ScaledFloatFieldType extends SimpleMappedFieldType implements NumericPointEncoder {
 
         private final double scalingFactor;
         private final Double nullValue;
@@ -188,6 +189,21 @@ public ScaledFloatFieldType(String name, double scalingFactor) {
             this(name, true, false, true, Collections.emptyMap(), scalingFactor, null);
         }
 
+        @Override
+        public byte[] encodePoint(Number value) {
+            assert value instanceof Double;
+            double doubleValue = (Double) value;
+            byte[] point = new byte[Long.BYTES];
+            if (doubleValue == Double.POSITIVE_INFINITY) {
+                LongPoint.encodeDimension(Long.MAX_VALUE, point, 0);
+            } else if (doubleValue == Double.NEGATIVE_INFINITY) {
+                LongPoint.encodeDimension(Long.MIN_VALUE, point, 0);
+            } else {
+                LongPoint.encodeDimension(Math.round(scale(value)), point, 0);
+            }
+            return point;
+        }
+
         public double getScalingFactor() {
             return scalingFactor;
         }
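For reference, the encoding above can be exercised directly with Lucene. Below is a minimal sketch (class name and values are illustrative, assuming only lucene-core on the classpath) of how a scaled_float value such as 1.53 with scaling_factor 100 becomes a sortable long point, mirroring the encodePoint implementation:

import org.apache.lucene.document.LongPoint;

// Sketch of the scaled_float point encoding: multiply by the scaling factor,
// round to a long, then encode in Lucene's sortable point format.
public class ScaledFloatEncodingDemo {
    public static void main(String[] args) {
        double scalingFactor = 100.0; // matches the test mapping's scaling_factor
        double value = 1.53;          // a doc value from the YAML test below

        long scaled = Math.round(value * scalingFactor); // 1.53 -> 153
        byte[] point = new byte[Long.BYTES];
        LongPoint.encodeDimension(scaled, point, 0);

        // Decoding returns the scaled long; dividing restores the original value.
        long decoded = LongPoint.decodeDimension(point, 0);
        System.out.println(decoded + " -> " + decoded / scalingFactor); // 153 -> 1.53
    }
}

The infinity clamping in the patch exists because open-ended range bounds arrive as +/-Infinity, which must saturate to Long.MAX_VALUE/MIN_VALUE rather than overflow.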
"*-50.0" } + - is_false: aggregations.double_range.buckets.0.from + - match: { aggregations.double_range.buckets.0.to: 50.0 } + - match: { aggregations.double_range.buckets.0.doc_count: 1 } + - match: { aggregations.double_range.buckets.1.key: "50.0-150.0" } + - match: { aggregations.double_range.buckets.1.from: 50.0 } + - match: { aggregations.double_range.buckets.1.to: 150.0 } + - match: { aggregations.double_range.buckets.1.doc_count: 2 } + - match: { aggregations.double_range.buckets.2.key: "150.0-*" } + - match: { aggregations.double_range.buckets.2.from: 150.0 } + - is_false: aggregations.double_range.buckets.2.to + - match: { aggregations.double_range.buckets.2.doc_count: 0 } + + - match: { profile.shards.0.aggregations.0.debug.optimized_segments: 1 } + - match: { profile.shards.0.aggregations.0.debug.unoptimized_segments: 0 } + - match: { profile.shards.0.aggregations.0.debug.leaf_visited: 1 } + - match: { profile.shards.0.aggregations.0.debug.inner_visited: 0 } + +--- +"Scaled Float Range Aggregation": + - do: + index: + index: test + id: 1 + body: { "scaled_field": 1 } + + - do: + index: + index: test + id: 2 + body: { "scaled_field": 1.53 } + + - do: + index: + index: test + id: 3 + body: { "scaled_field": -2.1 } + + - do: + index: + index: test + id: 4 + body: { "scaled_field": 1.53 } + + - do: + indices.refresh: { } + + - do: + search: + index: test + body: + size: 0 + aggs: + my_range: + range: + field: scaled_field + ranges: + - to: 0 + - from: 0 + to: 1 + - from: 1 + to: 1.5 + - from: 1.5 + + - length: { aggregations.my_range.buckets: 4 } + + - match: { aggregations.my_range.buckets.0.key: "*-0.0" } + - is_false: aggregations.my_range.buckets.0.from + - match: { aggregations.my_range.buckets.0.to: 0.0 } + - match: { aggregations.my_range.buckets.0.doc_count: 1 } + - match: { aggregations.my_range.buckets.1.key: "0.0-1.0" } + - match: { aggregations.my_range.buckets.1.from: 0.0 } + - match: { aggregations.my_range.buckets.1.to: 1.0 } + - match: { aggregations.my_range.buckets.1.doc_count: 0 } + - match: { aggregations.my_range.buckets.2.key: "1.0-1.5" } + - match: { aggregations.my_range.buckets.2.from: 1.0 } + - match: { aggregations.my_range.buckets.2.to: 1.5 } + - match: { aggregations.my_range.buckets.2.doc_count: 1 } + - match: { aggregations.my_range.buckets.3.key: "1.5-*" } + - match: { aggregations.my_range.buckets.3.from: 1.5 } + - is_false: aggregations.my_range.buckets.3.to + - match: { aggregations.my_range.buckets.3.doc_count: 2 } diff --git a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java index b7ee3bb8ca3e3..cf8703209fb37 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java @@ -348,7 +348,7 @@ public DateFieldMapper build(BuilderContext context) { * * @opensearch.internal */ - public static final class DateFieldType extends MappedFieldType { + public static final class DateFieldType extends MappedFieldType implements NumericPointEncoder { protected final DateFormatter dateTimeFormatter; protected final DateMathParser dateMathParser; protected final Resolution resolution; @@ -549,6 +549,13 @@ public static long parseToLong( return resolution.convert(dateParser.parse(BytesRefs.toString(value), now, roundUp, zone)); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Long.BYTES]; + LongPoint.encodeDimension(value.longValue(), point, 
0); + return point; + } + @Override public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) { failIfNotIndexedAndNoDocValues(); diff --git a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java index eb3a99b0e0388..25e5f9970795f 100644 --- a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java @@ -171,7 +171,7 @@ public NumberFieldMapper build(BuilderContext context) { * * @opensearch.internal */ - public enum NumberType { + public enum NumberType implements NumericPointEncoder { HALF_FLOAT("half_float", NumericType.HALF_FLOAT) { @Override public Float parse(Object value, boolean coerce) { @@ -194,6 +194,13 @@ public Number parsePoint(byte[] value) { return HalfFloatPoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[HalfFloatPoint.BYTES]; + HalfFloatPoint.encodeDimension(value.floatValue(), point, 0); + return point; + } + @Override public Float parse(XContentParser parser, boolean coerce) throws IOException { float parsed = parser.floatValue(coerce); @@ -331,6 +338,13 @@ public Number parsePoint(byte[] value) { return FloatPoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Float.BYTES]; + FloatPoint.encodeDimension(value.floatValue(), point, 0); + return point; + } + @Override public Float parse(XContentParser parser, boolean coerce) throws IOException { float parsed = parser.floatValue(coerce); @@ -457,6 +471,13 @@ public Number parsePoint(byte[] value) { return DoublePoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Double.BYTES]; + DoublePoint.encodeDimension(value.doubleValue(), point, 0); + return point; + } + @Override public Double parse(XContentParser parser, boolean coerce) throws IOException { double parsed = parser.doubleValue(coerce); @@ -582,6 +603,13 @@ public Number parsePoint(byte[] value) { return INTEGER.parsePoint(value).byteValue(); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Integer.BYTES]; + IntPoint.encodeDimension(value.intValue(), point, 0); + return point; + } + @Override public Short parse(XContentParser parser, boolean coerce) throws IOException { int value = parser.intValue(coerce); @@ -654,6 +682,13 @@ public Number parsePoint(byte[] value) { return INTEGER.parsePoint(value).shortValue(); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Integer.BYTES]; + IntPoint.encodeDimension(value.intValue(), point, 0); + return point; + } + @Override public Short parse(XContentParser parser, boolean coerce) throws IOException { return parser.shortValue(coerce); @@ -722,6 +757,13 @@ public Number parsePoint(byte[] value) { return IntPoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Integer.BYTES]; + IntPoint.encodeDimension(value.intValue(), point, 0); + return point; + } + @Override public Integer parse(XContentParser parser, boolean coerce) throws IOException { return parser.intValue(coerce); @@ -868,6 +910,13 @@ public Number parsePoint(byte[] value) { return LongPoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[Long.BYTES]; + 
LongPoint.encodeDimension(value.longValue(), point, 0); + return point; + } + @Override public Long parse(XContentParser parser, boolean coerce) throws IOException { return parser.longValue(coerce); @@ -988,6 +1037,13 @@ public Number parsePoint(byte[] value) { return BigIntegerPoint.decodeDimension(value, 0); } + @Override + public byte[] encodePoint(Number value) { + byte[] point = new byte[BigIntegerPoint.BYTES]; + BigIntegerPoint.encodeDimension(objectToUnsignedLong(value, false, true), point, 0); + return point; + } + @Override public BigInteger parse(XContentParser parser, boolean coerce) throws IOException { return parser.bigIntegerValue(coerce); @@ -1215,16 +1271,30 @@ public static long objectToLong(Object value, boolean coerce) { return Numbers.toLong(stringValue, coerce); } + public static BigInteger objectToUnsignedLong(Object value, boolean coerce) { + return objectToUnsignedLong(value, coerce, false); + } + /** - * Converts and Object to a {@code long} by checking it against known + * Converts an Object to a {@code BigInteger} by checking it against known * types and checking its range. + * + * @param lenientBound if true, use MIN or MAX if the value is out of bound */ - public static BigInteger objectToUnsignedLong(Object value, boolean coerce) { + public static BigInteger objectToUnsignedLong(Object value, boolean coerce, boolean lenientBound) { if (value instanceof Long) { return Numbers.toUnsignedBigInteger(((Long) value).longValue()); } double doubleValue = objectToDouble(value); + if (lenientBound) { + if (doubleValue < Numbers.MIN_UNSIGNED_LONG_VALUE.doubleValue()) { + return Numbers.MIN_UNSIGNED_LONG_VALUE; + } + if (doubleValue > Numbers.MAX_UNSIGNED_LONG_VALUE.doubleValue()) { + return Numbers.MAX_UNSIGNED_LONG_VALUE; + } + } if (doubleValue < Numbers.MIN_UNSIGNED_LONG_VALUE.doubleValue() || doubleValue > Numbers.MAX_UNSIGNED_LONG_VALUE.doubleValue()) { throw new IllegalArgumentException("Value [" + value + "] is out of range for an unsigned long"); @@ -1349,7 +1419,7 @@ public static Query unsignedLongRangeQuery( * * @opensearch.internal */ - public static class NumberFieldType extends SimpleMappedFieldType { + public static class NumberFieldType extends SimpleMappedFieldType implements NumericPointEncoder { private final NumberType type; private final boolean coerce; @@ -1394,6 +1464,10 @@ public String typeName() { return type.name; } + public NumberType numberType() { + return type; + } + public NumericType numericType() { return type.numericType(); } @@ -1501,6 +1575,11 @@ public DocValueFormat docValueFormat(String format, ZoneId timeZone) { public Number parsePoint(byte[] value) { return type.parsePoint(value); } + + @Override + public byte[] encodePoint(Number value) { + return type.encodePoint(value); + } } private final NumberType type; diff --git a/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java b/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java new file mode 100644 index 0000000000000..be746a5526594 --- /dev/null +++ b/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.index.mapper; + +/** + * Interface for encoding a point value + */ +public interface NumericPointEncoder { + byte[] encodePoint(Number value); +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java index c8ce39a52f869..2ab003fb94e33 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; @@ -32,24 +33,26 @@ import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.DocCountFieldMapper; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumericPointEncoder; import org.opensearch.index.query.DateRangeIncludingNowQuery; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregator; import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig; import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource; import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.bucket.range.RangeAggregator.Range; +import org.opensearch.search.aggregations.support.ValuesSource; +import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.OptionalLong; import java.util.function.BiConsumer; -import java.util.function.BiFunction; import java.util.function.Function; +import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** @@ -187,8 +190,8 @@ public static class FastFilterContext { private AggregationType aggregationType; private final SearchContext context; - private String fieldName; - private long[][] ranges; + private MappedFieldType fieldType; + private Ranges ranges; // debug info related fields public int leaf; @@ -196,12 +199,13 @@ public static class FastFilterContext { public int segments; public int optimizedSegments; - public void setFieldName(String fieldName) { - this.fieldName = fieldName; + public FastFilterContext(SearchContext context) { + this.context = context; } - public FastFilterContext(SearchContext context) { + public FastFilterContext(SearchContext context, AggregationType aggregationType) { this.context = context; + this.aggregationType = aggregationType; } public AggregationType getAggregationType() { @@ -221,23 +225,87 @@ public boolean isRewriteable(final Object parent, final int subAggLength) { return rewriteable; } - public void buildRanges() throws IOException { + public void buildRanges(MappedFieldType fieldType) throws IOException { assert ranges == null : "Ranges should only be built once at shard level, but they are already built"; - this.ranges = this.aggregationType.buildRanges(context); + this.fieldType = fieldType; + this.ranges = this.aggregationType.buildRanges(context, 
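The interface works because Lucene's point encodings are order-preserving: comparing encoded bytes as unsigned values agrees with comparing the original numbers. A minimal sketch of that property (illustrative names, assuming lucene-core), which is exactly what the traversal code in the next file relies on:

import org.apache.lucene.document.DoublePoint;
import org.apache.lucene.util.ArrayUtil;

// The encoded point format is order-preserving, so range bounds can be
// compared as unsigned byte arrays instead of decoded numbers.
public class SortableEncodingDemo {
    public static void main(String[] args) {
        ArrayUtil.ByteArrayComparator cmp = ArrayUtil.getUnsignedComparator(Double.BYTES);

        byte[] a = new byte[Double.BYTES];
        byte[] b = new byte[Double.BYTES];
        DoublePoint.encodeDimension(-2.1, a, 0);
        DoublePoint.encodeDimension(1.53, b, 0);

        // Negative < positive still holds after encoding.
        System.out.println(cmp.compare(a, 0, b, 0) < 0); // true
    }
}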
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java
index c8ce39a52f869..2ab003fb94e33 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java
@@ -10,6 +10,7 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.lucene.document.LongPoint;
 import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.NumericDocValues;
@@ -32,24 +33,26 @@
 import org.opensearch.index.mapper.DateFieldMapper;
 import org.opensearch.index.mapper.DocCountFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.index.mapper.NumericPointEncoder;
 import org.opensearch.index.query.DateRangeIncludingNowQuery;
 import org.opensearch.search.aggregations.bucket.composite.CompositeAggregator;
 import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
 import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
 import org.opensearch.search.aggregations.bucket.histogram.LongBounds;
+import org.opensearch.search.aggregations.bucket.range.RangeAggregator.Range;
+import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.HashMap;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.OptionalLong;
 import java.util.function.BiConsumer;
-import java.util.function.BiFunction;
 import java.util.function.Function;
 
+import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
 
 /**
@@ -187,8 +190,8 @@ public static class FastFilterContext {
         private AggregationType aggregationType;
         private final SearchContext context;
 
-        private String fieldName;
-        private long[][] ranges;
+        private MappedFieldType fieldType;
+        private Ranges ranges;
 
         // debug info related fields
         public int leaf;
@@ -196,12 +199,13 @@ public static class FastFilterContext {
         public int segments;
         public int optimizedSegments;
 
-        public void setFieldName(String fieldName) {
-            this.fieldName = fieldName;
+        public FastFilterContext(SearchContext context) {
+            this.context = context;
         }
 
-        public FastFilterContext(SearchContext context) {
+        public FastFilterContext(SearchContext context, AggregationType aggregationType) {
             this.context = context;
+            this.aggregationType = aggregationType;
         }
 
         public AggregationType getAggregationType() {
@@ -221,23 +225,87 @@ public boolean isRewriteable(final Object parent, final int subAggLength) {
             return rewriteable;
         }
 
-        public void buildRanges() throws IOException {
+        public void buildRanges(MappedFieldType fieldType) throws IOException {
             assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
-            this.ranges = this.aggregationType.buildRanges(context);
+            this.fieldType = fieldType;
+            this.ranges = this.aggregationType.buildRanges(context, fieldType);
             if (ranges != null) {
                 logger.debug("Ranges built for shard {}", context.indexShard().shardId());
                 rangesBuiltAtShardLevel = true;
             }
         }
 
-        public long[][] buildRanges(LeafReaderContext leaf) throws IOException {
-            long[][] ranges = this.aggregationType.buildRanges(leaf, context);
+        private Ranges buildRanges(LeafReaderContext leaf) throws IOException {
+            Ranges ranges = this.aggregationType.buildRanges(leaf, context, fieldType);
             if (ranges != null) {
                 logger.debug("Ranges built for shard {} segment {}", context.indexShard().shardId(), leaf.ord);
             }
             return ranges;
         }
 
+        /**
+         * Try to populate the bucket doc counts for aggregation
+         * <p>
+         * Usage: invoked at segment level, in getLeafCollector of the aggregator
+         *
+         * @param bucketOrd bucket ordinal producer
+         * @param incrementDocCount consumes the doc_count results for a certain ordinal
+         */
+        public boolean tryFastFilterAggregation(
+            final LeafReaderContext ctx,
+            final BiConsumer<Long, Long> incrementDocCount,
+            final Function<Object, Long> bucketOrd
+        ) throws IOException {
+            this.segments++;
+            if (!this.rewriteable) {
+                return false;
+            }
+
+            if (ctx.reader().hasDeletions()) return false;
+
+            PointValues values = ctx.reader().getPointValues(this.fieldType.name());
+            if (values == null) return false;
+            // only proceed if every document corresponds to exactly one point
+            if (values.getDocCount() != values.size()) return false;
+
+            NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
+            if (docCountValues.nextDoc() != NO_MORE_DOCS) {
+                logger.debug(
+                    "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
+                    this.context.indexShard().shardId(),
+                    ctx.ord
+                );
+                return false;
+            }
+
+            // even if no ranges built at shard level, we can still perform the optimization
+            // when functionally match-all at segment level
+            if (!this.rangesBuiltAtShardLevel && !segmentMatchAll(this.context, ctx)) {
+                return false;
+            }
+
+            Ranges ranges = this.ranges;
+            if (ranges == null) {
+                logger.debug(
+                    "Shard {} segment {} functionally match all documents. Build the fast filter",
+                    this.context.indexShard().shardId(),
+                    ctx.ord
+                );
+                ranges = this.buildRanges(ctx);
+                if (ranges == null) {
+                    return false;
+                }
+            }
+
+            DebugInfo debugInfo = this.aggregationType.tryFastFilterAggregation(values, ranges, incrementDocCount, bucketOrd);
+            this.consumeDebugInfo(debugInfo);
+
+            this.optimizedSegments++;
+            logger.debug("Fast filter optimization applied to shard {} segment {}", this.context.indexShard().shardId(), ctx.ord);
+            logger.debug("crossed leaf nodes: {}, inner nodes: {}", this.leaf, this.inner);
+            return true;
+        }
+
         private void consumeDebugInfo(DebugInfo debug) {
             leaf += debug.leaf;
             inner += debug.inner;
@@ -250,9 +318,16 @@ private void consumeDebugInfo(DebugInfo debug) {
     interface AggregationType {
         boolean isRewriteable(Object parent, int subAggLength);
 
-        long[][] buildRanges(SearchContext ctx) throws IOException;
+        Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) throws IOException;
 
-        long[][] buildRanges(LeafReaderContext leaf, SearchContext ctx) throws IOException;
+        Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) throws IOException;
+
+        DebugInfo tryFastFilterAggregation(
+            PointValues values,
+            Ranges ranges,
+            BiConsumer<Long, Long> incrementDocCount,
+            Function<Object, Long> bucketOrd
+        ) throws IOException;
     }
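The getDocCount() == size() precondition above can be observed with plain Lucene: a multi-valued document contributes more points than documents, which would make pure point counting overcount. A runnable sketch under that assumption (field and class names are illustrative, not OpenSearch code):

import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

// Demonstrates why the rewrite requires one point per document: counting
// points would otherwise diverge from counting documents.
public class SinglePointPreconditionDemo {
    public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
            Document singleValued = new Document();
            singleValued.add(new LongPoint("price", 42));
            writer.addDocument(singleValued);

            Document multiValued = new Document();
            multiValued.add(new LongPoint("price", 1));
            multiValued.add(new LongPoint("price", 2)); // second point on the same doc
            writer.addDocument(multiValued);
            writer.commit();

            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                for (LeafReaderContext leaf : reader.leaves()) {
                    PointValues values = leaf.reader().getPointValues("price");
                    // 2 docs but 3 points: the fast filter path would bail out here
                    System.out.println(values.getDocCount() + " docs, " + values.size() + " points");
                }
            }
        }
    }
}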
{}", bounds, context.indexShard().shardId(), leaf.ord); + return buildRanges(context, bounds); + } + + private Ranges buildRanges(SearchContext context, long[] bounds) throws IOException { bounds = processHardBounds(bounds); if (bounds == null) { return null; @@ -319,13 +401,6 @@ private long[][] buildRanges(SearchContext context, long[] bounds) throws IOExce ); } - @Override - public long[][] buildRanges(LeafReaderContext leaf, SearchContext context) throws IOException { - long[] bounds = getSegmentBounds(leaf, fieldType.name()); - logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord); - return buildRanges(context, bounds); - } - protected abstract Rounding getRounding(final long low, final long high); protected abstract Rounding.Prepared getRoundingPrepared(); @@ -354,86 +429,118 @@ public DateFieldMapper.DateFieldType getFieldType() { assert fieldType instanceof DateFieldMapper.DateFieldType; return (DateFieldMapper.DateFieldType) fieldType; } - } - public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) { - return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource; - } + @Override + public DebugInfo tryFastFilterAggregation( + PointValues values, + Ranges ranges, + BiConsumer incrementDocCount, + Function bucketOrd + ) throws IOException { + int size = Integer.MAX_VALUE; + if (this instanceof CompositeAggregator.CompositeAggregationType) { + size = ((CompositeAggregator.CompositeAggregationType) this).getSize(); + } + + DateFieldMapper.DateFieldType fieldType = getFieldType(); + BiConsumer incrementFunc = (activeIndex, docCount) -> { + long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrd.apply(rangeStart)); + incrementDocCount.accept(ord, (long) docCount); + }; - public static long getBucketOrd(long bucketOrd) { - if (bucketOrd < 0) { // already seen - bucketOrd = -1 - bucketOrd; + return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); } - return bucketOrd; + private static long getBucketOrd(long bucketOrd) { + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + } + + return bucketOrd; + } } /** - * Try to get the bucket doc counts for the date histogram aggregation - *

- * Usage: invoked at segment level — in getLeafCollector of aggregator - * - * @param incrementDocCount takes in the bucket key value and the bucket count + * For range aggregation */ - public static boolean tryFastFilterAggregation( - final LeafReaderContext ctx, - FastFilterContext fastFilterContext, - final BiConsumer incrementDocCount - ) throws IOException { - fastFilterContext.segments++; - if (!fastFilterContext.rewriteable) { - return false; - } + public static class RangeAggregationType implements AggregationType { - if (ctx.reader().hasDeletions()) return false; + private final ValuesSourceConfig config; + private final Range[] ranges; - PointValues values = ctx.reader().getPointValues(fastFilterContext.fieldName); - if (values == null) return false; - // only proceed if every document corresponds to exactly one point - if (values.getDocCount() != values.size()) return false; - - NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME); - if (docCountValues.nextDoc() != NO_MORE_DOCS) { - logger.debug( - "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization", - fastFilterContext.context.indexShard().shardId(), - ctx.ord - ); - return false; + public RangeAggregationType(ValuesSourceConfig config, Range[] ranges) { + this.config = config; + this.ranges = ranges; } - // even if no ranges built at shard level, we can still perform the optimization - // when functionally match-all at segment level - if (!fastFilterContext.rangesBuiltAtShardLevel && !segmentMatchAll(fastFilterContext.context, ctx)) { + @Override + public boolean isRewriteable(Object parent, int subAggLength) { + if (config.fieldType() == null) return false; + MappedFieldType fieldType = config.fieldType(); + if (fieldType.isSearchable() == false || !(fieldType instanceof NumericPointEncoder)) return false; + + if (parent == null && subAggLength == 0 && config.script() == null && config.missing() == null) { + if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) { + // ranges are already sorted by from and then to + // we want ranges not overlapping with each other + double prevTo = ranges[0].getTo(); + for (int i = 1; i < ranges.length; i++) { + if (prevTo > ranges[i].getFrom()) { + return false; + } + prevTo = ranges[i].getTo(); + } + return true; + } + } return false; } - long[][] ranges = fastFilterContext.ranges; - if (ranges == null) { - logger.debug( - "Shard {} segment {} functionally match all documents. 
Build the fast filter", - fastFilterContext.context.indexShard().shardId(), - ctx.ord - ); - ranges = fastFilterContext.buildRanges(ctx); - if (ranges == null) { - return false; + + @Override + public Ranges buildRanges(SearchContext context, MappedFieldType fieldType) { + assert fieldType instanceof NumericPointEncoder; + NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType; + byte[][] lowers = new byte[ranges.length][]; + byte[][] uppers = new byte[ranges.length][]; + for (int i = 0; i < ranges.length; i++) { + double rangeMin = ranges[i].getFrom(); + double rangeMax = ranges[i].getTo(); + byte[] lower = numericPointEncoder.encodePoint(rangeMin); + byte[] upper = numericPointEncoder.encodePoint(rangeMax); + lowers[i] = lower; + uppers[i] = upper; } + + return new Ranges(lowers, uppers); } - final AggregationType aggregationType = fastFilterContext.aggregationType; - assert aggregationType instanceof AbstractDateHistogramAggregationType; - final DateFieldMapper.DateFieldType fieldType = ((AbstractDateHistogramAggregationType) aggregationType).getFieldType(); - int size = Integer.MAX_VALUE; - if (aggregationType instanceof CompositeAggregator.CompositeAggregationType) { - size = ((CompositeAggregator.CompositeAggregationType) aggregationType).getSize(); + @Override + public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) { + throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level"); + } + + @Override + public DebugInfo tryFastFilterAggregation( + PointValues values, + Ranges ranges, + BiConsumer incrementDocCount, + Function bucketOrd + ) throws IOException { + int size = Integer.MAX_VALUE; + + BiConsumer incrementFunc = (activeIndex, docCount) -> { + long ord = bucketOrd.apply(activeIndex); + incrementDocCount.accept(ord, (long) docCount); + }; + + return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); } - DebugInfo debugInfo = multiRangesTraverse(values.getPointTree(), ranges, incrementDocCount, fieldType, size); - fastFilterContext.consumeDebugInfo(debugInfo); + } - fastFilterContext.optimizedSegments++; - logger.debug("Fast filter optimization applied to shard {} segment {}", fastFilterContext.context.indexShard().shardId(), ctx.ord); - logger.debug("crossed leaf nodes: {}, inner nodes: {}", fastFilterContext.leaf, fastFilterContext.inner); - return true; + public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) { + return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource; } private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException { @@ -445,7 +552,7 @@ private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leaf * Creates the date ranges from date histo aggregations using its interval, * and min/max boundaries */ - private static long[][] createRangesFromAgg( + private static Ranges createRangesFromAgg( final SearchContext context, final DateFieldMapper.DateFieldType fieldType, final long interval, @@ -481,9 +588,8 @@ private static long[][] createRangesFromAgg( long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow); roundedLow = preparedRounding.round(roundedLow + interval); - // Subtract -1 if the minimum is roundedLow as roundedLow itself - // is included in the next bucket - long upper = i + 1 == bucketCount ? 
 
     private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
@@ -445,7 +552,7 @@ private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leaf
      * Creates the date ranges from date histo aggregations using its interval,
      * and min/max boundaries
      */
-    private static long[][] createRangesFromAgg(
+    private static Ranges createRangesFromAgg(
         final SearchContext context,
         final DateFieldMapper.DateFieldType fieldType,
         final long interval,
@@ -481,9 +588,8 @@
                 long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow);
                 roundedLow = preparedRounding.round(roundedLow + interval);
 
-                // Subtract -1 if the minimum is roundedLow as roundedLow itself
-                // is included in the next bucket
-                long upper = i + 1 == bucketCount ? high : fieldType.convertRoundedMillisToNanos(roundedLow) - 1;
+                // add one to the high value because the upper bound is exclusive, but the high value itself must be included
+                long upper = i + 1 == bucketCount ? high + 1 : fieldType.convertRoundedMillisToNanos(roundedLow);
 
                 ranges[i][0] = lower;
                 ranges[i][1] = upper;
             }
         }
 
-        return ranges;
+        byte[][] lowers = new byte[ranges.length][];
+        byte[][] uppers = new byte[ranges.length][];
+        for (int i = 0; i < ranges.length; i++) {
+            byte[] lower = LONG.encodePoint(ranges[i][0]);
+            byte[] max = LONG.encodePoint(ranges[i][1]);
+            lowers[i] = lower;
+            uppers[i] = max;
+        }
+
+        return new Ranges(lowers, uppers);
     }
 
     /**
@@ -499,39 +614,18 @@
      */
     private static DebugInfo multiRangesTraverse(
         final PointValues.PointTree tree,
-        final long[][] ranges,
-        final BiConsumer<Long, Long> incrementDocCount,
-        final DateFieldMapper.DateFieldType fieldType,
+        final Ranges ranges,
+        final BiConsumer<Integer, Integer> incrementDocCount,
         final int maxNumNonZeroRanges
     ) throws IOException {
-        // ranges are connected and in ascending order
-        Iterator<long[]> rangeIter = Arrays.stream(ranges).iterator();
-        long[] activeRange = rangeIter.next();
-
-        // make sure the first range at least crosses the min value of the tree
         DebugInfo debugInfo = new DebugInfo();
-        if (activeRange[0] > NumericUtils.sortableBytesToLong(tree.getMaxPackedValue(), 0)) {
+        int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue());
+        if (activeIndex < 0) {
             logger.debug("No ranges match the query, skip the fast filter optimization");
             return debugInfo;
         }
-        while (activeRange[1] < NumericUtils.sortableBytesToLong(tree.getMinPackedValue(), 0)) {
-            if (!rangeIter.hasNext()) {
-                logger.debug("No ranges match the query, skip the fast filter optimization");
-                return debugInfo;
-            }
-            activeRange = rangeIter.next();
-        }
-
-        RangeCollectorForPointTree collector = new RangeCollectorForPointTree(
-            incrementDocCount,
-            fieldType,
-            rangeIter,
-            maxNumNonZeroRanges,
-            activeRange
-        );
-
-        final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(8);
-        PointValues.IntersectVisitor visitor = getIntersectVisitor(collector, comparator);
+        RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex);
+        PointValues.IntersectVisitor visitor = getIntersectVisitor(collector);
         try {
             intersectWithRanges(visitor, tree, collector, debugInfo);
         } catch (CollectionTerminatedException e) {
@@ -542,6 +636,49 @@
         return debugInfo;
     }
 
+    private static class Ranges {
+        byte[][] lowers; // inclusive
+        byte[][] uppers; // exclusive
+        int size;
+        int byteLen;
+        static ArrayUtil.ByteArrayComparator comparator;
+
+        Ranges(byte[][] lowers, byte[][] uppers) {
+            this.lowers = lowers;
+            this.uppers = uppers;
+            assert lowers.length == uppers.length;
+            this.size = lowers.length;
+            this.byteLen = lowers[0].length;
+            comparator = ArrayUtil.getUnsignedComparator(byteLen);
+        }
+
+        public int firstRangeIndex(byte[] globalMin, byte[] globalMax) {
+            if (compareByteValue(lowers[0], globalMax) > 0) {
+                return -1;
+            }
+            int i = 0;
+            while (compareByteValue(uppers[i], globalMin) <= 0) {
+                i++;
+                if (i >= size) {
+                    return -1;
+                }
+            }
+            return i;
+        }
+
+        public static int compareByteValue(byte[] value1, byte[] value2) {
+            return comparator.compare(value1, 0, value2, 0);
+        }
+
+        public static boolean withinLowerBound(byte[] value, byte[] lowerBound) {
+            return compareByteValue(value, lowerBound) >= 0;
+        }
+
+        public static boolean withinUpperBound(byte[] value, byte[] upperBound) {
+            return compareByteValue(value, upperBound) < 0;
+        }
+    }
+
     private static void intersectWithRanges(
         PointValues.IntersectVisitor visitor,
         PointValues.PointTree pointTree,
@@ -570,10 +707,7 @@
         }
     }
 
-    private static PointValues.IntersectVisitor getIntersectVisitor(
-        RangeCollectorForPointTree collector,
-        ArrayUtil.ByteArrayComparator comparator
-    ) {
+    private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) {
         return new PointValues.IntersectVisitor() {
             @Override
             public void visit(int docID) throws IOException {
@@ -591,86 +725,67 @@ public void visit(int docID, byte[] packedValue) throws IOException {
             @Override
             public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
                 visitPoints(packedValue, () -> {
-                    for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
+                    for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
                         collector.count();
                     }
                 });
             }
 
             private void visitPoints(byte[] packedValue, CheckedRunnable<IOException> collect) throws IOException {
-                if (comparator.compare(packedValue, 0, collector.activeRangeAsByteArray[1], 0) > 0) {
-                    // need to move to next range
+                if (!collector.withinUpperBound(packedValue)) {
                     collector.finalizePreviousRange();
-                    if (collector.iterateRangeEnd(packedValue, this::compareByteValue)) {
+                    if (collector.iterateRangeEnd(packedValue)) {
                         throw new CollectionTerminatedException();
                     }
                 }
 
-                if (pointCompare(collector.activeRangeAsByteArray[0], collector.activeRangeAsByteArray[1], packedValue)) {
+                if (collector.withinRange(packedValue)) {
                     collect.run();
                 }
             }
 
-            private boolean pointCompare(byte[] lower, byte[] upper, byte[] packedValue) {
-                if (compareByteValue(packedValue, lower) < 0) {
-                    return false;
-                }
-                return compareByteValue(packedValue, upper) <= 0;
-            }
-
-            private int compareByteValue(byte[] value1, byte[] value2) {
-                return comparator.compare(value1, 0, value2, 0);
-            }
-
             @Override
             public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-                byte[] rangeMin = collector.activeRangeAsByteArray[0];
-                byte[] rangeMax = collector.activeRangeAsByteArray[1];
-
-                if (compareByteValue(rangeMax, minPackedValue) < 0) {
+                // try to find the first range that may collect values from this cell
+                if (!collector.withinUpperBound(minPackedValue)) {
                     collector.finalizePreviousRange();
-                    if (collector.iterateRangeEnd(minPackedValue, this::compareByteValue)) {
+                    if (collector.iterateRangeEnd(minPackedValue)) {
                         throw new CollectionTerminatedException();
                     }
-                    // compare the next range with this node's min max again
-                    // new rangeMin = previous rangeMax + 1 <= min
-                    rangeMax = collector.activeRangeAsByteArray[1];
                 }
-
-                if (compareByteValue(rangeMin, minPackedValue) > 0 || compareByteValue(rangeMax, maxPackedValue) < 0) {
-                    return PointValues.Relation.CELL_CROSSES_QUERY;
-                } else {
+                // after the step above, the cell min is below the active range's upper bound
+                // the cell is entirely outside the range when its max is below the lower bound
+                if (!collector.withinLowerBound(maxPackedValue)) {
+                    return PointValues.Relation.CELL_OUTSIDE_QUERY;
+                }
+                if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) {
                     return PointValues.Relation.CELL_INSIDE_QUERY;
                 }
+                return PointValues.Relation.CELL_CROSSES_QUERY;
             }
         };
     }
 
     private static class RangeCollectorForPointTree {
-        private final BiConsumer<Long, Long> incrementDocCount;
-        private final DateFieldMapper.DateFieldType fieldType;
+        private final BiConsumer<Integer, Integer> incrementRangeDocCount;
         private int counter = 0;
 
-        private long[] activeRange;
-        private byte[][] activeRangeAsByteArray;
-        private final Iterator<long[]> rangeIter;
+        private final Ranges ranges;
+        private int activeIndex;
 
         private int visitedRange = 0;
        private final int maxNumNonZeroRange;
 
         public RangeCollectorForPointTree(
-            BiConsumer<Long, Long> incrementDocCount,
-            DateFieldMapper.DateFieldType fieldType,
-            Iterator<long[]> rangeIter,
+            BiConsumer<Integer, Integer> incrementRangeDocCount,
             int maxNumNonZeroRange,
-            long[] activeRange
+            Ranges ranges,
+            int activeIndex
         ) {
-            this.incrementDocCount = incrementDocCount;
-            this.fieldType = fieldType;
-            this.rangeIter = rangeIter;
+            this.incrementRangeDocCount = incrementRangeDocCount;
             this.maxNumNonZeroRange = maxNumNonZeroRange;
-            this.activeRange = activeRange;
-            this.activeRangeAsByteArray = activeRangeAsByteArray();
+            this.ranges = ranges;
+            this.activeIndex = activeIndex;
         }
 
         private void count() {
@@ -683,9 +798,7 @@ private void countNode(int count) {
 
         private void finalizePreviousRange() {
             if (counter > 0) {
-                logger.debug("finalize previous range: {}", activeRange[0]);
-                logger.debug("counter: {}", counter);
-                incrementDocCount.accept(fieldType.convertNanosToMillis(activeRange[0]), counter);
+                incrementRangeDocCount.accept(activeIndex, counter);
                 counter = 0;
             }
         }
@@ -693,29 +806,34 @@ private void finalizePreviousRange() {
         /**
         * @return true when the ranges are exhausted or enough non-zero ranges have been collected
         */
-        private boolean iterateRangeEnd(byte[] value, BiFunction<byte[], byte[], Integer> comparator) {
+        private boolean iterateRangeEnd(byte[] value) {
             // the new value may not be contiguous to the previous one
             // so try to find the next range that crosses the new value
-            while (comparator.apply(activeRangeAsByteArray[1], value) < 0) {
-                if (!rangeIter.hasNext()) {
+            while (!withinUpperBound(value)) {
+                if (++activeIndex >= ranges.size) {
                     return true;
                 }
-                activeRange = rangeIter.next();
-                activeRangeAsByteArray = activeRangeAsByteArray();
             }
             visitedRange++;
             return visitedRange > maxNumNonZeroRange;
         }
 
-        private byte[][] activeRangeAsByteArray() {
-            byte[] lower = new byte[8];
-            byte[] upper = new byte[8];
-            NumericUtils.longToSortableBytes(activeRange[0], lower, 0);
-            NumericUtils.longToSortableBytes(activeRange[1], upper, 0);
-            return new byte[][] { lower, upper };
+        private boolean withinLowerBound(byte[] value) {
+            return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]);
+        }
+
+        private boolean withinUpperBound(byte[] value) {
+            return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]);
+        }
+
+        private boolean withinRange(byte[] value) {
+            return withinLowerBound(value) && withinUpperBound(value);
         }
     }
 
+    /**
+     * Contains debug info of BKD traversal to show in profile
+     */
     private static class DebugInfo {
         private int leaf = 0; // leaf node visited
         private int inner = 0; // inner node visited
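A runnable, single-range distillation of the traversal above (an illustrative stand-in for multiRangesTraverse, assuming lucene-core): it counts documents in [lower, upper) straight from the point tree, pruning whole cells via compare and only inspecting individual values for cells that cross a range boundary:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;

// Counts docs in a [lower, upper) range without per-document collection when
// a whole BKD cell falls inside the range.
public class PointTreeCountDemo {
    public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory(); IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) {
            for (long v : new long[] { 42, 50, 100 }) { // values from the profiler YAML test
                Document doc = new Document();
                doc.add(new LongPoint("value", v));
                w.addDocument(doc);
            }
            w.commit();

            byte[] lower = new byte[Long.BYTES], upper = new byte[Long.BYTES];
            LongPoint.encodeDimension(50, lower, 0);  // inclusive
            LongPoint.encodeDimension(150, upper, 0); // exclusive
            ArrayUtil.ByteArrayComparator cmp = ArrayUtil.getUnsignedComparator(Long.BYTES);

            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                for (LeafReaderContext leaf : reader.leaves()) {
                    PointValues values = leaf.reader().getPointValues("value");
                    long[] count = new long[1];
                    values.intersect(new PointValues.IntersectVisitor() {
                        @Override
                        public void visit(int docID) {
                            count[0]++; // whole cell inside the range
                        }

                        @Override
                        public void visit(int docID, byte[] packedValue) {
                            if (cmp.compare(packedValue, 0, lower, 0) >= 0 && cmp.compare(packedValue, 0, upper, 0) < 0) {
                                count[0]++; // crossing cell: check each value
                            }
                        }

                        @Override
                        public PointValues.Relation compare(byte[] min, byte[] max) {
                            if (cmp.compare(max, 0, lower, 0) < 0 || cmp.compare(min, 0, upper, 0) >= 0) {
                                return PointValues.Relation.CELL_OUTSIDE_QUERY;
                            }
                            if (cmp.compare(min, 0, lower, 0) >= 0 && cmp.compare(max, 0, upper, 0) < 0) {
                                return PointValues.Relation.CELL_INSIDE_QUERY;
                            }
                            return PointValues.Relation.CELL_CROSSES_QUERY;
                        }
                    });
                    System.out.println("docs in [50, 150): " + count[0]); // 2, matching the test bucket
                }
            }
        }
    }
}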
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
index 3713d8f83990d..bfb484dcf478d 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java
@@ -74,6 +74,7 @@
 import org.opensearch.search.aggregations.MultiBucketConsumerService;
 import org.opensearch.search.aggregations.bucket.BucketsAggregator;
 import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
+import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper.AbstractDateHistogramAggregationType;
 import org.opensearch.search.aggregations.bucket.missing.MissingOrder;
 import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
 import org.opensearch.search.internal.SearchContext;
@@ -166,21 +167,22 @@ public final class CompositeAggregator extends BucketsAggregator {
         this.rawAfterKey = rawAfterKey;
 
         fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context);
-        if (!FastFilterRewriteHelper.isCompositeAggRewriteable(sourceConfigs)) return;
+        if (!FastFilterRewriteHelper.isCompositeAggRewriteable(sourceConfigs)) {
+            return;
+        }
         fastFilterContext.setAggregationType(new CompositeAggregationType());
         if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
             // bucketOrds is used for saving date histogram results
             bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
             preparedRounding = ((CompositeAggregationType) fastFilterContext.getAggregationType()).getRoundingPrepared();
-            fastFilterContext.setFieldName(sourceConfigs[0].fieldType().name());
-            fastFilterContext.buildRanges();
+            fastFilterContext.buildRanges(sourceConfigs[0].fieldType());
         }
     }
 
     /**
      * Currently the filter rewrite is only supported for date histograms
      */
-    public class CompositeAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType {
+    public class CompositeAggregationType extends AbstractDateHistogramAggregationType {
 
         private final RoundingValuesSource valuesSource;
         private long afterKey = -1L;
@@ -549,13 +551,10 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
 
     @Override
     protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
-        boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(
+        boolean optimized = fastFilterContext.tryFastFilterAggregation(
             ctx,
-            fastFilterContext,
-            (key, count) -> incrementBucketDocCount(
-                FastFilterRewriteHelper.getBucketOrd(bucketOrds.add(0, preparedRounding.round(key))),
-                count
-            )
+            this::incrementBucketDocCount,
+            (key) -> bucketOrds.add(0, preparedRounding.round((long) key))
         );
         if (optimized) throw new CollectionTerminatedException();

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
index f326426800909..d13d575a9d696 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java
@@ -64,6 +64,7 @@
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
+import java.util.Objects;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 import java.util.function.LongToIntFunction;
@@ -157,8 +158,8 @@ private AutoDateHistogramAggregator(
         this.roundingPreparer = roundingPreparer;
         this.preparedRounding = prepareRounding(0);
 
-        fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context);
-        fastFilterContext.setAggregationType(
+        fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
+            context,
             new AutoHistogramAggregationType(
                 valuesSourceConfig.fieldType(),
                 valuesSourceConfig.missing() != null,
@@ -166,8 +167,7 @@ private AutoDateHistogramAggregator(
             )
         );
         if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
-            fastFilterContext.setFieldName(valuesSourceConfig.fieldType().name());
-            fastFilterContext.buildRanges();
+            fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType()));
         }
     }
 
@@ -236,13 +236,10 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
             return LeafBucketCollector.NO_OP_COLLECTOR;
         }
 
-        boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(
+        boolean optimized = fastFilterContext.tryFastFilterAggregation(
             ctx,
-            fastFilterContext,
-            (key, count) -> incrementBucketDocCount(
-                FastFilterRewriteHelper.getBucketOrd(getBucketOrds().add(0, preparedRounding.round(key))),
-                count
-            )
+            this::incrementBucketDocCount,
+            (key) -> getBucketOrds().add(0, preparedRounding.round((long) key))
         );
         if (optimized) throw new CollectionTerminatedException();

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
index dd4ee9196fd62..4b84797c18922 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
@@ -58,6 +58,7 @@
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
+import java.util.Objects;
 import java.util.function.BiConsumer;
 
 /**
@@ -116,8 +117,8 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
 
         bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality);
 
-        fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context);
-        fastFilterContext.setAggregationType(
+        fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
+            context,
             new DateHistogramAggregationType(
                 valuesSourceConfig.fieldType(),
                 valuesSourceConfig.missing() != null,
@@ -126,8 +127,7 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
             )
         );
         if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
-            fastFilterContext.setFieldName(valuesSourceConfig.fieldType().name());
-            fastFilterContext.buildRanges();
+            fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType()));
         }
     }
 
@@ -162,13 +162,10 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
             return LeafBucketCollector.NO_OP_COLLECTOR;
         }
 
-        boolean optimized = FastFilterRewriteHelper.tryFastFilterAggregation(
+        boolean optimized = fastFilterContext.tryFastFilterAggregation(
             ctx,
-            fastFilterContext,
-            (key, count) -> incrementBucketDocCount(
-                FastFilterRewriteHelper.getBucketOrd(bucketOrds.add(0, preparedRounding.round(key))),
-                count
-            )
+            this::incrementBucketDocCount,
+            (key) -> bucketOrds.add(0, preparedRounding.round((long) key))
         );
         if (optimized) throw new CollectionTerminatedException();
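The (key) -> bucketOrds.add(...) lambdas above feed the helper's now-private getBucketOrd, which decodes the negative "already seen" ordinals that LongKeyedBucketOrds.add returns. A self-contained sketch of that convention (the class below is an illustrative stand-in, not an OpenSearch type):

import java.util.HashMap;
import java.util.Map;

// Mimics the LongKeyedBucketOrds.add contract: a non-negative ordinal for a new
// key, and (-1 - ordinal) for a key that was already added.
public class BucketOrdConventionDemo {
    private final Map<Long, Long> ords = new HashMap<>();
    private long next = 0;

    long add(long key) {
        Long existing = ords.get(key);
        if (existing != null) {
            return -1 - existing; // negative result marks "already seen"
        }
        ords.put(key, next);
        return next++;
    }

    static long getBucketOrd(long ord) {
        return ord < 0 ? -1 - ord : ord; // decode "already seen" ordinals
    }

    public static void main(String[] args) {
        BucketOrdConventionDemo demo = new BucketOrdConventionDemo();
        System.out.println(demo.add(1000L));               // 0 (new key)
        System.out.println(demo.add(1000L));               // -1 (already seen)
        System.out.println(getBucketOrd(demo.add(1000L))); // 0 (decoded)
    }
}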
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/AbstractRangeAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/AbstractRangeAggregatorFactory.java
index 41f2768eb7544..fd334638a0c1f 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/AbstractRangeAggregatorFactory.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/AbstractRangeAggregatorFactory.java
@@ -119,7 +119,8 @@ protected Aggregator doCreateInternal(
             searchContext,
             parent,
             cardinality,
-            metadata
+            metadata,
+            config
         );
     }

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceAggregatorSupplier.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceAggregatorSupplier.java
index c4a9efda18bda..d72c817c4515b 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceAggregatorSupplier.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceAggregatorSupplier.java
@@ -39,6 +39,7 @@
 import org.opensearch.search.aggregations.AggregatorFactories;
 import org.opensearch.search.aggregations.CardinalityUpperBound;
 import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -64,6 +65,7 @@ Aggregator build(
         SearchContext context,
         Aggregator parent,
         CardinalityUpperBound cardinality,
-        Map<String, Object> metadata
+        Map<String, Object> metadata,
+        ValuesSourceConfig valuesSourceConfig
     ) throws IOException;
 }

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceRangeAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceRangeAggregatorFactory.java
index 728f43094cf7e..f9e966deb3cc9 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceRangeAggregatorFactory.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/GeoDistanceRangeAggregatorFactory.java
@@ -82,7 +82,8 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
             context,
             parent,
             cardinality,
-            metadata) -> {
+            metadata,
+            config) -> {
             DistanceSource distanceSource = new DistanceSource((ValuesSource.GeoPoint) valuesSource, distanceType, origin, units);
             return new RangeAggregator(
                 name,
@@ -95,7 +96,8 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
                 context,
                 parent,
                 cardinality,
-                metadata
+                metadata,
+                config
             );
         },
         true
@@ -168,7 +170,8 @@ protected Aggregator doCreateInternal(
             searchContext,
             parent,
             cardinality,
-            metadata
+            metadata,
+            config
         );
     }

diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
index b56b817b8177b..2ba2b06514de1 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java
@@ -32,6 +32,7 @@
 package org.opensearch.search.aggregations.bucket.range;
 
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.ScoreMode;
 import org.opensearch.core.ParseField;
 import org.opensearch.core.common.io.stream.StreamInput;
@@ -54,7 +55,9 @@
 import org.opensearch.search.aggregations.LeafBucketCollectorBase;
 import org.opensearch.search.aggregations.NonCollectingAggregator;
 import org.opensearch.search.aggregations.bucket.BucketsAggregator;
+import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper;
 import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -62,6 +65,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.BiConsumer;
 
 import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
@@ -245,6 +249,8 @@ public boolean equals(Object obj) {
 
     final double[] maxTo;
 
+    private final FastFilterRewriteHelper.FastFilterContext fastFilterContext;
+
     public RangeAggregator(
         String name,
         AggregatorFactories factories,
@@ -256,17 +262,16 @@ public RangeAggregator(
         SearchContext context,
         Aggregator parent,
         CardinalityUpperBound cardinality,
-        Map<String, Object> metadata
+        Map<String, Object> metadata,
+        ValuesSourceConfig config
     ) throws IOException {
-
         super(name, factories, context, parent, cardinality.multiply(ranges.length), metadata);
         assert valuesSource != null;
         this.valuesSource = valuesSource;
         this.format = format;
         this.keyed = keyed;
         this.rangeFactory = rangeFactory;
-
-        this.ranges = ranges;
+        this.ranges = ranges; // already sorted by the range.from and range.to
 
         maxTo = new double[this.ranges.length];
         maxTo[0] = this.ranges[0].to;
@@ -274,6 +279,13 @@ public RangeAggregator(
             maxTo[i] = Math.max(this.ranges[i].to, maxTo[i - 1]);
         }
 
+        fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(
+            context,
+            new FastFilterRewriteHelper.RangeAggregationType(config, ranges)
+        );
+        if (fastFilterContext.isRewriteable(parent, subAggregators.length)) {
+            fastFilterContext.buildRanges(Objects.requireNonNull(config.fieldType()));
+        }
     }
 
     @Override
@@ -286,6 +298,13 @@ public ScoreMode scoreMode() {
 
     @Override
     public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
+        boolean optimized = fastFilterContext.tryFastFilterAggregation(
+            ctx,
+            this::incrementBucketDocCount,
+            (activeIndex) -> subBucketOrdinal(0, (int) activeIndex)
+        );
+        if (optimized) throw new CollectionTerminatedException();
+
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
@@ -430,4 +449,14 @@ public InternalAggregation buildEmptyAggregation() {
         }
     }
 
+    @Override
+    public void collectDebugInfo(BiConsumer<String, Object> add) {
+        super.collectDebugInfo(add);
+        if (fastFilterContext.optimizedSegments > 0) {
+            add.accept("optimized_segments", fastFilterContext.optimizedSegments);
+            add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments);
+            add.accept("leaf_visited", fastFilterContext.leaf);
+            add.accept("inner_visited", fastFilterContext.inner);
+        }
+    }
 }
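The fast path above maps each matched range index straight to a bucket ordinal, which works because range buckets are from-inclusive and to-exclusive, matching the Ranges semantics earlier. A small runnable sketch of that [lower, upper) check over encoded bounds (illustrative names, assuming lucene-core):

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.util.ArrayUtil;

// Lower bounds are inclusive (value >= lower); upper bounds are exclusive (value < upper).
public class RangeBoundsDemo {
    static final ArrayUtil.ByteArrayComparator CMP = ArrayUtil.getUnsignedComparator(Long.BYTES);

    static byte[] encode(long v) {
        byte[] b = new byte[Long.BYTES];
        LongPoint.encodeDimension(v, b, 0);
        return b;
    }

    static boolean withinRange(byte[] value, byte[] lower, byte[] upper) {
        return CMP.compare(value, 0, lower, 0) >= 0   // inclusive lower
            && CMP.compare(value, 0, upper, 0) < 0;   // exclusive upper
    }

    public static void main(String[] args) {
        byte[] lower = encode(50), upper = encode(150);
        System.out.println(withinRange(encode(50), lower, upper));  // true: lower is included
        System.out.println(withinRange(encode(150), lower, upper)); // false: upper is excluded
    }
}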
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorSupplier.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorSupplier.java
index de9b979a56107..02b0c2e612d57 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorSupplier.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorSupplier.java
@@ -36,6 +36,7 @@
 import org.opensearch.search.aggregations.AggregatorFactories;
 import org.opensearch.search.aggregations.CardinalityUpperBound;
 import org.opensearch.search.aggregations.support.ValuesSource;
+import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -58,6 +59,7 @@ Aggregator build(
         SearchContext context,
         Aggregator parent,
         CardinalityUpperBound cardinality,
-        Map<String, Object> metadata
+        Map<String, Object> metadata,
+        ValuesSourceConfig config
     ) throws IOException;
 }
diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregatorTests.java
index cf95999ec5086..f6e06cce6e233 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregatorTests.java
@@ -1614,7 +1614,7 @@ public void testMultiRangeTraversalNotApplicable() throws IOException {
             },
             true,
             collectCount -> assertTrue(collectCount > 0),
-            true
+            false
        );
    }
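The counters emitted by collectDebugInfo surface in the profile output and can also be read back directly in tests like the one above. A hedged sketch of doing so; the scaffolding around the `aggregator` variable is assumed:

    // Illustrative only: the map keys come straight from collectDebugInfo above.
    Map<String, Object> debug = new HashMap<>();
    aggregator.collectDebugInfo(debug::put);
    if (debug.containsKey("optimized_segments")) {
        // every rewritten segment skipped the leaf collector entirely
        assertTrue((int) debug.get("optimized_segments") > 0);
    }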
diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java
index dd7ae915c3b45..7e796b684e869 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java
@@ -37,29 +37,44 @@
 import org.apache.lucene.document.SortedSetDocValuesField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.BytesRef;
 import org.opensearch.common.CheckedConsumer;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
 import org.opensearch.index.mapper.DateFieldMapper;
 import org.opensearch.index.mapper.KeywordFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
-import org.opensearch.index.mapper.NumberFieldMapper;
+import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType;
+import org.opensearch.index.mapper.NumberFieldMapper.NumberType;
+import org.opensearch.search.aggregations.AggregationBuilder;
 import org.opensearch.search.aggregations.AggregatorTestCase;
 import org.opensearch.search.aggregations.CardinalityUpperBound;
+import org.opensearch.search.aggregations.InternalAggregation;
+import org.opensearch.search.aggregations.MultiBucketConsumerService;
+import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
 import org.opensearch.search.aggregations.support.AggregationInspectionHelper;
 
 import java.io.IOException;
 import java.time.ZoneOffset;
 import java.time.ZonedDateTime;
 import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 
 import static java.util.Collections.singleton;
+import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS;
 import static org.hamcrest.Matchers.equalTo;
 
 public class RangeAggregatorTests extends AggregatorTestCase {
@@ -199,7 +214,7 @@ public void testMissingDateWithNumberField() throws IOException {
             .addRange(-2d, 5d)
             .missing("1979-01-01T00:00:00");
 
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
         expectThrows(NumberFormatException.class, () -> testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new NumericDocValuesField(NUMBER_FIELD_NAME, 7)));
@@ -212,7 +227,7 @@ public void testUnmappedWithMissingNumber() throws IOException {
             .addRange(-2d, 5d)
             .missing(0L);
 
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
         testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new NumericDocValuesField(NUMBER_FIELD_NAME, 7)));
@@ -230,7 +245,7 @@ public void testUnmappedWithMissingDate() throws IOException {
             .addRange(-2d, 5d)
             .missing("2020-02-13T10:11:12");
 
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
         expectThrows(NumberFormatException.class, () -> testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new NumericDocValuesField(NUMBER_FIELD_NAME, 7)));
@@ -257,7 +272,7 @@ public void testBadMissingField() {
             .addRange(-2d, 5d)
             .missing("bogus");
 
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
         expectThrows(NumberFormatException.class, () -> testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new NumericDocValuesField(NUMBER_FIELD_NAME, 7)));
@@ -270,7 +285,7 @@ public void testUnmappedWithBadMissingField() {
             .addRange(-2d, 5d)
             .missing("bogus");
 
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
         expectThrows(NumberFormatException.class, () -> testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
             iw.addDocument(singleton(new NumericDocValuesField(NUMBER_FIELD_NAME, 7)));
@@ -305,12 +320,185 @@ public void testSubAggCollectsFromManyBucketsIfManyRanges() throws IOException
         });
     }
 
+    public void testOverlappingRanges() throws IOException {
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.DOUBLE.typeName(), NumberType.DOUBLE),
+            new double[][] { { 1, 2 }, { 1, 1.5 }, { 0, 0.5 } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0.1, 1.1, 2.1 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(3, ranges.size());
+                assertEquals("0.0-0.5", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("1.0-1.5", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertEquals("1.0-2.0", ranges.get(2).getKeyAsString());
+                assertEquals(1, ranges.get(2).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            false
+        );
+    }
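Note the final `false` in the test above: with overlapping ranges a document can fall into more than one bucket, so counting each value once per segment over the sorted intervals would be wrong, and the aggregator has to fall back to per-doc collection. A hypothetical pre-check in the spirit of that eligibility rule:

    // Hypothetical disjointness check: with ranges sorted by "from", each
    // range must end before the next begins ([from, to) semantics), otherwise
    // a single value could land in two buckets and per-segment counting breaks.
    static boolean rangesAreDisjoint(double[][] sortedRanges) {
        for (int i = 1; i < sortedRanges.length; i++) {
            if (sortedRanges[i - 1][1] > sortedRanges[i][0]) { // to(i-1) > from(i)
                return false;
            }
        }
        return true;
    }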
+
+    /**
+     * @return Map [lower, upper) TO data points
+     */
+    private Map<double[], double[]> buildRandomRanges(double[][] possibleRanges) {
+        Map<double[], double[]> dataSet = new LinkedHashMap<>();
+        for (double[] range : possibleRanges) {
+            double lower = randomDoubleBetween(range[0], range[1], true);
+            double upper = randomDoubleBetween(range[0], range[1], true);
+            if (lower > upper) {
+                double d = lower;
+                lower = upper;
+                upper = d;
+            }
+
+            int dataNumber = randomInt(200);
+            double[] data = new double[dataNumber];
+            for (int i = 0; i < dataNumber; i++) {
+                data[i] = randomDoubleBetween(lower, upper, true);
+            }
+            dataSet.put(new double[] { lower, upper }, data);
+        }
+
+        return dataSet;
+    }
+
+    public void testRandomRanges() throws IOException {
+        Map<double[], double[]> dataSet = buildRandomRanges(new double[][] { { 0, 100 }, { 200, 1000 }, { 1000, 3000 } });
+
+        int size = dataSet.size();
+        double[][] ranges = new double[size][];
+        int[] expected = new int[size];
+        List<Double> dataPoints = new LinkedList<>();
+
+        int i = 0;
+        for (Map.Entry<double[], double[]> entry : dataSet.entrySet()) {
+            ranges[i] = entry.getKey();
+            expected[i] = entry.getValue().length;
+            for (double dataPoint : entry.getValue()) {
+                dataPoints.add(dataPoint);
+            }
+            i++;
+        }
+
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.DOUBLE.typeName(), NumberType.DOUBLE),
+            ranges,
+            new MatchAllDocsQuery(),
+            dataPoints.toArray(new Number[0]),
+            range -> {
+                List<? extends InternalRange.Bucket> rangeBuckets = range.getBuckets();
+                assertEquals(size, rangeBuckets.size());
+                for (int j = 0; j < rangeBuckets.size(); j++) {
+                    assertEquals(expected[j], rangeBuckets.get(j).getDocCount());
+                }
+            },
+            true
+        );
+    }
+
+    public void testDoubleType() throws IOException {
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.DOUBLE.typeName(), NumberType.DOUBLE),
+            new double[][] { { 1, 2 }, { 2, 3 } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0.1, 1.1, 2.1 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(2, ranges.size());
+                assertEquals("1.0-2.0", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("2.0-3.0", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            true
+        );
+    }
+
+    public void testHalfFloatType() throws IOException {
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.HALF_FLOAT.typeName(), NumberType.HALF_FLOAT),
+            new double[][] { { 1, 2 }, { 2, 3 } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0.1, 1.1, 2.1 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(2, ranges.size());
+                assertEquals("1.0-2.0", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("2.0-3.0", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            true
+        );
+    }
+
+    public void testFloatType() throws IOException {
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.FLOAT.typeName(), NumberType.FLOAT),
+            new double[][] { { 1, 2 }, { 2, 3 } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0.1, 1.1, 2.1 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(2, ranges.size());
+                assertEquals("1.0-2.0", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("2.0-3.0", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            true
+        );
+    }
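These per-type tests matter because the rewrite compares encoded byte[] endpoints against the BKD point tree, so each numeric type has to encode `from` and `to` exactly the way its values were indexed; that is what the NumericPointEncoder abstraction introduced by this change is for. A sketch for the double case, assuming Lucene's sortable point encoding:

    // Sketch only: DoublePoint's encoding is order-preserving, so unsigned
    // byte[] comparisons inside the tree agree with numeric order on doubles.
    import org.apache.lucene.document.DoublePoint;

    static byte[] encodeDoubleEndpoint(Number value) {
        byte[] point = new byte[Double.BYTES];
        DoublePoint.encodeDimension(value.doubleValue(), point, 0);
        return point;
    }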
+
+    public void testUnsignedLongType() throws IOException {
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.UNSIGNED_LONG.typeName(), NumberType.UNSIGNED_LONG),
+            new double[][] { { 1, 2 }, { 2, 3 } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0, 1, 2 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(2, ranges.size());
+                assertEquals("1.0-2.0", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("2.0-3.0", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            true
+        );
+
+        testRewriteOptimizationCase(
+            new NumberFieldType(NumberType.UNSIGNED_LONG.typeName(), NumberType.UNSIGNED_LONG),
+            new double[][] { { Double.NEGATIVE_INFINITY, 1 }, { 2, Double.POSITIVE_INFINITY } },
+            new MatchAllDocsQuery(),
+            new Number[] { 0, 1, 2 },
+            range -> {
+                List<? extends InternalRange.Bucket> ranges = range.getBuckets();
+                assertEquals(2, ranges.size());
+                assertEquals("*-1.0", ranges.get(0).getKeyAsString());
+                assertEquals(1, ranges.get(0).getDocCount());
+                assertEquals("2.0-*", ranges.get(1).getKeyAsString());
+                assertEquals(1, ranges.get(1).getDocCount());
+                assertTrue(AggregationInspectionHelper.hasValue(range));
+            },
+            true
+        );
+    }
+
     private void testCase(
         Query query,
         CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
         Consumer<InternalRange<? extends InternalRange.Bucket, ? extends InternalRange<?, ?>>> verify
     ) throws IOException {
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
         RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("test_range_agg");
         aggregationBuilder.field(NUMBER_FIELD_NAME);
         aggregationBuilder.addRange(0d, 5d);
@@ -323,9 +511,9 @@ private void simpleTestCase(
         Query query,
         Consumer<InternalRange<? extends InternalRange.Bucket, ? extends InternalRange<?, ?>>> verify
     ) throws IOException {
-        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NUMBER_FIELD_NAME, NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType = new NumberFieldType(NUMBER_FIELD_NAME, NumberType.INTEGER);
 
-        testCase(aggregationBuilder, new MatchAllDocsQuery(), iw -> {
+        testCase(aggregationBuilder, query, iw -> {
             iw.addDocument(singleton(new SortedNumericDocValuesField(NUMBER_FIELD_NAME, 7)));
             iw.addDocument(singleton(new SortedNumericDocValuesField(NUMBER_FIELD_NAME, 2)));
             iw.addDocument(singleton(new SortedNumericDocValuesField(NUMBER_FIELD_NAME, 3)));
@@ -354,8 +542,84 @@ private void testCase(
                 fieldType
             );
             verify.accept(agg);
         }
     }
     }
 
+    private void testRewriteOptimizationCase(
+        NumberFieldType fieldType,
+        double[][] ranges,
+        Query query,
+        Number[] dataPoints,
+        Consumer<InternalRange<? extends InternalRange.Bucket, ? extends InternalRange<?, ?>>> verify,
+        boolean optimized
+    ) throws IOException {
+        NumberType numberType = fieldType.numberType();
+        String fieldName = numberType.typeName();
+
+        try (Directory directory = newDirectory()) {
+            try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()))) {
+                for (Number dataPoint : dataPoints) {
+                    indexWriter.addDocument(numberType.createFields(fieldName, dataPoint, true, true, false));
+                }
+            }
+
+            try (IndexReader indexReader = DirectoryReader.open(directory)) {
+                IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
+
+                RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("range").field(fieldName);
+                for (double[] range : ranges) {
+                    aggregationBuilder.addRange(range[0], range[1]);
+                }
+
+                CountingAggregator aggregator = createCountingAggregator(query, aggregationBuilder, indexSearcher, fieldType);
+                aggregator.preCollection();
+                indexSearcher.search(query, aggregator);
+                aggregator.postCollection();
+
+                MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer(
+                    Integer.MAX_VALUE,
+                    new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
+                );
+                InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction(
+                    aggregator.context().bigArrays(),
+                    getMockScriptService(),
+                    reduceBucketConsumer,
+                    PipelineAggregator.PipelineTree.EMPTY
+                );
+                InternalRange topLevel = (InternalRange) aggregator.buildTopLevel();
+                InternalRange agg = (InternalRange) topLevel.reduce(Collections.singletonList(topLevel), context);
+                doAssertReducedMultiBucketConsumer(agg, reduceBucketConsumer);
+
+                verify.accept(agg);
+
+                if (optimized) {
+                    assertEquals(0, aggregator.getCollectCount().get());
+                } else {
+                    assertTrue(aggregator.getCollectCount().get() > 0);
+                }
+            }
+        }
+    }
+
+    protected CountingAggregator createCountingAggregator(
+        Query query,
+        AggregationBuilder builder,
+        IndexSearcher searcher,
+        MappedFieldType... fieldTypes
+    ) throws IOException {
+        return new CountingAggregator(
+            new AtomicInteger(),
+            createAggregator(
+                query,
+                builder,
+                searcher,
+                new MultiBucketConsumerService.MultiBucketConsumer(
+                    DEFAULT_MAX_BUCKETS,
+                    new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST)
+                ),
+                fieldTypes
+            )
+        );
+    }
 }
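createCountingAggregator relies on the CountingAggregator wrapper from the shared aggregation test infrastructure; conceptually it delegates to the real aggregator and counts every collect() call, which is why the optimized runs above can assert a count of exactly zero. A minimal illustration of that counting idea, not the real class; names other than LeafBucketCollectorBase are assumptions:

    // The wrapped leaf collector bumps a counter per collected doc; when the
    // fast filter rewrite handles a segment, this collector is never invoked.
    static LeafBucketCollector countingLeafCollector(LeafBucketCollector in, AtomicInteger counter) {
        return new LeafBucketCollectorBase(in, null) {
            @Override
            public void collect(int doc, long bucket) throws IOException {
                counter.incrementAndGet(); // stays at 0 when every segment was rewritten
                in.collect(doc, bucket);
            }
        };
    }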