Skip to content

Commit

Permalink
star tree parsing approach
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeshkr419 committed Jul 18, 2024
1 parent 781a2d4 commit b02365b
Show file tree
Hide file tree
Showing 19 changed files with 1,447 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public class FeatureFlags {
* aggregations.
*/
public static final String STAR_TREE_INDEX = "opensearch.experimental.feature.composite_index.star_tree.enabled";
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, false, Property.NodeScope);
public static final Setting<Boolean> STAR_TREE_INDEX_SETTING = Setting.boolSetting(STAR_TREE_INDEX, true, Property.NodeScope);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.query.startree.StarTreeQuery;
import org.opensearch.transport.RemoteClusterAware;

import java.io.IOException;
Expand Down Expand Up @@ -498,6 +499,12 @@ public boolean indexSortedOnField(String field) {
return indexSortConfig.hasPrimarySortOnField(field);
}

public ParsedQuery toStarTreeQuery(Map<String, List<Predicate<Long>>> compositePredicateMap,
Set<String> groupByColumns) {
StarTreeQuery starTreeQuery = new StarTreeQuery(compositePredicateMap, groupByColumns);
return new ParsedQuery(starTreeQuery);
}

public ParsedQuery toQuery(QueryBuilder queryBuilder) {
return toQuery(queryBuilder, q -> {
Query query = q.toQuery(this);
Expand Down
134 changes: 130 additions & 4 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,20 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.compositeindex.datacube.Metric;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregatorFactory;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeHelper;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.CompositeDataCubeFieldType;
import org.opensearch.index.mapper.CompositeMappedFieldType;
import org.opensearch.index.mapper.DerivedFieldResolver;
import org.opensearch.index.mapper.DerivedFieldResolverFactory;
import org.opensearch.index.mapper.StarTreeMapper;
import org.opensearch.index.query.InnerHitContextBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
Expand All @@ -97,11 +105,20 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregationInitializationException;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.SearchContextAggregations;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MetricsAggregator;
import org.opensearch.search.aggregations.metrics.MinAggregatorFactory;
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.opensearch.search.aggregations.startree.StarTreeAggregator;
import org.opensearch.search.aggregations.startree.StarTreeAggregatorFactory;
import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.collapse.CollapseContext;
import org.opensearch.search.dfs.DfsPhase;
Expand Down Expand Up @@ -148,6 +165,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -158,6 +176,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;

import static org.opensearch.common.unit.TimeValue.timeValueHours;
import static org.opensearch.common.unit.TimeValue.timeValueMillis;
Expand Down Expand Up @@ -1314,20 +1333,29 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.evaluateRequestShouldUseConcurrentSearch();
return;
}

// Can be marked false for majority cases for which star-tree cannot be used
// Will save checking the criteria later and we can have a limit on what search requests are supported
// As we increment the cases where star-tree can be used, this can be set back to true
boolean canUseStarTree = context.mapperService().isCompositeIndexPresent();

SearchShardTarget shardTarget = context.shardTarget();
QueryShardContext queryShardContext = context.getQueryShardContext();
context.from(source.from());
context.size(source.size());
Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
if (source.query() != null) {
canUseStarTree = false;
InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders);
context.parsedQuery(queryShardContext.toQuery(source.query()));
}
if (source.postFilter() != null) {
canUseStarTree = false;
InnerHitContextBuilder.extractInnerHits(source.postFilter(), innerHitBuilders);
context.parsedPostFilter(queryShardContext.toQuery(source.postFilter()));
}
if (innerHitBuilders.size() > 0) {
if (!innerHitBuilders.isEmpty()) {
canUseStarTree = false;
for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
try {
entry.getValue().build(context, context.innerHits());
Expand All @@ -1337,11 +1365,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.sorts() != null) {
canUseStarTree = false;
try {
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(source.sorts(), context.getQueryShardContext());
if (optionalSort.isPresent()) {
context.sort(optionalSort.get());
}
optionalSort.ifPresent(context::sort);
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create sort elements", e);
}
Expand All @@ -1354,9 +1381,11 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
throw new SearchException(shardTarget, "disabling [track_total_hits] is not allowed in a scroll context");
}
if (source.trackTotalHitsUpTo() != null) {
canUseStarTree = false;
context.trackTotalHitsUpTo(source.trackTotalHitsUpTo());
}
if (source.minScore() != null) {
canUseStarTree = false;
context.minimumScore(source.minScore());
}
if (source.timeout() != null) {
Expand All @@ -1372,13 +1401,15 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.suggest() != null) {
canUseStarTree = false;
try {
context.suggest(source.suggest().build(queryShardContext));
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create SuggestionSearchContext", e);
}
}
if (source.rescores() != null) {
canUseStarTree = false;
try {
for (RescorerBuilder<?> rescore : source.rescores()) {
context.addRescore(rescore.buildContext(queryShardContext));
Expand All @@ -1388,12 +1419,15 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.explain() != null) {
canUseStarTree = false;
context.explain(source.explain());
}
if (source.fetchSource() != null) {
canUseStarTree = false;
context.fetchSourceContext(source.fetchSource());
}
if (source.docValueFields() != null) {
canUseStarTree = false;
FetchDocValuesContext docValuesContext = FetchDocValuesContext.create(
context.mapperService()::simpleMatchToFullName,
context.mapperService().getIndexSettings().getMaxDocvalueFields(),
Expand All @@ -1402,10 +1436,12 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
context.docValuesContext(docValuesContext);
}
if (source.fetchFields() != null) {
canUseStarTree = false;
FetchFieldsContext fetchFieldsContext = new FetchFieldsContext(source.fetchFields());
context.fetchFieldsContext(fetchFieldsContext);
}
if (source.highlighter() != null) {
canUseStarTree = false;
HighlightBuilder highlightBuilder = source.highlighter();
try {
context.highlight(highlightBuilder.build(queryShardContext));
Expand All @@ -1414,6 +1450,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.scriptFields() != null && source.size() != 0) {
canUseStarTree = false;
int maxAllowedScriptFields = context.mapperService().getIndexSettings().getMaxScriptFields();
if (source.scriptFields().size() > maxAllowedScriptFields) {
throw new IllegalArgumentException(
Expand All @@ -1439,17 +1476,21 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
if (source.version() != null) {
canUseStarTree = false;
context.version(source.version());
}

if (source.seqNoAndPrimaryTerm() != null) {
// canUseStarTree = false;
context.seqNoAndPrimaryTerm(source.seqNoAndPrimaryTerm());
}

if (source.stats() != null) {
// canUseStarTree = false;
context.groupStats(source.stats());
}
if (CollectionUtils.isEmpty(source.searchAfter()) == false) {
// canUseStarTree = false;
if (context.scrollContext() != null) {
throw new SearchException(shardTarget, "`search_after` cannot be used in a scroll context.");
}
Expand All @@ -1461,13 +1502,15 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}

if (source.slice() != null) {
// canUseStarTree = false;
if (context.scrollContext() == null && !(context.readerContext() instanceof PitReaderContext)) {
throw new SearchException(shardTarget, "`slice` cannot be used outside of a scroll context or PIT context");
}
context.sliceBuilder(source.slice());
}

if (source.storedFields() != null) {
// canUseStarTree = false;
if (source.storedFields().fetchFields() == false) {
if (context.sourceRequested()) {
throw new SearchException(shardTarget, "[stored_fields] cannot be disabled if [_source] is requested");
Expand All @@ -1480,6 +1523,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}

if (source.collapse() != null) {
// canUseStarTree = false;
if (context.scrollContext() != null) {
throw new SearchException(shardTarget, "cannot use `collapse` in a scroll context");
}
Expand All @@ -1496,6 +1540,88 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
if (source.profile()) {
context.setProfilers(new Profilers(context.searcher(), context.shouldUseConcurrentSearch()));
}

if (canUseStarTree) {
try {
if (setStarTreeQuery(context, queryShardContext, source)) {
logger.info("Star Tree will be used in execution");
};
} catch (IOException e) {
logger.info("Cannot use star-tree");
}

}
}

private boolean setStarTreeQuery(SearchContext context, QueryShardContext queryShardContext, SearchSourceBuilder source) throws IOException {

// TODO: (finish)
// 1. Check criteria for star-tree query / aggregation formation
// 2: Set StarTree Query & Star Tree Aggregator here

if (source.aggregations() == null) {
return false;
}

// TODO: Support for multiple startrees
CompositeDataCubeFieldType compositeMappedFieldType = (StarTreeMapper.StarTreeFieldType) context.mapperService().getCompositeFieldTypes().iterator().next();;
List<String> supportedDimensions = new ArrayList<>(compositeMappedFieldType.fields());
Map<String, List<MetricStat>> supportedMetrics = compositeMappedFieldType.getMetrics().stream()
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));

AggregatorFactories defaultFactories = context.aggregations().factories();
if (defaultFactories.countAggregators() != 1) {
return false;
}
AggregatorFactory defaultFactory = defaultFactories.getFactories()[0];
if (!(defaultFactory instanceof TermsAggregatorFactory)) {
return false;
}
String dimension = ((ValuesSourceAggregatorFactory)defaultFactory).getField();
if (!supportedDimensions.contains(dimension)) {
return false;
}

TermsAggregatorFactory termsFactory = (TermsAggregatorFactory) defaultFactory;
AggregatorFactory defaultSubFactory = termsFactory.getSubFactories().getFactories()[0];

String field, metric;
// TODO: abstract the below metric checks


if (defaultSubFactory instanceof MaxAggregatorFactory) {
field = ((ValuesSourceAggregatorFactory)defaultSubFactory).getField();
if (!(supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MAX))) {
return false;
}
metric = MetricStat.MAX.name();
} else if (defaultSubFactory instanceof SumAggregatorFactory) {
field = ((SumAggregatorFactory)defaultSubFactory).getField();
if (!(supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM))) {
return false;
}
metric = MetricStat.SUM.name();
} else if (defaultSubFactory instanceof MinAggregatorFactory) {
field = ((MinAggregatorFactory)defaultSubFactory).getField();
if (!(supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MIN))) {
return false;
}
metric = MetricStat.MIN.name();
} else {
return false;
}
String metricKey = StarTreeHelper.fullFieldNameForStarTreeMetricsDocValues(compositeMappedFieldType.name(), field, metric.toLowerCase());

ParsedQuery query = queryShardContext.toStarTreeQuery(null, Set.of(dimension));

StarTreeAggregatorFactory factory = new StarTreeAggregatorFactory(defaultFactory.name(), queryShardContext, null, AggregatorFactories.builder(), null, List.of(dimension), List.of(metricKey));
StarTreeAggregatorFactory[] factories = {factory};
AggregatorFactories aggregatorFactories = new AggregatorFactories(factories);

context.parsedQuery(query)
.aggregations(new SearchContextAggregations(aggregatorFactories, multiBucketConsumerService.create()));

return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregator;
import org.opensearch.index.compositeindex.datacube.startree.aggregators.ValueAggregatorFactory;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ public boolean test(AggregatorFactory o) {
}
};

private AggregatorFactory[] factories;
protected AggregatorFactory[] factories;

public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories) {
public AggregatorFactories(AggregatorFactory[] factories) {
this.factories = factories;
}

Expand Down Expand Up @@ -661,4 +661,8 @@ public PipelineTree buildPipelineTree() {
return new PipelineTree(subTrees, aggregators);
}
}

public AggregatorFactory[] getFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ protected boolean supportsConcurrentSegmentSearch() {
public boolean evaluateChildFactories() {
return factories.allFactoriesSupportConcurrentSearch();
}

public AggregatorFactories getSubFactories() {
return factories;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
*
* @opensearch.internal
*/
class MaxAggregatorFactory extends ValuesSourceAggregatorFactory {
public class MaxAggregatorFactory extends ValuesSourceAggregatorFactory {

static void registerAggregators(ValuesSourceRegistry.Builder builder) {
builder.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
*
* @opensearch.internal
*/
class MinAggregatorFactory extends ValuesSourceAggregatorFactory {
public class MinAggregatorFactory extends ValuesSourceAggregatorFactory {

static void registerAggregators(ValuesSourceRegistry.Builder builder) {
builder.register(
Expand Down
Loading

0 comments on commit b02365b

Please sign in to comment.