diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index 9f2c2c5fa8..4a5276418d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -34,6 +34,7 @@ import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Default implementor for implementing logical to physical translation. "Default" here means all @@ -123,6 +124,11 @@ public PhysicalPlan visitLimit(LogicalLimit node, C context) { return new LimitOperator(visitChild(node, context), node.getLimit(), node.getOffset()); } + @Override + public PhysicalPlan visitTableScanBuilder(TableScanBuilder plan, C context) { + return plan.build(); + } + @Override public PhysicalPlan visitRelation(LogicalRelation node, C context) { throw new UnsupportedOperationException("Storage engine is responsible for " diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 28539562e7..0386eb6e2a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -6,6 +6,8 @@ package org.opensearch.sql.planner.logical; +import org.opensearch.sql.storage.read.TableScanBuilder; + /** * The visitor of {@link LogicalPlan}. 
* @@ -22,6 +24,10 @@ public R visitRelation(LogicalRelation plan, C context) { return visitNode(plan, context); } + public R visitTableScanBuilder(TableScanBuilder plan, C context) { + return visitNode(plan, context); + } + public R visitFilter(LogicalFilter plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 0e547df68d..f241e76993 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -15,6 +15,8 @@ import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; +import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; +import org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown; /** * {@link LogicalPlan} Optimizer. 
@@ -39,8 +41,21 @@ public LogicalPlanOptimizer(List> rules) { */ public static LogicalPlanOptimizer create() { return new LogicalPlanOptimizer(Arrays.asList( + /* + * Phase 1: Transformations that rely on relational algebra equivalence + */ new MergeFilterAndFilter(), - new PushFilterUnderSort())); + new PushFilterUnderSort(), + /* + * Phase 2: Transformations that rely on data source push down capability + */ + new CreateTableScanBuilder(), + TableScanPushDown.PUSH_DOWN_FILTER, + TableScanPushDown.PUSH_DOWN_AGGREGATION, + TableScanPushDown.PUSH_DOWN_SORT, + TableScanPushDown.PUSH_DOWN_LIMIT, + TableScanPushDown.PUSH_DOWN_HIGHLIGHT, + TableScanPushDown.PUSH_DOWN_PROJECT)); } /** @@ -63,7 +78,14 @@ private LogicalPlan internalOptimize(LogicalPlan plan) { Match match = DEFAULT_MATCHER.match(rule.pattern(), node); if (match.isPresent()) { node = rule.apply(match.value(), match.captures()); - done = false; + + // For new TableScanPushDown impl, pattern match doesn't necessarily cause + // push down to happen. So reiterate all rules against the node only if the node + // is actually replaced by any rule. 
+ // TODO: may need to introduce fixed point or maximum iteration limit in future + if (node != match.value()) { + done = false; + } } } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java index 73d0f8c577..0ba478594a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java @@ -6,10 +6,22 @@ package org.opensearch.sql.planner.optimizer.pattern; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.Property; +import com.facebook.presto.matching.PropertyPattern; import java.util.Optional; import lombok.experimental.UtilityClass; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Pattern helper class. @@ -17,6 +29,55 @@ @UtilityClass public class Patterns { + /** + * Logical filter with a given pattern on inner field. + */ + public static Pattern filter(Pattern pattern) { + return Pattern.typeOf(LogicalFilter.class).with(source(pattern)); + } + + /** + * Logical aggregate operator with a given pattern on inner field. + */ + public static Pattern aggregate(Pattern pattern) { + return Pattern.typeOf(LogicalAggregation.class).with(source(pattern)); + } + + /** + * Logical sort operator with a given pattern on inner field. 
+ */ + public static Pattern sort(Pattern pattern) { + return Pattern.typeOf(LogicalSort.class).with(source(pattern)); + } + + /** + * Logical limit operator with a given pattern on inner field. + */ + public static Pattern limit(Pattern pattern) { + return Pattern.typeOf(LogicalLimit.class).with(source(pattern)); + } + + /** + * Logical highlight operator with a given pattern on inner field. + */ + public static Pattern highlight(Pattern pattern) { + return Pattern.typeOf(LogicalHighlight.class).with(source(pattern)); + } + + /** + * Logical project operator with a given pattern on inner field. + */ + public static Pattern project(Pattern pattern) { + return Pattern.typeOf(LogicalProject.class).with(source(pattern)); + } + + /** + * Pattern for {@link TableScanBuilder} and capture it meanwhile. + */ + public static Pattern scanBuilder() { + return Pattern.typeOf(TableScanBuilder.class).capturedAs(Capture.newCapture()); + } + /** * LogicalPlan source {@link Property}. */ @@ -25,4 +86,28 @@ public static Property source() { ? Optional.of(plan.getChild().get(0)) : Optional.empty()); } + + /** + * Source (children field) with a given pattern. + */ + @SuppressWarnings("unchecked") + public static + PropertyPattern source(Pattern pattern) { + Property property = Property.optionalProperty("source", + plan -> plan.getChild().size() == 1 + ? Optional.of((T) plan.getChild().get(0)) + : Optional.empty()); + + return property.matching(pattern); + } + + /** + * Logical relation with table field. + */ + public static Property table() { + return Property.optionalProperty("table", + plan -> plan instanceof LogicalRelation + ? 
Optional.of(((LogicalRelation) plan).getTable()) + : Optional.empty()); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java new file mode 100644 index 0000000000..dbe61ca8c3 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule.read; + +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.table; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.optimizer.Rule; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Rule that replace logical relation operator to {@link TableScanBuilder} for later + * push down optimization. All push down optimization rules that depends on table scan + * builder needs to run after this. + */ +public class CreateTableScanBuilder implements Rule { + + /** Capture the table inside matched logical relation operator. */ + private final Capture capture; + + /** Pattern that matches logical relation operator. */ + @Accessors(fluent = true) + @Getter + private final Pattern pattern; + + /** + * Construct create table scan builder rule. 
+ */ + public CreateTableScanBuilder() { + this.capture = Capture.newCapture(); + this.pattern = Pattern.typeOf(LogicalRelation.class) + .with(table().capturedAs(capture)); + } + + @Override + public LogicalPlan apply(LogicalRelation plan, Captures captures) { + TableScanBuilder scanBuilder = captures.get(capture).createScanBuilder(); + // TODO: Remove this after Prometheus refactored to new table scan builder too + return (scanBuilder == null) ? plan : scanBuilder; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java new file mode 100644 index 0000000000..556a12bb34 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule.read; + +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.aggregate; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.filter; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.highlight; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.project; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.scanBuilder; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.sort; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.TableScanPushDownBuilder.match; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.matching.pattern.CapturePattern; +import com.facebook.presto.matching.pattern.WithPattern; +import java.util.function.BiFunction; +import 
org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.Rule; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Rule template for all table scan push down rules. Because all push down optimization rules + * have similar workflow in common, such as a pattern that match an operator on top of table scan + * builder, and action that eliminates the original operator if pushed down, this class helps + * remove redundant code and improve readability. + * + * @param logical plan node type + */ +public class TableScanPushDown implements Rule { + + /** Push down optimize rule for filtering condition. */ + public static final Rule PUSH_DOWN_FILTER = + match( + filter( + scanBuilder())) + .apply((filter, scanBuilder) -> scanBuilder.pushDownFilter(filter)); + + /** Push down optimize rule for aggregate operator. */ + public static final Rule PUSH_DOWN_AGGREGATION = + match( + aggregate( + scanBuilder())) + .apply((agg, scanBuilder) -> scanBuilder.pushDownAggregation(agg)); + + /** Push down optimize rule for sort operator. */ + public static final Rule PUSH_DOWN_SORT = + match( + sort( + scanBuilder())) + .apply((sort, scanBuilder) -> scanBuilder.pushDownSort(sort)); + + /** Push down optimize rule for limit operator. */ + public static final Rule PUSH_DOWN_LIMIT = + match( + limit( + scanBuilder())) + .apply((limit, scanBuilder) -> scanBuilder.pushDownLimit(limit)); + + public static final Rule PUSH_DOWN_PROJECT = + match( + project( + scanBuilder())) + .apply((project, scanBuilder) -> scanBuilder.pushDownProject(project)); + + public static final Rule PUSH_DOWN_HIGHLIGHT = + match( + highlight( + scanBuilder())) + .apply((highlight, scanBuilder) -> scanBuilder.pushDownHighlight(highlight)); + + + /** Pattern that matches a plan node. */ + private final WithPattern pattern; + + /** Capture table scan builder inside a plan node. 
*/ + private final Capture capture; + + /** Push down function applied to the plan node and captured table scan builder. */ + private final BiFunction pushDownFunction; + + + @SuppressWarnings("unchecked") + private TableScanPushDown(WithPattern pattern, + BiFunction pushDownFunction) { + this.pattern = pattern; + this.capture = ((CapturePattern) pattern.getPattern()).capture(); + this.pushDownFunction = pushDownFunction; + } + + @Override + public Pattern pattern() { + return pattern; + } + + @Override + public LogicalPlan apply(T plan, Captures captures) { + TableScanBuilder scanBuilder = captures.get(capture); + if (pushDownFunction.apply(plan, scanBuilder)) { + return scanBuilder; + } + return plan; + } + + /** + * Custom builder class other than generated by Lombok to provide more readable code. + */ + static class TableScanPushDownBuilder { + + private WithPattern pattern; + + public static + TableScanPushDownBuilder match(Pattern pattern) { + TableScanPushDownBuilder builder = new TableScanPushDownBuilder<>(); + builder.pattern = (WithPattern) pattern; + return builder; + } + + public TableScanPushDown apply( + BiFunction pushDownFunction) { + return new TableScanPushDown<>(pattern, pushDownFunction); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java b/core/src/main/java/org/opensearch/sql/storage/Table.java index f43531e2a6..ae0aaaf17b 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -11,6 +11,7 @@ import org.opensearch.sql.executor.streaming.StreamingSource; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Table. 
@@ -45,7 +46,9 @@ default void create(Map schema) { * * @param plan logical plan * @return physical plan + * @deprecated because of new {@link TableScanBuilder} implementation */ + @Deprecated(since = "2.5.0") PhysicalPlan implement(LogicalPlan plan); /** @@ -54,11 +57,22 @@ default void create(Map schema) { * * @param plan logical plan. * @return logical plan. + * @deprecated because of new {@link TableScanBuilder} implementation */ + @Deprecated(since = "2.5.0") default LogicalPlan optimize(LogicalPlan plan) { return plan; } + /** + * Create table scan builder for logical to physical transformation. + * + * @return table scan builder + */ + default TableScanBuilder createScanBuilder() { + return null; // TODO: Enforce all subclasses to implement this later + } + /** * Translate {@link Table} to {@link StreamingSource} if possible. */ diff --git a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java new file mode 100644 index 0000000000..c0fdf36e70 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.storage.read; + +import java.util.Collections; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * A TableScanBuilder represents transition state between logical planning and physical planning + * for table scan operator. 
The concrete implementation class gets involved in the logical + * optimization through this abstraction and thus get the chance to handle push down optimization + * without intruding core engine. + */ +public abstract class TableScanBuilder extends LogicalPlan { + + /** + * Construct and initialize children to empty list. + */ + public TableScanBuilder() { + super(Collections.emptyList()); + } + + /** + * Build table scan operator. + * + * @return table scan operator + */ + public abstract TableScanOperator build(); + + /** + * Can a given filter operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param filter logical filter operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownFilter(LogicalFilter filter) { + return false; + } + + /** + * Can a given aggregate operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param aggregation logical aggregate operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownAggregation(LogicalAggregation aggregation) { + return false; + } + + /** + * Can a given sort operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param sort logical sort operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownSort(LogicalSort sort) { + return false; + } + + /** + * Can a given limit operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param limit logical limit operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownLimit(LogicalLimit limit) { + return false; + } + + /** + * Can a given project operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. 
+ * + * @param project logical project operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownProject(LogicalProject project) { + return false; + } + + /** + * Can a given highlight operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param highlight logical highlight operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownHighlight(LogicalHighlight highlight) { + return false; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitTableScanBuilder(this, context); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 3a6a95764c..2322e4684e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -36,6 +36,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; @@ -55,6 +56,8 @@ import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; @ExtendWith(MockitoExtension.class) class DefaultImplementorTest { @@ -197,4 +200,16 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { assertEquals(physicalPlan, logicalPlan.accept(implementor, null)); } + + @Test + public void visitTableScanBuilderShouldBuildTableScanOperator() { + TableScanOperator tableScanOperator = Mockito.mock(TableScanOperator.class); 
+ TableScanBuilder tableScanBuilder = new TableScanBuilder() { + @Override + public TableScanOperator build() { + return tableScanOperator; + } + }; + assertEquals(tableScanOperator, tableScanBuilder.accept(implementor, null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 03eeb9c626..33c6b02c38 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -32,6 +32,8 @@ import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Todo. Temporary added for UT coverage, Will be removed. @@ -72,6 +74,15 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(relation.accept(new LogicalPlanNodeVisitor() { }, null)); + LogicalPlan tableScanBuilder = new TableScanBuilder() { + @Override + public TableScanOperator build() { + return null; + } + }; + assertNull(tableScanBuilder.accept(new LogicalPlanNodeVisitor() { + }, null)); + LogicalPlan filter = LogicalPlanDSL.filter(relation, expression); assertNull(filter.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 9f3035888f..e2510ec464 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -7,29 +7,53 @@ package org.opensearch.sql.planner.optimizer; import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.longValue; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.Map; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.opensearch.sql.analysis.AnalyzerTestBase; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.springframework.context.annotation.Configuration; -import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit.jupiter.SpringExtension; - -@Configuration -@ExtendWith(SpringExtension.class) -@ContextConfiguration(classes = {AnalyzerTestBase.class}) -class LogicalPlanOptimizerTest extends AnalyzerTestBase { +import org.opensearch.sql.planner.physical.PhysicalPlan; 
+import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +@ExtendWith(MockitoExtension.class) +class LogicalPlanOptimizerTest { + + @Mock + private Table table; + + @Spy + private TableScanBuilder tableScanBuilder; + + @BeforeEach + void setUp() { + when(table.createScanBuilder()).thenReturn(tableScanBuilder); + } + /** * Filter - Filter --> Filter. */ @@ -37,7 +61,7 @@ class LogicalPlanOptimizerTest extends AnalyzerTestBase { void filter_merge_filter() { assertEquals( filter( - relation("schema", table), + tableScanBuilder, DSL.and(DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(2))), DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))) ), @@ -61,7 +85,7 @@ void push_filter_under_sort() { assertEquals( sort( filter( - relation("schema", table), + tableScanBuilder, DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) @@ -86,7 +110,7 @@ void multiple_filter_should_eventually_be_merged() { assertEquals( sort( filter( - relation("schema", table), + tableScanBuilder, DSL.and(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), DSL.less(DSL.ref("longV", INTEGER), DSL.literal(longValue(1L)))) ), @@ -107,6 +131,145 @@ void multiple_filter_should_eventually_be_merged() { ); } + @Test + void default_table_scan_builder_should_not_push_down_anything() { + LogicalPlan[] plans = { + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + sort( + relation("schema", table), + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), + limit( + relation("schema", 
table), + 1, 1) + }; + + for (LogicalPlan plan : plans) { + assertEquals(plan, optimize(plan)); + } + } + + @Test + void table_scan_builder_support_project_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownProject(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void table_scan_builder_support_filter_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownFilter(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + ) + ); + } + + @Test + void table_scan_builder_support_aggregation_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownAggregation(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))) + ) + ); + } + + @Test + void table_scan_builder_support_sort_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownSort(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + sort( + relation("schema", table), + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void table_scan_builder_support_limit_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownLimit(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + limit( + relation("schema", table), + 1, 1) + ) + ); + } + + @Test + void table_scan_builder_support_highlight_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownHighlight(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + highlight( + relation("schema", table), + DSL.literal("*"), + Collections.emptyMap()) + ) 
+ ); + } + + @Test + void table_not_support_scan_builder_should_not_be_impact() { + Mockito.reset(table, tableScanBuilder); + Table table = new Table() { + @Override + public Map getFieldTypes() { + return null; + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + return null; + } + }; + + assertEquals( + relation("schema", table), + optimize(relation("schema", table)) + ); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); final LogicalPlan optimize = optimizer.optimize(plan); diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java index ad7c7c50dc..61d192362a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java @@ -13,7 +13,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalPlan; @ExtendWith(MockitoExtension.class) @@ -26,5 +28,12 @@ class PatternsTest { void source_is_empty() { when(plan.getChild()).thenReturn(Collections.emptyList()); assertFalse(Patterns.source().getFunction().apply(plan).isPresent()); + assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()); + } + + @Test + void table_is_empty() { + plan = Mockito.mock(LogicalFilter.class); + assertFalse(Patterns.table().getFunction().apply(plan).isPresent()); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java deleted file 
mode 100644 index 84bfb47a08..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import com.google.common.collect.ImmutableList; -import java.util.List; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; - -/** - * Logical Index Scan Aggregation Operation. - */ -@Getter -@ToString -@EqualsAndHashCode(callSuper = false) -public class OpenSearchLogicalIndexAgg extends LogicalPlan { - - private final String relationName; - - /** - * Filter Condition. - */ - @Setter - private Expression filter; - - /** - * Aggregation List. - */ - @Setter - private List aggregatorList; - - /** - * Group List. - */ - @Setter - private List groupByList; - - /** - * Sort List. - */ - @Setter - private List> sortList; - - /** - * ElasticsearchLogicalIndexAgg Constructor. 
- */ - @Builder - public OpenSearchLogicalIndexAgg( - String relationName, - Expression filter, - List aggregatorList, - List groupByList, - List> sortList) { - super(ImmutableList.of()); - this.relationName = relationName; - this.filter = filter; - this.aggregatorList = aggregatorList; - this.groupByList = groupByList; - this.sortList = sortList; - } - - @Override - public R accept(LogicalPlanNodeVisitor visitor, C context) { - return visitor.visitNode(this, context); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java deleted file mode 100644 index d182b5f84d..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import com.google.common.collect.ImmutableList; -import java.util.List; -import java.util.Set; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; - -/** - * OpenSearch Logical Index Scan Operation. - */ -@Getter -@ToString -@EqualsAndHashCode(callSuper = false) -public class OpenSearchLogicalIndexScan extends LogicalPlan { - - /** - * Relation Name. - */ - private final String relationName; - - /** - * Filter Condition. - */ - @Setter - private Expression filter; - - /** - * Projection List. 
- */ - @Setter - private Set projectList; - - /** - * Sort List. - */ - @Setter - private List> sortList; - - @Setter - private Integer offset; - - @Setter - private Integer limit; - - /** - * ElasticsearchLogicalIndexScan Constructor. - */ - @Builder - public OpenSearchLogicalIndexScan( - String relationName, - Expression filter, - Set projectList, - List> sortList, - Integer limit, Integer offset) { - super(ImmutableList.of()); - this.relationName = relationName; - this.filter = filter; - this.projectList = projectList; - this.sortList = sortList; - this.limit = limit; - this.offset = offset; - } - - @Override - public R accept(LogicalPlanNodeVisitor visitor, C context) { - return visitor.visitNode(this, context); - } - - public boolean hasLimit() { - return limit != null; - } - - /** - * Test has projects or not. - * - * @return true for has projects, otherwise false. - */ - public boolean hasProjects() { - return projectList != null && !projectList.isEmpty(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java deleted file mode 100644 index 77cb6b13bd..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import java.util.Arrays; -import lombok.experimental.UtilityClass; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeAggAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeAggAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeFilterAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeLimitAndIndexScan; -import 
org.opensearch.sql.opensearch.planner.logical.rule.MergeLimitAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.PushProjectAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.PushProjectAndRelation; -import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; - -/** - * OpenSearch storage specified logical plan optimizer. - */ -@UtilityClass -public class OpenSearchLogicalPlanOptimizerFactory { - - /** - * Create OpenSearch storage specified logical plan optimizer. - */ - public static LogicalPlanOptimizer create() { - return new LogicalPlanOptimizer(Arrays.asList( - new MergeFilterAndRelation(), - new MergeAggAndIndexScan(), - new MergeAggAndRelation(), - new MergeSortAndRelation(), - new MergeSortAndIndexScan(), - new MergeSortAndIndexAgg(), - new MergeSortAndIndexScan(), - new MergeLimitAndRelation(), - new MergeLimitAndIndexScan(), - new PushProjectAndRelation(), - new PushProjectAndIndexScan() - )); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java deleted file mode 100644 index 3d4d999d12..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import 
com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalAggregation; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Aggregation -- Relation to IndexScanAggregation. - */ -public class MergeAggAndIndexScan implements Rule { - - private final Capture capture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndIndexScan. - */ - public MergeAggAndIndexScan() { - this.capture = Capture.newCapture(); - this.pattern = typeOf(LogicalAggregation.class) - .with(source().matching(typeOf(OpenSearchLogicalIndexScan.class) - .matching(indexScan -> !indexScan.hasLimit()) - .capturedAs(capture))); - } - - @Override - public LogicalPlan apply(LogicalAggregation aggregation, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(capture); - return OpenSearchLogicalIndexAgg - .builder() - .relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .aggregatorList(aggregation.getAggregatorList()) - .groupByList(aggregation.getGroupByList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java deleted file mode 100644 index 2e79e7c51a..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - 
-import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.planner.logical.LogicalAggregation; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Aggregation -- Relation to IndexScanAggregation. - */ -public class MergeAggAndRelation implements Rule { - - private final Capture relationCapture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndRelation. - */ - public MergeAggAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalAggregation.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public LogicalPlan apply(LogicalAggregation aggregation, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexAgg - .builder() - .relationName(relation.getRelationName()) - .aggregatorList(aggregation.getAggregatorList()) - .groupByList(aggregation.getGroupByList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java deleted file mode 100644 index 19143c390e..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * 
SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalFilter; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Filter -- Relation to LogicalIndexScan. - */ -public class MergeFilterAndRelation implements Rule { - - private final Capture relationCapture; - private final Pattern pattern; - - /** - * Constructor of MergeFilterAndRelation. - */ - public MergeFilterAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalFilter.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalFilter filter, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .filter(filter.getCondition()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java deleted file mode 100644 index 9d880bb4dc..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * 
SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalLimit; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.Rule; - -@Getter -public class MergeLimitAndIndexScan implements Rule { - - private final Capture indexScanCapture; - - @Accessors(fluent = true) - private final Pattern pattern; - - /** - * Constructor of MergeLimitAndIndexScan. - */ - public MergeLimitAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalLimit.class) - .with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class).capturedAs(indexScanCapture))); - } - - @Override - public LogicalPlan apply(LogicalLimit plan, Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - OpenSearchLogicalIndexScan.OpenSearchLogicalIndexScanBuilder builder = - OpenSearchLogicalIndexScan.builder(); - builder.relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .offset(plan.getOffset()) - .limit(plan.getLimit()); - if (indexScan.getSortList() != null) { - builder.sortList(indexScan.getSortList()); - } - return builder.build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java deleted file mode 100644 index 8a170aaa4a..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalLimit; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -@Getter -public class MergeLimitAndRelation implements Rule { - - private final Capture relationCapture; - - @Accessors(fluent = true) - private final Pattern pattern; - - /** - * Constructor of MergeLimitAndRelation. 
- */ - public MergeLimitAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalLimit.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public LogicalPlan apply(LogicalLimit plan, Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan.builder() - .relationName(relation.getRelationName()) - .offset(plan.getOffset()) - .limit(plan.getLimit()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java deleted file mode 100644 index 57dac4dcf1..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.planner.logical.LogicalPlan; -import 
org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort -- IndexScanAggregation to IndexScanAggregation. - */ -public class MergeSortAndIndexAgg implements Rule { - - private final Capture indexAggCapture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndIndexScan. - */ - public MergeSortAndIndexAgg() { - this.indexAggCapture = Capture.newCapture(); - final AtomicReference sortRef = new AtomicReference<>(); - - this.pattern = typeOf(LogicalSort.class) - .matching(OptimizationRuleUtils::sortByFieldsOnly) - .matching(sort -> { - sortRef.set(sort); - return true; - }) - .with(source().matching(typeOf(OpenSearchLogicalIndexAgg.class) - .matching(indexAgg -> !hasAggregatorInSortBy(sortRef.get(), indexAgg)) - .capturedAs(indexAggCapture))); - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - OpenSearchLogicalIndexAgg indexAgg = captures.get(indexAggCapture); - return OpenSearchLogicalIndexAgg.builder() - .relationName(indexAgg.getRelationName()) - .filter(indexAgg.getFilter()) - .groupByList(indexAgg.getGroupByList()) - .aggregatorList(indexAgg.getAggregatorList()) - .sortList(sort.getSortList()) - .build(); - } - - private boolean hasAggregatorInSortBy(LogicalSort sort, OpenSearchLogicalIndexAgg agg) { - final Set aggregatorNames = - agg.getAggregatorList().stream().map(NamedAggregator::getName).collect(Collectors.toSet()); - for (Pair sortPair : sort.getSortList()) { - if (aggregatorNames.contains(((ReferenceExpression) sortPair.getRight()).getAttr())) { - return true; - } - } - return false; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java deleted file mode 100644 index 337f09308c..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort with IndexScan only when Sort by fields. - */ -public class MergeSortAndIndexScan implements Rule { - - private final Capture indexScanCapture; - private final Pattern pattern; - - /** - * Constructor of MergeSortAndRelation. 
- */ - public MergeSortAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalSort.class).matching(OptimizationRuleUtils::sortByFieldsOnly) - .with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class).capturedAs(indexScanCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - - return OpenSearchLogicalIndexScan - .builder() - .relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .sortList(mergeSortList(indexScan.getSortList(), sort.getSortList())) - .build(); - } - - private List> mergeSortList(List> l1, List> l2) { - if (null == l1) { - return l2; - } else { - return Stream.concat(l1.stream(), l2.stream()).collect(Collectors.toList()); - } - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java deleted file mode 100644 index 3ba3c7f645..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import 
org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort with Relation only when Sort by fields. - */ -public class MergeSortAndRelation implements Rule { - - private final Capture relationCapture; - private final Pattern pattern; - - /** - * Constructor of MergeSortAndRelation. - */ - public MergeSortAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalSort.class).matching(OptimizationRuleUtils::sortByFieldsOnly) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .sortList(sort.getSortList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java deleted file mode 100644 index aa1ffa9e4c..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import lombok.experimental.UtilityClass; -import org.opensearch.sql.expression.ExpressionNodeVisitor; -import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.planner.logical.LogicalSort; - -@UtilityClass -public class OptimizationRuleUtils { - - /** - * Does the sort list 
only contain {@link ReferenceExpression}. - * - * @param logicalSort LogicalSort. - * @return true only contain ReferenceExpression, otherwise false. - */ - public static boolean sortByFieldsOnly(LogicalSort logicalSort) { - return logicalSort.getSortList().stream() - .map(sort -> sort.getRight() instanceof ReferenceExpression) - .reduce(true, Boolean::logicalAnd); - } - - /** - * Find reference expression from expression. - * @param expressions a list of expression. - * - * @return a list of ReferenceExpression - */ - public static Set findReferenceExpressions( - List expressions) { - Set projectList = new HashSet<>(); - for (NamedExpression namedExpression : expressions) { - projectList.addAll(findReferenceExpression(namedExpression)); - } - return projectList; - } - - /** - * Find reference expression from expression. - * @param expression expression. - * - * @return a list of ReferenceExpression - */ - public static List findReferenceExpression( - NamedExpression expression) { - List results = new ArrayList<>(); - expression.accept(new ExpressionNodeVisitor() { - @Override - public Object visitReference(ReferenceExpression node, Object context) { - return results.add(node); - } - }, null); - return results; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java deleted file mode 100644 index 43714282fb..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.opensearch.planner.logical.rule.OptimizationRuleUtils.findReferenceExpressions; -import 
static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalProject; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Push Project list into ElasticsearchLogicalIndexScan. - */ -public class PushProjectAndIndexScan implements Rule { - - private final Capture indexScanCapture; - - private final Pattern pattern; - - private Set pushDownProjects; - - /** - * Constructor of MergeProjectAndIndexScan. - */ - public PushProjectAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalProject.class).matching( - project -> { - pushDownProjects = findReferenceExpressions(project.getProjectList()); - return !pushDownProjects.isEmpty(); - }).with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class) - .matching(indexScan -> !indexScan.hasProjects()) - .capturedAs(indexScanCapture))); - - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalProject project, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - indexScan.setProjectList(pushDownProjects); - return new LogicalProject(indexScan, project.getProjectList(), - project.getNamedParseExpressions()); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java deleted file mode 100644 index a29a1df466..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.opensearch.planner.logical.rule.OptimizationRuleUtils.findReferenceExpressions; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalProject; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Push Project list into Relation. The transformed plan is Project - IndexScan - */ -public class PushProjectAndRelation implements Rule { - - private final Capture relationCapture; - - private final Pattern pattern; - - private Set pushDownProjects; - - /** - * Constructor of MergeProjectAndRelation. 
- */ - public PushProjectAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalProject.class) - .matching(project -> { - pushDownProjects = findReferenceExpressions(project.getProjectList()); - return !pushDownProjects.isEmpty(); - }) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalProject project, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return new LogicalProject( - OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .projectList(findReferenceExpressions(project.getProjectList())) - .build(), - project.getProjectList(), - project.getNamedParseExpressions() - ); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index c26413c622..439a970a4f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -9,6 +9,8 @@ import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; +import com.google.common.collect.Lists; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -24,7 +26,9 @@ import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.ast.expression.Literal; import 
org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; @@ -158,6 +162,11 @@ public void pushDownAggregation( * @param sortBuilders sortBuilders. */ public void pushDownSort(List> sortBuilders) { + // TODO: Sort by _doc is added when filter push down. Remove both logic once doctest fixed. + if (isSortByDocOnly()) { + sourceBuilder.sorts().clear(); + } + for (SortBuilder sortBuilder : sortBuilders) { sourceBuilder.sort(sortBuilder); } @@ -220,4 +229,12 @@ public void pushTypeMapping(Map typeMapping) { private boolean isBoolFilterQuery(QueryBuilder current) { return (current instanceof BoolQueryBuilder); } + + private boolean isSortByDocOnly() { + List> sorts = sourceBuilder.sorts(); + if (sorts != null) { + return sorts.equals(Arrays.asList(SortBuilders.fieldSort(DOC_FIELD_NAME))); + } + return false; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java index 00e8a5154c..7459300caa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java @@ -18,12 +18,14 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; /** * Composite Aggregation Parser which include composite aggregation and metric parsers. 
*/ +@EqualsAndHashCode public class CompositeAggregationParser implements OpenSearchAggregationResponseParser { private final MetricParserHelper metricsParser; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java index cfcba82c18..8358379be0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java @@ -15,6 +15,7 @@ import java.util.Map; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.bucket.filter.Filter; @@ -25,6 +26,7 @@ * do nothing and return the result from metricsParser. */ @Builder +@EqualsAndHashCode public class FilterParser implements MetricParser { private final MetricParser metricsParser; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java index 54b9305f49..d5c0141ad2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.Aggregations; @@ -25,6 +26,7 @@ /** * Parse multiple metrics in one bucket. 
*/ +@EqualsAndHashCode @RequiredArgsConstructor public class MetricParserHelper { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java index 88d9604137..384e07ad8f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java @@ -17,6 +17,7 @@ import java.util.Collections; import java.util.Map; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -25,6 +26,7 @@ /** * {@link NumericMetricsAggregation.SingleValue} metric parser. */ +@EqualsAndHashCode @RequiredArgsConstructor public class SingleValueParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java index 5928b7efc9..c80b75de05 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.Map; import java.util.function.Function; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -26,6 +27,7 @@ /** * {@link ExtendedStats} metric parser. 
*/ +@EqualsAndHashCode @RequiredArgsConstructor public class StatsParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java index 4a3a346a84..a98e1b4ce3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -18,6 +19,7 @@ /** * {@link TopHits} metric parser. */ +@EqualsAndHashCode @RequiredArgsConstructor public class TopHitsParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 26082abed1..c694769b89 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -8,41 +8,27 @@ import com.google.common.annotations.VisibleForTesting; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import 
org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; -import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; -import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; -import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; -import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; /** OpenSearch table (index) implementation. 
*/ public class OpenSearchIndex implements Table { @@ -122,98 +108,30 @@ public Integer getMaxResultWindow() { */ @Override public PhysicalPlan implement(LogicalPlan plan) { - OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, - getMaxResultWindow(), new OpenSearchExprValueFactory(getFieldTypes())); - - /* - * Visit logical plan with index scan as context so logical operators visited, such as - * aggregation, filter, will accumulate (push down) OpenSearch query and aggregation DSL on - * index scan. - */ - return plan.accept(new OpenSearchDefaultImplementor(indexScan, client), indexScan); + // TODO: Kept here to avoid impacting Prometheus and AD operators. Needs to move to Planner. + return plan.accept(new OpenSearchDefaultImplementor(client), null); } @Override public LogicalPlan optimize(LogicalPlan plan) { - return OpenSearchLogicalPlanOptimizerFactory.create().optimize(plan); + // No-op because optimization is already done in Planner + return plan; + } + + @Override + public TableScanBuilder createScanBuilder() { + OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, + getMaxResultWindow(), new OpenSearchExprValueFactory(getFieldTypes())); + return new OpenSearchIndexScanBuilder(indexScan); } @VisibleForTesting @RequiredArgsConstructor public static class OpenSearchDefaultImplementor extends DefaultImplementor { - private final OpenSearchIndexScan indexScan; private final OpenSearchClient client; - @Override - public PhysicalPlan visitNode(LogicalPlan plan, OpenSearchIndexScan context) { - if (plan instanceof OpenSearchLogicalIndexScan) { - return visitIndexScan((OpenSearchLogicalIndexScan) plan, context); - } else if (plan instanceof OpenSearchLogicalIndexAgg) { - return visitIndexAggregation((OpenSearchLogicalIndexAgg) plan, context); - } else { - throw new IllegalStateException(StringUtils.format("unexpected plan node type %s", - plan.getClass())); - } - } - - /** - * Implement 
ElasticsearchLogicalIndexScan. - */ - public PhysicalPlan visitIndexScan(OpenSearchLogicalIndexScan node, - OpenSearchIndexScan context) { - if (null != node.getSortList()) { - final SortQueryBuilder builder = new SortQueryBuilder(); - context.getRequestBuilder().pushDownSort(node.getSortList().stream() - .map(sort -> builder.build(sort.getValue(), sort.getKey())) - .collect(Collectors.toList())); - } - - if (null != node.getFilter()) { - FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer()); - QueryBuilder query = queryBuilder.build(node.getFilter()); - context.getRequestBuilder().pushDown(query); - } - - if (node.getLimit() != null) { - context.getRequestBuilder().pushDownLimit(node.getLimit(), node.getOffset()); - } - - if (node.hasProjects()) { - context.getRequestBuilder().pushDownProjects(node.getProjectList()); - } - return indexScan; - } - - /** - * Implement ElasticsearchLogicalIndexAgg. - */ - public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, - OpenSearchIndexScan context) { - if (node.getFilter() != null) { - FilterQueryBuilder queryBuilder = new FilterQueryBuilder( - new DefaultExpressionSerializer()); - QueryBuilder query = queryBuilder.build(node.getFilter()); - context.getRequestBuilder().pushDown(query); - } - AggregationQueryBuilder builder = - new AggregationQueryBuilder(new DefaultExpressionSerializer()); - Pair, OpenSearchAggregationResponseParser> aggregationBuilder = - builder.buildAggregationBuilder(node.getAggregatorList(), - node.getGroupByList(), node.getSortList()); - context.getRequestBuilder().pushDownAggregation(aggregationBuilder); - context.getRequestBuilder().pushTypeMapping( - builder.buildTypeMapping(node.getAggregatorList(), - node.getGroupByList())); - return indexScan; - } - - @Override - public PhysicalPlan visitRelation(LogicalRelation node, OpenSearchIndexScan context) { - return indexScan; - } - @Override public PhysicalPlan visitMLCommons(LogicalMLCommons node, 
OpenSearchIndexScan context) { return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), @@ -231,12 +149,5 @@ public PhysicalPlan visitML(LogicalML node, OpenSearchIndexScan context) { return new MLOperator(visitChild(node, context), node.getArguments(), client.getNodeClient()); } - - @Override - public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { - context.getRequestBuilder().pushDownHighlight( - StringUtils.unquoteText(node.getHighlightField().toString()), node.getArguments()); - return visitChild(node, context); - } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java new file mode 100644 index 0000000000..e52fc566cd --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; +import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import 
org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Index scan builder for aggregate query used by {@link OpenSearchIndexScanBuilder} internally. + */ +class OpenSearchIndexScanAggregationBuilder extends TableScanBuilder { + + /** OpenSearch index scan to be optimized. */ + private final OpenSearchIndexScan indexScan; + + /** Aggregators pushed down; assigned in pushDownAggregation. */ + private List aggregatorList; + + /** Grouping items pushed down; assigned in pushDownAggregation. */ + private List groupByList; + + /** Sorting items pushed down; assigned only when pushDownSort accepts the sort. */ + private List> sortList; + + /** + * Initialize with given index scan and perform push-down optimization later. + * + * @param indexScan index scan not fully optimized yet + */ + OpenSearchIndexScanAggregationBuilder(OpenSearchIndexScan indexScan) { + this.indexScan = indexScan; + } + + @Override + public TableScanOperator build() { + AggregationQueryBuilder builder = + new AggregationQueryBuilder(new DefaultExpressionSerializer()); + Pair, OpenSearchAggregationResponseParser> aggregationBuilder = + builder.buildAggregationBuilder(aggregatorList, groupByList, sortList); + indexScan.getRequestBuilder().pushDownAggregation(aggregationBuilder); + indexScan.getRequestBuilder().pushTypeMapping( + builder.buildTypeMapping(aggregatorList, groupByList)); + return indexScan; + } + + @Override + public boolean pushDownAggregation(LogicalAggregation aggregation) { + aggregatorList = aggregation.getAggregatorList(); + groupByList = aggregation.getGroupByList(); + return true; + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + if (hasAggregatorInSortBy(sort)) { + return false; + } + + sortList = sort.getSortList(); + return true; + } + + private boolean hasAggregatorInSortBy(LogicalSort sort) { + final Set aggregatorNames = + 
aggregatorList.stream().map(NamedAggregator::getName).collect(Collectors.toSet()); + for (Pair sortPair : sort.getSortList()) { + if (aggregatorNames.contains(((ReferenceExpression) sortPair.getRight()).getAttr())) { + return true; + } + } + return false; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java new file mode 100644 index 0000000000..d7483cfcf0 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import com.google.common.annotations.VisibleForTesting; +import lombok.EqualsAndHashCode; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Table scan builder that builds the table scan operator for OpenSearch. The actual work is + * performed by a delegated builder internally. This avoids conditional checks everywhere for the + * different push-down logic of non-aggregate and aggregate queries. + */ +public class OpenSearchIndexScanBuilder extends TableScanBuilder { + + /** + * Delegated index scan builder for non-aggregate or aggregate query. 
+ */ + @EqualsAndHashCode.Include + private TableScanBuilder delegate; + + /** Whether a limit push-down was requested; set in pushDownLimit before delegating. */ + private boolean isLimitPushedDown = false; + + @VisibleForTesting + OpenSearchIndexScanBuilder(TableScanBuilder delegate) { + this.delegate = delegate; + } + + /** + * Initialize with given index scan. + * + * @param indexScan index scan to optimize + */ + public OpenSearchIndexScanBuilder(OpenSearchIndexScan indexScan) { + this.delegate = new OpenSearchIndexScanQueryBuilder(indexScan); + } + + @Override + public TableScanOperator build() { + return delegate.build(); + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + return delegate.pushDownFilter(filter); + } + + @Override + public boolean pushDownAggregation(LogicalAggregation aggregation) { + if (isLimitPushedDown) { + return false; + } + + // Switch to the builder for aggregate queries, which has different push-down logic + // for later filter, sort and limit operators. + delegate = new OpenSearchIndexScanAggregationBuilder( + (OpenSearchIndexScan) delegate.build()); + + return delegate.pushDownAggregation(aggregation); + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + if (!sortByFieldsOnly(sort)) { + return false; + } + return delegate.pushDownSort(sort); + } + + @Override + public boolean pushDownLimit(LogicalLimit limit) { + // Assumes limit push-down happens on OpenSearchIndexScanQueryBuilder + isLimitPushedDown = true; + return delegate.pushDownLimit(limit); + } + + @Override + public boolean pushDownProject(LogicalProject project) { + return delegate.pushDownProject(project); + } + + @Override + public boolean pushDownHighlight(LogicalHighlight highlight) { + return delegate.pushDownHighlight(highlight); + } + + private boolean sortByFieldsOnly(LogicalSort sort) { + return sort.getSortList().stream() + .map(sortItem -> sortItem.getRight() instanceof ReferenceExpression) + .reduce(true, Boolean::logicalAnd); + } +} diff --git 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java new file mode 100644 index 0000000000..7190d58000 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ExpressionNodeVisitor; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; +import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; +import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Index scan builder for simple non-aggregate query used by + * {@link 
OpenSearchIndexScanBuilder} internally. + */ +@VisibleForTesting +class OpenSearchIndexScanQueryBuilder extends TableScanBuilder { + + /** OpenSearch index scan to be optimized. */ + @EqualsAndHashCode.Include + private final OpenSearchIndexScan indexScan; + + /** + * Initialize with given index scan and perform push-down optimization later. + * + * @param indexScan index scan not optimized yet + */ + OpenSearchIndexScanQueryBuilder(OpenSearchIndexScan indexScan) { + this.indexScan = indexScan; + } + + @Override + public TableScanOperator build() { + return indexScan; + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + FilterQueryBuilder queryBuilder = new FilterQueryBuilder( + new DefaultExpressionSerializer()); + QueryBuilder query = queryBuilder.build(filter.getCondition()); + indexScan.getRequestBuilder().pushDown(query); + return true; + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + List> sortList = sort.getSortList(); + final SortQueryBuilder builder = new SortQueryBuilder(); + indexScan.getRequestBuilder().pushDownSort(sortList.stream() + .map(sortItem -> builder.build(sortItem.getValue(), sortItem.getKey())) + .collect(Collectors.toList())); + return true; + } + + @Override + public boolean pushDownLimit(LogicalLimit limit) { + indexScan.getRequestBuilder().pushDownLimit(limit.getLimit(), limit.getOffset()); + return true; + } + + @Override + public boolean pushDownProject(LogicalProject project) { + indexScan.getRequestBuilder().pushDownProjects( + findReferenceExpressions(project.getProjectList())); + + // Return false intentionally to keep the original project operator + return false; + } + + @Override + public boolean pushDownHighlight(LogicalHighlight highlight) { + indexScan.getRequestBuilder().pushDownHighlight( + StringUtils.unquoteText(highlight.getHighlightField().toString()), + highlight.getArguments()); + return true; + } + + /** + * Find the distinct reference expressions in a list of named expressions. 
+ * @param expressions a list of named expressions. + * + * @return a set of ReferenceExpression + */ + public static Set findReferenceExpressions( + List expressions) { + Set projectList = new HashSet<>(); + for (NamedExpression namedExpression : expressions) { + projectList.addAll(findReferenceExpression(namedExpression)); + } + return projectList; + } + + /** + * Find reference expressions within a single named expression. + * @param expression expression. + * + * @return a list of ReferenceExpression + */ + public static List findReferenceExpression(NamedExpression expression) { + List results = new ArrayList<>(); + expression.accept(new ExpressionNodeVisitor<>() { + @Override + public Object visitReference(ReferenceExpression node, Object context) { + return results.add(node); + } + }, null); + return results; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java deleted file mode 100644 index 31ad2b2ee3..0000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java +++ /dev/null @@ -1,576 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.data.type.ExprCoreType.LONG; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.opensearch.utils.Utils.indexScan; -import static org.opensearch.sql.opensearch.utils.Utils.indexScanAgg; -import static org.opensearch.sql.opensearch.utils.Utils.noProjects; -import static 
org.opensearch.sql.opensearch.utils.Utils.projects; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.opensearch.utils.Utils; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; -import org.opensearch.sql.storage.Table; - -@ExtendWith(MockitoExtension.class) -class OpenSearchLogicOptimizerTest { - - @Mock - private Table table; - - /** - * SELECT intV as i FROM schema WHERE intV = 1. - */ - @Test - void project_filter_merge_with_relation() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - ImmutableSet.of(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - DSL.named("i", DSL.ref("intV", INTEGER))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY string_value. 
- */ - @Test - void aggregation_merge_relation() { - assertEquals( - project( - indexScanAgg("schema", ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema WHERE intV = 1 GROUP BY string_value. - */ - @Test - void aggregation_merge_filter_relation() { - assertEquals( - project( - indexScanAgg("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - @Disabled("This test should be enabled once https://github.com/opensearch-project/sql/issues/912 is fixed") - @Test - void aggregation_cant_merge_indexScan_with_project() { - assertEquals( - aggregation( - OpenSearchLogicalIndexScan.builder().relationName("schema") - .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) - .projectList(ImmutableSet.of(DSL.ref("intV", INTEGER))) - .build(), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - 
ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - optimize( - aggregation( - OpenSearchLogicalIndexScan.builder().relationName("schema") - .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) - .projectList( - ImmutableSet.of(DSL.ref("intV", INTEGER))) - .build(), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG)))))) - ); - } - - /** - * Sort - Relation --> IndexScan. - */ - @Test - void sort_merge_with_relation() { - assertEquals( - indexScan("schema", Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), - optimize( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - /** - * Sort - IndexScan --> IndexScan. - */ - @Test - void sort_merge_with_indexScan() { - assertEquals( - indexScan("schema", - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG))), - optimize( - sort( - indexScan("schema", Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ) - ) - ); - } - - /** - * Sort - Filter - Relation --> IndexScan. 
- */ - @Test - void sort_filter_merge_with_relation() { - assertEquals( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ), - optimize( - sort( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ) - ) - ); - } - - @Test - void sort_with_expression_cannot_merge_with_relation() { - assertEquals( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) - ), - optimize( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) - ) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. - */ - @Test - void sort_merge_indexagg() { - assertEquals( - project( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList - .of(Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("stringV", STRING)))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - sort( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("stringV", STRING)) - ), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. 
- */ - @Test - void sort_merge_indexagg_nulls_last() { - assertEquals( - project( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList - .of(Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - sort( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)) - ), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - - /** - * Can't Optimize the following query. - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY avg(intV). - */ - @Test - void sort_refer_to_aggregator_should_not_merge_with_indexAgg() { - assertEquals( - sort( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("AVG(intV)", INTEGER)) - ), - optimize( - sort( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("AVG(intV)", INTEGER)) - ) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV ASC NULL_LAST. 
- */ - @Test - void sort_with_customized_option_should_merge_with_indexAgg() { - assertEquals( - indexScanAgg( - "schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList.of( - Pair.of( - new Sort.SortOption(Sort.SortOrder.ASC, Sort.NullOrder.NULL_LAST), - DSL.ref("stringV", STRING)))), - optimize( - sort( - indexScanAgg( - "schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of( - new Sort.SortOption(Sort.SortOrder.ASC, Sort.NullOrder.NULL_LAST), - DSL.ref("stringV", STRING))))); - } - - @Test - void limit_merge_with_relation() { - assertEquals( - project( - indexScan("schema", 1, 1, projects(DSL.ref("intV", INTEGER))), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - relation("schema", table), - 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void limit_merge_with_index_scan() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - 1, 1, - projects(DSL.ref("intV", INTEGER)) - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER))) - ) - ); - } - - @Test - void limit_merge_with_index_scan_sort() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - 1, 1, - Utils.sort(DSL.ref("longV", LONG), Sort.SortOption.DEFAULT_ASC), - projects(DSL.ref("intV", INTEGER)) - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - sort( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), 
DSL.literal(integerValue(1))) - ), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ), 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void aggregation_cant_merge_index_scan_with_limit() { - assertEquals( - project( - aggregation( - indexScan("schema", 10, 0, noProjects()), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - indexScan("schema", 10, 0, noProjects()), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))))); - } - - @Test - void push_down_projectList_to_relation() { - assertEquals( - project( - indexScan("schema", projects(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER))) - ) - ); - } - - /** - * Project(intV, abs(intV)) -> Relation. - * -- will be optimized as - * Project(intV, abs(intV)) -> Relation(project=intV). - */ - @Test - void push_down_should_handle_duplication() { - assertEquals( - project( - indexScan("schema", projects(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("absi", DSL.abs(DSL.ref("intV", INTEGER))) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("absi", DSL.abs(DSL.ref("intV", INTEGER)))) - ) - ); - } - - /** - * Project(ListA) -> Project(ListB) -> Relation. - * -- will be optimized as - * Project(ListA) -> Project(ListB) -> Relation(project=ListB). 
- */ - @Test - void only_one_project_should_be_push() { - assertEquals( - project( - project( - indexScan("schema", - projects(DSL.ref("intV", INTEGER), DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("s", DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("s", DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void project_literal_no_push() { - assertEquals( - project( - relation("schema", table), - DSL.named("i", DSL.literal("str")) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.literal("str")) - ) - ) - ); - } - - /** - * SELECT AVG(intV) FILTER(WHERE intV > 1) FROM schema GROUP BY stringV. - */ - @Test - void filter_aggregation_merge_relation() { - assertEquals( - project( - indexScanAgg("schema", ImmutableList.of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))), - optimize( - project( - aggregation( - relation("schema", table), - ImmutableList.of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT AVG(intV) FILTER(WHERE intV > 1) FROM schema WHERE longV < 1 GROUP BY stringV. 
- */ - @Test - void filter_aggregation_merge_filter_relation() { - assertEquals( - project( - indexScanAgg("schema", - DSL.less(DSL.ref("longV", LONG), DSL.literal(1)), - ImmutableList.of(DSL.named("avg(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))), - optimize( - project( - aggregation( - filter( - relation("schema", table), - DSL.less(DSL.ref("longV", LONG), DSL.literal(1)) - ), - ImmutableList.of(DSL.named("avg(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))) - ) - ); - } - - private LogicalPlan optimize(LogicalPlan plan) { - final LogicalPlanOptimizer optimizer = OpenSearchLogicalPlanOptimizerFactory.create(); - final LogicalPlan optimize = optimizer.optimize(plan); - return optimize; - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java deleted file mode 100644 index 2e10f33787..0000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import static org.junit.jupiter.api.Assertions.assertFalse; - -import com.google.common.collect.ImmutableSet; -import org.junit.jupiter.api.Test; - -class OpenSearchLogicalIndexScanTest { - - @Test - void has_projects() { - assertFalse(OpenSearchLogicalIndexScan.builder() - 
.projectList(ImmutableSet.of()).build() - .hasProjects()); - - assertFalse(OpenSearchLogicalIndexScan.builder().build().hasProjects()); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 43b9353190..33376ece83 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -7,41 +7,70 @@ package org.opensearch.sql.opensearch.request; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; +import static org.opensearch.search.sort.SortOrder.ASC; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.ScoreSortBuilder; +import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.common.setting.Settings; +import 
org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; @ExtendWith(MockitoExtension.class) public class OpenSearchRequestBuilderTest { - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + private static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + private static final Integer DEFAULT_OFFSET = 0; + private static final Integer DEFAULT_LIMIT = 200; + private static final Integer MAX_RESULT_WINDOW = 500; + @Mock private Settings settings; @Mock - private OpenSearchExprValueFactory factory; + private OpenSearchExprValueFactory exprValueFactory; + + private OpenSearchRequestBuilder requestBuilder; @BeforeEach void setup() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + + requestBuilder = new OpenSearchRequestBuilder( + "test", MAX_RESULT_WINDOW, settings, exprValueFactory); } @Test void buildQueryRequest() { - Integer maxResultWindow = 500; Integer limit = 200; Integer offset = 0; - OpenSearchRequestBuilder builder = - new OpenSearchRequestBuilder("test", maxResultWindow, settings, factory); - builder.pushDownLimit(limit, offset); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchQueryRequest( @@ -50,27 +79,145 @@ void buildQueryRequest() { .from(offset) .size(limit) .timeout(DEFAULT_QUERY_TIMEOUT), - factory), - builder.build()); + exprValueFactory), + requestBuilder.build()); } @Test void buildScrollRequestWithCorrectSize() { - Integer maxResultWindow = 500; Integer limit = 800; Integer offset = 10; - OpenSearchRequestBuilder builder = - new 
OpenSearchRequestBuilder("test", maxResultWindow, settings, factory); - builder.pushDownLimit(limit, offset); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), new SearchSourceBuilder() .from(offset) - .size(maxResultWindow - offset) + .size(MAX_RESULT_WINDOW - offset) .timeout(DEFAULT_QUERY_TIMEOUT), - factory), - builder.build()); + exprValueFactory), + requestBuilder.build()); + } + + @Test + void testPushDownQuery() { + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDown(query); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(DOC_FIELD_NAME, ASC), + requestBuilder.getSourceBuilder() + ); + } + + @Test + void testPushDownAggregation() { + AggregationBuilder aggBuilder = AggregationBuilders.composite( + "composite_buckets", + Collections.singletonList(new TermsValuesSourceBuilder("longA"))); + OpenSearchAggregationResponseParser responseParser = + new CompositeAggregationParser( + new SingleValueParser("AVG(intA)")); + requestBuilder.pushDownAggregation(Pair.of(List.of(aggBuilder), responseParser)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(0) + .timeout(DEFAULT_QUERY_TIMEOUT) + .aggregation(aggBuilder), + requestBuilder.getSourceBuilder() + ); + verify(exprValueFactory).setParser(responseParser); + } + + @Test + void testPushDownQueryAndSort() { + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDown(query); + + FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownSort() { + FieldSortBuilder 
sortBuilder = SortBuilders.fieldSort("intA"); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownNonFieldSort() { + ScoreSortBuilder sortBuilder = SortBuilders.scoreSort(); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownMultipleSort() { + requestBuilder.pushDownSort(List.of( + SortBuilders.fieldSort("intA"), + SortBuilders.fieldSort("intB"))); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(SortBuilders.fieldSort("intA")) + .sort(SortBuilders.fieldSort("intB")), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownProject() { + Set references = Set.of(DSL.ref("intA", INTEGER)); + requestBuilder.pushDownProjects(references); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .fetchSource(new String[]{"intA"}, new String[0]), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushTypeMapping() { + Map typeMapping = Map.of("intA", INTEGER); + requestBuilder.pushTypeMapping(typeMapping); + + verify(exprValueFactory).setTypeMapping(typeMapping); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index a74c5fcbd4..d7e5955491 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -6,12 +6,7 @@ package org.opensearch.sql.opensearch.storage; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -20,9 +15,7 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAD; -import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -31,41 +24,20 @@ @ExtendWith(MockitoExtension.class) public class OpenSearchDefaultImplementorTest { - @Mock - OpenSearchIndexScan indexScan; @Mock OpenSearchClient client; @Mock Table table; - /** - * For test coverage. 
- */ - @Test - public void visitInvalidTypeShouldThrowException() { - final OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - - final IllegalStateException exception = - assertThrows(IllegalStateException.class, - () -> implementor.visitNode(relation("index", table), - indexScan)); - ; - assertEquals( - "unexpected plan node type " - + "class org.opensearch.sql.planner.logical.LogicalRelation", - exception.getMessage()); - } - @Test public void visitMachineLearning() { LogicalMLCommons node = Mockito.mock(LogicalMLCommons.class, Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitMLCommons(node, indexScan)); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitMLCommons(node, null)); } @Test @@ -74,8 +46,8 @@ public void visitAD() { Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitAD(node, indexScan)); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitAD(node, null)); } @Test @@ -84,21 +56,7 @@ public void visitML() { Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitML(node, indexScan)); - } - - @Test - public void visitHighlight() { - LogicalHighlight node = Mockito.mock(LogicalHighlight.class, - Answers.RETURNS_DEEP_STUBS); - 
Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); - OpenSearchRequestBuilder requestBuilder = Mockito.mock(OpenSearchRequestBuilder.class); - Mockito.when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); - OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - - implementor.visitHighlight(node, indexScan); - verify(requestBuilder).pushDownHighlight(any(), any()); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitML(node, null)); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 9e375aa1b0..74c18f7c3d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -9,8 +9,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.arrayContaining; -import static org.hamcrest.Matchers.emptyArray; import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -23,14 +21,7 @@ import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD; -import static org.opensearch.sql.opensearch.utils.Utils.indexScan; -import static org.opensearch.sql.opensearch.utils.Utils.indexScanAgg; -import static org.opensearch.sql.opensearch.utils.Utils.noProjects; -import static org.opensearch.sql.opensearch.utils.Utils.projects; -import static 
org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.eval; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.remove; @@ -49,13 +40,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; @@ -67,12 +56,7 @@ import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; -import org.opensearch.sql.planner.physical.AggregationOperator; -import org.opensearch.sql.planner.physical.FilterOperator; -import org.opensearch.sql.planner.physical.LimitOperator; -import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; -import org.opensearch.sql.planner.physical.ProjectOperator; import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) @@ -144,25 +128,28 @@ void getFieldTypes() { .put("blob", "binary") .build()))); - Map fieldTypes = index.getFieldTypes(); - assertThat( - fieldTypes, - allOf( - aMapWithSize(13), - hasEntry("name", 
ExprCoreType.STRING), - hasEntry("address", (ExprType) OpenSearchDataType.OPENSEARCH_TEXT), - hasEntry("age", ExprCoreType.INTEGER), - hasEntry("account_number", ExprCoreType.LONG), - hasEntry("balance1", ExprCoreType.FLOAT), - hasEntry("balance2", ExprCoreType.DOUBLE), - hasEntry("gender", ExprCoreType.BOOLEAN), - hasEntry("family", ExprCoreType.ARRAY), - hasEntry("employer", ExprCoreType.STRUCT), - hasEntry("birthday", ExprCoreType.TIMESTAMP), - hasEntry("id1", ExprCoreType.BYTE), - hasEntry("id2", ExprCoreType.SHORT), - hasEntry("blob", (ExprType) OpenSearchDataType.OPENSEARCH_BINARY) - )); + // Run more than once to confirm caching logic is covered and can work + for (int i = 0; i < 2; i++) { + Map fieldTypes = index.getFieldTypes(); + assertThat( + fieldTypes, + allOf( + aMapWithSize(13), + hasEntry("name", ExprCoreType.STRING), + hasEntry("address", (ExprType) OpenSearchDataType.OPENSEARCH_TEXT), + hasEntry("age", ExprCoreType.INTEGER), + hasEntry("account_number", ExprCoreType.LONG), + hasEntry("balance1", ExprCoreType.FLOAT), + hasEntry("balance2", ExprCoreType.DOUBLE), + hasEntry("gender", ExprCoreType.BOOLEAN), + hasEntry("family", ExprCoreType.ARRAY), + hasEntry("employer", ExprCoreType.STRUCT), + hasEntry("birthday", ExprCoreType.TIMESTAMP), + hasEntry("id1", ExprCoreType.BYTE), + hasEntry("id2", ExprCoreType.SHORT), + hasEntry("blob", (ExprType) OpenSearchDataType.OPENSEARCH_BINARY) + )); + } } @Test @@ -170,7 +157,7 @@ void implementRelationOperatorOnly() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - LogicalPlan plan = relation(indexName, table); + LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), @@ -182,7 +169,7 @@ void implementRelationOperatorWithOptimization() { 
when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - LogicalPlan plan = relation(indexName, table); + LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), @@ -217,7 +204,7 @@ void implementOtherLogicalOperators() { eval( remove( rename( - relation(indexName, table), + index.createScanBuilder(), mappings), exclude), newEvalField), @@ -243,214 +230,4 @@ void implementOtherLogicalOperators() { include), index.implement(plan)); } - - @Test - void shouldImplLogicalIndexScan() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - Expression filterExpr = DSL.equal(field, literal("John")); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - filterExpr - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownFilterFarFromRelation() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - PhysicalPlan plan = index.implement( - filter( - aggregation( - relation(indexName, table), - aggregators, - groupByExprs - ), - filterExpr)); - - 
assertTrue(plan instanceof FilterOperator); - } - - @Test - void shouldImplLogicalIndexScanAgg() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - // IndexScanAgg without Filter - PhysicalPlan plan = index.implement( - filter( - indexScanAgg( - indexName, - aggregators, - groupByExprs - ), - filterExpr)); - - assertTrue(plan.getChild().get(0) instanceof OpenSearchIndexScan); - - // IndexScanAgg with Filter - plan = index.implement( - indexScanAgg( - indexName, - filterExpr, - aggregators, - groupByExprs)); - assertTrue(plan instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownAggregationFarFromRelation() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - PhysicalPlan plan = index.implement( - aggregation( - filter(filter( - relation(indexName, table), - filterExpr), filterExpr), - aggregators, - groupByExprs)); - assertTrue(plan instanceof AggregationOperator); - } - - @Test - void shouldImplIndexScanWithSort() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = 
ref("name", STRING); - NamedExpression named = named("n", field); - Expression sortExpr = ref("name", STRING); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - Pair.of(Sort.SortOption.DEFAULT_ASC, sortExpr) - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldImplIndexScanWithLimit() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - 1, 1, noProjects() - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldImplIndexScanWithSortAndLimit() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - Expression sortExpr = ref("name", STRING); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - sortExpr, - 1, 1, - noProjects() - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownLimitFarFromRelationButUpdateScanSize() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - PhysicalPlan plan = index.implement(index.optimize( - project( - limit( - sort( - relation("test", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, - DSL.abs(named("intV", 
ref("intV", INTEGER)))) - ), - 300, 1 - ), - named("intV", ref("intV", INTEGER)) - ) - )); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof LimitOperator); - } - - @Test - void shouldPushDownProjects() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, projects(ref("intV", INTEGER)) - ), - named("i", ref("intV", INTEGER)))); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - - final FetchSourceContext fetchSource = - ((OpenSearchIndexScan) ((ProjectOperator) plan).getInput()).getRequestBuilder() - .getSourceBuilder().fetchSource(); - assertThat(fetchSource.includes(), arrayContaining("intV")); - assertThat(fetchSource.excludes(), emptyArray()); - } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java new file mode 100644 index 0000000000..363727cbd3 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -0,0 +1,609 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; +import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; +import static 
org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_AGGREGATION; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_FILTER; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_HIGHLIGHT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_LIMIT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_PROJECT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_SORT; + +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import 
org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.HighlightExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; +import org.opensearch.sql.storage.Table; + + +@ExtendWith(MockitoExtension.class) +class OpenSearchIndexScanOptimizationTest { + + @Mock + private Table table; + + @Mock + private OpenSearchIndexScan indexScan; + + private OpenSearchIndexScanBuilder indexScanBuilder; + + @Mock + private OpenSearchRequestBuilder requestBuilder; + + private Runnable[] verifyPushDownCalls = {}; + + @BeforeEach + void setUp() { + indexScanBuilder = new OpenSearchIndexScanBuilder(indexScan); + when(table.createScanBuilder()).thenReturn(indexScanBuilder); + 
when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); + } + + @Test + void test_project_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withProjectPushedDown(DSL.ref("intV", INTEGER))), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER))) + ); + } + + /** + * SELECT intV as i FROM schema WHERE intV = 1. + */ + @Test + void test_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + //withProjectPushedDown(DSL.ref("intV", INTEGER)), + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema GROUP BY string_value. + */ + @Test + void test_aggregation_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("longV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + /* + @Disabled("This test should be enabled once https://github.com/opensearch-project/sql/issues/912 is fixed") + @Test + void aggregation_cant_merge_indexScan_with_project() { + assertEquals( + aggregation( + OpenSearchLogicalIndexScan.builder().relationName("schema") + .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + .projectList(ImmutableSet.of(DSL.ref("intV", INTEGER))) + .build(), + ImmutableList + 
.of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + optimize( + aggregation( + OpenSearchLogicalIndexScan.builder().relationName("schema") + .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + .projectList( + ImmutableSet.of(DSL.ref("intV", INTEGER))) + .build(), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG)))))) + ); + } + */ + + /** + * Sort - Relation --> IndexScan. + */ + @Test + void test_sort_push_down() { + assertEqualsAfterOptimization( + indexScanBuilder( + withSortPushedDown( + SortBuilders.fieldSort("intV").order(SortOrder.ASC).missing("_first")) + ), + sort( + relation("schema", table), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void test_limit_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withLimitPushedDown(1, 1)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + project( + limit( + relation("schema", table), + 1, 1), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void test_highlight_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withHighlightPushedDown("*", Collections.emptyMap())), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) + ), + project( + highlight( + relation("schema", table), + DSL.literal("*"), Collections.emptyMap()), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema WHERE intV = 1 GROUP BY string_value. 
+ */ + @Test + void test_aggregation_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("longV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ), + project( + aggregation( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + /** + * Sort - Filter - Relation --> IndexScan. + */ + @Test + void test_sort_filter_push_down() { + assertEqualsAfterOptimization( + indexScanBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withSortPushedDown( + SortBuilders.fieldSort("longV").order(SortOrder.ASC).missing("_first")) + ), + sort( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. 
+ */ + @Test + void test_sort_aggregation_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .sortBy(SortOption.DEFAULT_DESC) + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + @Test + void test_limit_sort_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withSortPushedDown( + SortBuilders.fieldSort("longV").order(SortOrder.ASC).missing("_first")), + withLimitPushedDown(1, 1)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + project( + limit( + sort( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) + ), 1, 1 + ), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ) + ); + } + + /* + * Project(ListA) -> Project(ListB) -> Relation. + * -- will be optimized as + * Project(ListA) -> Project(ListB) -> Relation(project=ListB). 
+ */ + @Test + void only_one_project_should_be_push() { + assertEqualsAfterOptimization( + project( + project( + indexScanBuilder( + withProjectPushedDown( + DSL.ref("intV", INTEGER), + DSL.ref("stringV", STRING))), + DSL.named("i", DSL.ref("intV", INTEGER)), + DSL.named("s", DSL.ref("stringV", STRING)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER)), + DSL.named("s", DSL.ref("stringV", STRING)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void sort_with_expression_cannot_merge_with_relation() { + assertEqualsAfterOptimization( + sort( + indexScanBuilder(), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ), + sort( + relation("schema", table), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void sort_with_expression_cannot_merge_with_aggregation() { + assertEqualsAfterOptimization( + sort( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ), + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void aggregation_cant_merge_index_scan_with_limit() { + assertEqualsAfterOptimization( + project( + aggregation( + indexScanBuilder( + withLimitPushedDown(10, 0)), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + aggregation( + limit( + relation("schema", 
table), + 10, 0), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)))); + } + + /** + * Can't Optimize the following query. + * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY avg(intV). + */ + @Test + void sort_refer_to_aggregator_should_not_merge_with_indexAgg() { + assertEqualsAfterOptimization( + project( + sort( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("AVG(intV)", INTEGER)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("AVG(intV)", INTEGER)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + @Test + void project_literal_should_not_be_pushed_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder(), + DSL.named("i", DSL.literal("str")) + ), + optimize( + project( + relation("schema", table), + DSL.named("i", DSL.literal("str")) + ) + ) + ); + } + + private OpenSearchIndexScanBuilder indexScanBuilder(Runnable... verifyPushDownCalls) { + this.verifyPushDownCalls = verifyPushDownCalls; + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(indexScan)); + } + + private OpenSearchIndexScanBuilder indexScanAggBuilder(Runnable... 
verifyPushDownCalls) { + this.verifyPushDownCalls = verifyPushDownCalls; + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder(indexScan)); + } + + private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan actual) { + assertEquals(expected, optimize(actual)); + + // Trigger build to make sure all push down actually happened in scan builder + indexScanBuilder.build(); + + // Verify to make sure all push down methods are called as expected + if (verifyPushDownCalls.length == 0) { + reset(indexScan); + } else { + Arrays.stream(verifyPushDownCalls).forEach(Runnable::run); + } + } + + private Runnable withFilterPushedDown(QueryBuilder filteringCondition) { + return () -> verify(requestBuilder, times(1)).pushDown(filteringCondition); + } + + private Runnable withAggregationPushedDown( + AggregationAssertHelper.AggregationAssertHelperBuilder aggregation) { + + // Assume single term bucket and AVG metric in all tests in this suite + CompositeAggregationBuilder aggBuilder = AggregationBuilders.composite( + "composite_buckets", + Collections.singletonList( + new TermsValuesSourceBuilder(aggregation.groupBy) + .field(aggregation.groupBy) + .order(aggregation.sortBy.getSortOrder() == ASC ? "asc" : "desc") + .missingOrder(aggregation.sortBy.getNullOrder() == NULL_FIRST ? "first" : "last") + .missingBucket(true))) + .subAggregation( + AggregationBuilders.avg(aggregation.aggregateName) + .field(aggregation.aggregateBy)) + .size(AggregationQueryBuilder.AGGREGATION_BUCKET_SIZE); + + List aggBuilders = Collections.singletonList(aggBuilder); + OpenSearchAggregationResponseParser responseParser = + new CompositeAggregationParser( + new SingleValueParser(aggregation.aggregateName)); + + return () -> { + verify(requestBuilder, times(1)).pushDownAggregation(Pair.of(aggBuilders, responseParser)); + verify(requestBuilder, times(1)).pushTypeMapping(aggregation.resultTypes); + }; + } + + private Runnable withSortPushedDown(SortBuilder... 
sorts) { + return () -> verify(requestBuilder, times(1)).pushDownSort(Arrays.asList(sorts)); + } + + private Runnable withLimitPushedDown(int size, int offset) { + return () -> verify(requestBuilder, times(1)).pushDownLimit(size, offset); + } + + private Runnable withProjectPushedDown(ReferenceExpression... references) { + return () -> verify(requestBuilder, times(1)).pushDownProjects( + new HashSet<>(Arrays.asList(references))); + } + + private Runnable withHighlightPushedDown(String field, Map arguments) { + return () -> verify(requestBuilder, times(1)).pushDownHighlight(field, arguments); + } + + private static AggregationAssertHelper.AggregationAssertHelperBuilder aggregate(String aggName) { + var aggBuilder = new AggregationAssertHelper.AggregationAssertHelperBuilder(); + aggBuilder.aggregateName = aggName; + aggBuilder.sortBy = SortOption.DEFAULT_ASC; + return aggBuilder; + } + + /** Assertion helper for readability. */ + @Builder + private static class AggregationAssertHelper { + + String aggregateName; + + String aggregateBy; + + String groupBy; + + SortOption sortBy; + + Map resultTypes; + } + + private LogicalPlan optimize(LogicalPlan plan) { + LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(List.of( + new CreateTableScanBuilder(), + PUSH_DOWN_FILTER, + PUSH_DOWN_AGGREGATION, + PUSH_DOWN_SORT, + PUSH_DOWN_LIMIT, + PUSH_DOWN_HIGHLIGHT, + PUSH_DOWN_PROJECT)); + return optimizer.optimize(plan); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java index 2ed9a16434..85b8889de3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java @@ -20,141 +20,10 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.AvgAggregator; import 
org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; @UtilityClass public class Utils { - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Expression filter) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Pair... sorts) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .sortList(Arrays.asList(sorts)) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Pair... sorts) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .sortList(Arrays.asList(sorts)) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Integer offset, Integer limit, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Integer offset, Integer limit, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. 
- */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Integer offset, Integer limit, - List> sorts, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .sortList(sorts) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Set projects) { - return OpenSearchLogicalIndexScan.builder() - .relationName(tableName) - .projectList(projects) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Expression filter, - Set projects) { - return OpenSearchLogicalIndexScan.builder() - .relationName(tableName) - .filter(filter) - .projectList(projects) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. - */ - public static LogicalPlan indexScanAgg(String tableName, List aggregators, - List groupByList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName) - .aggregatorList(aggregators).groupByList(groupByList).build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. - */ - public static LogicalPlan indexScanAgg(String tableName, List aggregators, - List groupByList, - List> sortList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName) - .aggregatorList(aggregators).groupByList(groupByList).sortList(sortList).build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. - */ - public static LogicalPlan indexScanAgg(String tableName, - Expression filter, - List aggregators, - List groupByList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName).filter(filter) - .aggregatorList(aggregators).groupByList(groupByList).build(); - } - public static AvgAggregator avg(Expression expr, ExprCoreType type) { return new AvgAggregator(Arrays.asList(expr), type); }