diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java new file mode 100644 index 0000000000000..0af2a3eb98adb --- /dev/null +++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java @@ -0,0 +1,156 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.benchmark.search.aggregations; + +import org.openjdk.jmh.annotations.*; + +import org.apache.lucene.index.PointValues; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.opensearch.common.logging.LogConfigurator; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.mapper.NumericPointEncoder; +import org.opensearch.search.optimization.filterrewrite.Ranges; +import org.opensearch.search.optimization.filterrewrite.TreeTraversal; + +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + +@Warmup(iterations = 10) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Thread) +@Fork(value = 1) +public class BKDTreeMultiRangesTraverseBenchmark { + @State(Scope.Benchmark) + public static class treeState { + @Param({ "10000", "10000000" }) + int treeSize; + + @Param({ "10000", "10000000" }) + int valMax; + + @Param({ "10", "100" }) + int buckets; + + @Param({ "12345" }) + int seed; + + private Random random; + + Path tmpDir; + Directory directory; + IndexWriter writer; + IndexReader reader; + + // multiRangesTraverse params + PointValues.PointTree pointTree; + Ranges ranges; + BiConsumer<Integer, List<Integer>> collectRangeIDs; + int maxNumNonZeroRanges = Integer.MAX_VALUE; + + @Setup + public void
setup() throws IOException { + LogConfigurator.setNodeName("sample-name"); + random = new Random(seed); + tmpDir = Files.createTempDirectory("tree-test"); + directory = FSDirectory.open(tmpDir); + writer = new IndexWriter(directory, new IndexWriterConfig()); + + for (int i = 0; i < treeSize; i++) { + writer.addDocument(List.of(new IntField("val", random.nextInt(valMax), Field.Store.NO))); + } + + reader = DirectoryReader.open(writer); + + // should only contain a single segment + for (LeafReaderContext lrc : reader.leaves()) { + pointTree = lrc.reader().getPointValues("val").getPointTree(); + } + + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("val", NumberFieldMapper.NumberType.INTEGER); + NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType; + + int bucketWidth = valMax / buckets; + byte[][] lowers = new byte[buckets][]; + byte[][] uppers = new byte[buckets][]; + for (int i = 0; i < buckets; i++) { + lowers[i] = numericPointEncoder.encodePoint(i * bucketWidth); + // ranges are [lower, upper); use the next bucket boundary as the exclusive upper bound, + // otherwise every range is empty and the benchmark measures nothing + uppers[i] = numericPointEncoder.encodePoint((i + 1) * bucketWidth); + } + + ranges = new Ranges(lowers, uppers); + } + + @TearDown + public void tearDown() throws IOException { + // release index resources before deleting the files + reader.close(); + writer.close(); + directory.close(); + for (String indexFile : FSDirectory.listAll(tmpDir)) { + Files.deleteIfExists(tmpDir.resolve(indexFile)); + } + Files.deleteIfExists(tmpDir); + } + } + + @Benchmark + public Map<Integer, List<Integer>> multiRangeTraverseTree(treeState state) throws Exception { + Map<Integer, List<Integer>> mockIDCollect = new HashMap<>(); + + TreeTraversal.RangeAwareIntersectVisitor treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(state.pointTree, state.ranges, state.maxNumNonZeroRanges, (activeIndex, docID) -> { + if (mockIDCollect.containsKey(activeIndex)) { + mockIDCollect.get(activeIndex).add(docID); + } else { + // a mutable list is required here; a plain List.of(docID) would make the add() above throw + mockIDCollect.put(activeIndex, new ArrayList<>(List.of(docID))); + } + }); + + multiRangesTraverse(treeVisitor); + return mockIDCollect; + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml index a75b1d0eac793..940e5adc6468f 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml @@ -706,3 +706,80 @@ setup: - match: { profile.shards.0.aggregations.0.debug.unoptimized_segments: 0 } - match: { profile.shards.0.aggregations.0.debug.leaf_visited: 1 } - match: { profile.shards.0.aggregations.0.debug.inner_visited: 0 } + +--- +"date_histogram with range sub aggregation": + - do: + indices.create: + index: test_date_hist_range_sub_agg + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + date: + type: date + - do: + bulk: + refresh: true + index: test_date_hist_range_sub_agg + body: + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 11}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 12}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 23}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 39}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 4}' + - do: + search: + index: test_date_hist_range_sub_agg + body: + size: 0 + aggs: + histo: + date_histogram: + field: date + calendar_interval: day + aggs: + my_range: + range: + field: v + ranges: + - to: 10 + - from: 10 + to: 20 + - from: 20 + to: 30 + - from:
30 + to: 40 + + - match: { hits.total.value: 8 } + - length: { aggregations.histo.buckets: 9 } + + - match: { aggregations.histo.buckets.0.key_as_string: "2020-03-01T00:00:00.000Z" } + - match: { aggregations.histo.buckets.1.key_as_string: "2020-03-02T00:00:00.000Z" } + - match: { aggregations.histo.buckets.7.key_as_string: "2020-03-08T00:00:00.000Z" } + - match: { aggregations.histo.buckets.8.key_as_string: "2020-03-09T00:00:00.000Z" } + + - match: { aggregations.histo.buckets.0.doc_count: 2 } + - match: { aggregations.histo.buckets.1.doc_count: 1 } + - match: { aggregations.histo.buckets.2.doc_count: 0 } + - match: { aggregations.histo.buckets.7.doc_count: 4 } + - match: { aggregations.histo.buckets.8.doc_count: 1 } + + - match: { aggregations.histo.buckets.0.my_range.buckets.0.doc_count: 1 } + + - match: { aggregations.histo.buckets.7.my_range.buckets.0.doc_count: 0 } + - match: { aggregations.histo.buckets.7.my_range.buckets.1.doc_count: 0 } + - match: { aggregations.histo.buckets.7.my_range.buckets.2.doc_count: 3 } + - match: { aggregations.histo.buckets.7.my_range.buckets.3.doc_count: 1 } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml index 0897e0bdd894b..593083cb384ef 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml @@ -158,3 +158,75 @@ setup: - match: { profile.shards.0.aggregations.0.debug.unoptimized_segments: 0 } - match: { profile.shards.0.aggregations.0.debug.leaf_visited: 1 } - match: { profile.shards.0.aggregations.0.debug.inner_visited: 0 } + +--- +"Range aggregation with auto_date_histogram sub-aggregation": + - do: + indices.create: + index: sub_agg_profile + body: + mappings: + properties: + "@timestamp": + type: date + metrics.size: + type: long + + - do: + bulk: + refresh: true + index: sub_agg_profile + body: + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 2}' + - '{"index": {}}' + - '{"date": "2020-03-03", "v": 3}' + - '{"index": {}}' + - '{"date": "2020-04-09", "v": 4}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 13}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 14}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 15}' + - '{"index": {}}' + - '{"date": "2020-04-11", "v": 19}' + + - do: + search: + index: sub_agg_profile + body: + size: 0 + aggs: + range_histo: + range: + field: v + ranges: + - to: 0 + - from: 0 + to: 10 + - from: 10 + aggs: + date: + auto_date_histogram: + field: "date" + buckets: 3 + + - match: { hits.total.value: 8 } + - length: { aggregations.range_histo.buckets: 3 } + + - match: { aggregations.range_histo.buckets.0.key: "*-0.0" } + - match: { aggregations.range_histo.buckets.1.key: "0.0-10.0" } + - match: { aggregations.range_histo.buckets.2.key: "10.0-*" } + + - match: { aggregations.range_histo.buckets.0.doc_count: 0 } + - match: { aggregations.range_histo.buckets.1.doc_count: 4 } + - match: { aggregations.range_histo.buckets.2.doc_count: 4 } + + - match: { aggregations.range_histo.buckets.1.date.buckets.0.doc_count: 3 } + - match: { aggregations.range_histo.buckets.1.date.buckets.1.doc_count: 1 } + + - match: { aggregations.range_histo.buckets.2.date.buckets.0.doc_count: 3 } + - match: { 
aggregations.range_histo.buckets.2.date.buckets.1.doc_count: 1 } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml index 80aad96ce1f6b..807117bee81ea 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml @@ -673,3 +673,71 @@ setup: - match: { aggregations.my_range.buckets.3.from: 1.5 } - is_false: aggregations.my_range.buckets.3.to - match: { aggregations.my_range.buckets.3.doc_count: 2 } + +--- +"range with auto date sub aggregation": + - do: + indices.create: + index: test_range_auto_date + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + date: + type: date + - do: + bulk: + refresh: true + index: test_range_auto_date + body: + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 11}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 12}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 23}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 39}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 4}' + - do: + search: + index: test_range_auto_date + body: + size: 0 + aggs: + my_range: + range: + field: v + ranges: + - to: 10 + - from: 10 + to: 20 + - from: 20 + aggs: + histo: + date_histogram: + field: date + calendar_interval: day + + - match: { hits.total.value: 8 } + - length: { aggregations.my_range.buckets: 3 } + + - match: { aggregations.my_range.buckets.0.key: "*-10.0" } + - match: { aggregations.my_range.buckets.1.key: "10.0-20.0" } + - match: { aggregations.my_range.buckets.2.key: "20.0-*" } + + - match: { aggregations.my_range.buckets.0.doc_count: 2 } + - match: { aggregations.my_range.buckets.1.doc_count: 2 } + - match: { aggregations.my_range.buckets.2.doc_count: 4 } + + - match: { aggregations.my_range.buckets.0.histo.buckets.0.doc_count: 1 } + - match: { aggregations.my_range.buckets.0.histo.buckets.1.doc_count: 0 } + - match: { aggregations.my_range.buckets.2.histo.buckets.0.doc_count: 4 } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java deleted file mode 100644 index 2ab003fb94e33..0000000000000 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java +++ /dev/null @@ -1,849 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license.
- */ - -package org.opensearch.search.aggregations.bucket; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.document.LongPoint; -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.index.PointValues; -import org.apache.lucene.search.CollectionTerminatedException; -import org.apache.lucene.search.ConstantScoreQuery; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.FieldExistsQuery; -import org.apache.lucene.search.IndexOrDocValuesQuery; -import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.PointRangeQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Weight; -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.NumericUtils; -import org.opensearch.common.CheckedRunnable; -import org.opensearch.common.Rounding; -import org.opensearch.common.lucene.search.function.FunctionScoreQuery; -import org.opensearch.index.mapper.DateFieldMapper; -import org.opensearch.index.mapper.DocCountFieldMapper; -import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.mapper.NumericPointEncoder; -import org.opensearch.index.query.DateRangeIncludingNowQuery; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregator; -import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig; -import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource; -import org.opensearch.search.aggregations.bucket.histogram.LongBounds; -import org.opensearch.search.aggregations.bucket.range.RangeAggregator.Range; -import org.opensearch.search.aggregations.support.ValuesSource; -import org.opensearch.search.aggregations.support.ValuesSourceConfig; -import org.opensearch.search.internal.SearchContext; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.OptionalLong; -import java.util.function.BiConsumer; -import java.util.function.Function; - -import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -/** - * Utility class to help rewrite aggregations into filters. - * Instead of aggregation collects documents one by one, filter may count all documents that match in one pass. - *
<p> - * Currently supported rewrite: - * <ul> - *  <li> date histogram : date range filter. - *   Applied: DateHistogramAggregator, AutoDateHistogramAggregator, CompositeAggregator </li> - * </ul>
- * - * @opensearch.internal - */ -public final class FastFilterRewriteHelper { - - private FastFilterRewriteHelper() {} - - private static final Logger logger = LogManager.getLogger(FastFilterRewriteHelper.class); - - private static final Map<Class<?>, Function<Query, Query>> queryWrappers; - - // Initialize the wrapper map for unwrapping the query - static { - queryWrappers = new HashMap<>(); - queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery()); - queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery()); - queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery()); - queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery()); - } - - /** - * Recursively unwraps query into the concrete form - * for applying the optimization - */ - private static Query unwrapIntoConcreteQuery(Query query) { - while (queryWrappers.containsKey(query.getClass())) { - query = queryWrappers.get(query.getClass()).apply(query); - } - - return query; - } - - /** - * Finds the global min and max bounds of the field for the shard across all segments - * - * @return null if the field is empty or not indexed - */ - private static long[] getShardBounds(final SearchContext context, final String fieldName) throws IOException { - final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves(); - long min = Long.MAX_VALUE, max = Long.MIN_VALUE; - for (LeafReaderContext leaf : leaves) { - final PointValues values = leaf.reader().getPointValues(fieldName); - if (values != null) { - min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0)); - max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0)); - } - } - - if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) { - return null; - } - return new long[] { min, max }; - } - - /** - * Finds the min and max bounds of the field for the segment - * - * @return null if the field is empty or not indexed - */ - private static long[] getSegmentBounds(final LeafReaderContext context, final String fieldName) throws IOException { - long min = Long.MAX_VALUE, max = Long.MIN_VALUE; - final PointValues values = context.reader().getPointValues(fieldName); - if (values != null) { - min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0)); - max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0)); - } - - if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) { - return null; - } - return new long[] { min, max }; - } - - /** - * Gets the min and max bounds of the field for the shard search - * Depending on the query part, the bounds are computed differently - * - * @return null if the processed query not supported by the optimization - */ - public static long[] getDateHistoAggBounds(final SearchContext context, final String fieldName) throws IOException { - final Query cq = unwrapIntoConcreteQuery(context.query()); - if (cq instanceof PointRangeQuery) { - final PointRangeQuery prq = (PointRangeQuery) cq; - final long[] indexBounds = getShardBounds(context, fieldName); - if (indexBounds == null) return null; - return getBoundsWithRangeQuery(prq, fieldName, indexBounds); - } else if (cq instanceof MatchAllDocsQuery) { - return getShardBounds(context, fieldName); - } else if (cq instanceof FieldExistsQuery) { - // when a range query covers all values of a shard, it will be rewrite field exists query - if (((FieldExistsQuery) cq).getField().equals(fieldName)) { -
return getShardBounds(context, fieldName); - } - } - - return null; - } - - private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) { - // Ensure that the query and aggregation are on the same field - if (prq.getField().equals(fieldName)) { - // Minimum bound for aggregation is the max between query and global - long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]); - // Maximum bound for aggregation is the min between query and global - long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]); - if (lower > upper) { - return null; - } - return new long[] { lower, upper }; - } - - return null; - } - - /** - * Context object for fast filter optimization - *
<p>
- * Usage: first set aggregation type, then check isRewriteable, then buildFastFilter - */ - public static class FastFilterContext { - private boolean rewriteable = false; - private boolean rangesBuiltAtShardLevel = false; - - private AggregationType aggregationType; - private final SearchContext context; - - private MappedFieldType fieldType; - private Ranges ranges; - - // debug info related fields - public int leaf; - public int inner; - public int segments; - public int optimizedSegments; - - public FastFilterContext(SearchContext context) { - this.context = context; - } - - public FastFilterContext(SearchContext context, AggregationType aggregationType) { - this.context = context; - this.aggregationType = aggregationType; - } - - public AggregationType getAggregationType() { - return aggregationType; - } - - public void setAggregationType(AggregationType aggregationType) { - this.aggregationType = aggregationType; - } - - public boolean isRewriteable(final Object parent, final int subAggLength) { - if (context.maxAggRewriteFilters() == 0) return false; - - boolean rewriteable = aggregationType.isRewriteable(parent, subAggLength); - logger.debug("Fast filter rewriteable: {} for shard {}", rewriteable, context.indexShard().shardId()); - this.rewriteable = rewriteable; - return rewriteable; - } - - public void buildRanges(MappedFieldType fieldType) throws IOException { - assert ranges == null : "Ranges should only be built once at shard level, but they are already built"; - this.fieldType = fieldType; - this.ranges = this.aggregationType.buildRanges(context, fieldType); - if (ranges != null) { - logger.debug("Ranges built for shard {}", context.indexShard().shardId()); - rangesBuiltAtShardLevel = true; - } - } - - private Ranges buildRanges(LeafReaderContext leaf) throws IOException { - Ranges ranges = this.aggregationType.buildRanges(leaf, context, fieldType); - if (ranges != null) { - logger.debug("Ranges built for shard {} segment {}", context.indexShard().shardId(), leaf.ord); - } - return ranges; - } - - /** - * Try to populate the bucket doc counts for aggregation - *
<p>
- * Usage: invoked at segment level — in getLeafCollector of aggregator - * - * @param bucketOrd bucket ordinal producer - * @param incrementDocCount consume the doc_count results for certain ordinal - */ - public boolean tryFastFilterAggregation( - final LeafReaderContext ctx, - final BiConsumer<Long, Long> incrementDocCount, - final Function<Object, Long> bucketOrd - ) throws IOException { - this.segments++; - if (!this.rewriteable) { - return false; - } - - if (ctx.reader().hasDeletions()) return false; - - PointValues values = ctx.reader().getPointValues(this.fieldType.name()); - if (values == null) return false; - // only proceed if every document corresponds to exactly one point - if (values.getDocCount() != values.size()) return false; - - NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME); - if (docCountValues.nextDoc() != NO_MORE_DOCS) { - logger.debug( - "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization", - this.context.indexShard().shardId(), - ctx.ord - ); - return false; - } - - // even if no ranges built at shard level, we can still perform the optimization - // when functionally match-all at segment level - if (!this.rangesBuiltAtShardLevel && !segmentMatchAll(this.context, ctx)) { - return false; - } - - Ranges ranges = this.ranges; - if (ranges == null) { - logger.debug( - "Shard {} segment {} functionally match all documents. Build the fast filter", - this.context.indexShard().shardId(), - ctx.ord - ); - ranges = this.buildRanges(ctx); - if (ranges == null) { - return false; - } - } - - DebugInfo debugInfo = this.aggregationType.tryFastFilterAggregation(values, ranges, incrementDocCount, bucketOrd); - this.consumeDebugInfo(debugInfo); - - this.optimizedSegments++; - logger.debug("Fast filter optimization applied to shard {} segment {}", this.context.indexShard().shardId(), ctx.ord); - logger.debug("crossed leaf nodes: {}, inner nodes: {}", this.leaf, this.inner); - return true; - } - - private void consumeDebugInfo(DebugInfo debug) { - leaf += debug.leaf; - inner += debug.inner; - } - } - - /** - * Different types have different pre-conditions, filter building logic, etc.
- */ - interface AggregationType { - boolean isRewriteable(Object parent, int subAggLength); - - Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) throws IOException; - - Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) throws IOException; - - DebugInfo tryFastFilterAggregation( - PointValues values, - Ranges ranges, - BiConsumer<Long, Long> incrementDocCount, - Function<Object, Long> bucketOrd - ) throws IOException; - } - - /** - * For date histogram aggregation - */ - public static abstract class AbstractDateHistogramAggregationType implements AggregationType { - private final MappedFieldType fieldType; - private final boolean missing; - private final boolean hasScript; - private LongBounds hardBounds; - - public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) { - this.fieldType = fieldType; - this.missing = missing; - this.hasScript = hasScript; - } - - public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript, LongBounds hardBounds) { - this(fieldType, missing, hasScript); - this.hardBounds = hardBounds; - } - - @Override - public boolean isRewriteable(Object parent, int subAggLength) { - if (parent == null && subAggLength == 0 && !missing && !hasScript) { - if (fieldType != null && fieldType instanceof DateFieldMapper.DateFieldType) { - return fieldType.isSearchable(); - } - } - return false; - } - - @Override - public Ranges buildRanges(SearchContext context, MappedFieldType fieldType) throws IOException { - long[] bounds = getDateHistoAggBounds(context, fieldType.name()); - logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId()); - return buildRanges(context, bounds); - } - - @Override - public Ranges buildRanges(LeafReaderContext leaf, SearchContext context, MappedFieldType fieldType) throws IOException { - long[] bounds = getSegmentBounds(leaf, fieldType.name()); - logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord); - return buildRanges(context, bounds); - } - - private Ranges buildRanges(SearchContext context, long[] bounds) throws IOException { - bounds = processHardBounds(bounds); - if (bounds == null) { - return null; - } - assert bounds[0] <= bounds[1] : "Low bound should be less than high bound"; - - final Rounding rounding = getRounding(bounds[0], bounds[1]); - final OptionalLong intervalOpt = Rounding.getInterval(rounding); - if (intervalOpt.isEmpty()) { - return null; - } - final long interval = intervalOpt.getAsLong(); - - // process the after key of composite agg - processAfterKey(bounds, interval); - - return FastFilterRewriteHelper.createRangesFromAgg( - context, - (DateFieldMapper.DateFieldType) fieldType, - interval, - getRoundingPrepared(), - bounds[0], - bounds[1] - ); - } - - protected abstract Rounding getRounding(final long low, final long high); - - protected abstract Rounding.Prepared getRoundingPrepared(); - - protected void processAfterKey(long[] bound, long interval) {} - - protected long[] processHardBounds(long[] bounds) { - if (bounds != null) { - // Update min/max limit if user specified any hard bounds - if (hardBounds != null) { - if (hardBounds.getMin() > bounds[0]) { - bounds[0] = hardBounds.getMin(); - } - if (hardBounds.getMax() - 1 < bounds[1]) { - bounds[1] = hardBounds.getMax() - 1; // hard bounds max is exclusive - } - if (bounds[0] > bounds[1]) { - return null; - } - } - } - return bounds; - } - - public DateFieldMapper.DateFieldType
getFieldType() { - assert fieldType instanceof DateFieldMapper.DateFieldType; - return (DateFieldMapper.DateFieldType) fieldType; - } - - @Override - public DebugInfo tryFastFilterAggregation( - PointValues values, - Ranges ranges, - BiConsumer<Long, Long> incrementDocCount, - Function<Object, Long> bucketOrd - ) throws IOException { - int size = Integer.MAX_VALUE; - if (this instanceof CompositeAggregator.CompositeAggregationType) { - size = ((CompositeAggregator.CompositeAggregationType) this).getSize(); - } - - DateFieldMapper.DateFieldType fieldType = getFieldType(); - BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> { - long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0); - rangeStart = fieldType.convertNanosToMillis(rangeStart); - long ord = getBucketOrd(bucketOrd.apply(rangeStart)); - incrementDocCount.accept(ord, (long) docCount); - }; - - return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); - } - - private static long getBucketOrd(long bucketOrd) { - if (bucketOrd < 0) { // already seen - bucketOrd = -1 - bucketOrd; - } - - return bucketOrd; - } - } - - /** - * For range aggregation - */ - public static class RangeAggregationType implements AggregationType { - - private final ValuesSourceConfig config; - private final Range[] ranges; - - public RangeAggregationType(ValuesSourceConfig config, Range[] ranges) { - this.config = config; - this.ranges = ranges; - } - - @Override - public boolean isRewriteable(Object parent, int subAggLength) { - if (config.fieldType() == null) return false; - MappedFieldType fieldType = config.fieldType(); - if (fieldType.isSearchable() == false || !(fieldType instanceof NumericPointEncoder)) return false; - - if (parent == null && subAggLength == 0 && config.script() == null && config.missing() == null) { - if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) { - // ranges are already sorted by from and then to - // we want ranges not overlapping with each other - double prevTo = ranges[0].getTo(); - for (int i = 1; i < ranges.length; i++) { - if (prevTo > ranges[i].getFrom()) { - return false; - } - prevTo = ranges[i].getTo(); - } - return true; - } - } - return false; - } - - @Override - public Ranges buildRanges(SearchContext context, MappedFieldType fieldType) { - assert fieldType instanceof NumericPointEncoder; - NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType; - byte[][] lowers = new byte[ranges.length][]; - byte[][] uppers = new byte[ranges.length][]; - for (int i = 0; i < ranges.length; i++) { - double rangeMin = ranges[i].getFrom(); - double rangeMax = ranges[i].getTo(); - byte[] lower = numericPointEncoder.encodePoint(rangeMin); - byte[] upper = numericPointEncoder.encodePoint(rangeMax); - lowers[i] = lower; - uppers[i] = upper; - } - - return new Ranges(lowers, uppers); - } - - @Override - public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) { - throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level"); - } - - @Override - public DebugInfo tryFastFilterAggregation( - PointValues values, - Ranges ranges, - BiConsumer<Long, Long> incrementDocCount, - Function<Object, Long> bucketOrd - ) throws IOException { - int size = Integer.MAX_VALUE; - - BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> { - long ord = bucketOrd.apply(activeIndex); - incrementDocCount.accept(ord, (long) docCount); - }; - - return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); - } - } - - public static boolean
isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) { - return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource; - } - - private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException { - Weight weight = ctx.searcher().createWeight(ctx.query(), ScoreMode.COMPLETE_NO_SCORES, 1f); - return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs(); - } - - /** - * Creates the date ranges from date histo aggregations using its interval, - * and min/max boundaries - */ - private static Ranges createRangesFromAgg( - final SearchContext context, - final DateFieldMapper.DateFieldType fieldType, - final long interval, - final Rounding.Prepared preparedRounding, - long low, - final long high - ) { - // Calculate the number of buckets using range and interval - long roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low)); - long prevRounded = roundedLow; - int bucketCount = 0; - while (roundedLow <= fieldType.convertNanosToMillis(high)) { - bucketCount++; - int maxNumFilterBuckets = context.maxAggRewriteFilters(); - if (bucketCount > maxNumFilterBuckets) { - logger.debug("Max number of filters reached [{}], skip the fast filter optimization", maxNumFilterBuckets); - return null; - } - // Below rounding is needed as the interval could return in - // non-rounded values for something like calendar month - roundedLow = preparedRounding.round(roundedLow + interval); - if (prevRounded == roundedLow) break; // prevents getting into an infinite loop - prevRounded = roundedLow; - } - - long[][] ranges = new long[bucketCount][2]; - if (bucketCount > 0) { - roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low)); - - int i = 0; - while (i < bucketCount) { - // Calculate the lower bucket bound - long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow); - roundedLow = preparedRounding.round(roundedLow + interval); - - // plus one on high value because upper bound is exclusive, but high value exists - long upper = i + 1 == bucketCount ? 
high + 1 : fieldType.convertRoundedMillisToNanos(roundedLow); - - ranges[i][0] = lower; - ranges[i][1] = upper; - i++; - } - } - - byte[][] lowers = new byte[ranges.length][]; - byte[][] uppers = new byte[ranges.length][]; - for (int i = 0; i < ranges.length; i++) { - byte[] lower = LONG.encodePoint(ranges[i][0]); - byte[] max = LONG.encodePoint(ranges[i][1]); - lowers[i] = lower; - uppers[i] = max; - } - - return new Ranges(lowers, uppers); - } - - /** - * @param maxNumNonZeroRanges the number of non-zero ranges to collect - */ - private static DebugInfo multiRangesTraverse( - final PointValues.PointTree tree, - final Ranges ranges, - final BiConsumer<Integer, Integer> incrementDocCount, - final int maxNumNonZeroRanges - ) throws IOException { - DebugInfo debugInfo = new DebugInfo(); - int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); - if (activeIndex < 0) { - logger.debug("No ranges match the query, skip the fast filter optimization"); - return debugInfo; - } - RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex); - PointValues.IntersectVisitor visitor = getIntersectVisitor(collector); - try { - intersectWithRanges(visitor, tree, collector, debugInfo); - } catch (CollectionTerminatedException e) { - logger.debug("Early terminate since no more range to collect"); - } - collector.finalizePreviousRange(); - - return debugInfo; - } - - private static class Ranges { - byte[][] lowers; // inclusive - byte[][] uppers; // exclusive - int size; - int byteLen; - static ArrayUtil.ByteArrayComparator comparator; - - Ranges(byte[][] lowers, byte[][] uppers) { - this.lowers = lowers; - this.uppers = uppers; - assert lowers.length == uppers.length; - this.size = lowers.length; - this.byteLen = lowers[0].length; - comparator = ArrayUtil.getUnsignedComparator(byteLen); - } - - public int firstRangeIndex(byte[] globalMin, byte[] globalMax) { - if (compareByteValue(lowers[0], globalMax) > 0) { - return -1; - } - int i = 0; - while (compareByteValue(uppers[i], globalMin) <= 0) { - i++; - if (i >= size) { - return -1; - } - } - return i; - } - - public static int compareByteValue(byte[] value1, byte[] value2) { - return comparator.compare(value1, 0, value2, 0); - } - - public static boolean withinLowerBound(byte[] value, byte[] lowerBound) { - return compareByteValue(value, lowerBound) >= 0; - } - - public static boolean withinUpperBound(byte[] value, byte[] upperBound) { - return compareByteValue(value, upperBound) < 0; - } - } - - private static void intersectWithRanges( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - RangeCollectorForPointTree collector, - DebugInfo debug - ) throws IOException { - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - - switch (r) { - case CELL_INSIDE_QUERY: - collector.countNode((int) pointTree.size()); - debug.visitInner(); - break; - case CELL_CROSSES_QUERY: - if (pointTree.moveToChild()) { - do { - intersectWithRanges(visitor, pointTree, collector, debug); - } while (pointTree.moveToSibling()); - pointTree.moveToParent(); - } else { - pointTree.visitDocValues(visitor); - debug.visitLeaf(); - } - break; - case CELL_OUTSIDE_QUERY: - } - } - - private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) { - return new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) throws IOException { - // this branch should be
unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited" - ); - } - - @Override - public void visit(int docID, byte[] packedValue) throws IOException { - visitPoints(packedValue, collector::count); - } - - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - visitPoints(packedValue, () -> { - for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { - collector.count(); - } - }); - } - - private void visitPoints(byte[] packedValue, CheckedRunnable<IOException> collect) throws IOException { - if (!collector.withinUpperBound(packedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(packedValue)) { - throw new CollectionTerminatedException(); - } - } - - if (collector.withinRange(packedValue)) { - collect.run(); - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - // try to find the first range that may collect values from this cell - if (!collector.withinUpperBound(minPackedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(minPackedValue)) { - throw new CollectionTerminatedException(); - } - } - // after the loop, min < upper - // cell could be outside [min max] lower - if (!collector.withinLowerBound(maxPackedValue)) { - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } - if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) { - return PointValues.Relation.CELL_INSIDE_QUERY; - } - return PointValues.Relation.CELL_CROSSES_QUERY; - } - }; - } - - private static class RangeCollectorForPointTree { - private final BiConsumer<Integer, Integer> incrementRangeDocCount; - private int counter = 0; - - private final Ranges ranges; - private int activeIndex; - - private int visitedRange = 0; - private final int maxNumNonZeroRange; - - public RangeCollectorForPointTree( - BiConsumer<Integer, Integer> incrementRangeDocCount, - int maxNumNonZeroRange, - Ranges ranges, - int activeIndex - ) { - this.incrementRangeDocCount = incrementRangeDocCount; - this.maxNumNonZeroRange = maxNumNonZeroRange; - this.ranges = ranges; - this.activeIndex = activeIndex; - } - - private void count() { - counter++; - } - - private void countNode(int count) { - counter += count; - } - - private void finalizePreviousRange() { - if (counter > 0) { - incrementRangeDocCount.accept(activeIndex, counter); - counter = 0; - } - } - - /** - * @return true when iterator exhausted or collect enough non-zero ranges - */ - private boolean iterateRangeEnd(byte[] value) { - // the new value may not be contiguous to the previous one - // so try to find the first next range that cross the new value - while (!withinUpperBound(value)) { - if (++activeIndex >= ranges.size) { - return true; - } - } - visitedRange++; - return visitedRange > maxNumNonZeroRange; - } - - private boolean withinLowerBound(byte[] value) { - return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]); - } - - private boolean withinUpperBound(byte[] value) { - return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]); - } - - private boolean withinRange(byte[] value) { - return withinLowerBound(value) && withinUpperBound(value); - } - } - - /** - * Contains debug info of BKD traversal to show in profile - */ - private static class DebugInfo { - private int leaf = 0; // leaf node visited - private int inner = 0; // inner node visited - - private void visitLeaf() { - leaf++; - } - - private void
visitInner() { - inner++; - } - } -} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index bfb484dcf478d..9ca27a3459ecf 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -73,11 +73,11 @@ import org.opensearch.search.aggregations.MultiBucketCollector; import org.opensearch.search.aggregations.MultiBucketConsumerService; import org.opensearch.search.aggregations.bucket.BucketsAggregator; -import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper; -import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper.AbstractDateHistogramAggregationType; import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.optimization.filterrewrite.CompositeAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import org.opensearch.search.searchafter.SearchAfterBuilder; import org.opensearch.search.sort.SortAndFormats; @@ -89,13 +89,15 @@ import java.util.List; import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.function.LongUnaryOperator; import java.util.stream.Collectors; import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING; +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; /** - * Main aggregator that aggregates docs from mulitple aggregations + * Main aggregator that aggregates docs from multiple aggregations * * @opensearch.internal */ @@ -118,9 +120,8 @@ public final class CompositeAggregator extends BucketsAggregator { private boolean earlyTerminated; - private final FastFilterRewriteHelper.FastFilterContext fastFilterContext; - private LongKeyedBucketOrds bucketOrds = null; - private Rounding.Prepared preparedRounding = null; + private final OptimizationContext optimizationContext; + private LongKeyedBucketOrds bucketOrds; CompositeAggregator( String name, @@ -166,56 +167,64 @@ public final class CompositeAggregator extends BucketsAggregator { this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey); this.rawAfterKey = rawAfterKey; - fastFilterContext = new FastFilterRewriteHelper.FastFilterContext(context); - if (!FastFilterRewriteHelper.isCompositeAggRewriteable(sourceConfigs)) { - return; - } - fastFilterContext.setAggregationType(new CompositeAggregationType()); - if (fastFilterContext.isRewriteable(parent, subAggregators.length)) { - // bucketOrds is used for saving date histogram results - bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE); - preparedRounding = ((CompositeAggregationType) fastFilterContext.getAggregationType()).getRoundingPrepared(); - fastFilterContext.buildRanges(sourceConfigs[0].fieldType()); - } - } + optimizationContext = new OptimizationContext(new CompositeAggregatorBridge() { + private RoundingValuesSource valuesSource; + private long afterKey = -1L; - /** - * Currently the filter rewrite is only supported for date histograms - */ - public class CompositeAggregationType 
extends AbstractDateHistogramAggregationType { - private final RoundingValuesSource valuesSource; - private long afterKey = -1L; - - public CompositeAggregationType() { - super(sourceConfigs[0].fieldType(), sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript()); - this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource(); - if (rawAfterKey != null) { - assert rawAfterKey.size() == 1 && formats.size() == 1; - this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> { - throw new IllegalArgumentException("now() is not supported in [after] key"); - }); + @Override + public boolean canOptimize() { + if (parent != null || subAggregators.length != 0) return false; + if (canOptimize(sourceConfigs)) { + this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource(); + if (rawAfterKey != null) { + assert rawAfterKey.size() == 1 && formats.size() == 1; + this.afterKey = formats.get(0).parseLong(rawAfterKey.get(0).toString(), false, () -> { + throw new IllegalArgumentException("now() is not supported in [after] key"); + }); + } + + // bucketOrds is used for saving the date histogram results got from the optimization path + bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE); + return true; + } + return false; } - } - public Rounding getRounding(final long low, final long high) { - return valuesSource.getRounding(); - } + @Override + public void prepare() throws IOException { + buildRanges(context); + } - public Rounding.Prepared getRoundingPrepared() { - return valuesSource.getPreparedRounding(); - } + protected Rounding getRounding(final long low, final long high) { + return valuesSource.getRounding(); + } - @Override - protected void processAfterKey(long[] bound, long interval) { - // afterKey is the last bucket key in previous response, and the bucket key - // is the minimum of all values in the bucket, so need to add the interval - if (afterKey != -1L) { - bound[0] = afterKey + interval; + protected Rounding.Prepared getRoundingPrepared() { + return valuesSource.getPreparedRounding(); } - } - public int getSize() { - return size; + @Override + protected long[] processAfterKey(long[] bounds, long interval) { + // afterKey is the last bucket key in previous response, and the bucket key + // is the minimum of all values in the bucket, so need to add the interval + if (afterKey != -1L) { + bounds[0] = afterKey + interval; + } + return bounds; + } + + @Override + protected int getSize() { + return size; + } + + @Override + protected Function<Object, Long> bucketOrdProducer() { + return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key)); + } + }); + if (optimizationContext.canOptimize(parent, context)) { + optimizationContext.prepare(); } } @@ -368,7 +377,7 @@ private boolean isMaybeMultivalued(LeafReaderContext context, SortField sortFiel return v2 != null && DocValues.unwrapSingleton(v2) == null; default: - // we have no clue whether the field is multi-valued or not so we assume it is. + // we have no clue whether the field is multivalued or not so we assume it is.
return true; } } @@ -551,11 +560,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - boolean optimized = fastFilterContext.tryFastFilterAggregation( - ctx, - this::incrementBucketDocCount, - (key) -> bucketOrds.add(0, preparedRounding.round((long) key)) - ); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); finishLeaf(); @@ -709,11 +714,6 @@ private static class Entry { @Override public void collectDebugInfo(BiConsumer<String, Object> add) { - if (fastFilterContext.optimizedSegments > 0) { - add.accept("optimized_segments", fastFilterContext.optimizedSegments); - add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments); - add.accept("leaf_visited", fastFilterContext.leaf); - add.accept("inner_visited", fastFilterContext.inner); - } + optimizationContext.populateDebugInfo(add); } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java index d13d575a9d696..dd6a4be8fbf7b 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java @@ -42,7 +42,6 @@ import org.opensearch.common.util.IntArray; import org.opensearch.common.util.LongArray; import org.opensearch.core.common.util.ByteArray; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.AggregatorFactories; @@ -53,22 +52,24 @@ import org.opensearch.search.aggregations.LeafBucketCollectorBase; import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator; import org.opensearch.search.aggregations.bucket.DeferringBucketCollector; -import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper; import org.opensearch.search.aggregations.bucket.MergingBucketsDeferringCollector; import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder.RoundingInfo; import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongToIntFunction; +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; + /** * An aggregator for date values that attempts to return a specific number of * buckets, reconfiguring how it rounds dates to buckets on the fly as new @@ -135,7 +136,7 @@ static AutoDateHistogramAggregator build( protected int roundingIdx; protected Rounding.Prepared preparedRounding; - private
final FastFilterRewriteHelper.FastFilterContext fastFilterContext; + private final OptimizationContext optimizationContext; private AutoDateHistogramAggregator( String name, @@ -158,52 +159,54 @@ private AutoDateHistogramAggregator( this.roundingPreparer = roundingPreparer; this.preparedRounding = prepareRounding(0); - fastFilterContext = new FastFilterRewriteHelper.FastFilterContext( - context, - new AutoHistogramAggregationType( - valuesSourceConfig.fieldType(), - valuesSourceConfig.missing() != null, - valuesSourceConfig.script() != null - ) - ); - if (fastFilterContext.isRewriteable(parent, subAggregators.length)) { - fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType())); - } - } - - private class AutoHistogramAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType { + optimizationContext = new OptimizationContext(new DateHistogramAggregatorBridge() { + @Override + public boolean canOptimize() { + return canOptimize(valuesSourceConfig); + } - public AutoHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) { - super(fieldType, missing, hasScript); - } + @Override + public void prepare() throws IOException { + buildRanges(context); + } - @Override - protected Rounding getRounding(final long low, final long high) { - // max - min / targetBuckets = bestDuration - // find the right innerInterval this bestDuration belongs to - // since we cannot exceed targetBuckets, bestDuration should go up, - // so the right innerInterval should be an upper bound - long bestDuration = (high - low) / targetBuckets; - // reset so this function is idempotent - roundingIdx = 0; - while (roundingIdx < roundingInfos.length - 1) { - final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx]; - final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1]; - // If the interval duration is covered by the maximum inner interval, - // we can start with this outer interval for creating the buckets - if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) { - break; + @Override + protected Rounding getRounding(final long low, final long high) { + // max - min / targetBuckets = bestDuration + // find the right innerInterval this bestDuration belongs to + // since we cannot exceed targetBuckets, bestDuration should go up, + // so the right innerInterval should be an upper bound + long bestDuration = (high - low) / targetBuckets; + // reset so this function is idempotent + roundingIdx = 0; + while (roundingIdx < roundingInfos.length - 1) { + final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx]; + final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1]; + // If the interval duration is covered by the maximum inner interval, + // we can start with this outer interval for creating the buckets + if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) { + break; + } + roundingIdx++; } - roundingIdx++; + + preparedRounding = prepareRounding(roundingIdx); + return roundingInfos[roundingIdx].rounding; } - preparedRounding = prepareRounding(roundingIdx); - return roundingInfos[roundingIdx].rounding; - } + @Override + protected Prepared getRoundingPrepared() { + return preparedRounding; + } - @Override - protected Prepared getRoundingPrepared() { - return preparedRounding; + @Override + protected Function<Object, Long> bucketOrdProducer() { + return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key)); + } + + }); + if
(optimizationContext.canOptimize(parent, context)) { + optimizationContext.prepare(); } } @@ -236,11 +239,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = fastFilterContext.tryFastFilterAggregation( - ctx, - this::incrementBucketDocCount, - (key) -> getBucketOrds().add(0, preparedRounding.round((long) key)) - ); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDocValues values = valuesSource.longValues(ctx); @@ -308,12 +307,7 @@ protected final void merge(long[] mergeMap, long newNumBuckets) { @Override public void collectDebugInfo(BiConsumer<String, Object> add) { super.collectDebugInfo(add); - if (fastFilterContext.optimizedSegments > 0) { - add.accept("optimized_segments", fastFilterContext.optimizedSegments); - add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments); - add.accept("leaf_visited", fastFilterContext.leaf); - add.accept("inner_visited", fastFilterContext.inner); - } + optimizationContext.populateDebugInfo(add); } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java index 4b84797c18922..3be7646518dc4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java @@ -39,7 +39,6 @@ import org.opensearch.common.Nullable; import org.opensearch.common.Rounding; import org.opensearch.common.lease.Releasables; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.AggregatorFactories; @@ -49,17 +48,20 @@ import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.LeafBucketCollectorBase; import org.opensearch.search.aggregations.bucket.BucketsAggregator; -import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper; import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.Objects; import java.util.function.BiConsumer; +import java.util.function.Function; + +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; /** * An aggregator for date values.
Every date is rounded down using a configured @@ -84,7 +86,7 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg private final LongBounds hardBounds; private final LongKeyedBucketOrds bucketOrds; - private final FastFilterRewriteHelper.FastFilterContext fastFilterContext; + private final OptimizationContext optimizationContext; DateHistogramAggregator( String name, @@ -117,34 +119,40 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality); - fastFilterContext = new FastFilterRewriteHelper.FastFilterContext( - context, - new DateHistogramAggregationType( - valuesSourceConfig.fieldType(), - valuesSourceConfig.missing() != null, - valuesSourceConfig.script() != null, - hardBounds - ) - ); - if (fastFilterContext.isRewriteable(parent, subAggregators.length)) { - fastFilterContext.buildRanges(Objects.requireNonNull(valuesSourceConfig.fieldType())); - } - } + optimizationContext = new OptimizationContext(new DateHistogramAggregatorBridge() { + @Override + public boolean canOptimize() { + return canOptimize(valuesSourceConfig); + } - private class DateHistogramAggregationType extends FastFilterRewriteHelper.AbstractDateHistogramAggregationType { + @Override + public void prepare() throws IOException { + buildRanges(context); + } - public DateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript, LongBounds hardBounds) { - super(fieldType, missing, hasScript, hardBounds); - } + @Override + protected Rounding getRounding(long low, long high) { + return rounding; + } - @Override - protected Rounding getRounding(long low, long high) { - return rounding; - } + @Override + protected Rounding.Prepared getRoundingPrepared() { + return preparedRounding; + } + + @Override + protected long[] processHardBounds(long[] bounds) { + return super.processHardBounds(bounds, hardBounds); + } - @Override - protected Rounding.Prepared getRoundingPrepared() { - return preparedRounding; + @Override + protected Function bucketOrdProducer() { + return (key) -> bucketOrds.add(0, preparedRounding.round((long) key)); + } + + }); + if (optimizationContext.canOptimize(parent, context)) { + optimizationContext.prepare(); } } @@ -162,11 +170,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = fastFilterContext.tryFastFilterAggregation( - ctx, - this::incrementBucketDocCount, - (key) -> bucketOrds.add(0, preparedRounding.round((long) key)) - ); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); SortedNumericDocValues values = valuesSource.longValues(ctx); @@ -253,12 +257,7 @@ public void doClose() { @Override public void collectDebugInfo(BiConsumer add) { add.accept("total_buckets", bucketOrds.size()); - if (fastFilterContext.optimizedSegments > 0) { - add.accept("optimized_segments", fastFilterContext.optimizedSegments); - add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments); - add.accept("leaf_visited", fastFilterContext.leaf); - add.accept("inner_visited", fastFilterContext.inner); - } + optimizationContext.populateDebugInfo(add); } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java 
b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index 2ba2b06514de1..312ba60dcfb83 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -55,10 +55,11 @@ import org.opensearch.search.aggregations.LeafBucketCollectorBase; import org.opensearch.search.aggregations.NonCollectingAggregator; import org.opensearch.search.aggregations.bucket.BucketsAggregator; -import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.RangeAggregatorBridge; import java.io.IOException; import java.util.ArrayList; @@ -66,6 +67,7 @@ import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; +import java.util.function.Function; import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -249,7 +251,7 @@ public boolean equals(Object obj) { final double[] maxTo; - private final FastFilterRewriteHelper.FastFilterContext fastFilterContext; + private final OptimizationContext optimizationContext; public RangeAggregator( String name, @@ -279,12 +281,24 @@ public RangeAggregator( maxTo[i] = Math.max(this.ranges[i].to, maxTo[i - 1]); } - fastFilterContext = new FastFilterRewriteHelper.FastFilterContext( - context, - new FastFilterRewriteHelper.RangeAggregationType(config, ranges) - ); - if (fastFilterContext.isRewriteable(parent, subAggregators.length)) { - fastFilterContext.buildRanges(Objects.requireNonNull(config.fieldType())); + optimizationContext = new OptimizationContext(new RangeAggregatorBridge() { + @Override + public boolean canOptimize() { + return canOptimize(config, ranges); + } + + @Override + public void prepare() { + buildRanges(ranges); + } + + @Override + protected Function bucketOrdProducer() { + return (activeIndex) -> subBucketOrdinal(0, (int) activeIndex); + } + }); + if (optimizationContext.canOptimize(parent, context)) { + optimizationContext.prepare(); } } @@ -298,11 +312,7 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - boolean optimized = fastFilterContext.tryFastFilterAggregation( - ctx, - this::incrementBucketDocCount, - (activeIndex) -> subBucketOrdinal(0, (int) activeIndex) - ); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, false); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); @@ -452,11 +462,6 @@ public InternalAggregation buildEmptyAggregation() { @Override public void collectDebugInfo(BiConsumer add) { super.collectDebugInfo(add); - if (fastFilterContext.optimizedSegments > 0) { - add.accept("optimized_segments", fastFilterContext.optimizedSegments); - add.accept("unoptimized_segments", fastFilterContext.segments - fastFilterContext.optimizedSegments); - add.accept("leaf_visited", fastFilterContext.leaf); - add.accept("inner_visited", fastFilterContext.inner); - } + optimizationContext.populateDebugInfo(add); } } diff --git 
a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java
new file mode 100644
index 0000000000000..17b8c9db9a782
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java
@@ -0,0 +1,77 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.optimization.filterrewrite;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.search.aggregations.LeafBucketCollector;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+
+/**
+ * This abstract class provides a bridge between an aggregator and the optimization context, allowing
+ * the aggregator to provide data and optimize the aggregation process.
+ *
+ * <p>The main purpose of this class is to encapsulate the aggregator-specific optimization
+ * logic and provide access to the data in the aggregator that is required for the optimization, while keeping
+ * the optimization business logic separate from the aggregator implementation.
+ *
+ * <p>To use this class to optimize an aggregator, subclass it in this package
+ * and put any aggregator-specific optimization business logic in it. Then, in the aggregator,
+ * implement the subclass to provide the data that is needed for doing the optimization.
+ *
+ * @opensearch.internal
+ */
+public abstract class AggregatorBridge {
+
+    /**
+     * The optimization context associated with this aggregator bridge.
+     */
+    OptimizationContext optimizationContext;
+
+    /**
+     * The field type associated with this aggregator bridge.
+     */
+    MappedFieldType fieldType;
+
+    void setOptimizationContext(OptimizationContext context) {
+        this.optimizationContext = context;
+    }
+
+    /**
+     * Checks whether the aggregator can be optimized.
+     *
+     * @return {@code true} if the aggregator can be optimized, {@code false} otherwise.
+     * The result will be saved in the optimization context.
+     */
+    public abstract boolean canOptimize();
+
+    /**
+     * Prepares the optimization at shard level.
+     * For example, figure out the ranges from the aggregation to do the optimization later.
+     */
+    public abstract void prepare() throws IOException;
+
+    /**
+     * Prepares the optimization for a specific segment, ignoring whatever was built at shard level.
+     *
+     * @param leaf the leaf reader context for the segment
+     */
+    public abstract void prepareFromSegment(LeafReaderContext leaf) throws IOException;
+
+    /**
+     * Attempts to build aggregation results for a segment.
+     *
+     * @param values            the point values (index structure for numeric values) for a segment
+     * @param incrementDocCount a consumer to increment the document count for a range bucket. The first argument is
+     *                          the bucket ordinal (key), the second is the number of documents to add
+     * @param sub               the leaf bucket collector for sub aggregations; may be {@code null}
+     */
+    public abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, final LeafBucketCollector sub)
+        throws IOException;
+}
diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java
new file mode 100644
index 0000000000000..74fe30720b85d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java
@@ -0,0 +1,74 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.PointValues; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig; +import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource; + +import java.io.IOException; +import java.util.List; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + +/** + * For composite aggregation to do optimization when it only has a single date histogram source + */ +public abstract class CompositeAggregatorBridge extends DateHistogramAggregatorBridge { + protected boolean canOptimize(CompositeValuesSourceConfig[] sourceConfigs) { + if (sourceConfigs.length != 1 || !(sourceConfigs[0].valuesSource() instanceof RoundingValuesSource)) return false; + return canOptimize(sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript(), sourceConfigs[0].fieldType()); + } + + private boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) { + if (!missing && !hasScript) { + if (fieldType instanceof DateFieldMapper.DateFieldType) { + if (fieldType.isSearchable()) { + this.fieldType = fieldType; + return true; + } + } + } + return false; + } + + @Override + public final void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { + DateFieldMapper.DateFieldType fieldType = getFieldType(); + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, ord); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); + } + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + incrementDocCount.accept(ord, (long) docCount); + }); + } + + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java new file mode 100644 index 0000000000000..06d0e251a4105 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.opensearch.common.Rounding; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.support.ValuesSourceConfig; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.OptionalLong; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + +/** + * For date histogram aggregation + */ +public abstract class DateHistogramAggregatorBridge extends AggregatorBridge { + + protected boolean canOptimize(ValuesSourceConfig config) { + if (config.script() == null && config.missing() == null) { + MappedFieldType fieldType = config.fieldType(); + if (fieldType instanceof DateFieldMapper.DateFieldType) { + if (fieldType.isSearchable()) { + this.fieldType = fieldType; + return true; + } + } + } + return false; + } + + protected void buildRanges(SearchContext context) throws IOException { + long[] bounds = Helper.getDateHistoAggBounds(context, fieldType.name()); + optimizationContext.setRanges(buildRanges(bounds)); + } + + @Override + public void prepareFromSegment(LeafReaderContext leaf) throws IOException { + long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name()); + optimizationContext.setRangesFromSegment(buildRanges(bounds)); + } + + private Ranges buildRanges(long[] bounds) { + bounds = processHardBounds(bounds); + if (bounds == null) { + return null; + } + assert bounds[0] <= bounds[1] : "Low bound should be less than high bound"; + + final Rounding rounding = getRounding(bounds[0], bounds[1]); + final OptionalLong intervalOpt = Rounding.getInterval(rounding); + if (intervalOpt.isEmpty()) { + return null; + } + final long interval = intervalOpt.getAsLong(); + + // process the after key of composite agg + bounds = processAfterKey(bounds, interval); + + return Helper.createRangesFromAgg( + (DateFieldMapper.DateFieldType) fieldType, + interval, + getRoundingPrepared(), + bounds[0], + bounds[1], + optimizationContext.maxAggRewriteFilters + ); + } + + protected abstract Rounding getRounding(final long low, final long high); + + protected abstract Rounding.Prepared getRoundingPrepared(); + + protected long[] processAfterKey(long[] bounds, long interval) { + return bounds; + } + + protected long[] processHardBounds(long[] bounds) { + return processHardBounds(bounds, null); + } + + protected long[] processHardBounds(long[] bounds, LongBounds hardBounds) { + if (bounds != null) { + // Update min/max limit if user specified any hard bounds + if (hardBounds != null) { + if (hardBounds.getMin() > bounds[0]) { + bounds[0] = hardBounds.getMin(); + } + if (hardBounds.getMax() - 1 < bounds[1]) { + bounds[1] = hardBounds.getMax() - 1; // hard bounds max is exclusive + } + if (bounds[0] > bounds[1]) { + return null; + } + } + } + return bounds; + } + + protected DateFieldMapper.DateFieldType getFieldType() { + assert fieldType instanceof DateFieldMapper.DateFieldType; + return 
(DateFieldMapper.DateFieldType) fieldType; + } + + protected int getSize() { + return Integer.MAX_VALUE; + } + + @Override + public void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { + DateFieldMapper.DateFieldType fieldType = getFieldType(); + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, ord); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); + } + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + incrementDocCount.accept(ord, (long) docCount); + }); + } + + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); + } + + protected static long getBucketOrd(long bucketOrd) { + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + } + + return bucketOrd; + } + + /** + * Provides a function to produce bucket ordinals from the lower bound of the range + */ + protected abstract Function bucketOrdProducer(); + + /** + * Checks whether the top level query matches all documents on the segment + * + *
<p>
This method creates a weight from the search context's query and checks whether the weight's + * document count matches the total number of documents in the leaf reader context. + * + * @param ctx the search context + * @param leafCtx the leaf reader context for the segment + * @return {@code true} if the segment matches all documents, {@code false} otherwise + */ + public static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException { + Weight weight = ctx.query().rewrite(ctx.searcher()).createWeight(ctx.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f); + return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs(); + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java new file mode 100644 index 0000000000000..eb57cd90b9ad9 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java @@ -0,0 +1,213 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.IndexOrDocValuesQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.PointRangeQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.util.NumericUtils; +import org.opensearch.common.Rounding; +import org.opensearch.common.lucene.search.function.FunctionScoreQuery; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.query.DateRangeIncludingNowQuery; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; + +/** + * Utility class to help range filters rewrite optimization + * + * @opensearch.internal + */ +final class Helper { + + private Helper() {} + + static final String loggerName = Helper.class.getPackageName(); + private static final Logger logger = LogManager.getLogger(loggerName); + + private static final Map, Function> queryWrappers; + + // Initialize the wrapper map for unwrapping the query + static { + queryWrappers = new HashMap<>(); + queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery()); + queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery()); + queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery()); + queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery()); + } + + /** + * Recursively unwraps query into the concrete form + * for applying the optimization + */ + private static Query unwrapIntoConcreteQuery(Query query) { + while (queryWrappers.containsKey(query.getClass())) { + query = queryWrappers.get(query.getClass()).apply(query); + } + + return query; + } + + /** + * Finds the global min and max bounds of the field for the 
shard across all segments + * + * @return null if the field is empty or not indexed + */ + private static long[] getShardBounds(final List leaves, final String fieldName) throws IOException { + long min = Long.MAX_VALUE, max = Long.MIN_VALUE; + for (LeafReaderContext leaf : leaves) { + final PointValues values = leaf.reader().getPointValues(fieldName); + if (values != null) { + min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0)); + max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0)); + } + } + + if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) { + return null; + } + return new long[] { min, max }; + } + + /** + * Finds the min and max bounds of the field for the segment + * + * @return null if the field is empty or not indexed + */ + static long[] getSegmentBounds(final LeafReaderContext context, final String fieldName) throws IOException { + long min = Long.MAX_VALUE, max = Long.MIN_VALUE; + final PointValues values = context.reader().getPointValues(fieldName); + if (values != null) { + min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0)); + max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0)); + } + + if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) { + return null; + } + return new long[] { min, max }; + } + + /** + * Gets the min and max bounds of the field for the shard search + * Depending on the query part, the bounds are computed differently + * + * @return null if the processed query not supported by the optimization + */ + public static long[] getDateHistoAggBounds(final SearchContext context, final String fieldName) throws IOException { + final Query cq = unwrapIntoConcreteQuery(context.query()); + final List leaves = context.searcher().getIndexReader().leaves(); + + if (cq instanceof PointRangeQuery) { + final PointRangeQuery prq = (PointRangeQuery) cq; + final long[] indexBounds = getShardBounds(leaves, fieldName); + if (indexBounds == null) return null; + return getBoundsWithRangeQuery(prq, fieldName, indexBounds); + } else if (cq instanceof MatchAllDocsQuery) { + return getShardBounds(leaves, fieldName); + } else if (cq instanceof FieldExistsQuery) { + // when a range query covers all values of a shard, it will be rewrite field exists query + if (((FieldExistsQuery) cq).getField().equals(fieldName)) { + return getShardBounds(leaves, fieldName); + } + } + + return null; + } + + private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) { + // Ensure that the query and aggregation are on the same field + if (prq.getField().equals(fieldName)) { + // Minimum bound for aggregation is the max between query and global + long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]); + // Maximum bound for aggregation is the min between query and global + long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]); + if (lower > upper) { + return null; + } + return new long[] { lower, upper }; + } + + return null; + } + + /** + * Creates the date ranges from date histo aggregations using its interval, + * and min/max boundaries + */ + static Ranges createRangesFromAgg( + final DateFieldMapper.DateFieldType fieldType, + final long interval, + final Rounding.Prepared preparedRounding, + long low, + final long high, + final int maxAggRewriteFilters + ) { + // Calculate the number of buckets using range and interval + long roundedLow = 
preparedRounding.round(fieldType.convertNanosToMillis(low)); + long prevRounded = roundedLow; + int bucketCount = 0; + while (roundedLow <= fieldType.convertNanosToMillis(high)) { + bucketCount++; + if (bucketCount > maxAggRewriteFilters) { + logger.debug("Max number of range filters reached [{}], skip the optimization", maxAggRewriteFilters); + return null; + } + // Below rounding is needed as the interval could return in + // non-rounded values for something like calendar month + roundedLow = preparedRounding.round(roundedLow + interval); + if (prevRounded == roundedLow) break; // prevents getting into an infinite loop + prevRounded = roundedLow; + } + + long[][] ranges = new long[bucketCount][2]; + if (bucketCount > 0) { + roundedLow = preparedRounding.round(fieldType.convertNanosToMillis(low)); + + int i = 0; + while (i < bucketCount) { + // Calculate the lower bucket bound + long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow); + roundedLow = preparedRounding.round(roundedLow + interval); + + // plus one on high value because upper bound is exclusive, but high value exists + long upper = i + 1 == bucketCount ? high + 1 : fieldType.convertRoundedMillisToNanos(roundedLow); + + ranges[i][0] = lower; + ranges[i][1] = upper; + i++; + } + } + + byte[][] lowers = new byte[ranges.length][]; + byte[][] uppers = new byte[ranges.length][]; + for (int i = 0; i < ranges.length; i++) { + byte[] lower = LONG.encodePoint(ranges[i][0]); + byte[] max = LONG.encodePoint(ranges[i][1]); + lowers[i] = lower; + uppers[i] = max; + } + + return new Ranges(lowers, uppers); + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java new file mode 100644 index 0000000000000..896867ebafd53 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java @@ -0,0 +1,188 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.PointValues; +import org.opensearch.index.mapper.DocCountFieldMapper; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.Helper.loggerName; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Context object for doing the filter rewrite optimization in ranges type aggregation + *
+ * <p>
+ * This holds the common business logic and delegate aggregator-specific logic to {@link AggregatorBridge} + * + * @opensearch.internal + */ +public final class OptimizationContext { + + private static final Logger logger = LogManager.getLogger(loggerName); + + private boolean canOptimize = false; + private boolean preparedAtShardLevel = false; + + final AggregatorBridge aggregatorBridge; + int maxAggRewriteFilters; + String shardId; + + private Ranges ranges; + private Ranges rangesFromSegment; + + // debug info related fields + private int leaf; + private int inner; + private int segments; + private int optimizedSegments; + + public OptimizationContext(AggregatorBridge aggregatorBridge) { + this.aggregatorBridge = aggregatorBridge; + } + + public boolean canOptimize(final Object parent, SearchContext context) { + if (context.maxAggRewriteFilters() == 0) return false; + if (parent != null) return false; + + this.canOptimize = aggregatorBridge.canOptimize(); + if (canOptimize) { + aggregatorBridge.setOptimizationContext(this); + this.maxAggRewriteFilters = context.maxAggRewriteFilters(); + this.shardId = context.indexShard().shardId().toString(); + } + logger.debug("Fast filter rewriteable: {} for shard {}", canOptimize, shardId); + return canOptimize; + } + + public void prepare() throws IOException { + assert ranges == null : "Ranges should only be built once at shard level, but they are already built"; + aggregatorBridge.prepare(); + if (ranges != null) { + preparedAtShardLevel = true; + } + } + + public void prepareFromSegment(LeafReaderContext leaf) throws IOException { + aggregatorBridge.prepareFromSegment(leaf); + } + + void setRanges(Ranges ranges) { + this.ranges = ranges; + } + + void setRangesFromSegment(Ranges ranges) { + this.rangesFromSegment = ranges; + } + + Ranges getRanges() { + if (rangesFromSegment != null) return rangesFromSegment; + return ranges; + } + + /** + * Try to populate the bucket doc counts for aggregation + *
+ * <p>
+ * Usage: invoked at segment level — in getLeafCollector of aggregator + * + * @param incrementDocCount consume the doc_count results for certain ordinal + * @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment + */ + public boolean tryOptimize(final LeafReaderContext leafCtx, LeafBucketCollector sub, final BiConsumer incrementDocCount, boolean segmentMatchAll) + throws IOException { + segments++; + if (!canOptimize) { + return false; + } + + if (leafCtx.reader().hasDeletions()) return false; + + PointValues values = leafCtx.reader().getPointValues(aggregatorBridge.fieldType.name()); + if (values == null) return false; + // only proceed if every document corresponds to exactly one point + if (values.getDocCount() != values.size()) return false; + + NumericDocValues docCountValues = DocValues.getNumeric(leafCtx.reader(), DocCountFieldMapper.NAME); + if (docCountValues.nextDoc() != NO_MORE_DOCS) { + logger.debug( + "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization", + shardId, + leafCtx.ord + ); + return false; + } + + Ranges ranges = tryBuildRangesFromSegment(leafCtx, segmentMatchAll); + if (ranges == null) return false; + + aggregatorBridge.tryOptimize(values, incrementDocCount, sub); + + optimizedSegments++; + logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord); + logger.debug("crossed leaf nodes: {}, inner nodes: {}", leaf, inner); + + rangesFromSegment = null; + return true; + } + + /** + * Even when ranges cannot be built at shard level, we can still build ranges + * at segment level when it's functionally match-all at segment level + */ + private Ranges tryBuildRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMatchAll) throws IOException { + if (!preparedAtShardLevel && !segmentMatchAll) { + return null; + } + + if (ranges == null) { // not built at shard level but segment match all + logger.debug("Shard {} segment {} functionally match all documents. Build the fast filter", shardId, leafCtx.ord); + prepareFromSegment(leafCtx); + return rangesFromSegment; + } + return ranges; + } + + /** + * Contains debug info of BKD traversal to show in profile + */ + public static class DebugInfo { + private int leaf = 0; // leaf node visited + private int inner = 0; // inner node visited + + void visitLeaf() { + leaf++; + } + + void visitInner() { + inner++; + } + } + + void consumeDebugInfo(DebugInfo debug) { + leaf += debug.leaf; + inner += debug.inner; + } + + public void populateDebugInfo(BiConsumer add) { + if (optimizedSegments > 0) { + add.accept("optimized_segments", optimizedSegments); + add.accept("unoptimized_segments", segments - optimizedSegments); + add.accept("leaf_visited", leaf); + add.accept("inner_visited", inner); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java new file mode 100644 index 0000000000000..0c276891f345d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumericPointEncoder; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.range.RangeAggregator; +import org.opensearch.search.aggregations.support.ValuesSource; +import org.opensearch.search.aggregations.support.ValuesSourceConfig; + +import java.io.IOException; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + +/** + * For range aggregation + */ +public abstract class RangeAggregatorBridge extends AggregatorBridge { + + protected boolean canOptimize(ValuesSourceConfig config, RangeAggregator.Range[] ranges) { + if (config.fieldType() == null) return false; + MappedFieldType fieldType = config.fieldType(); + assert fieldType != null; + if (fieldType.isSearchable() == false || !(fieldType instanceof NumericPointEncoder)) return false; + + if (config.script() == null && config.missing() == null) { + if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) { + // ranges are already sorted by from and then to + // we want ranges not overlapping with each other + double prevTo = ranges[0].getTo(); + for (int i = 1; i < ranges.length; i++) { + if (prevTo > ranges[i].getFrom()) { + return false; + } + prevTo = ranges[i].getTo(); + } + this.fieldType = config.fieldType(); + return true; + } + } + return false; + } + + protected void buildRanges(RangeAggregator.Range[] ranges) { + assert fieldType instanceof NumericPointEncoder; + NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType; + byte[][] lowers = new byte[ranges.length][]; + byte[][] uppers = new byte[ranges.length][]; + for (int i = 0; i < ranges.length; i++) { + double rangeMin = ranges[i].getFrom(); + double rangeMax = ranges[i].getTo(); + byte[] lower = numericPointEncoder.encodePoint(rangeMin); + byte[] upper = numericPointEncoder.encodePoint(rangeMax); + lowers[i] = lower; + uppers[i] = upper; + } + + optimizationContext.setRanges(new Ranges(lowers, uppers)); + } + + @Override + public void prepareFromSegment(LeafReaderContext leaf) { + throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level"); + } + + @Override + public final void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docID) -> { + long ord = bucketOrdProducer().apply(activeIndex); + + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, ord); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); + } + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docCount) -> { + long ord = bucketOrdProducer().apply(activeIndex); + incrementDocCount.accept(ord, (long) docCount); + }); + } + + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); + } + + /** + * Provides a function to produce bucket ordinals 
from index of the corresponding range in the range array + */ + protected abstract Function bucketOrdProducer(); +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java new file mode 100644 index 0000000000000..18c88a4af2e67 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.lucene.util.ArrayUtil; + +/** + * Internal ranges representation for the filter rewrite optimization + */ +public final class Ranges { + static ArrayUtil.ByteArrayComparator comparator; + byte[][] lowers; // inclusive + byte[][] uppers; // exclusive + int size; + int byteLen; + + public Ranges(byte[][] lowers, byte[][] uppers) { + this.lowers = lowers; + this.uppers = uppers; + assert lowers.length == uppers.length; + this.size = lowers.length; + this.byteLen = lowers[0].length; + comparator = ArrayUtil.getUnsignedComparator(byteLen); + } + + public static int compareByteValue(byte[] value1, byte[] value2) { return comparator.compare(value1, 0, value2, 0); } + public static boolean withinLowerBound(byte[] value, byte[] lowerBound) { return compareByteValue(value, lowerBound) >= 0; } + public static boolean withinUpperBound(byte[] value, byte[] upperBound) { return compareByteValue(value, upperBound) < 0; } + public boolean withinLowerBound(byte[] value, int idx) { return Ranges.withinLowerBound(value, lowers[idx]); } + public boolean withinUpperBound(byte[] value, int idx) { return Ranges.withinUpperBound(value, uppers[idx]); } + public boolean withinRange(byte[] value, int idx) { return withinLowerBound(value, idx) && withinUpperBound(value, idx); } + + public int firstRangeIndex(byte[] globalMin, byte[] globalMax) { + if (compareByteValue(lowers[0], globalMax) > 0) { + return -1; + } + int i = 0; + while (compareByteValue(uppers[i], globalMin) <= 0) { + i++; + if (i >= size) { + return -1; + } + } + return i; + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java new file mode 100644 index 0000000000000..8ff3821f45dc8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java @@ -0,0 +1,264 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.optimization.filterrewrite; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.Helper.loggerName; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Utility class for traversing a {@link PointValues.PointTree} and collecting document counts for the ranges. + * + *
+ * <p>The main entry point is the {@link #multiRangesTraverse(RangeAwareIntersectVisitor)} method
+ *
+ * <p>The class uses a {@link RangeAwareIntersectVisitor} to keep track of the active ranges, traverse the tree, and
+ * consume documents.
+ */
+public final class TreeTraversal {
+    private static final Logger logger = LogManager.getLogger(loggerName);
+
+    /**
+     * Traverses the {@link PointValues.PointTree} held by the visitor and collects
+     * debug information from the traversal.
+     *
+     * @param visitor the range-aware visitor that tracks the active ranges and consumes matching documents
+     * @return a {@link OptimizationContext.DebugInfo} object containing debug information about the traversal
+     */
+    public static OptimizationContext.DebugInfo multiRangesTraverse(RangeAwareIntersectVisitor visitor) throws IOException {
+        OptimizationContext.DebugInfo debugInfo = new OptimizationContext.DebugInfo();
+
+        if (visitor.getActiveIndex() < 0) {
+            logger.debug("No ranges match the query, skip the fast filter optimization");
+            return debugInfo;
+        }
+
+        try {
+            visitor.traverse(debugInfo);
+        } catch (CollectionTerminatedException e) {
+            logger.debug("Early terminate since no more range to collect");
+        }
+
+        return debugInfo;
+    }
+
+    /**
+     * This IntersectVisitor contains a packed value representation of Ranges
+     * as well as the current activeIndex being considered for collection.
+     */
+    public static abstract class RangeAwareIntersectVisitor implements PointValues.IntersectVisitor {
+        private final PointValues.PointTree pointTree;
+        private final Ranges ranges;
+        private final int maxNumNonZeroRange;
+        protected int visitedRange = 0;
+        protected int activeIndex;
+
+        public RangeAwareIntersectVisitor(
+            PointValues.PointTree pointTree,
+            Ranges ranges,
+            int maxNumNonZeroRange
+        ) {
+            this.ranges = ranges;
+            this.pointTree = pointTree;
+            this.maxNumNonZeroRange = maxNumNonZeroRange;
+            this.activeIndex = ranges.firstRangeIndex(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
+        }
+
+        public int getActiveIndex() {
+            return activeIndex;
+        }
+
+        public abstract void visit(int docID);
+
+        public abstract void visit(int docID, byte[] packedValue);
+
+        public abstract void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException;
+
+        protected abstract void consumeContainedNode(PointValues.PointTree pointTree) throws IOException;
+
+        protected abstract void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException;
+
+        public void traverse(OptimizationContext.DebugInfo debug) throws IOException {
+            PointValues.Relation r = compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
+            switch (r) {
+                case CELL_INSIDE_QUERY:
+                    consumeContainedNode(pointTree);
+                    debug.visitInner();
+                    break;
+                case CELL_CROSSES_QUERY:
+                    if (pointTree.moveToChild()) {
+                        do {
+                            traverse(debug);
+                        } while (pointTree.moveToSibling());
+                        pointTree.moveToParent();
+                    } else {
+                        consumeCrossedNode(pointTree);
+                        debug.visitLeaf();
+                    }
+                    break;
+                case CELL_OUTSIDE_QUERY:
+                    // the node is fully outside the remaining ranges; nothing to collect
+            }
+        }
+
+        /**
+         * Increments activeIndex until we run out of ranges or find a range that may contain this node,
+         * and throws CollectionTerminatedException when no range is left to check.
+         *
+         * @param minPackedValue lower bound of the PointValues.PointTree node
+         * @param maxPackedValue upper bound of the PointValues.PointTree node
+         * @return the relation of the node to the active range; the node can be
+         * 1.) Completely outside the activeIndex range
+         * 2.) Completely inside the activeIndex range
+         * 3.) Overlapping with the activeIndex range
+         */
+        @Override
+        public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            // try to find the first range that may collect values from this cell
+            if (!ranges.withinUpperBound(minPackedValue, activeIndex) && iterateRangeEnd(minPackedValue)) {
+                throw new CollectionTerminatedException();
+            }
+
+            // after the step above, minPackedValue is below the active range's upper bound,
+            // but the cell may still sit entirely below the active range's lower bound
+            if (!ranges.withinLowerBound(maxPackedValue, activeIndex) && iterateRangeEnd(maxPackedValue)) {
+                return PointValues.Relation.CELL_OUTSIDE_QUERY;
+            }
+
+            if (ranges.withinRange(minPackedValue, activeIndex) && ranges.withinRange(maxPackedValue, activeIndex)) {
+                return PointValues.Relation.CELL_INSIDE_QUERY;
+            }
+            return PointValues.Relation.CELL_CROSSES_QUERY;
+        }
+
+        /**
+         * Checks whether packedValue can be collected into the range at activeIndex.
+         *
+         * @param packedValue the packed value to test against the range at activeIndex
+         * @return true when packedValue falls within the activeIndex range
+         * @throws CollectionTerminatedException if we have passed the last range and it does not contain packedValue
+         */
+        protected boolean canCollect(byte[] packedValue) {
+            if (!ranges.withinUpperBound(packedValue, activeIndex) && iterateRangeEnd(packedValue)) {
+                throw new CollectionTerminatedException();
+            }
+            return ranges.withinRange(packedValue, activeIndex);
+        }
+
+        /**
+         * Increments activeIndex until we reach a range whose upper bound is beyond packedValue.
+         *
+         * @param packedValue the packed value used to advance the active range
+         * @return true when we have exhausted all available ranges, or have visited maxNumNonZeroRange ranges and can stop early
+         */
+        protected boolean iterateRangeEnd(byte[] packedValue) {
+            // the new value may not be contiguous to the previous one,
+            // so try to find the first next range that crosses the new value
+            while (!ranges.withinUpperBound(packedValue, activeIndex)) {
+                if (++activeIndex >= ranges.size) {
+                    return true;
+                }
+            }
+            visitedRange++;
+            return visitedRange > maxNumNonZeroRange;
+        }
+    }
+
+    /**
+     * Traverse PointTree with countDocs callback where countDocs inputs are
+     * 1.) activeIndex for range in which document(s) reside
+     * 2.) total documents counted
+     */
+    public static class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
+        BiConsumer<Integer, Integer> countDocs;
+
+        public DocCountRangeAwareIntersectVisitor(
+            PointValues.PointTree pointTree,
+            Ranges ranges,
+            int maxNumNonZeroRange,
+            BiConsumer<Integer, Integer> countDocs
+        ) {
+            super(pointTree, ranges, maxNumNonZeroRange);
+            this.countDocs = countDocs;
+        }
+
+        @Override
+        public void visit(int docID) {
+            countDocs.accept(activeIndex, 1);
+        }
+
+        @Override
+        public void visit(int docID, byte[] packedValue) {
+            if (canCollect(packedValue)) {
+                countDocs.accept(activeIndex, 1);
+            }
+        }
+
+        @Override
+        public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
+            if (canCollect(packedValue)) {
+                for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
+                    countDocs.accept(activeIndex, 1);
+                }
+            }
+        }
+
+        @Override
+        protected void consumeContainedNode(PointValues.PointTree pointTree) throws IOException {
+            countDocs.accept(activeIndex, (int) pointTree.size());
+        }
+
+        @Override
+        protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException {
+            pointTree.visitDocValues(this);
+        }
+    }
+
+    /**
+     * Traverse PointTree with collectDocs callback where collectDocs inputs are
+     * 1.) activeIndex for range in which document(s) reside
+     * 2.) document id to collect
+     */
+    public static class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor {
+        BiConsumer<Integer, Integer> collectDocs;
+
+        public DocCollectRangeAwareIntersectVisitor(
+            PointValues.PointTree pointTree,
+            Ranges ranges,
+            int maxNumNonZeroRange,
+            BiConsumer<Integer, Integer> collectDocs
+        ) {
+            super(pointTree, ranges, maxNumNonZeroRange);
+            this.collectDocs = collectDocs;
+        }
+
+        @Override
+        public void visit(int docID) {
+            collectDocs.accept(activeIndex, docID);
+        }
+
+        @Override
+        public void visit(int docID, byte[] packedValue) {
+            if (canCollect(packedValue)) {
+                collectDocs.accept(activeIndex, docID);
+            }
+        }
+
+        @Override
+        public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
+            if (canCollect(packedValue)) {
+                for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
+                    collectDocs.accept(activeIndex, doc); // doc is the iterator's current doc id
+                }
+            }
+        }
+
+        @Override
+        protected void consumeContainedNode(PointValues.PointTree pointTree) throws IOException {
+            pointTree.visitDocIDs(this);
+        }
+
+        @Override
+        protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException {
+            pointTree.visitDocValues(this);
+        }
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java
new file mode 100644
index 0000000000000..7c7385bb6102d
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * This package contains filter rewrite optimization for range-type aggregations
+ * <p>
+ * The idea is to
+ * <ul>
+ *   <li>figure out the "ranges" from the aggregation</li>
+ *   <li>leverage the ranges and bkd index to get the result of each range bucket quickly</li>
+ * </ul>
+ *
+ * More details in https://github.com/opensearch-project/OpenSearch/pull/14464 + */ +package org.opensearch.search.optimization.filterrewrite; diff --git a/server/src/main/java/org/opensearch/search/optimization/package-info.java b/server/src/main/java/org/opensearch/search/optimization/package-info.java new file mode 100644 index 0000000000000..05e1d67d28e0e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Optimization logic on the search path + */ +package org.opensearch.search.optimization;
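Note for reviewers: the Ranges contract above (lower bounds inclusive, upper bounds exclusive, and firstRangeIndex returning the first range that can intersect a node's [min, max] interval) can be exercised in isolation. The sketch below is illustrative only and not part of the patch; RangesContractDemo is a hypothetical class name, and it assumes Lucene's LongPoint packing, which matches the LONG.encodePoint encoding used by Helper.createRangesFromAgg for long-valued fields.

import org.apache.lucene.document.LongPoint;
import org.opensearch.search.optimization.filterrewrite.Ranges;

// Hypothetical demo, for illustration only; not included in this change.
public class RangesContractDemo {
    public static void main(String[] args) {
        // Two buckets, [10, 20) and [20, 30), encoded as sortable packed bytes
        byte[][] lowers = { LongPoint.pack(10L).bytes, LongPoint.pack(20L).bytes };
        byte[][] uppers = { LongPoint.pack(20L).bytes, LongPoint.pack(30L).bytes };
        Ranges ranges = new Ranges(lowers, uppers);

        byte[] twenty = LongPoint.pack(20L).bytes;
        System.out.println(ranges.withinRange(twenty, 0)); // false: upper bounds are exclusive
        System.out.println(ranges.withinRange(twenty, 1)); // true: lower bounds are inclusive

        // firstRangeIndex returns the first range that may intersect [globalMin, globalMax], or -1 if none can
        System.out.println(ranges.firstRangeIndex(LongPoint.pack(5L).bytes, LongPoint.pack(15L).bytes)); // prints 0
        System.out.println(ranges.firstRangeIndex(LongPoint.pack(31L).bytes, LongPoint.pack(40L).bytes)); // prints -1
    }
}

Because the upper bounds are exclusive, adjacent buckets can share an encoded boundary without double-counting a document during the BKD traversal.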