diff --git a/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java b/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java
index e538be5d5bece..07747c3f487cb 100644
--- a/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java
+++ b/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java
@@ -29,6 +29,7 @@
 import org.opensearch.search.aggregations.AggregatorFactory;
 import org.opensearch.search.aggregations.LeafBucketCollector;
 import org.opensearch.search.aggregations.LeafBucketCollectorBase;
+import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregatorFactory;
 import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
 import org.opensearch.search.aggregations.support.ValuesSource;
 import org.opensearch.search.builder.SearchSourceBuilder;
@@ -40,6 +41,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
@@ -74,9 +76,14 @@ public static StarTreeQueryContext getStarTreeQueryContext(SearchContext context
         );
 
         for (AggregatorFactory aggregatorFactory : context.aggregations().factories().getFactories()) {
+            // first check whether the aggregation is a metric aggregation
             MetricStat metricStat = validateStarTreeMetricSupport(compositeMappedFieldType, aggregatorFactory);
+
+            // if it is not a metric aggregation, check for a supported date histogram shape
             if (metricStat == null) {
-                return null;
+                if (validateDateHistogramSupport(compositeMappedFieldType, aggregatorFactory) == false) {
+                    return null;
+                }
             }
         }
 
@@ -159,6 +166,26 @@ private static MetricStat validateStarTreeMetricSupport(
         return null;
     }
 
