diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java new file mode 100644 index 0000000000000..c29808cb3ae4f --- /dev/null +++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java @@ -0,0 +1,171 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.benchmark.search.aggregations; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.opensearch.common.logging.LogConfigurator; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.mapper.NumericPointEncoder; +import org.opensearch.search.optimization.filterrewrite.PackedValueRanges; +import org.opensearch.search.optimization.filterrewrite.TreeTraversal; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + +@Warmup(iterations = 10) +@Measurement(iterations = 5) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Thread) +@Fork(value = 1) +public class 
BKDTreeMultiRangesTraverseBenchmark { + @State(Scope.Benchmark) + public static class treeState { + @Param({ "10000", "10000000" }) + int treeSize; + + @Param({ "10000", "10000000" }) + int valMax; + + @Param({ "10", "100" }) + int buckets; + + @Param({ "12345" }) + int seed; + + private Random random; + + Path tmpDir; + Directory directory; + IndexWriter writer; + IndexReader reader; + + // multiRangesTraverse params + PointValues.PointTree pointTree; + PackedValueRanges packedValueRanges; + BiConsumer> collectRangeIDs; + int maxNumNonZeroRanges = Integer.MAX_VALUE; + + @Setup + public void setup() throws IOException { + LogConfigurator.setNodeName("sample-name"); + random = new Random(seed); + tmpDir = Files.createTempDirectory("tree-test"); + directory = FSDirectory.open(tmpDir); + writer = new IndexWriter(directory, new IndexWriterConfig()); + + for (int i = 0; i < treeSize; i++) { + writer.addDocument(List.of(new IntField("val", random.nextInt(valMax), Field.Store.NO))); + } + + reader = DirectoryReader.open(writer); + + // should only contain single segment + for (LeafReaderContext lrc : reader.leaves()) { + pointTree = lrc.reader().getPointValues("val").getPointTree(); + } + + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("val", NumberFieldMapper.NumberType.INTEGER); + NumericPointEncoder numericPointEncoder = (NumericPointEncoder) fieldType; + + int bucketWidth = valMax / buckets; + byte[][] lowers = new byte[buckets][]; + byte[][] uppers = new byte[buckets][]; + for (int i = 0; i < buckets; i++) { + lowers[i] = numericPointEncoder.encodePoint(i * bucketWidth); + uppers[i] = numericPointEncoder.encodePoint((i + 1) * bucketWidth); + } + + packedValueRanges = new PackedValueRanges(lowers, uppers); + } + + @TearDown + public void tearDown() throws IOException { + for (String indexFile : FSDirectory.listAll(tmpDir)) { + Files.deleteIfExists(tmpDir.resolve(indexFile)); + } + Files.deleteIfExists(tmpDir); + } + } + + @Benchmark + public Map> 
multiRangeTraverseTree(treeState state) throws Exception { + Map> mockIDCollect = new HashMap<>(); + + TreeTraversal.RangeAwareIntersectVisitor treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor( + state.pointTree, + state.packedValueRanges, + state.maxNumNonZeroRanges, + (activeIndex, docID) -> { + if (mockIDCollect.containsKey(activeIndex)) { + mockIDCollect.get(activeIndex).add(docID); + } else { + mockIDCollect.put(activeIndex, new java.util.ArrayList<>(List.of(docID))); + } + } + ); + + multiRangesTraverse(treeVisitor); + return mockIDCollect; + } +} diff --git a/release-notes/opensearch.release-notes-2.1.0.md b/release-notes/opensearch.release-notes-2.1.0.md index b32864990b82e..e55aab4810c2c 100644 --- a/release-notes/opensearch.release-notes-2.1.0.md +++ b/release-notes/opensearch.release-notes-2.1.0.md @@ -61,7 +61,7 @@ * Update github action gradle-check to use pull_request_target for accessing token (#3728) ([#3731](https://github.com/opensearch-project/opensearch/pull/3731)) * Add gradle check test for github workflows (#3717) ([#3723](https://github.com/opensearch-project/opensearch/pull/3723)) * Used set to make shell scripts more strict (#3278) ([#3344](https://github.com/opensearch-project/opensearch/pull/3344)) -* Bootstrap should implement a denylist of Java versions (ranges) (#3164) ([#3292](https://github.com/opensearch-project/opensearch/pull/3292)) +* Bootstrap should implement a denylist of Java versions (ranges) (#3164) ([#3292](https://github.com/opensearch-project/opensearch/pull/3292)) * Add Github Workflow to build and publish lucene snapshots. 
(#2906) ([#3038](https://github.com/opensearch-project/opensearch/pull/3038)) * Remove JavaVersion in favour of standard Runtime.Version (java-version-checker) (#3027) ([#3034](https://github.com/opensearch-project/opensearch/pull/3034)) * Remove JavaVersion, use builtin Runtime.Version to deal with runtime versions (#3006) ([#3013](https://github.com/opensearch-project/opensearch/pull/3013)) diff --git a/release-notes/opensearch.release-notes-2.14.0.md b/release-notes/opensearch.release-notes-2.14.0.md index c5fc3e895c45d..c55e0a6c27196 100644 --- a/release-notes/opensearch.release-notes-2.14.0.md +++ b/release-notes/opensearch.release-notes-2.14.0.md @@ -34,7 +34,7 @@ - [Search Pipeline] Handle default pipeline for multiple indices ([#13276](https://github.com/opensearch-project/OpenSearch/pull/13276)) - [Batch Ingestion] Add `batch_size` to `_bulk` API. ([#12457](https://github.com/opensearch-project/OpenSearch/issues/12457)) - [Remote Store] Add capability of doing refresh as determined by the translog ([#12992](https://github.com/opensearch-project/OpenSearch/pull/12992)) -- Support multi ranges traversal when doing date histogram rewrite optimization. ([#13317](https://github.com/opensearch-project/OpenSearch/pull/13317)) +- Support multi ranges traversal when doing date histogram rewrite optimization. 
([#13317](https://github.com/opensearch-project/OpenSearch/pull/13317)) ### Dependencies - Bump `org.apache.commons:commons-configuration2` from 2.10.0 to 2.10.1 ([#12896](https://github.com/opensearch-project/OpenSearch/pull/12896)) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml index a75b1d0eac793..940e5adc6468f 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/10_histogram.yml @@ -706,3 +706,80 @@ setup: - match: { profile.shards.0.aggregations.0.debug.unoptimized_segments: 0 } - match: { profile.shards.0.aggregations.0.debug.leaf_visited: 1 } - match: { profile.shards.0.aggregations.0.debug.inner_visited: 0 } + +--- +"date_histogram with range sub aggregation": + - do: + indices.create: + index: test_date_hist_range_sub_agg + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + date: + type: date + - do: + bulk: + refresh: true + index: test_date_hist_range_sub_agg + body: + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 11}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 12}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 23}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 39}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 4}' + - do: + search: + body: + size: 0 + aggs: + histo: + date_histogram: + field: date + calendar_interval: day + aggs: + my_range: + range: + field: v + ranges: + - to: 10 + - from: 10 + to: 20 + - from: 20 + to: 30 + - from: 30 + to: 40 + + - match: { hits.total.value: 8 } + - length: { aggregations.histo.buckets: 9 } + + - match: { 
aggregations.histo.buckets.0.key_as_string: "2020-03-01T00:00:00.000Z" } + - match: { aggregations.histo.buckets.1.key_as_string: "2020-03-02T00:00:00.000Z" } + - match: { aggregations.histo.buckets.7.key_as_string: "2020-03-08T00:00:00.000Z" } + - match: { aggregations.histo.buckets.8.key_as_string: "2020-03-09T00:00:00.000Z" } + + - match: { aggregations.histo.buckets.0.doc_count: 2 } + - match: { aggregations.histo.buckets.1.doc_count: 1 } + - match: { aggregations.histo.buckets.2.doc_count: 0 } + - match: { aggregations.histo.buckets.7.doc_count: 4 } + - match: { aggregations.histo.buckets.8.doc_count: 1 } + + - match: { aggregations.histo.buckets.0.my_range.buckets.0.doc_count: 1 } + + - match: { aggregations.histo.buckets.7.my_range.buckets.0.doc_count: 0 } + - match: { aggregations.histo.buckets.7.my_range.buckets.1.doc_count: 0 } + - match: { aggregations.histo.buckets.7.my_range.buckets.2.doc_count: 3 } + - match: { aggregations.histo.buckets.7.my_range.buckets.3.doc_count: 1 } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml index 0897e0bdd894b..593083cb384ef 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/330_auto_date_histogram.yml @@ -158,3 +158,75 @@ setup: - match: { profile.shards.0.aggregations.0.debug.unoptimized_segments: 0 } - match: { profile.shards.0.aggregations.0.debug.leaf_visited: 1 } - match: { profile.shards.0.aggregations.0.debug.inner_visited: 0 } + +--- +"Range aggregation with auto_date_histogram sub-aggregation": + - do: + indices.create: + index: sub_agg_profile + body: + mappings: + properties: + "@timestamp": + type: date + metrics.size: + type: long + + - do: + bulk: + refresh: true + index: sub_agg_profile + body: + - '{"index": 
{}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 2}' + - '{"index": {}}' + - '{"date": "2020-03-03", "v": 3}' + - '{"index": {}}' + - '{"date": "2020-04-09", "v": 4}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 13}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 14}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 15}' + - '{"index": {}}' + - '{"date": "2020-04-11", "v": 19}' + + - do: + search: + index: sub_agg_profile + body: + size: 0 + aggs: + range_histo: + range: + field: v + ranges: + - to: 0 + - from: 0 + to: 10 + - from: 10 + aggs: + date: + auto_date_histogram: + field: "date" + buckets: 3 + + - match: { hits.total.value: 8 } + - length: { aggregations.range_histo.buckets: 3 } + + - match: { aggregations.range_histo.buckets.0.key: "*-0.0" } + - match: { aggregations.range_histo.buckets.1.key: "0.0-10.0" } + - match: { aggregations.range_histo.buckets.2.key: "10.0-*" } + + - match: { aggregations.range_histo.buckets.0.doc_count: 0 } + - match: { aggregations.range_histo.buckets.1.doc_count: 4 } + - match: { aggregations.range_histo.buckets.2.doc_count: 4 } + + - match: { aggregations.range_histo.buckets.1.date.buckets.0.doc_count: 3 } + - match: { aggregations.range_histo.buckets.1.date.buckets.1.doc_count: 1 } + + - match: { aggregations.range_histo.buckets.2.date.buckets.0.doc_count: 3 } + - match: { aggregations.range_histo.buckets.2.date.buckets.1.doc_count: 1 } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml index 80aad96ce1f6b..807117bee81ea 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/40_range.yml @@ -673,3 +673,71 @@ setup: - match: { aggregations.my_range.buckets.3.from: 1.5 } - is_false: aggregations.my_range.buckets.3.to 
- match: { aggregations.my_range.buckets.3.doc_count: 2 } + +--- +"range with auto date sub aggregation": + - do: + indices.create: + index: test_range_auto_date + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + properties: + date: + type: date + - do: + bulk: + refresh: true + index: test_range_auto_date + body: + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 1}' + - '{"index": {}}' + - '{"date": "2020-03-01", "v": 11}' + - '{"index": {}}' + - '{"date": "2020-03-02", "v": 12}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 23}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 28}' + - '{"index": {}}' + - '{"date": "2020-03-08", "v": 39}' + - '{"index": {}}' + - '{"date": "2020-03-09", "v": 4}' + - do: + search: + body: + size: 0 + aggs: + my_range: + range: + field: v + ranges: + - to: 10 + - from: 10 + to: 20 + - from: 20 + aggs: + histo: + date_histogram: + field: date + calendar_interval: day + + - match: { hits.total.value: 8 } + - length: { aggregations.my_range.buckets: 3 } + + - match: { aggregations.my_range.buckets.0.key: "*-10.0" } + - match: { aggregations.my_range.buckets.1.key: "10.0-20.0" } + - match: { aggregations.my_range.buckets.2.key: "20.0-*" } + + - match: { aggregations.my_range.buckets.0.doc_count: 2 } + - match: { aggregations.my_range.buckets.1.doc_count: 2 } + - match: { aggregations.my_range.buckets.2.doc_count: 4 } + + - match: { aggregations.my_range.buckets.0.histo.buckets.0.doc_count: 1 } + - match: { aggregations.my_range.buckets.0.histo.buckets.1.doc_count: 0 } + - match: { aggregations.my_range.buckets.2.histo.buckets.0.doc_count: 4 } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index b51bea511e067..c8ecb0d9dc68c 100644 --- 
a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -32,6 +32,7 @@ package org.opensearch.search.aggregations.bucket.composite; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; @@ -77,6 +78,7 @@ import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.optimization.filterrewrite.CompositeAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import org.opensearch.search.searchafter.SearchAfterBuilder; import org.opensearch.search.sort.SortAndFormats; @@ -90,6 +92,7 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.LongFunction; import java.util.function.LongUnaryOperator; import java.util.stream.Collectors; @@ -173,6 +176,7 @@ public final class CompositeAggregator extends BucketsAggregator { @Override public boolean canOptimize() { + if (parent != null || subAggregators.length != 0) return false; if (canOptimize(sourceConfigs)) { this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource(); if (rawAfterKey != null) { @@ -190,9 +194,7 @@ public boolean canOptimize() { } @Override - public void prepare() throws IOException { - buildRanges(context); - } + public void prepare() throws IOException { buildRanges(context); } protected Rounding getRounding(final long low, final long high) { return valuesSource.getRounding(); @@ -213,16 +215,24 @@ protected long[] processAfterKey(long[] bounds, long interval) { } @Override - protected int getSize() { + protected int 
rangeMax() { return size; } @Override - protected Function bucketOrdProducer() { - return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key)); + protected long getOrd(int rangeIdx){ + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().getLower(rangeIdx), 0); + rangeStart = this.getFieldType().convertNanosToMillis(rangeStart); + long ord = bucketOrds.add(0, getRoundingPrepared().round(rangeStart)); + + if (ord < 0) { // already seen + ord = -1 - ord; + } + + return ord; } }); - if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { + if (optimizationContext.canOptimize(parent, context)) { optimizationContext.prepare(); } } @@ -559,7 +569,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); finishLeaf(); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java index 9263d3935538d..ee036c8c837aa 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java @@ -31,6 +31,7 @@ package org.opensearch.search.aggregations.bucket.histogram; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; import 
org.apache.lucene.search.CollectionTerminatedException; @@ -66,6 +67,7 @@ import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.LongFunction; import java.util.function.LongToIntFunction; import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; @@ -166,9 +168,7 @@ public boolean canOptimize() { } @Override - public void prepare() throws IOException { - buildRanges(context); - } + public void prepare() throws IOException { buildRanges(context); } @Override protected Rounding getRounding(final long low, final long high) { @@ -200,12 +200,19 @@ protected Prepared getRoundingPrepared() { } @Override - protected Function bucketOrdProducer() { - return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key)); - } + protected long getOrd(int rangeIdx){ + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().getLower(rangeIdx), 0); + rangeStart = this.getFieldType().convertNanosToMillis(rangeStart); + long ord = getBucketOrds().add(0, getRoundingPrepared().round(rangeStart)); + if (ord < 0) { // already seen + ord = -1 - ord; + } + + return ord; + } }); - if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { + if (optimizationContext.canOptimize(parent, context)) { optimizationContext.prepare(); } } @@ -239,7 +246,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDocValues values = valuesSource.longValues(ctx); diff --git 
a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java index 20f62f4d6e3f8..99d8d44393ecf 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java @@ -31,6 +31,7 @@ package org.opensearch.search.aggregations.bucket.histogram; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.CollectionTerminatedException; @@ -39,6 +40,7 @@ import org.opensearch.common.Nullable; import org.opensearch.common.Rounding; import org.opensearch.common.lease.Releasables; +import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.AggregatorFactories; @@ -54,12 +56,15 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; import org.opensearch.search.optimization.filterrewrite.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.PackedValueRanges; +import org.opensearch.search.optimization.filterrewrite.RangeAggregatorBridge; import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.LongFunction; import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; @@ -126,9 +131,7 @@ public boolean canOptimize() { } @Override - public void prepare() throws IOException { - buildRanges(context); - } + public void prepare() throws IOException { 
buildRanges(context); } @Override protected Rounding getRounding(long low, long high) { @@ -146,12 +149,19 @@ protected long[] processHardBounds(long[] bounds) { } @Override - protected Function bucketOrdProducer() { - return (key) -> bucketOrds.add(0, preparedRounding.round((long) key)); - } + protected long getOrd(int rangeIdx){ + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().getLower(rangeIdx), 0); + rangeStart = this.getFieldType().convertNanosToMillis(rangeStart); + long ord = bucketOrds.add(0, getRoundingPrepared().round(rangeStart)); + if (ord < 0) { // already seen + ord = -1 - ord; + } + + return ord; + } }); - if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { + if (optimizationContext.canOptimize(parent, context)) { optimizationContext.prepare(); } } @@ -170,7 +180,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); SortedNumericDocValues values = valuesSource.longValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index c206d1e522e01..74ce313803c8b 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -68,6 +68,7 @@ import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.function.LongFunction; import static 
org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -288,16 +289,9 @@ public boolean canOptimize() { } @Override - public void prepare() { - buildRanges(ranges); - } - - @Override - protected Function bucketOrdProducer() { - return (activeIndex) -> subBucketOrdinal(0, (int) activeIndex); - } + public void prepare() { buildRanges(ranges); } }); - if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { + if (optimizationContext.canOptimize(parent, context)) { optimizationContext.prepare(); } } @@ -312,7 +306,7 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); + boolean optimized = optimizationContext.tryOptimize(ctx, sub, this::incrementBucketDocCount, false); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java index 9e1c75e659989..ed8c86747698c 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java @@ -11,9 +11,14 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.search.aggregations.LeafBucketCollector; import java.io.IOException; import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.LongFunction; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; /** * This interface 
provides a bridge between an aggregator and the optimization context, allowing @@ -67,10 +72,59 @@ void setOptimizationContext(OptimizationContext context) { public abstract void prepareFromSegment(LeafReaderContext leaf) throws IOException; /** - * Attempts to build aggregation results for a segment + * @return max range to stop collecting at. + * Utilized by aggs which stop early. + */ + protected int rangeMax() { + return Integer.MAX_VALUE; + } + + /** + * Translate an index of the packed value range array to an agg bucket ordinal. + */ + protected long getOrd(int rangeIdx){ + return rangeIdx; + } + + /** + * Attempts to build aggregation results for a segment. + * With no sub agg count docs and avoid iterating docIds. + * If a sub agg is present we must iterate through and collect docIds to support it. * * @param values the point values (index structure for numeric values) for a segment * @param incrementDocCount a consumer to increment the document count for a range bucket. The First parameter is document count, the second is the key of the bucket */ - public abstract void tryOptimize(PointValues values, BiConsumer incrementDocCount) throws IOException; + public final void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) + throws IOException { + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor( + values.getPointTree(), + optimizationContext.getRanges(), + rangeMax(), + (activeIndex, docID) -> { + long ord = this.getOrd(activeIndex); + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, ord); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + ); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor( + values.getPointTree(), + optimizationContext.getRanges(), + rangeMax(), + (activeIndex, docCount) -> { + long ord = this.getOrd(activeIndex); + 
incrementDocCount.accept(ord, (long) docCount); + } + ); + } + + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); + } } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java index 1982793332605..1d7d381128a13 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java @@ -8,11 +8,19 @@ package org.opensearch.search.optimization.filterrewrite; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.PointValues; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig; import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource; +import java.io.IOException; +import java.util.function.BiConsumer; + +import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; + /** * For composite aggregation to do optimization when it only has a single date histogram source */ diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java index da53e4aa73684..1e7008134ad32 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java @@ -16,7 +16,9 @@ import org.opensearch.common.Rounding; import org.opensearch.index.mapper.DateFieldMapper; import 
org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; @@ -31,7 +33,6 @@ * For date histogram aggregation */ public abstract class DateHistogramAggregatorBridge extends AggregatorBridge { - protected boolean canOptimize(ValuesSourceConfig config) { if (config.script() == null && config.missing() == null) { MappedFieldType fieldType = config.fieldType(); @@ -56,7 +57,7 @@ public void prepareFromSegment(LeafReaderContext leaf) throws IOException { optimizationContext.setRangesFromSegment(buildRanges(bounds)); } - private Ranges buildRanges(long[] bounds) { + private PackedValueRanges buildRanges(long[] bounds) { bounds = processHardBounds(bounds); if (bounds == null) { return null; @@ -113,45 +114,11 @@ protected long[] processHardBounds(long[] bounds, LongBounds hardBounds) { return bounds; } - private DateFieldMapper.DateFieldType getFieldType() { + protected DateFieldMapper.DateFieldType getFieldType() { assert fieldType instanceof DateFieldMapper.DateFieldType; return (DateFieldMapper.DateFieldType) fieldType; } - protected int getSize() { - return Integer.MAX_VALUE; - } - - @Override - public final void tryOptimize(PointValues values, BiConsumer incrementDocCount) throws IOException { - int size = getSize(); - - DateFieldMapper.DateFieldType fieldType = getFieldType(); - BiConsumer incrementFunc = (activeIndex, docCount) -> { - long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); - rangeStart = fieldType.convertNanosToMillis(rangeStart); - long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); - incrementDocCount.accept(ord, (long) docCount); - }; - - 
optimizationContext.consumeDebugInfo( - multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), incrementFunc, size) - ); - } - - private static long getBucketOrd(long bucketOrd) { - if (bucketOrd < 0) { // already seen - bucketOrd = -1 - bucketOrd; - } - - return bucketOrd; - } - - /** - * Provides a function to produce bucket ordinals from the lower bound of the range - */ - protected abstract Function bucketOrdProducer(); - /** * Checks whether the top level query matches all documents on the segment * diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java index eb57cd90b9ad9..a79a7d5d0b1a0 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java @@ -155,7 +155,7 @@ private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldN * Creates the date ranges from date histo aggregations using its interval, * and min/max boundaries */ - static Ranges createRangesFromAgg( + static PackedValueRanges createRangesFromAgg( final DateFieldMapper.DateFieldType fieldType, final long interval, final Rounding.Prepared preparedRounding, @@ -208,6 +208,6 @@ static Ranges createRangesFromAgg( uppers[i] = max; } - return new Ranges(lowers, uppers); + return new PackedValueRanges(lowers, uppers); } } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java index d4d5880b37ce1..96e6e6b652c1a 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java @@ -15,6 +15,7 @@ import 
org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.PointValues; import org.opensearch.index.mapper.DocCountFieldMapper; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.internal.SearchContext; import java.io.IOException; @@ -41,8 +42,8 @@ public final class OptimizationContext { int maxAggRewriteFilters; String shardId; - private Ranges ranges; - private Ranges rangesFromSegment; + private PackedValueRanges packedValueRanges; + private PackedValueRanges packedValueRangesFromSegment; // debug info related fields private int leaf; @@ -54,10 +55,9 @@ public OptimizationContext(AggregatorBridge aggregatorBridge) { this.aggregatorBridge = aggregatorBridge; } - public boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) { + public boolean canOptimize(final Object parent, SearchContext context) { if (context.maxAggRewriteFilters() == 0) return false; - - if (parent != null || subAggLength != 0) return false; + if (parent != null) return false; this.canOptimize = aggregatorBridge.canOptimize(); if (canOptimize) { @@ -70,9 +70,9 @@ public boolean canOptimize(final Object parent, final int subAggLength, SearchCo } public void prepare() throws IOException { - assert ranges == null : "Ranges should only be built once at shard level, but they are already built"; + assert packedValueRanges == null : "Ranges should only be built once at shard level, but they are already built"; aggregatorBridge.prepare(); - if (ranges != null) { + if (packedValueRanges != null) { preparedAtShardLevel = true; } } @@ -81,17 +81,17 @@ public void prepareFromSegment(LeafReaderContext leaf) throws IOException { aggregatorBridge.prepareFromSegment(leaf); } - void setRanges(Ranges ranges) { - this.ranges = ranges; + void setRanges(PackedValueRanges packedValueRanges) { + this.packedValueRanges = packedValueRanges; } - void setRangesFromSegment(Ranges ranges) { - this.rangesFromSegment = ranges; + void 
setRangesFromSegment(PackedValueRanges packedValueRanges) { + this.packedValueRangesFromSegment = packedValueRanges; } - Ranges getRanges() { - if (rangesFromSegment != null) return rangesFromSegment; - return ranges; + public PackedValueRanges getRanges() { + if (packedValueRangesFromSegment != null) return packedValueRangesFromSegment; + return packedValueRanges; } /** @@ -102,8 +102,12 @@ Ranges getRanges() { * @param incrementDocCount consume the doc_count results for certain ordinal * @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment */ - public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer incrementDocCount, boolean segmentMatchAll) - throws IOException { + public boolean tryOptimize( + final LeafReaderContext leafCtx, + LeafBucketCollector sub, + final BiConsumer incrementDocCount, + boolean segmentMatchAll + ) throws IOException { segments++; if (!canOptimize) { return false; @@ -126,16 +130,16 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer= 0; + } + + public static boolean withinUpperBound(byte[] value, byte[] upperBound) { + return compareByteValue(value, upperBound) < 0; + } + + public byte[] getLower(int idx){ + return lowers[idx]; + } + + public byte[] getUpper(int idx){ + return uppers[idx]; + } + + public boolean withinLowerBound(byte[] value, int idx) { + return PackedValueRanges.withinLowerBound(value, lowers[idx]); + } + + public boolean withinUpperBound(byte[] value, int idx) { + return PackedValueRanges.withinUpperBound(value, uppers[idx]); + } + + public boolean withinRange(byte[] value, int idx) { + return withinLowerBound(value, idx) && withinUpperBound(value, idx); + } + public int firstRangeIndex(byte[] globalMin, byte[] globalMax) { if (compareByteValue(lowers[0], globalMax) > 0) { return -1; @@ -42,16 +75,4 @@ public int firstRangeIndex(byte[] globalMin, byte[] globalMax) { } return i; } - - 
public static int compareByteValue(byte[] value1, byte[] value2) { - return comparator.compare(value1, 0, value2, 0); - } - - public static boolean withinLowerBound(byte[] value, byte[] lowerBound) { - return compareByteValue(value, lowerBound) >= 0; - } - - public static boolean withinUpperBound(byte[] value, byte[] upperBound) { - return compareByteValue(value, upperBound) < 0; - } } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java index f48e863ab8934..2adddbbf535b1 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.PointValues; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumericPointEncoder; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.bucket.range.RangeAggregator; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; @@ -65,30 +66,11 @@ protected void buildRanges(RangeAggregator.Range[] ranges) { uppers[i] = upper; } - optimizationContext.setRanges(new Ranges(lowers, uppers)); + optimizationContext.setRanges(new PackedValueRanges(lowers, uppers)); } @Override public void prepareFromSegment(LeafReaderContext leaf) { throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level"); } - - @Override - public final void tryOptimize(PointValues values, BiConsumer incrementDocCount) throws IOException { - int size = Integer.MAX_VALUE; - - BiConsumer incrementFunc = (activeIndex, docCount) -> { - long ord = bucketOrdProducer().apply(activeIndex); - incrementDocCount.accept(ord, (long) 
docCount); - }; - - optimizationContext.consumeDebugInfo( - multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), incrementFunc, size) - ); - } - - /** - * Provides a function to produce bucket ordinals from index of the corresponding range in the range array - */ - protected abstract Function bucketOrdProducer(); } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java index aad833324a841..02ed661c99265 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java @@ -13,7 +13,6 @@ import org.apache.lucene.index.PointValues; import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; -import org.opensearch.common.CheckedRunnable; import java.io.IOException; import java.util.function.BiConsumer; @@ -24,201 +23,240 @@ /** * Utility class for traversing a {@link PointValues.PointTree} and collecting document counts for the ranges. * - *

The main entry point is the {@link #multiRangesTraverse(PointValues.PointTree, Ranges, - * BiConsumer, int)} method + *

The main entry point is the {@link #multiRangesTraverse(RangeAwareIntersectVisitor)} method * - *

The class uses a {@link RangeCollectorForPointTree} to keep track of the active ranges and - * determine which parts of the tree to visit. The {@link - * PointValues.IntersectVisitor} implementation is responsible for the actual visitation and - * document count collection. + *

The class uses a {@link RangeAwareIntersectVisitor} to keep track of the active ranges, traverse the tree, and + * consume documents. */ -final class TreeTraversal { - private TreeTraversal() {} - +public final class TreeTraversal { private static final Logger logger = LogManager.getLogger(loggerName); /** - * Traverses the given {@link PointValues.PointTree} and collects document counts for the intersecting ranges. - * - * @param tree the point tree to traverse - * @param ranges the set of ranges to intersect with - * @param incrementDocCount a callback to increment the document count for a range bucket - * @param maxNumNonZeroRanges the maximum number of non-zero ranges to collect + * Traverse the RangeAwareIntersectVisitor PointTree. + * Collects and returns DebugInfo from traversal + * @param visitor the maximum number of non-zero ranges to collect * @return a {@link OptimizationContext.DebugInfo} object containing debug information about the traversal */ - static OptimizationContext.DebugInfo multiRangesTraverse( - final PointValues.PointTree tree, - final Ranges ranges, - final BiConsumer incrementDocCount, - final int maxNumNonZeroRanges - ) throws IOException { + public static OptimizationContext.DebugInfo multiRangesTraverse(RangeAwareIntersectVisitor visitor) throws IOException { OptimizationContext.DebugInfo debugInfo = new OptimizationContext.DebugInfo(); - int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); - if (activeIndex < 0) { + + if (visitor.getActiveIndex() < 0) { logger.debug("No ranges match the query, skip the fast filter optimization"); return debugInfo; } - RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex); - PointValues.IntersectVisitor visitor = getIntersectVisitor(collector); + try { - intersectWithRanges(visitor, tree, collector, debugInfo); + visitor.traverse(debugInfo); } catch (CollectionTerminatedException e) { 
logger.debug("Early terminate since no more range to collect"); } - collector.finalizePreviousRange(); return debugInfo; } - private static void intersectWithRanges( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - RangeCollectorForPointTree collector, - OptimizationContext.DebugInfo debug - ) throws IOException { - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - - switch (r) { - case CELL_INSIDE_QUERY: - collector.countNode((int) pointTree.size()); - debug.visitInner(); - break; - case CELL_CROSSES_QUERY: - if (pointTree.moveToChild()) { - do { - intersectWithRanges(visitor, pointTree, collector, debug); - } while (pointTree.moveToSibling()); - pointTree.moveToParent(); - } else { - pointTree.visitDocValues(visitor); - debug.visitLeaf(); - } - break; - case CELL_OUTSIDE_QUERY: + /** + * This IntersectVisitor contains a packed value representation of Ranges + * as well as the current activeIndex being considered for collection. 
+ */ + public static abstract class RangeAwareIntersectVisitor implements PointValues.IntersectVisitor { + private final PointValues.PointTree pointTree; + private final PackedValueRanges packedValueRanges; + private final int maxNumNonZeroRange; + protected int visitedRange = 0; + protected int activeIndex; + + public RangeAwareIntersectVisitor(PointValues.PointTree pointTree, PackedValueRanges packedValueRanges, int maxNumNonZeroRange) { + this.packedValueRanges = packedValueRanges; + this.pointTree = pointTree; + this.maxNumNonZeroRange = maxNumNonZeroRange; + this.activeIndex = packedValueRanges.firstRangeIndex(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); } - } - private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) { - return new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) { - // this branch should be unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited" - ); + public long getActiveIndex() { + return activeIndex; + } + + public abstract void visit(int docID); + + public abstract void visit(int docID, byte[] packedValue); + + public abstract void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException; + + protected abstract void consumeContainedNode(PointValues.PointTree pointTree) throws IOException; + + protected abstract void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException; + + public void traverse(OptimizationContext.DebugInfo debug) throws IOException { + PointValues.Relation r = compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + switch (r) { + case CELL_INSIDE_QUERY: + consumeContainedNode(pointTree); + debug.visitInner(); + break; + case CELL_CROSSES_QUERY: + if (pointTree.moveToChild()) { + do { + traverse(debug); + } while (pointTree.moveToSibling()); + pointTree.moveToParent(); + } else { + 
consumeCrossedNode(pointTree); + debug.visitLeaf(); + } + break; + case CELL_OUTSIDE_QUERY: } + } - @Override - public void visit(int docID, byte[] packedValue) throws IOException { - visitPoints(packedValue, collector::count); + /** + * increment activeIndex until we run out of ranges or find a valid range that contains maxPackedValue + * else throw CollectionTerminatedException if we run out of ranges to check + * @param minPackedValue lower bound of PointValues.PointTree node + * @param maxPackedValue upper bound of PointValues.PointTree node + * @return the min/max values of the PointValues.PointTree node can be one of: + * 1.) Completely outside the activeIndex range + * 2.) Completely inside the activeIndex range + * 3.) Overlapping with the activeIndex range + */ + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + // try to find the first range that may collect values from this cell + if (!packedValueRanges.withinUpperBound(minPackedValue, activeIndex) && iterateRangeEnd(minPackedValue)) { + throw new CollectionTerminatedException(); } - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - visitPoints(packedValue, () -> { - for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { - collector.count(); - } - }); + // after the loop, min < upper + // cell could be outside [min max] lower + if (!packedValueRanges.withinLowerBound(maxPackedValue, activeIndex) && iterateRangeEnd(maxPackedValue)) { + return PointValues.Relation.CELL_OUTSIDE_QUERY; } - private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException { - if (!collector.withinUpperBound(packedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(packedValue)) { - throw new CollectionTerminatedException(); - } - } + if (packedValueRanges.withinRange(minPackedValue, activeIndex) && packedValueRanges.withinRange(maxPackedValue, 
activeIndex)) { + return PointValues.Relation.CELL_INSIDE_QUERY; + } + return PointValues.Relation.CELL_CROSSES_QUERY; + } - if (collector.withinRange(packedValue)) { - collect.run(); - } + /** + * throws CollectionTerminatedException if we have reached our last range, and it does not contain packedValue + * @param packedValue determine if packedValue falls within the range at activeIndex + * @return true when packedValue falls within the activeIndex range + */ + protected boolean canCollect(byte[] packedValue) { + if (!packedValueRanges.withinUpperBound(packedValue, activeIndex) && iterateRangeEnd(packedValue)) { + throw new CollectionTerminatedException(); } + return packedValueRanges.withinRange(packedValue, activeIndex); + } - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - // try to find the first range that may collect values from this cell - if (!collector.withinUpperBound(minPackedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(minPackedValue)) { - throw new CollectionTerminatedException(); - } - } - // after the loop, min < upper - // cell could be outside [min max] lower - if (!collector.withinLowerBound(maxPackedValue)) { - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } - if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) { - return PointValues.Relation.CELL_INSIDE_QUERY; + /** + * @param packedValue increment active index until we reach a range containing value + * @return true when we've exhausted all available ranges or visited maxNumNonZeroRange and can stop early + */ + protected boolean iterateRangeEnd(byte[] packedValue) { + // the new value may not be contiguous to the previous one + // so try to find the first next range that cross the new value + while (!packedValueRanges.withinUpperBound(packedValue, activeIndex)) { + if (++activeIndex >= packedValueRanges.size) { + return true; } - return 
PointValues.Relation.CELL_CROSSES_QUERY; } - }; + visitedRange++; + return visitedRange > maxNumNonZeroRange; + } } - private static class RangeCollectorForPointTree { - private final BiConsumer incrementRangeDocCount; - private int counter = 0; - - private final Ranges ranges; - private int activeIndex; - - private int visitedRange = 0; - private final int maxNumNonZeroRange; + /** + * Traverse PointTree with countDocs callback where countDock inputs are + * 1.) activeIndex for range in which document(s) reside + * 2.) total documents counted + */ + public static class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { + BiConsumer countDocs; - public RangeCollectorForPointTree( - BiConsumer incrementRangeDocCount, + public DocCountRangeAwareIntersectVisitor( + PointValues.PointTree pointTree, + PackedValueRanges packedValueRanges, int maxNumNonZeroRange, - Ranges ranges, - int activeIndex + BiConsumer countDocs ) { - this.incrementRangeDocCount = incrementRangeDocCount; - this.maxNumNonZeroRange = maxNumNonZeroRange; - this.ranges = ranges; - this.activeIndex = activeIndex; + super(pointTree, packedValueRanges, maxNumNonZeroRange); + this.countDocs = countDocs; } - private void count() { - counter++; + @Override + public void visit(int docID) { + countDocs.accept(activeIndex, 1); } - private void countNode(int count) { - counter += count; + @Override + public void visit(int docID, byte[] packedValue) { + if (canCollect(packedValue)) { + countDocs.accept(activeIndex, 1); + } } - private void finalizePreviousRange() { - if (counter > 0) { - incrementRangeDocCount.accept(activeIndex, counter); - counter = 0; + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + if (canCollect(packedValue)) { + for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { + countDocs.accept(activeIndex, 1); + } } } - /** - * @return true when iterator exhausted or collect enough non-zero ranges - */ - 
private boolean iterateRangeEnd(byte[] value) { - // the new value may not be contiguous to the previous one - // so try to find the first next range that cross the new value - while (!withinUpperBound(value)) { - if (++activeIndex >= ranges.size) { - return true; - } + protected void consumeContainedNode(PointValues.PointTree pointTree) throws IOException { + countDocs.accept(activeIndex, (int) pointTree.size()); + } + + protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException { + pointTree.visitDocValues(this); + } + } + + /** + * Traverse PointTree with collectDocs callback where collectDocs inputs are + * 1.) activeIndex for range in which document(s) reside + * 2.) document id to collect + */ + public static class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { + BiConsumer collectDocs; + + public DocCollectRangeAwareIntersectVisitor( + PointValues.PointTree pointTree, + PackedValueRanges packedValueRanges, + int maxNumNonZeroRange, + BiConsumer collectDocs + ) { + super(pointTree, packedValueRanges, maxNumNonZeroRange); + this.collectDocs = collectDocs; + } + + @Override + public void visit(int docID) { + collectDocs.accept(activeIndex, docID); + } + + @Override + public void visit(int docID, byte[] packedValue) { + if (canCollect(packedValue)) { + collectDocs.accept(activeIndex, docID); } - visitedRange++; - return visitedRange > maxNumNonZeroRange; } - private boolean withinLowerBound(byte[] value) { - return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]); + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + if (canCollect(packedValue)) { + for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { + collectDocs.accept(activeIndex, iterator.docID()); + } + } } - private boolean withinUpperBound(byte[] value) { - return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]); + protected void 
consumeContainedNode(PointValues.PointTree pointTree) throws IOException { + pointTree.visitDocIDs(this); } - private boolean withinRange(byte[] value) { - return withinLowerBound(value) && withinUpperBound(value); + protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOException { + pointTree.visitDocValues(this); } } }