diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index ceb2559a0e16c..e4f5e46950b5f 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -105,7 +105,7 @@ public class FeatureFlags { * aggregations. */ public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled"; - public static final Setting STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope); + public static final Setting STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope); private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, diff --git a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java index 91313092d8d28..c9016dc4c95dd 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java +++ b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java @@ -76,6 +76,7 @@ import org.opensearch.search.aggregations.support.AggregationUsageService; import org.opensearch.search.aggregations.support.ValuesSourceRegistry; import org.opensearch.search.lookup.SearchLookup; +import org.opensearch.search.query.startree.StarTreeQuery; import org.opensearch.transport.RemoteClusterAware; import java.io.IOException; @@ -498,6 +499,12 @@ public boolean indexSortedOnField(String field) { return indexSortConfig.hasPrimarySortOnField(field); } + public ParsedQuery toStarTreeQuery(Map>> compositePredicateMap, + Set groupByColumns) { + StarTreeQuery starTreeQuery = new StarTreeQuery(compositePredicateMap, groupByColumns); + return new ParsedQuery(starTreeQuery); + } + public ParsedQuery toQuery(QueryBuilder queryBuilder) { return 
toQuery(queryBuilder, q -> { Query query = q.toQuery(this); diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index a53a7198c366f..ef561cc14bbe6 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -77,12 +77,20 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.IndexService; import org.opensearch.index.IndexSettings; +import org.opensearch.index.compositeindex.datacube.Metric; +import org.opensearch.index.compositeindex.datacube.MetricStat; +import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregatorFactory; +import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeHelper; import org.opensearch.index.engine.Engine; +import org.opensearch.index.mapper.CompositeDataCubeFieldType; +import org.opensearch.index.mapper.CompositeMappedFieldType; import org.opensearch.index.mapper.DerivedFieldResolver; import org.opensearch.index.mapper.DerivedFieldResolverFactory; +import org.opensearch.index.mapper.StarTreeMapper; import org.opensearch.index.query.InnerHitContextBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchNoneQueryBuilder; +import org.opensearch.index.query.ParsedQuery; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; @@ -97,11 +105,20 @@ import org.opensearch.script.ScriptService; import org.opensearch.search.aggregations.AggregationInitializationException; import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregation.ReduceContext; import 
org.opensearch.search.aggregations.MultiBucketConsumerService; import org.opensearch.search.aggregations.SearchContextAggregations; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory; +import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory; +import org.opensearch.search.aggregations.metrics.MetricsAggregator; +import org.opensearch.search.aggregations.metrics.MinAggregatorFactory; +import org.opensearch.search.aggregations.metrics.SumAggregatorFactory; import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree; +import org.opensearch.search.aggregations.startree.StarTreeAggregator; +import org.opensearch.search.aggregations.startree.StarTreeAggregatorFactory; +import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.collapse.CollapseContext; import org.opensearch.search.dfs.DfsPhase; @@ -148,6 +165,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -158,6 +176,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.LongSupplier; +import java.util.stream.Collectors; import static org.opensearch.common.unit.TimeValue.timeValueHours; import static org.opensearch.common.unit.TimeValue.timeValueMillis; @@ -1314,20 +1333,29 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc context.evaluateRequestShouldUseConcurrentSearch(); return; } + + // Can be marked false for majority cases for which star-tree cannot be used + // Will save checking the criteria later and we can have a limit on what search requests are supported + // As we increment the cases where star-tree can be used, this can be set back to true + boolean canUseStarTree = 
context.mapperService().isCompositeIndexPresent(); + SearchShardTarget shardTarget = context.shardTarget(); QueryShardContext queryShardContext = context.getQueryShardContext(); context.from(source.from()); context.size(source.size()); Map innerHitBuilders = new HashMap<>(); if (source.query() != null) { + canUseStarTree = false; InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders); context.parsedQuery(queryShardContext.toQuery(source.query())); } if (source.postFilter() != null) { + canUseStarTree = false; InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders); context.parsedPostFilter(queryShardContext.toQuery(source.postFilter())); } - if (innerHitBuilders.size() > 0) { + if (!innerHitBuilders.isEmpty()) { + canUseStarTree = false; for (Map.Entry entry : innerHitBuilders.entrySet()) { try { entry.getValue().build(context, context.innerHits()); @@ -1337,11 +1365,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } if (source.sorts() != null) { + canUseStarTree = false; try { Optional optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext()); - if (optionalSort.isPresent()) { - context.sort(optionalSort.get()); - } + optionalSort.ifPresent(context::sort); } catch (IOException e) { throw new SearchException(shardTarget, "failed to create sort elements", e); } @@ -1354,9 +1381,11 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc throw new SearchException(shardTarget, "disabling [track_total_hits] is not allowed in a scroll context"); } if (source.trackTotalHitsUpTo() != null) { + canUseStarTree = false; context.trackTotalHitsUpTo(source.trackTotalHitsUpTo()); } if (source.minScore() != null) { + canUseStarTree = false; context.minimumScore(source.minScore()); } if (source.timeout() != null) { @@ -1372,6 +1401,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } if (source.suggest() != 
null) { + canUseStarTree = false; try { context.suggest(source.suggest().build(queryShardContext)); } catch (IOException e) { @@ -1379,6 +1409,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } if (source.rescores() != null) { + canUseStarTree = false; try { for (RescorerBuilder rescore : source.rescores()) { context.addRescore(rescore.buildContext(queryShardContext)); @@ -1388,12 +1419,15 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } if (source.explain() != null) { + canUseStarTree = false; context.explain(source.explain()); } if (source.fetchSource() != null) { + canUseStarTree = false; context.fetchSourceContext(source.fetchSource()); } if (source.docValueFields() != null) { + canUseStarTree = false; FetchDocValuesContext docValuesContext = FetchDocValuesContext.create( context.mapperService()::simpleMatchToFullName, context.mapperService().getIndexSettings().getMaxDocvalueFields(), @@ -1402,10 +1436,12 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc context.docValuesContext(docValuesContext); } if (source.fetchFields() != null) { + canUseStarTree = false; FetchFieldsContext fetchFieldsContext = new FetchFieldsContext(source.fetchFields()); context.fetchFieldsContext(fetchFieldsContext); } if (source.highlighter() != null) { + canUseStarTree = false; HighlightBuilder highlightBuilder = source.highlighter(); try { context.highlight(highlightBuilder.build(queryShardContext)); @@ -1414,6 +1450,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } if (source.scriptFields() != null && source.size() != 0) { + canUseStarTree = false; int maxAllowedScriptFields = context.mapperService().getIndexSettings().getMaxScriptFields(); if (source.scriptFields().size() > maxAllowedScriptFields) { throw new IllegalArgumentException( @@ -1439,17 +1476,21 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder 
sourc } } if (source.version() != null) { + canUseStarTree = false; context.version(source.version()); } if (source.seqNoAndPrimaryTerm() != null) { +// canUseStarTree = false; context.seqNoAndPrimaryTerm(source.seqNoAndPrimaryTerm()); } if (source.stats() != null) { +// canUseStarTree = false; context.groupStats(source.stats()); } if (CollectionUtils.isEmpty(source.searchAfter()) == false) { +// canUseStarTree = false; if (context.scrollContext() != null) { throw new SearchException(shardTarget, "`search_after` cannot be used in a scroll context."); } @@ -1461,6 +1502,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } if (source.slice() != null) { +// canUseStarTree = false; if (context.scrollContext() == null && !(context.readerContext() instanceof PitReaderContext)) { throw new SearchException(shardTarget, "`slice` cannot be used outside of a scroll context or PIT context"); } @@ -1468,6 +1510,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } if (source.storedFields() != null) { +// canUseStarTree = false; if (source.storedFields().fetchFields() == false) { if (context.sourceRequested()) { throw new SearchException(shardTarget, "[stored_fields] cannot be disabled if [_source] is requested"); @@ -1480,6 +1523,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } if (source.collapse() != null) { +// canUseStarTree = false; if (context.scrollContext() != null) { throw new SearchException(shardTarget, "cannot use `collapse` in a scroll context"); } @@ -1496,6 +1540,88 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc if (source.profile()) { context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch())); } + + if (canUseStarTree) { + try { + if (setStarTreeQuery(context, queryShardContext, source)) { + logger.info("Star Tree will be used in execution"); + }; + } catch (IOException e) { + 
logger.info("Cannot use star-tree");
+            }
+
+        }
+    }
+
+    /**
+     * Switches the request to star-tree execution when it has the one shape supported so far: a single
+     * terms aggregation on a star-tree dimension with exactly one max/sum/min sub-aggregation on a
+     * metric the star-tree stores.
+     *
+     * @return {@code true} if the parsed query and aggregations on {@code context} were replaced,
+     *         {@code false} if the request was left untouched
+     */
+    private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source) throws IOException {
+        // TODO: (finish) widen the criteria for star-tree query / aggregation formation
+        if (source.aggregations() == null) {
+            return false;
+        }
+
+        // TODO: Support for multiple startrees
+        CompositeMappedFieldType compositeType = context.mapperService().getCompositeFieldTypes().iterator().next();
+        if (!(compositeType instanceof StarTreeMapper.StarTreeFieldType)) {
+            return false;
+        }
+        CompositeDataCubeFieldType starTreeType = (StarTreeMapper.StarTreeFieldType) compositeType;
+        List<String> supportedDimensions = new ArrayList<>(starTreeType.fields());
+        Map<String, List<MetricStat>> supportedMetrics = starTreeType.getMetrics().stream()
+            .collect(Collectors.toMap(Metric::getField, Metric::getMetrics));
+
+        AggregatorFactories defaultFactories = context.aggregations().factories();
+        if (defaultFactories.countAggregators() != 1) {
+            return false;
+        }
+        AggregatorFactory defaultFactory = defaultFactories.getFactories()[0];
+        if (!(defaultFactory instanceof TermsAggregatorFactory)) {
+            return false;
+        }
+        String dimension = ((ValuesSourceAggregatorFactory) defaultFactory).getField();
+        if (!supportedDimensions.contains(dimension)) {
+            return false;
+        }
+
+        // Require exactly one sub-aggregation: indexing [0] unchecked throws for a bare terms aggregation
+        // and would silently drop every sub-aggregation after the first.
+        AggregatorFactories subFactories = defaultFactory.getSubFactories();
+        if (subFactories.countAggregators() != 1) {
+            return false;
+        }
+        AggregatorFactory subFactory = subFactories.getFactories()[0];
+
+        // TODO: abstract the below metric checks
+        final MetricStat stat;
+        if (subFactory instanceof MaxAggregatorFactory) {
+            stat = MetricStat.MAX;
+        } else if (subFactory instanceof SumAggregatorFactory) {
+            stat = MetricStat.SUM;
+        } else if (subFactory instanceof MinAggregatorFactory) {
+            stat = MetricStat.MIN;
+        } else {
+            return false;
+        }
+        String field = ((ValuesSourceAggregatorFactory) subFactory).getField();
+        List<MetricStat> fieldStats = supportedMetrics.get(field);
+        if (fieldStats == null || !fieldStats.contains(stat)) {
+            return false;
+        }
+        // Locale.ROOT keeps the metric suffix identical whatever the JVM default locale is.
+        String metricKey = StarTreeHelper.fullFieldNameForStarTreeMetricsDocValues(starTreeType.name(), field, stat.name().toLowerCase(java.util.Locale.ROOT));
+        ParsedQuery query = queryShardContext.toStarTreeQuery(null, Set.of(dimension));
+        StarTreeAggregatorFactory factory = new StarTreeAggregatorFactory(defaultFactory.name(), queryShardContext, null, AggregatorFactories.builder(), null, List.of(dimension), List.of(metricKey));
+        AggregatorFactories aggregatorFactories = new AggregatorFactories(new AggregatorFactory[] { factory });
+        context.parsedQuery(query)
+            .aggregations(new SearchContextAggregations(aggregatorFactories, multiBucketConsumerService.create()));
+        return true;
     }
 
     /**
diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java
index 47e9def094623..6758d0e80194b 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java
@@ -37,6 +37,9 @@
 import org.opensearch.core.common.breaker.CircuitBreaker;
 import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.indices.breaker.CircuitBreakerService;
+import org.opensearch.index.compositeindex.datacube.MetricStat;
+import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregator;
+import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregatorFactory;
 import
org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java index eeb0c606694b0..cb1b14b7a263b 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java @@ -249,13 +249,13 @@ public boolean test(AggregatorFactory o) { } }; - private AggregatorFactory[] factories; + protected AggregatorFactory[] factories; public static Builder builder() { return new Builder(); } - private AggregatorFactories(AggregatorFactory[] factories) { + public AggregatorFactories(AggregatorFactory[] factories) { this.factories = factories; } @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() { return new PipelineTree(subTrees, aggregators); } } + + public AggregatorFactory[] getFactories() { + return factories; + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java index 6cc3a78fb1e36..86fbb46a9ad3c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactory.java @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() { public boolean evaluateChildFactories() { return factories.allFactoriesSupportConcurrentSearch(); } + + public AggregatorFactories getSubFactories() { + return factories; + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java index 4fe936c8b7797..0d537745126d3 100644 --- 
a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class MaxAggregatorFactory extends ValuesSourceAggregatorFactory { +public class MaxAggregatorFactory extends ValuesSourceAggregatorFactory { static void registerAggregators(ValuesSourceRegistry.Builder builder) { builder.register( diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java index 58fbe5edefd12..2b7c8a4cc8c9c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class MinAggregatorFactory extends ValuesSourceAggregatorFactory { +public class MinAggregatorFactory extends ValuesSourceAggregatorFactory { static void registerAggregators(ValuesSourceRegistry.Builder builder) { builder.register( diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java index ef9b93920ba18..e0cd44f2672a8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class SumAggregatorFactory extends ValuesSourceAggregatorFactory { +public class SumAggregatorFactory extends ValuesSourceAggregatorFactory { SumAggregatorFactory( String name, diff --git a/server/src/main/java/org/opensearch/search/aggregations/startree/InternalStarTree.java 
b/server/src/main/java/org/opensearch/search/aggregations/startree/InternalStarTree.java new file mode 100644 index 0000000000000..0dcdb5019db00 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/startree/InternalStarTree.java @@ -0,0 +1,261 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.startree; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.InternalMultiBucketAggregation; +import org.opensearch.search.aggregations.support.CoreValuesSourceType; +import org.opensearch.search.aggregations.support.ValueType; +import org.opensearch.search.aggregations.support.ValuesSourceType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InternalStarTree> extends InternalMultiBucketAggregation { + static final InternalStarTree.Factory FACTORY = new InternalStarTree.Factory(); + + public static class Bucket extends InternalMultiBucketAggregation.InternalBucket { + public double sum; + public InternalAggregations aggregations; + private final String key; + + public Bucket(String key, double sum, InternalAggregations aggregations) { + this.key = key; + this.sum = sum; + this.aggregations = aggregations; + } + + @Override + public String getKey() { + return getKeyAsString(); + } + + @Override + public String 
getKeyAsString() { + return key; + } + + @Override + public long getDocCount() { + return (long) sum; + } + + @Override + public InternalAggregations getAggregations() { + return aggregations; + } + + protected InternalStarTree.Factory getFactory() { + return FACTORY; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(Aggregation.CommonFields.KEY.getPreferredName(), key); + // TODO : this is hack ( we are mapping bucket.noofdocs to sum ) + builder.field("SUM", sum); + aggregations.toXContentInternal(builder, params); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(key); + out.writeVLong((long) sum); + aggregations.writeTo(out); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + InternalStarTree.Bucket that = (InternalStarTree.Bucket) other; + return Objects.equals(sum, that.sum) && Objects.equals(aggregations, that.aggregations) && Objects.equals(key, that.key); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), sum, aggregations, key); + } + } + + public static class Factory> { + public ValuesSourceType getValueSourceType() { + return CoreValuesSourceType.NUMERIC; + } + + public ValueType getValueType() { + return ValueType.NUMERIC; + } + + @SuppressWarnings("unchecked") + public R create(String name, List ranges, Map metadata) { + return (R) new InternalStarTree(name, ranges, metadata); + } + + @SuppressWarnings("unchecked") + public B createBucket(String key, long docCount, InternalAggregations aggregations) { + return (B) new InternalStarTree.Bucket(key, docCount, aggregations); + } + + @SuppressWarnings("unchecked") + public R create(List ranges, R prototype) { + return (R) new InternalStarTree(prototype.name, 
ranges, prototype.metadata); + } + + @SuppressWarnings("unchecked") + public B createBucket(InternalAggregations aggregations, B prototype) { + // TODO : prototype.getDocCount() -- is mapped to sum - change this + return (B) new InternalStarTree.Bucket(prototype.getKey(), prototype.getDocCount(), aggregations); + } + } + + public InternalStarTree.Factory getFactory() { + return FACTORY; + } + + private final List ranges; + + public InternalStarTree(String name, List ranges, Map metadata) { + super(name, metadata); + this.ranges = ranges; + } + + /** + * Read from a stream. + */ + public InternalStarTree(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + List ranges = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + String key = in.readString(); + ranges.add(getFactory().createBucket(key, in.readVLong(), InternalAggregations.readFrom(in))); + } + this.ranges = ranges; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeVInt(ranges.size()); + for (B bucket : ranges) { + bucket.writeTo(out); + } + } + + @Override + public String getWriteableName() { + return "startree"; + } + + @Override + public List getBuckets() { + return ranges; + } + + public R create(List buckets) { + return getFactory().create(buckets, (R) this); + } + + @Override + public B createBucket(InternalAggregations aggregations, B prototype) { + return getFactory().createBucket(aggregations, prototype); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + Map> bucketsMap = new HashMap<>(); + + for (InternalAggregation aggregation : aggregations) { + InternalStarTree filters = (InternalStarTree) aggregation; + int i = 0; + for (B bucket : filters.ranges) { + String key = bucket.getKey(); + List sameRangeList = bucketsMap.computeIfAbsent(key, k -> new ArrayList<>(aggregations.size())); + sameRangeList.add(bucket); + } + } + + ArrayList reducedBuckets = new 
ArrayList<>(bucketsMap.size()); + + for (List sameRangeList : bucketsMap.values()) { + B reducedBucket = reduceBucket(sameRangeList, reduceContext); + if (reducedBucket.getDocCount() >= 1) { + reducedBuckets.add(reducedBucket); + } + } + reduceContext.consumeBucketsAndMaybeBreak(reducedBuckets.size()); + reducedBuckets.sort(Comparator.comparing(Bucket::getKey)); + + return getFactory().create(name, reducedBuckets, getMetadata()); + } + + @Override + protected B reduceBucket(List buckets, ReduceContext context) { + assert !buckets.isEmpty(); + + B reduced = null; + List aggregationsList = new ArrayList<>(buckets.size()); + for (B bucket : buckets) { + if (reduced == null) { + reduced = (B) new Bucket(bucket.getKey(), bucket.getDocCount(), bucket.getAggregations()); + } else { + reduced.sum += bucket.sum; + } + aggregationsList.add(bucket.getAggregations()); + } + reduced.aggregations = InternalAggregations.reduce(aggregationsList, context); + return reduced; + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.startArray(CommonFields.BUCKETS.getPreferredName()); + + for (B range : ranges) { + range.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), ranges); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + if (super.equals(obj) == false) return false; + + InternalStarTree that = (InternalStarTree) obj; + return Objects.equals(ranges, that.ranges); + } + +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregationBuilder.java b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregationBuilder.java new file mode 100644 index 0000000000000..1252e8dc1c625 --- /dev/null +++ 
b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregationBuilder.java
@@ -0,0 +1,162 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.aggregations.startree;
+
+import org.opensearch.core.ParseField;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.ObjectParser;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.index.query.QueryShardContext;
+import org.opensearch.search.aggregations.AbstractAggregationBuilder;
+import org.opensearch.search.aggregations.AggregationBuilder;
+import org.opensearch.search.aggregations.AggregatorFactories;
+import org.opensearch.search.aggregations.AggregatorFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Builder for the star-tree aggregation: holds the {@code groupby} columns and {@code metrics}
+ * names and hands them to {@link StarTreeAggregatorFactory}.
+ *
+ * TODO: Will replace with InternalStarTreeAggregationBuilder since this is user-hidden implementation,
+ * decided for specific query/aggregations which can be resolved using StarTree
+ */
+public class StarTreeAggregationBuilder extends AbstractAggregationBuilder<StarTreeAggregationBuilder> {
+    public static final String NAME = "startree";
+    private static final ParseField GROUPBY_FIELD = new ParseField("groupby");
+    private static final ParseField METRICS_FIELD = new ParseField("metrics");
+
+    // Both lists stay null until set by the parser or read from a stream.
+    private List<String> fieldCols;
+    private List<String> metrics;
+
+    public static final ObjectParser<StarTreeAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
+        NAME,
+        StarTreeAggregationBuilder::new
+    );
+
+    static {
+        PARSER.declareStringArray(StarTreeAggregationBuilder::groupby, GROUPBY_FIELD);
+        PARSER.declareStringArray(StarTreeAggregationBuilder::metrics, METRICS_FIELD);
+    }
+
+    /** Parser callback: stores a private copy of the group-by column names. */
+    private void groupby(List<String> strings) {
+        fieldCols = new ArrayList<>(strings);
+    }
+
+    /** Parser callback: stores a private copy of the metric names. */
+    private void metrics(List<String> strings) {
+        metrics = new ArrayList<>(strings);
+    }
+
+    public StarTreeAggregationBuilder(String name) {
+        super(name);
+    }
+
+    /**
+     * Copy constructor used by {@link #shallowCopy}.
+     */
+    protected StarTreeAggregationBuilder(
+        StarTreeAggregationBuilder clone,
+        AggregatorFactories.Builder factoriesBuilder,
+        Map<String, Object> metadata
+    ) {
+        super(clone, factoriesBuilder, metadata);
+        // Carry the star-tree configuration over; otherwise the copy builds a factory with null columns.
+        this.fieldCols = clone.fieldCols == null ? null : new ArrayList<>(clone.fieldCols);
+        this.metrics = clone.metrics == null ? null : new ArrayList<>(clone.metrics);
+    }
+
+    @Override
+    protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metadata) {
+        return new StarTreeAggregationBuilder(this, factoriesBuilder, metadata);
+    }
+
+    /**
+     * Read from a stream.
+     */
+    public StarTreeAggregationBuilder(StreamInput in) throws IOException {
+        super(in);
+        String[] fieldArr = in.readOptionalStringArray();
+        String[] metricArr = in.readOptionalStringArray();
+        if (fieldArr != null) {
+            fieldCols = Arrays.asList(fieldArr);
+        }
+        if (metricArr != null) {
+            metrics = Arrays.asList(metricArr);
+        }
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        // Either list may be unset; write null rather than dereferencing it, matching the optional reads above.
+        out.writeOptionalStringArray(fieldCols == null ? null : fieldCols.toArray(new String[0]));
+        out.writeOptionalStringArray(metrics == null ? null : metrics.toArray(new String[0]));
+    }
+
+    @Override
+    public BucketCardinality bucketCardinality() {
+        return BucketCardinality.MANY;
+    }
+
+    @Override
+    protected AggregatorFactory doBuild(
+        QueryShardContext queryShardContext,
+        AggregatorFactory parent,
+        AggregatorFactories.Builder subFactoriesBuilder
+    ) throws IOException {
+        return new StarTreeAggregatorFactory(name, queryShardContext, parent, subFactoriesBuilder, metadata, fieldCols, metrics);
+    }
+
+    @Override
+    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        // Emit the same fields PARSER reads so the builder survives an XContent round trip.
+        if (fieldCols != null) {
+            builder.array(GROUPBY_FIELD.getPreferredName(), fieldCols.toArray(new String[0]));
+        }
+        if (metrics != null) {
+            builder.array(METRICS_FIELD.getPreferredName(), metrics.toArray(new String[0]));
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getType() {
+        return NAME;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), fieldCols, metrics);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        if (super.equals(obj) == false) {
+            return false;
+        }
+        StarTreeAggregationBuilder other = (StarTreeAggregationBuilder) obj;
+        return Objects.equals(fieldCols, other.fieldCols) && Objects.equals(metrics, other.metrics);
+    }
+}
diff --git a/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregator.java
new file mode 100644
index 0000000000000..58dac3811c848
--- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregator.java @@ -0,0 +1,239 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.startree; + +import java.util.concurrent.atomic.AtomicReference; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.util.NumericUtils; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ConstructingObjectParser; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.CardinalityUpperBound; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.LeafBucketCollectorBase; +import org.opensearch.search.aggregations.bucket.BucketsAggregator; +import org.opensearch.search.aggregations.bucket.SingleBucketAggregator; +import 
org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.StringJoiner; + +import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Bucket aggregator that reads star-tree values from each segment, groups collected documents by a key built from the {@code fieldCols} dimension values joined with "-", and keeps a running sum of the first entry of {@code metrics} per key. + * + * NOTE(review): generic type arguments (for example on {@code Map} and {@code List}) appear to have been stripped from this text; restore them from the original change before compiling. + */ +public class StarTreeAggregator extends BucketsAggregator implements SingleBucketAggregator { + + /* Running metric sum per bucket key; written in collect and read in buildAggregations. */ + private Map sumMap = new HashMap<>(); + /* Bucket index per key, assigned from the map's size at the moment the key is first seen. */ + private Map indexMap = new HashMap<>(); + + /* Dimension names whose values make up the bucket key. */ + private List fieldCols; + /* Metric names; the leaf collector only reads metrics.get(0). */ + private List metrics; + + private static final Logger logger = LogManager.getLogger(StarTreeAggregator.class); + + /** + * Creates the aggregator, passing CardinalityUpperBound.MANY to the parent and storing the dimension and metric names. + */ + public StarTreeAggregator( + String name, + AggregatorFactories factories, + SearchContext context, + Aggregator parent, + Map metadata, + List fieldCols, + List metrics + ) throws IOException { + super(name, factories, context, parent, CardinalityUpperBound.MANY, metadata); + this.fieldCols = fieldCols; + this.metrics = metrics; + } + + /** + * Serializable and XContent-renderable holder for a single optional string key. + */ + public static class StarTree implements Writeable, ToXContentObject { + public static final ParseField KEY_FIELD = new ParseField("key"); + + protected final String key; + + public StarTree(String key) { + this.key = key; + } + + /** + * Read from a stream.
+ */ + public StarTree(StreamInput in) throws IOException { + key = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(key); + } + + public String getKey() { + return this.key; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (key != null) { + builder.field(KEY_FIELD.getPreferredName(), key); + } + builder.endObject(); + return builder; + } + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("startree", arg -> { + String key = (String) arg[0]; + return new StarTree(key); + }); + + /* NOTE(review): the key is read with p.text() but the field is declared as ValueType.DOUBLE; confirm the intended value type. */ + static { + PARSER.declareField(optionalConstructorArg(), (p, c) -> p.text(), KEY_FIELD, ObjectParser.ValueType.DOUBLE); + } + + @Override + public int hashCode() { + return Objects.hash(key); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + StarTree other = (StarTree) obj; + return Objects.equals(key, other.key); + } + } + + /** + * Builds the results for the given owning buckets with a fixed bucket count of indexMap.size(). Each inner bucket's key is found by scanning indexMap for the entry whose index equals the offset, and the bucket carries that key's sum from sumMap. + */ + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + + return buildAggregationsForFixedBucketCount( + owningBucketOrds, + indexMap.size(), + (offsetInOwningOrd, docCount, subAggregationResults) -> { + // TODO : make this better + String key = ""; + for (Map.Entry entry : indexMap.entrySet()) { + if (offsetInOwningOrd == entry.getValue()) { + key = entry.getKey(); + break; + } + } + + // return starTreeFactory.createBucket(key, docCount, subAggregationResults); + return new InternalStarTree.Bucket(key, sumMap.get(key), subAggregationResults); + }, + buckets -> create(name, buckets, metadata()) + ); + } + + public InternalStarTree create(String name, List ranges, Map metadata) { + return new InternalStarTree(name, ranges, metadata); + } + + @Override + public InternalAggregation buildEmptyAggregation() 
{ + return new InternalStarTree(name, new ArrayList(), new HashMap<>()); + } + + /** + * Returns a collector that, for every collected doc, reads the first metric and the configured dimensions from the segment's star-tree values and adds the metric value to the bucket for the resulting key. Returns null when the segment's doc-values reader is not a CompositeIndexReader. + * + * NOTE(review): only the first composite field (fiList.get(0)) is used and the list is not checked for emptiness; confirm. + */ + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + //StarTreeAggregatedValues values = (StarTreeAggregatedValues) ctx.reader().getAggregatedDocValues(); + SegmentReader reader = Lucene.segmentReader(ctx.reader()); + + if(!(reader.getDocValuesReader() instanceof CompositeIndexReader)) return null; + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + List fiList = starTreeDocValuesReader.getCompositeIndexFields(); + StarTreeValues values = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(fiList.get(0)); + final AtomicReference aggrVal = new AtomicReference<>(null); + return new LeafBucketCollectorBase(sub, values) { + @Override + public void collect(int doc, long bucket) throws IOException { + /* The star-tree values are fetched again on the first collected doc and cached in aggrVal for later docs. */ + if(aggrVal.get() == null) { + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + List fiList = starTreeDocValuesReader.getCompositeIndexFields(); + StarTreeValues values = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(fiList.get(0)); + aggrVal.set(values); + } + StarTreeValues aggrVals = aggrVal.get(); + List fieldColToDocValuesMap = new ArrayList<>(); + + // TODO : validations + for (String field : fieldCols) { + fieldColToDocValuesMap.add((SortedNumericDocValues) aggrVals.getDimensionDocValuesIteratorMap().get(field)); + } + // Another hardcoding + SortedNumericDocValues dv = (SortedNumericDocValues) aggrVals.getMetricDocValuesIteratorMap().get(metrics.get(0)); + if (dv.advanceExact(doc)) { + long v = dv.nextValue(); + double val = NumericUtils.sortableLongToDouble(v); + //System.out.println(val); + // TODO : do optimization for sorted numeric doc vals ? 
+ // if (fieldColToDocValuesMap.size() == 1 ) { + //final int valuesCount = dv.docValueCount(); +// for (int i = 0; i < valuesCount; i++) { +// val1 += dv.nextValue(); +// } + // } + + String key = getKey(fieldColToDocValuesMap, doc); + //System.out.println(key); + if(key.equals("") ) { + return; + } + if (indexMap.containsKey(key)) { + sumMap.put(key, sumMap.getOrDefault(key, 0.0) + val); + } else { + indexMap.put(key, indexMap.size()); + sumMap.put(key, val); + } + collectBucket(sub, doc, subBucketOrdinal(bucket, indexMap.get(key))); + } + } + }; + + } + + /** + * Builds the bucket key for a doc by joining the next value of every given dimension with "-". + * + * NOTE(review): the boolean result of advanceExact is ignored before nextValue is called; confirm every doc has a value for each dimension. + */ + private String getKey(List dimensionsKeyList, int doc) throws IOException { + StringJoiner sj = new StringJoiner("-"); + for (SortedNumericDocValues dim : dimensionsKeyList) { + dim.advanceExact(doc); + long val = dim.nextValue(); + sj.add("" + val); + } + return sj.toString(); + } + + /** + * Maps an owning bucket ordinal and a key index to a flat bucket ordinal. + * + * NOTE(review): the multiplier is the current indexMap.size(), which grows as new keys are collected, so for a non-zero owning ordinal the same key can map to different ordinals over time; confirm. + */ + private long subBucketOrdinal(long owningBucketOrdinal, int keyOrd) { + return owningBucketOrdinal * indexMap.size() + keyOrd; + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregatorFactory.java new file mode 100644 index 0000000000000..1870715a8dba5 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/startree/StarTreeAggregatorFactory.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.startree; + +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; +import org.opensearch.search.aggregations.CardinalityUpperBound; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class StarTreeAggregatorFactory extends AggregatorFactory { + private List fieldCols; + private List metrics; + + public StarTreeAggregatorFactory( + String name, + QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metadata, + List fieldCols, + List metrics + ) throws IOException { + super(name, queryShardContext, parent, subFactoriesBuilder, metadata); + this.fieldCols = fieldCols; + this.metrics = metrics; + } + + @Override + public Aggregator createInternal( + SearchContext searchContext, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata + ) throws IOException { + return new StarTreeAggregator(name, factories, searchContext, parent, metadata, fieldCols, metrics); + } + + @Override + protected boolean supportsConcurrentSegmentSearch() { + return true; + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/startree/package-info.java b/server/src/main/java/org/opensearch/search/aggregations/startree/package-info.java new file mode 100644 index 0000000000000..ef76726106a25 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/aggregations/startree/package-info.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.startree; diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java index 69a4a5d8b6703..94570801536e8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java @@ -40,7 +40,12 @@ import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; /** * Base class for all values source agg factories @@ -102,4 +107,12 @@ protected abstract Aggregator doCreateInternal( public String getStatsSubtype() { return config.valueSourceType().typeName(); } + + /** + * Returns the name of the field this factory's values source is configured on. + * + * @return the field name, or {@code null} when the values source config has no field context + */ + public String getField() { + return config.fieldContext() != null ? config.fieldContext().field() : null; + } + + /** + * Returns the {@code name} held by this factory. + * + * @return the aggregation name + */ + public String getAggregationName() { + return name; + } } diff --git a/server/src/main/java/org/opensearch/search/query/startree/StarTreeFilter.java b/server/src/main/java/org/opensearch/search/query/startree/StarTreeFilter.java new file mode 100644 index 0000000000000..26b0c0f7ebc88 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/startree/StarTreeFilter.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opensearch.search.query.startree; + + +import java.util.LinkedHashMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.DocIdSetBuilder; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.Predicate; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; +import org.opensearch.index.compositeindex.datacube.startree.node.StarTreeNode; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Filter operator for star tree data structure: traverses the tree breadth-first to collect matching doc ids, then applies the predicates the traversal reported as remaining by reading dimension doc values. + * + * NOTE(review): generic type arguments (for example on {@code Map}, {@code Set}, {@code List} and {@code Predicate}) appear to have been stripped from this text; restore them from the original change before compiling. + */ +public class StarTreeFilter { + private static final Logger logger = LogManager.getLogger(StarTreeFilter.class); + + + /** Helper class to wrap the result from traversing the star tree. 
*/ + static class StarTreeResult { + final DocIdSetBuilder _matchedDocIds; + final Set _remainingPredicateColumns; + final int numOfMatchedDocs; + final int maxMatchedDoc; + + StarTreeResult(DocIdSetBuilder matchedDocIds, Set remainingPredicateColumns, int numOfMatchedDocs, + int maxMatchedDoc) { + _matchedDocIds = matchedDocIds; + _remainingPredicateColumns = remainingPredicateColumns; + this.numOfMatchedDocs = numOfMatchedDocs; + this.maxMatchedDoc = maxMatchedDoc; + } + } + + private final StarTreeNode starTreeRoot; + + Map>> _predicateEvaluators; + private final Set _groupByColumns; + + DocIdSetBuilder docsWithField; + + DocIdSetBuilder.BulkAdder adder; + Map dimValueMap; + /** + * Captures the tree root and the dimension doc-values map from the given star-tree values; a null predicate map or group-by set is replaced by an empty one. + */ + public StarTreeFilter( + StarTreeValues starTreeAggrStructure, + Map>> predicateEvaluators, + Set groupByColumns + ) throws IOException { + // This filter operator does not support AND/OR/NOT operations. + starTreeRoot = starTreeAggrStructure.getRoot(); + dimValueMap = starTreeAggrStructure.getDimensionDocValuesIteratorMap(); + _predicateEvaluators = predicateEvaluators != null ? predicateEvaluators : Collections.emptyMap(); + _groupByColumns = groupByColumns != null ? groupByColumns : Collections.emptySet(); + + // TODO : this should be the maximum number of doc values + docsWithField = new DocIdSetBuilder(Integer.MAX_VALUE); + } + + /** + * Returns an iterator over the matching star-tree doc ids. + * <ul> + * <li>First go over the star tree and try to match as many dimensions as possible + * <li>For the remaining columns, use doc values indexes to match them + * </ul> 
+ */ + public DocIdSetIterator getStarTreeResult() throws IOException { + StarTreeResult starTreeResult = traverseStarTree(); + //logger.info("Matched docs in star tree : {}" , starTreeResult.numOfMatchedDocs); + List andIterators = new ArrayList<>(); + andIterators.add(starTreeResult._matchedDocIds.build().iterator()); + DocIdSetIterator docIdSetIterator = andIterators.get(0); + // No matches, return + if(starTreeResult.maxMatchedDoc == -1) { + return docIdSetIterator; + } + int docCount = 0; + for (String remainingPredicateColumn : starTreeResult._remainingPredicateColumns) { + // TODO : set to max value of doc values + logger.info("remainingPredicateColumn : {}, maxMatchedDoc : {} ", remainingPredicateColumn, starTreeResult.maxMatchedDoc); + DocIdSetBuilder builder = new DocIdSetBuilder(starTreeResult.maxMatchedDoc + 1); + List> compositePredicateEvaluators = _predicateEvaluators.get(remainingPredicateColumn); + SortedNumericDocValues ndv = (SortedNumericDocValues) this.dimValueMap.get(remainingPredicateColumn); + List docIds = new ArrayList<>(); + while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + docCount++; + int docID = docIdSetIterator.docID(); + if(ndv.advanceExact(docID)) { + final int valuesCount = ndv.docValueCount(); + long value = ndv.nextValue(); + for (Predicate compositePredicateEvaluator : compositePredicateEvaluators) { + // TODO : this might be expensive as its done against all doc values docs + if (compositePredicateEvaluator.test(value)) { + docIds.add(docID); + /* NOTE(review): when the doc has more than one value, this inner loop drains the rest of docIdSetIterator and adds every remaining doc without testing it; confirm this is intended. */ + for (int i = 0; i < valuesCount - 1; i++) { + while(docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + docIds.add(docIdSetIterator.docID()); + } + } + break; + } + } + } + } + DocIdSetBuilder.BulkAdder adder = builder.grow(docIds.size()); + for(int docID : docIds) { + adder.add(docID); + } + docIdSetIterator = builder.build().iterator(); + } + return docIdSetIterator; + } + + /** + * Helper method to traverse the star tree, get matching documents and keep track of all the + * 
predicate dimensions that are not matched. + */ + private StarTreeResult traverseStarTree() throws IOException { + Set globalRemainingPredicateColumns = null; + + StarTreeNode starTree = starTreeRoot; + + List dimensionNames = new ArrayList<>(dimValueMap.keySet()); + + // Track whether we have found a leaf node added to the queue. If we have found a leaf node, and + // traversed to the + // level of the leaf node, we can set globalRemainingPredicateColumns if not already set + // because we know the leaf + // node won't split further on other predicate columns. + boolean foundLeafNode = starTree.isLeaf(); + + // Use BFS to traverse the star tree + Queue queue = new ArrayDeque<>(); + queue.add(starTree); + int currentDimensionId = -1; + Set remainingPredicateColumns = new HashSet<>(_predicateEvaluators.keySet()); + Set remainingGroupByColumns = new HashSet<>(_groupByColumns); + if (foundLeafNode) { + globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns); + } + + int matchedDocsCountInStarTree = 0; + int maxDocNum = -1; + + StarTreeNode starTreeNode; + List docIds = new ArrayList<>(); + while ((starTreeNode = queue.poll()) != null) { + int dimensionId = starTreeNode.getDimensionId(); + if (dimensionId > currentDimensionId) { + // Previous level finished + String dimension = dimensionNames.get(dimensionId); + remainingPredicateColumns.remove(dimension); + remainingGroupByColumns.remove(dimension); + if (foundLeafNode && globalRemainingPredicateColumns == null) { + globalRemainingPredicateColumns = new HashSet<>(remainingPredicateColumns); + } + currentDimensionId = dimensionId; + } + + // If all predicate columns and group-by columns are matched, we can use aggregated document + if (remainingPredicateColumns.isEmpty() && remainingGroupByColumns.isEmpty()) { + int docId = starTreeNode.getAggregatedDocId(); + docIds.add(docId); + matchedDocsCountInStarTree++; + maxDocNum = Math.max(docId, maxDocNum); + continue; + } + + // For leaf node, because we 
haven't exhausted all predicate columns and group-by columns, we + // cannot use + // the aggregated document. Add the range of documents for this node to the bitmap, and keep + // track of the + // remaining predicate columns for this node + if (starTreeNode.isLeaf()) { + for (long i = starTreeNode.getStartDocId(); i < starTreeNode.getEndDocId(); i++) { + docIds.add((int)i); + matchedDocsCountInStarTree++; + maxDocNum = Math.max((int) i, maxDocNum); + } + continue; + } + + // For non-leaf node, proceed to next level + String childDimension = dimensionNames.get(dimensionId + 1); + + // Only read star-node when the dimension is not in the global remaining predicate columns or + // group-by columns + // because we cannot use star-node in such cases + StarTreeNode starNode = null; + if ((globalRemainingPredicateColumns == null || !globalRemainingPredicateColumns.contains(childDimension)) + && !remainingGroupByColumns.contains(childDimension)) { + starNode = starTreeNode.getChildForDimensionValue(StarTreeNode.ALL); + } + + if (remainingPredicateColumns.contains(childDimension)) { + // Have predicates on the next level, add matching nodes to the queue + + // Calculate the matching dictionary ids for the child dimension + int numChildren = starTreeNode.getNumChildren(); + + // If number of matching dictionary ids is large, use scan instead of binary search + + Iterator childrenIterator = starTreeNode.getChildrenIterator(); + + // When the star-node exists, and the number of matching doc ids is more than or equal to + // the + // number of non-star child nodes, check if all the child nodes match the predicate, and use + // the star-node if so + if (starNode != null) { + List matchingChildNodes = new ArrayList<>(); + boolean findLeafChildNode = false; + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + List> predicates = _predicateEvaluators.get(childDimension); + for (Predicate predicate : predicates) { + long val = 
childNode.getDimensionValue(); + if (predicate.test(val)) { + matchingChildNodes.add(childNode); + findLeafChildNode |= childNode.isLeaf(); + break; + } + } + } + if (matchingChildNodes.size() == numChildren - 1) { + // All the child nodes (except for the star-node) match the predicate, use the star-node + queue.add(starNode); + foundLeafNode |= starNode.isLeaf(); + } else { + // Some child nodes do not match the predicate, use the matching child nodes + queue.addAll(matchingChildNodes); + foundLeafNode |= findLeafChildNode; + } + } else { + // Cannot use the star-node, use the matching child nodes + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + List> predicates = _predicateEvaluators.get(childDimension); + for (Predicate predicate : predicates) { + if (predicate.test(childNode.getDimensionValue())) { + queue.add(childNode); + foundLeafNode |= childNode.isLeaf(); + break; + } + } + } + } + } else { + // No predicate on the next level + + if (starNode != null) { + // Star-node exists, use it + queue.add(starNode); + foundLeafNode |= starNode.isLeaf(); + } else { + // Star-node does not exist or cannot be used, add all non-star nodes to the queue + Iterator childrenIterator = starTreeNode.getChildrenIterator(); + while (childrenIterator.hasNext()) { + StarTreeNode childNode = childrenIterator.next(); + if (childNode.getDimensionValue() != StarTreeNode.ALL) { + queue.add(childNode); + foundLeafNode |= childNode.isLeaf(); + } + } + } + } + } + + /* All collected ids are added to the shared docsWithField builder in one bulk step. */ + adder = docsWithField.grow(docIds.size()); + for(int id : docIds) { + adder.add(id); + } + return new StarTreeResult( + docsWithField, + globalRemainingPredicateColumns != null ? 
globalRemainingPredicateColumns : Collections.emptySet(), + matchedDocsCountInStarTree, + maxDocNum + ); + } +} diff --git a/server/src/main/java/org/opensearch/search/query/startree/StarTreeQuery.java b/server/src/main/java/org/opensearch/search/query/startree/StarTreeQuery.java new file mode 100644 index 0000000000000..e4e6c1b0afb77 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/startree/StarTreeQuery.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.opensearch.search.query.startree; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Accountable; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.datacube.startree.StarTreeValues; + +/** Query class for querying star tree data structure */ +public class StarTreeQuery extends Query implements Accountable { + + Map>> compositePredicateMap; + Set groupByColumns; + + public StarTreeQuery(Map>> compositePredicateMap, Set groupByColumns) { + this.compositePredicateMap = compositePredicateMap; + this.groupByColumns = groupByColumns; + } + + @Override + public String toString(String field) { + return null; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + return sameClassAs(obj); + } + + @Override + public int hashCode() { + return classHash(); + } + + @Override + public long ramBytesUsed() { + return 0; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new ConstantScoreWeight(this, boost) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { 
+ SegmentReader reader = Lucene.segmentReader(context.reader()); + + // We get the 'StarTreeReader' instance so that we can get StarTreeValues + + if(!(reader.getDocValuesReader() instanceof CompositeIndexReader)) return null; + + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + List fiList = starTreeDocValuesReader.getCompositeIndexFields(); + StarTreeValues starTreeValues = null; + if(fiList != null && !fiList.isEmpty()) { + starTreeValues = (StarTreeValues) starTreeDocValuesReader.getCompositeIndexValues(fiList.get(0)); + } else { + return null; + } + + DocIdSetIterator result = null; + StarTreeFilter filter = new StarTreeFilter(starTreeValues, compositePredicateMap, groupByColumns); + result = filter.getStarTreeResult(); + return new ConstantScoreScorer(this, score(), scoreMode, result); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } +} diff --git a/server/src/main/java/org/opensearch/search/query/startree/StarTreeQueryBuilder.java b/server/src/main/java/org/opensearch/search/query/startree/StarTreeQueryBuilder.java new file mode 100644 index 0000000000000..41364af69ea90 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/startree/StarTreeQueryBuilder.java @@ -0,0 +1,169 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.query.startree; + + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; + +/** + * TODO: Will replace with InternalStarTreeQueryBuilder since this is user-hidden implementation, + * decided for specific query/aggregations which can be resolved using StarTree + */ +public class StarTreeQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "startree"; + private static final ParseField FILTER = new ParseField("filter"); + private final List filterClauses = new ArrayList<>(); + + private final Set groupBy = new HashSet<>(); + Map>> predicateMap = new HashMap<>(); + private static final Logger logger = LogManager.getLogger(StarTreeQueryBuilder.class); + + public StarTreeQueryBuilder() {} + + /** + * Read from a stream. 
+ */ + public StarTreeQueryBuilder(StreamInput in) throws IOException { + super(in); + filterClauses.addAll(readQueries(in)); + in.readOptionalStringArray(); + } + + static List readQueries(StreamInput in) throws IOException { + int size = in.readVInt(); + List queries = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + queries.add(in.readNamedWriteable(QueryBuilder.class)); + } + return queries; + } + + @Override + protected void doWriteTo(StreamOutput out) { + // only superclass has state + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + doXArrayContent(FILTER, filterClauses, builder, params); + builder.endObject(); + } + + private static void doXArrayContent(ParseField field, List clauses, XContentBuilder builder, Params params) + throws IOException { + if (clauses.isEmpty()) { + return; + } + builder.startArray(field.getPreferredName()); + for (QueryBuilder clause : clauses) { + clause.toXContent(builder, params); + } + builder.endArray(); + } + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, StarTreeQueryBuilder::new); + + static { + PARSER.declareObjectArrayOrNull( + (builder, clauses) -> clauses.forEach(builder::filter), + (p, c) -> parseInnerQueryBuilder(p), + FILTER + ); + PARSER.declareStringArray(StarTreeQueryBuilder::groupby, new ParseField("groupby")); + + } + + private void groupby(List strings) { + groupBy.addAll(strings); + } + + public StarTreeQueryBuilder filter(QueryBuilder queryBuilder) { + if (queryBuilder == null) { + throw new IllegalArgumentException("inner bool query clause cannot be null"); + } + filterClauses.add(queryBuilder); + + for (QueryBuilder filterClause : filterClauses) { + if (filterClause instanceof BoolQueryBuilder) { + BoolQueryBuilder bq = (BoolQueryBuilder) filterClause; + List shouldQbs = bq.should(); + for (QueryBuilder sqb : shouldQbs) { + if (sqb instanceof TermQueryBuilder) { + TermQueryBuilder tq = 
(TermQueryBuilder) sqb; + String field = tq.fieldName(); + long inputQueryVal = Long.valueOf((String) tq.value()); + List> predicates = predicateMap.getOrDefault(field, new ArrayList<>()); + Predicate predicate = dimVal -> dimVal == inputQueryVal; + predicates.add(predicate); + predicateMap.put(field, predicates); + } + } + } + } + + return this; + } + + public static StarTreeQueryBuilder fromXContent(XContentParser parser) { + try { + return PARSER.apply(parser, null); + } catch (IllegalArgumentException e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + @Override + protected Query doToQuery(QueryShardContext context) { + // TODO : star tree supports either group by or filter + if (predicateMap.size() > 0) { + return new StarTreeQuery(predicateMap, new HashSet<>()); + } + logger.info("Group by : {} ", this.groupBy.toString() ); + return new StarTreeQuery(new HashMap<>(), this.groupBy); + } + + @Override + protected boolean doEquals(StarTreeQueryBuilder other) { + return true; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/server/src/main/java/org/opensearch/search/query/startree/package-info.java b/server/src/main/java/org/opensearch/search/query/startree/package-info.java new file mode 100644 index 0000000000000..93d166e7c0af6 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/query/startree/package-info.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.query.startree;