From 9ab1447c1c33565406c9dfa206a26626f8bcac46 Mon Sep 17 00:00:00 2001 From: Finn Carroll Date: Thu, 1 Aug 2024 15:09:01 -0700 Subject: [PATCH] Modify TreeTraversal usage to match new format Signed-off-by: Finn Carroll --- .../BKDTreeMultiRangesTraverseBenchmark.java | 12 +++--- .../CompositeAggregatorBridge.java | 32 ++++++++++++---- .../DateHistogramAggregatorBridge.java | 38 +++++++++++-------- .../filterrewrite/RangeAggregatorBridge.java | 29 +++++++------- .../filterrewrite/TreeTraversal.java | 4 +- 5 files changed, 70 insertions(+), 45 deletions(-) diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java index 134dcdcb7266f..0af2a3eb98adb 100644 --- a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java +++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/BKDTreeMultiRangesTraverseBenchmark.java @@ -54,6 +54,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.mapper.NumericPointEncoder; import org.opensearch.search.optimization.filterrewrite.Ranges; +import org.opensearch.search.optimization.filterrewrite.TreeTraversal; import java.util.*; import java.util.concurrent.TimeUnit; @@ -141,16 +142,15 @@ public void tearDown() throws IOException { public Map> multiRangeTraverseTree(treeState state) throws Exception { Map> mockIDCollect = new HashMap<>(); - BiConsumer> collectRangeIDs = (activeIndex, docIDs) -> { + TreeTraversal.RangeAwareIntersectVisitor treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(state.pointTree, state.ranges, state.maxNumNonZeroRanges, (activeIndex, docID) -> { if (mockIDCollect.containsKey(activeIndex)) { - mockIDCollect.get(activeIndex).addAll(docIDs); + mockIDCollect.get(activeIndex).add(docID); } else { - mockIDCollect.put(activeIndex, docIDs); + mockIDCollect.put(activeIndex, List.of(docID)); } - }; - - multiRangesTraverse(state.pointTree, state.ranges, collectRangeIDs, state.maxNumNonZeroRanges); + }); + multiRangesTraverse(treeVisitor); return mockIDCollect; } } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java index 9cd6c35ad9541..86c695ea4dcdc 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java @@ -46,13 +46,29 @@ private boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType @Override public final void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { DateFieldMapper.DateFieldType fieldType = getFieldType(); - BiConsumer> collectRangeIDs = (activeIndex, docIDs) -> { - long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); - rangeStart = fieldType.convertNanosToMillis(rangeStart); - long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); - incrementDocCount.accept(ord, (long) docIDs.size()); - }; - - optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, getSize())); + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, activeIndex); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); + } + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + incrementDocCount.accept(ord, (long) docCount); + }); + } + + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); } } diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java index bff63137f8575..ecba0ee0c3566 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java @@ -126,25 +126,31 @@ protected int getSize() { @Override public void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { - int size = getSize(); - DateFieldMapper.DateFieldType fieldType = getFieldType(); - BiConsumer> collectRangeIDs = (activeIndex, docIDs) -> { - long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); - rangeStart = fieldType.convertNanosToMillis(rangeStart); - long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); - incrementDocCount.accept(ord, (long) docIDs.size()); - - try { - for (int docID : docIDs) { - sub.collect(docID, ord); + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docID) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + + try { + incrementDocCount.accept(ord, (long) 1); + sub.collect(docID, activeIndex); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); } - } catch ( IOException ioe) { - throw new RuntimeException(ioe); - } - }; + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), getSize(), (activeIndex, docCount) -> { + long rangeStart = LongPoint.decodeDimension(optimizationContext.getRanges().lowers[activeIndex], 0); + rangeStart = fieldType.convertNanosToMillis(rangeStart); + long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart)); + incrementDocCount.accept(ord, (long) docCount); + }); + } - optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, size)); + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); } protected static long getBucketOrd(long bucketOrd) { diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java index 3bda6c349de86..a6b456d6971bd 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/RangeAggregatorBridge.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.BiConsumer; import java.util.function.Function; -import java.util.List; import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse; @@ -77,22 +76,26 @@ public void prepareFromSegment(LeafReaderContext leaf) { @Override public final void tryOptimize(PointValues values, BiConsumer incrementDocCount, final LeafBucketCollector sub) throws IOException { + TreeTraversal.RangeAwareIntersectVisitor treeVisitor; + if (sub != null) { + treeVisitor = new TreeTraversal.DocCollectRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docID) -> { + long ord = bucketOrdProducer().apply(activeIndex); - - BiConsumer> collectRangeIDs = (activeIndex, docIDs) -> { - long ord = bucketOrdProducer().apply(activeIndex); - incrementDocCount.accept(ord, (long) docIDs.size()); - - try { - for (int docID : docIDs) { + try { + incrementDocCount.accept(ord, (long) 1); sub.collect(docID, activeIndex); + } catch ( IOException ioe) { + throw new RuntimeException(ioe); } - } catch ( IOException ioe) { - throw new RuntimeException(ioe); - } - }; + }); + } else { + treeVisitor = new TreeTraversal.DocCountRangeAwareIntersectVisitor(values.getPointTree(), optimizationContext.getRanges(), Integer.MAX_VALUE, (activeIndex, docCount) -> { + long ord = bucketOrdProducer().apply(activeIndex); + incrementDocCount.accept(ord, (long) docCount); + }); + } - optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), collectRangeIDs, Integer.MAX_VALUE)); + optimizationContext.consumeDebugInfo(multiRangesTraverse(treeVisitor)); } /** diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java index 1e6418e081186..8ff3821f45dc8 100644 --- a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java @@ -173,7 +173,7 @@ protected boolean iterateRangeEnd(byte[] packedValue) { * 1.) activeIndex for range in which document(s) reside * 2.) total documents counted */ - private static abstract class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { + public static class DocCountRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { BiConsumer countDocs; public DocCountRangeAwareIntersectVisitor( @@ -220,7 +220,7 @@ protected void consumeCrossedNode(PointValues.PointTree pointTree) throws IOExce * 1.) activeIndex for range in which document(s) reside * 2.) document id to collect */ - private static abstract class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { + public static class DocCollectRangeAwareIntersectVisitor extends RangeAwareIntersectVisitor { BiConsumer collectDocs; public DocCollectRangeAwareIntersectVisitor(