+    /**
+     * Checks whether an aggregation is a date histogram that wraps exactly one
+     * metric aggregation supported by the star tree.
+     *
+     * @param compositeIndexFieldInfo star-tree field type used to validate the metric sub-aggregation
+     * @param aggregatorFactory aggregation factory to inspect
+     * @return {@code true} only if the factory is a date histogram whose single sub-aggregation is a supported metric
+     */
+    private static boolean validateDateHistogramSupport(
+        CompositeDataCubeFieldType compositeIndexFieldInfo,
+        AggregatorFactory aggregatorFactory
+    ) {
+        if (aggregatorFactory instanceof DateHistogramAggregatorFactory == false) {
+            return false;
+        }
+        // NOTE(review): the histogram field and interval are not checked against the star-tree date dimension here.
+        AggregatorFactory[] subFactories = aggregatorFactory.getSubFactories().getFactories();
+        return subFactories.length == 1 && validateStarTreeMetricSupport(compositeIndexFieldInfo, subFactories[0]) != null;
+    }
+
     public static CompositeIndexFieldInfo getSupportedStarTree(SearchContext context) {
         StarTreeQueryContext starTreeQueryContext = context.getStarTreeQueryContext();
         return (starTreeQueryContext != null) ? starTreeQueryContext.getStarTree() : null;
@@ -240,7 +267,7 @@ public static FixedBitSet getStarTreeFilteredValues(SearchContext context, LeafR
         throws IOException {
         FixedBitSet result = context.getStarTreeQueryContext().getStarTreeValues(ctx);
         if (result == null) {
-            result = StarTreeFilter.getStarTreeResult(starTreeValues, context.getStarTreeQueryContext().getQueryMap());
+            result = StarTreeFilter.getStarTreeResult(starTreeValues, context.getStarTreeQueryContext().getQueryMap(), Set.of());
             context.getStarTreeQueryContext().setStarTreeValues(ctx, result);
         }
         return result;
diff --git a/server/src/main/java/org/opensearch/search/aggregations/StarTreeBucketCollector.java b/server/src/main/java/org/opensearch/search/aggregations/StarTreeBucketCollector.java
new file mode 100644
index 0000000000000..722bd147ff510
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/StarTreeBucketCollector.java
@@ -0,0 +1,28 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations;
+
+import java.io.IOException;
+
+/**
+ * A {@link LeafBucketCollector} that can additionally collect star-tree entries into buckets.
+ *
+ * @opensearch.internal
+ */
+public abstract class StarTreeBucketCollector extends LeafBucketCollector {
+
+    /**
+     * Collects the given star-tree entry into the given bucket.
+     *
+     * @param starTreeEntry identifier of the star-tree entry to collect
+     * @param bucket ordinal of the bucket the entry is collected into
+     * @throws IOException declared so that implementations may propagate I/O failures
+     */
+    public abstract void collectStarEntry(int starTreeEntry, long bucket) throws IOException;
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/StarTreeLeafBucketCollectorBase.java b/server/src/main/java/org/opensearch/search/aggregations/StarTreeLeafBucketCollectorBase.java
new file mode 100644
index 0000000000000..458f85bfbf57c
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/aggregations/StarTreeLeafBucketCollectorBase.java
@@ -0,0 +1,59 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+package org.opensearch.search.aggregations;
+
+import org.apache.lucene.search.Scorable;
+import org.opensearch.common.lucene.ScorerAware;
+
+import java.io.IOException;
+
+/**
+ * A {@link StarTreeBucketCollector} that delegates per-document collection to the sub leaf
+ * collector and sets the scorer on its source of values if it implements
+ * {@link ScorerAware}. Star-tree entries are ignored unless a subclass overrides
+ * {@link #collectStarEntry(int, long)}.
+ *
+ * @opensearch.internal
+ */
+public class StarTreeLeafBucketCollectorBase extends StarTreeBucketCollector {
+    private final LeafBucketCollector sub;
+    private final ScorerAware values;
+
+    /**
+     * @param sub The leaf collector for sub aggregations.
+     * @param values The values. {@link ScorerAware#setScorer} will be called automatically on them if they implement {@link ScorerAware}.
+     */
+    public StarTreeLeafBucketCollectorBase(LeafBucketCollector sub, Object values) {
+        this.sub = sub;
+        this.values = values instanceof ScorerAware ? (ScorerAware) values : null;
+    }
+
+    @Override
+    public void setScorer(Scorable s) throws IOException {
+        sub.setScorer(s);
+        if (values != null) {
+            values.setScorer(s);
+        }
+    }
+
+    @Override
+    public void collect(int doc, long bucket) throws IOException {
+        sub.collect(doc, bucket);
+    }
+
+    /**
+     * No-op by default: the star-tree entry is neither recorded nor forwarded to the sub collector.
+     */
+    @Override
+    public void collectStarEntry(int starTreeEntry, long bucket) throws IOException {}
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java
index eef427754f535..3c08dfaff692a 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java
@@ -43,6 +43,7 @@
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.InternalAggregations;
 import org.opensearch.search.aggregations.LeafBucketCollector;
+import org.opensearch.search.aggregations.StarTreeBucketCollector;
 import org.opensearch.search.aggregations.bucket.global.GlobalAggregator;
 import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
 import org.opensearch.search.aggregations.support.AggregationPath;
@@ -129,6 +130,25 @@ public final void collectExistingBucket(LeafBucketCollector subCollector, int do
         subCollector.collect(doc, bucketOrd);
     }
 
+    /**
+     * Collects a star-tree entry into a bucket, adding the entry's document count to the
+     * bucket's doc count before handing the entry to the sub collector.
+     *
+     * @param subCollector collector that receives the star-tree entry
+     * @param docCount number of documents added to the bucket for this entry
+     * @param bucketOrd ordinal of the bucket to collect into
+     * @param entryBit identifier of the star-tree entry passed on to the sub collector
+     * @throws IOException if the sub collector fails
+     */
+    public final void collectStarTreeBucket(StarTreeBucketCollector subCollector, long docCount, long bucketOrd, int entryBit)
+        throws IOException {
+        // presumably true only when this entry is the first one collected into the bucket (mirrors collectExistingBucket)
+        if (docCounts.increment(bucketOrd, docCount) == docCount) {
+            multiBucketConsumer.accept(0);
+        }
+        subCollector.collectStarEntry(entryBit, bucketOrd);
+    }
+
     /**
      * This only tidies up doc counts. Call {@link MergingBucketsDeferringCollector#mergeBuckets(long[])} to merge the actual
      * ordinals and doc ID deltas.
diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
index 96a49bc3fd5f6..6cdeda9f4cd04 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java
@@ -34,11 +34,16 @@
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.CollectionUtil;
+import org.apache.lucene.util.FixedBitSet;
 import org.opensearch.common.Nullable;
 import org.opensearch.common.Rounding;
 import org.opensearch.common.lease.Releasables;
+import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
+import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
+import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
 import org.opensearch.search.DocValueFormat;
 import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.AggregatorFactories;
@@ -47,6 +52,7 @@
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.LeafBucketCollector;
 import org.opensearch.search.aggregations.LeafBucketCollectorBase;
+import org.opensearch.search.aggregations.StarTreeBucketCollector;
 import org.opensearch.search.aggregations.bucket.BucketsAggregator;
 import org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge;
 import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
@@ -54,6 +60,7 @@
 import org.opensearch.search.aggregations.support.ValuesSource;
 import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.startree.StarTreeFilter;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -61,6 +68,8 @@
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 
+import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getStarTreeValues;
+import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getSupportedStarTree;
 import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll;
 
 /**
@@ -171,6 +180,53 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
         if (optimized) throw new CollectionTerminatedException();
 
         SortedNumericDocValues values = valuesSource.longValues(ctx);
+        CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context);
+        // Take the star-tree path only when the sub collector can consume star-tree entries;
+        // otherwise fall through to regular per-document collection instead of failing on a cast.
+        if (supportedStarTree != null && sub instanceof StarTreeBucketCollector) {
+            StarTreeBucketCollector starTreeSub = (StarTreeBucketCollector) sub;
+            StarTreeValues starTreeValues = getStarTreeValues(ctx, supportedStarTree);
+            assert starTreeValues != null;
+
+            // TODO: derive both names from the aggregation and the star-tree field; hard-coded values only fit one index.
+            final String dateDimension = "@timestamp_month";
+            final String docCountMetric = "startree1__doc_count_doc_count_metric";
+
+            FixedBitSet matchingDocsBitSet = StarTreeFilter.getPredicateValueToFixedBitSetMap(starTreeValues, dateDimension);
+            SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
+                .getDimensionValuesIterator(dateDimension);
+            SortedNumericStarTreeValuesIterator metricValuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
+                .getMetricValuesIterator(docCountMetric);
+
+            int numBits = matchingDocsBitSet.length();
+            if (numBits > 0) {
+                // nextSetBit must not be called with an index at or past the bit set length, hence the bounds check.
+                for (int bit = matchingDocsBitSet.nextSetBit(0); bit != DocIdSetIterator.NO_MORE_DOCS; bit = (bit + 1 < numBits)
+                    ? matchingDocsBitSet.nextSetBit(bit + 1)
+                    : DocIdSetIterator.NO_MORE_DOCS) {
+
+                    // Skip entries that lack either a date value or a doc count.
+                    if (valuesIterator.advanceExact(bit) == false || metricValuesIterator.advanceExact(bit) == false) {
+                        continue;
+                    }
+                    // Read the doc count once per entry; it applies to every date value of that entry.
+                    long docCount = metricValuesIterator.nextValue();
+
+                    for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) {
+                        long bucketOrd = bucketOrds.add(0, valuesIterator.nextValue());
+                        if (bucketOrd < 0) {
+                            // presumably a negative ordinal encodes an already existing bucket as -1 - ord
+                            bucketOrd = -1 - bucketOrd;
+                        } else {
+                            grow(bucketOrd + 1);
+                        }
+                        collectStarTreeBucket(starTreeSub, docCount, bucketOrd, bit);
+                    }
+                }
+            }
+            throw new CollectionTerminatedException();
+        }
+
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long owningBucketOrd) throws IOException {
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java
index 3d237a94c5699..3b41d70e62d4c 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java
@@ -39,13 +39,16 @@
 import org.opensearch.common.util.DoubleArray;
 import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
 import org.opensearch.index.compositeindex.datacube.MetricStat;
+import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
 import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper;
+import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
 import org.opensearch.index.fielddata.SortedNumericDoubleValues;
 import org.opensearch.search.DocValueFormat;
 import org.opensearch.search.aggregations.Aggregator;
 import org.opensearch.search.aggregations.InternalAggregation;
 import org.opensearch.search.aggregations.LeafBucketCollector;
 import org.opensearch.search.aggregations.LeafBucketCollectorBase;
+import org.opensearch.search.aggregations.StarTreeBucketCollector;
 import org.opensearch.search.aggregations.support.ValuesSource;
 import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
@@ -53,6 +56,7 @@
 import java.io.IOException;
 import java.util.Map;
 
+import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getStarTreeValues;
 import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getSupportedStarTree;
 
 /**
@@ -136,6 +140,32 @@ public void collect(int doc, long bucket) throws IOException {
     public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
         throws IOException {
         final CompensatedSum kahanSummation = new CompensatedSum(sums.get(0), 0);
+        if (parent != null && subAggregators.length == 0) {
+            // TODO: derive the metric name from the star-tree field and the aggregated field; the hard-coded value only fits one index.
+            final String sumMetricName = "startree1_status_sum_metric";
+            final StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
+            assert starTreeValues != null;
+            final SortedNumericStarTreeValuesIterator metricValuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
+                .getMetricValuesIterator(sumMetricName);
+
+            return new StarTreeBucketCollector() {
+                // Adds the sum stored for the star-tree entry to the running sum of the bucket.
+                // NOTE(review): plain addition; the CompensatedSum created above is not used on this path.
+                @Override
+                public void collectStarEntry(int starTreeEntryBit, long bucket) throws IOException {
+                    sums = context.bigArrays().grow(sums, bucket + 1);
+                    if (metricValuesIterator.advanceExact(starTreeEntryBit) == false) {
+                        return; // no sum stored for this star-tree entry
+                    }
+                    double metricValue = NumericUtils.sortableLongToDouble(metricValuesIterator.nextValue());
+                    sums.set(bucket, sums.get(bucket) + metricValue);
+                }
+
+                // Per-document collection is intentionally a no-op; values arrive through collectStarEntry.
+                @Override
+                public void collect(int doc, long owningBucketOrd) throws IOException {}
+            };
+        }
         return StarTreeQueryHelper.getStarTreeLeafCollector(
             context,
             valuesSource,
diff --git a/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java b/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java
index f7fa210691678..dca11f99ec9a5 100644
--- a/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java
+++ b/server/src/main/java/org/opensearch/search/startree/StarTreeFilter.java
@@ -47,9 +47,27 @@ public class StarTreeFilter {
      * First go over the star tree and try to match as many dimensions as possible
      * For the remaining columns, use star-tree doc values to match them
      */
     public static FixedBitSet getStarTreeResult(StarTreeValues starTreeValues, Map<String, Long> predicateEvaluators) throws IOException {
+        return getStarTreeResult(starTreeValues, predicateEvaluators, Collections.emptySet());
+    }
+
+    /**
+     * Same traversal as the two-argument overload, except that a match is only recorded once every
+     * dimension listed in {@code groupbyField} has been traversed in addition to the predicate dimensions.
+     *
+     * @param starTreeValues star-tree values to traverse
+     * @param predicateEvaluators dimension filters; {@code null} is treated as no filters
+     * @param groupbyField dimensions to group by; {@code null} is treated as no group-by dimensions
+     * @return bit set of the matching star-tree documents
+     */
+    public static FixedBitSet getStarTreeResult(
+        StarTreeValues starTreeValues,
+        Map<String, Long> predicateEvaluators,
+        Set<String> groupbyField
+    ) throws IOException {
         Map<String, Long> queryMap = predicateEvaluators != null ? predicateEvaluators : Collections.emptyMap();
-        StarTreeResult starTreeResult = traverseStarTree(starTreeValues, queryMap);
+        Set<String> groupByFields = groupbyField != null ? groupbyField : Collections.emptySet();
+        StarTreeResult starTreeResult = traverseStarTree(starTreeValues, queryMap, groupByFields);
 
         // Initialize FixedBitSet with size maxMatchedDoc + 1
         FixedBitSet bitSet = new FixedBitSet(starTreeResult.maxMatchedDoc + 1);
@@ -113,7 +131,8 @@ public static FixedBitSet getStarTreeResult(StarTreeValues starTreeValues, Map<S
      * Helper method to traverse the star tree, get matching documents and keep track of all the
      * predicate dimensions that are not matched.
      */
-    private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Map<String, Long> queryMap) throws IOException {
+    private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Map<String, Long> queryMap, Set<String> groupbyField)
+        throws IOException {
         DocIdSetBuilder docsWithField = new DocIdSetBuilder(starTreeValues.getStarTreeDocumentCount());
         DocIdSetBuilder.BulkAdder adder;
         Set<String> globalRemainingPredicateColumns = null;
@@ -129,6 +148,7 @@ private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Ma
         queue.add(starTree);
         int currentDimensionId = -1;
         Set<String> remainingPredicateColumns = new HashSet<>(queryMap.keySet());
+        Set<String> remainingGroupByColumns = new HashSet<>(groupbyField);
         int matchedDocsCountInStarTree = 0;
         int maxDocNum = -1;
         StarTreeNode starTreeNode;
@@ -139,13 +159,14 @@ private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Ma
             if (dimensionId > currentDimensionId) {
                 String dimension = dimensionNames.get(dimensionId);
                 remainingPredicateColumns.remove(dimension);
+                remainingGroupByColumns.remove(dimension);
                 if (foundLeafNode && globalRemainingPredicateColumns == null) {
                     globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns);
                 }
                 currentDimensionId = dimensionId;
             }
 
-            if (remainingPredicateColumns.isEmpty()) {
+            if (remainingPredicateColumns.isEmpty() && remainingGroupByColumns.isEmpty()) {
                 int docId = starTreeNode.getAggregatedDocId();
                 docIds.add(docId);
                 matchedDocsCountInStarTree++;
@@ -164,7 +185,10 @@ private static StarTreeResult traverseStarTree(StarTreeValues starTreeValues, Ma
 
             String childDimension = dimensionNames.get(dimensionId + 1);
             StarTreeNode starNode = null;
-            if (globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension)) {
+            // A group-by dimension has to be expanded value by value, so its star node is skipped
+            // regardless of the predicate-column state; the parentheses keep '&&' from binding first.
+            if ((globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension))
+                && !remainingGroupByColumns.contains(childDimension)) {
                 starNode = starTreeNode.getChildStarNode();
             }
 
@@ -225,4 +249,16 @@ public StarTreeResult(
             this.maxMatchedDoc = maxMatchedDoc;
         }
     }
+
+    /**
+     * Returns the star-tree documents that match when grouping by a single dimension and applying no filters.
+     *
+     * @param starTreeValues star-tree values to traverse
+     * @param predicateField dimension to group by
+     * @return bit set of the matching star-tree documents
+     * @throws IOException propagated from the star-tree traversal
+     */
+    public static FixedBitSet getPredicateValueToFixedBitSetMap(StarTreeValues starTreeValues, String predicateField) throws IOException {
+        return getStarTreeResult(starTreeValues, Collections.emptyMap(), Set.of(predicateField));
+    }
 }