diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
index 49fc2879f..09b3dbdd7 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
@@ -6,7 +6,7 @@
 package org.opensearch.flint.spark
 
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, EqualTo, GreaterThan, LessThan, LessThanOrEqual, Literal, Not, Or}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.{QueryTest, Row}
@@ -631,4 +631,140 @@ class FlintSparkPPLITSuite
     // Compare the two plans
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
+
+  /**
+   * +--------+---------+
+   * |age_span|count_age|
+   * +--------+---------+
+   * |      20|        2|
+   * |      30|        1|
+   * |      70|        1|
+   * +--------+---------+
+   */
+  test("create ppl simple count age by span of interval of 10 years query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats count(age) by span(age, 10) as age_span
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(1, 70L),
+      Row(1, 30L),
+      Row(2, 20L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")()
+    val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  /**
+   * +--------+-----------+
+   * |age_span|average_age|
+   * +--------+-----------+
+   * |      20|       22.5|
+   * |      30|         30|
+   * |      70|         70|
+   * +--------+-----------+
+   */
+  test("create ppl simple avg age by span of interval of 10 years query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable| stats avg(age) by span(age, 10) as age_span
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(70D, 70L),
+      Row(30D, 30L),
+      Row(22.5D, 20L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
+    val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
+
+  /**
+   * +--------+-------+-----------+
+   * |age_span|country|average_age|
+   * +--------+-------+-----------+
+   * |      20| Canada|       22.5|
+   * |      30|    USA|         30|
+   * |      70|    USA|         70|
+   * +--------+-------+-----------+
+   */
+  ignore("create ppl average age by span of interval of 10 years group by country query test ") {
+    val frame = sql(
+      s"""
+         | source = $testTable | stats avg(age) by span(age, 10) as age_span, country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(1, 70L),
+      Row(1, 30L),
+      Row(2, 20L))
+
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val countryField = UnresolvedAttribute("country")
+    val ageField = UnresolvedAttribute("age")
+    val table = UnresolvedRelation(Seq("default", "flint_ppl_test"))
+
+    val groupByAttributes = Seq(Alias(countryField, "country")())
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")()
+    val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java
index 450fbaf3a..b68edbc62 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java
@@ -22,6 +22,18 @@ public Span(UnresolvedExpression field, UnresolvedExpression value, SpanUnit uni
     this.unit = unit;
   }
 
+  public UnresolvedExpression getField() {
+    return field;
+  }
+
+  public UnresolvedExpression getValue() {
+    return value;
+  }
+
+  public SpanUnit getUnit() {
+    return unit;
+  }
+
   @Override
   public List getChild() {
     return ImmutableList.of(field, value);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index bfd2464e5..039459150 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -12,7 +12,10 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
+import org.apache.spark.sql.catalyst.expressions.Divide;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Floor;
+import org.apache.spark.sql.catalyst.expressions.Multiply;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
@@ -34,6 +37,7 @@
 import org.opensearch.sql.ast.expression.Or;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
+import org.opensearch.sql.ast.expression.WindowFunction;
 import org.opensearch.sql.ast.expression.Xor;
 import org.opensearch.sql.ast.statement.Explain;
 import org.opensearch.sql.ast.statement.Query;
@@ -55,6 +59,7 @@
 import scala.collection.Seq;
 
 import java.util.List;
+import java.util.Objects;
 import java.util.stream.Collectors;
 
 import static com.google.common.base.Strings.isNullOrEmpty;
@@ -147,25 +152,33 @@ public String visitAggregation(Aggregation node, CatalystPlanContext context) {
     final String visitExpressionList = visitExpressionList(node.getAggExprList(), context);
     final String group = visitExpressionList(node.getGroupExprList(), context);
-
-    if(!isNullOrEmpty(group)) {
-      NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek();
-      Seq namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream()
-          .map(v->(NamedExpression)v).collect(Collectors.toList())).toSeq();
-      //now remove all context.getNamedParseExpressions()
-      context.getNamedParseExpressions().retainAll(emptyList());
-      context.plan(p->new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)),namedExpressionSeq,p));
+    if (!isNullOrEmpty(group)) {
+      extractedAggregation(context);
+    }
+    UnresolvedExpression span = node.getSpan();
+    if (!Objects.isNull(span)) {
+      span.accept(this, context);
+      extractedAggregation(context);
     }
     return format(
         "%s | stats %s",
         child, String.join(" ", visitExpressionList, groupBy(group)).trim());
   }
 
-  @Override
-  public String visitSpan(Span node, CatalystPlanContext context) {
-    return super.visitSpan(node, context);
+  private static void extractedAggregation(CatalystPlanContext context) {
+    NamedExpression namedExpression = (NamedExpression) context.getNamedParseExpressions().peek();
+    Seq namedExpressionSeq = asScalaBuffer(context.getNamedParseExpressions().stream()
+        .map(v -> (NamedExpression) v).collect(Collectors.toList())).toSeq();
+    //now remove all context.getNamedParseExpressions()
+    context.getNamedParseExpressions().retainAll(emptyList());
+    context.plan(p -> new Aggregate(asScalaBuffer(singletonList((Expression) namedExpression)), namedExpressionSeq, p));
   }
 
+  @Override
+  public String visitAlias(Alias node, CatalystPlanContext context) {
+    return expressionAnalyzer.visitAlias(node, context);
+  }
+
   @Override
   public String visitRareTopN(RareTopN node, CatalystPlanContext context) {
     final String child = node.getChild().get(0).accept(this, context);
@@ -190,7 +203,7 @@ public String visitProject(Project node, CatalystPlanContext context) {
     // Create a projection list from the existing expressions
     Seq projectList = asScalaBuffer(context.getNamedParseExpressions()).toSeq();
-    if(!projectList.isEmpty()) {
+    if (!projectList.isEmpty()) {
      // build the plan with the projection step
      context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project((Seq) projectList, p));
     }
 
@@ -296,7 +309,7 @@ public String visitAnd(And node, CatalystPlanContext context) {
       String left = node.getLeft().accept(this, context);
       String right = node.getRight().accept(this, context);
       context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.And(
-          (Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
+          (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
       return format("%s and %s", left, right);
     }
 
@@ -305,7 +318,7 @@ public String visitOr(Or node, CatalystPlanContext context) {
       String left = node.getLeft().accept(this, context);
       String right = node.getRight().accept(this, context);
       context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Or(
-          (Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
+          (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
       return format("%s or %s", left, right);
     }
 
@@ -314,7 +327,7 @@ public String visitXor(Xor node, CatalystPlanContext context) {
       String left = node.getLeft().accept(this, context);
       String right = node.getRight().accept(this, context);
       context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(
-          (Expression) context.getNamedParseExpressions().pop(),context.getNamedParseExpressions().pop()));
+          (Expression) context.getNamedParseExpressions().pop(), context.getNamedParseExpressions().pop()));
       return format("%s xor %s", left, right);
     }
 
@@ -328,7 +341,14 @@ public String visitNot(Not node, CatalystPlanContext context) {
 
     @Override
     public String visitSpan(Span node, CatalystPlanContext context) {
-      return super.visitSpan(node, context);
+      String field = node.getField().accept(this, context);
+      String value = node.getValue().accept(this, context);
+      String unit = node.getUnit().name();
+
+      Expression valueExpression = context.getNamedParseExpressions().pop();
+      Expression fieldExpression = context.getNamedParseExpressions().pop();
+      context.getNamedParseExpressions().push(new Multiply(new Floor(new Divide(fieldExpression, valueExpression)), valueExpression));
+      return format("span (%s,%s,%s)", field, value, unit);
     }
 
     @Override
@@ -366,7 +386,7 @@ public String visitField(Field node, CatalystPlanContext context) {
     @Override
     public String visitAllFields(AllFields node, CatalystPlanContext context) {
       // Case of aggregation step - no start projection can be added
-      if(!context.getNamedParseExpressions().isEmpty()) {
+      if (!context.getNamedParseExpressions().isEmpty()) {
        // if named expression exist - just return their names
        return context.getNamedParseExpressions().peek().toString();
       } else {
@@ -376,6 +396,11 @@ public String visitAllFields(AllFields node, CatalystPlanContext context) {
       }
     }
 
+    @Override
+    public String visitWindowFunction(WindowFunction node, CatalystPlanContext context) {
+      return super.visitWindowFunction(node, context);
+    }
+
     @Override
     public String visitAlias(Alias node, CatalystPlanContext context) {
       String expr = node.getDelegated().accept(this, context);
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanSimpleTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala
similarity index 99%
rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanSimpleTranslatorTestSuite.scala
rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala
index 6076131fc..3bbdf7669 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanSimpleTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala
@@ -16,7 +16,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
 import org.scalatest.matchers.should.Matchers
 
-class PPLLogicalPlanSimpleTranslatorTestSuite
+class PPLLogicalAdvancedTranslatorTestSuite
     extends SparkFunSuite
     with Matchers {
 
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
index 000f77afc..092efe22a 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, EqualTo, Floor, Literal, Multiply}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.junit.Assert.assertEquals
 import org.opensearch.flint.spark.ppl.PlaneUtils.plan
@@ -104,5 +104,39 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
     assertEquals(compareByString(expectedPlan), compareByString(context.getPlan))
   }
 
+  test("create ppl simple avg age by span of interval of 10 years query test ") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(plan(pplParser, "source = table | stats avg(age) by span(age, 10) as age_span", false), context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
+    val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
+    assert(compareByString(expectedPlan) === compareByString(context.getPlan))
+  }
+
+  ignore("create ppl simple avg age by span of interval of 10 years by country query test ") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(plan(pplParser, "source = table | stats avg(age) by span(age, 10) as age_span, country", false), context)
+    // Define the expected logical plan
+    val star = Seq(UnresolvedStar(None))
+    val ageField = UnresolvedAttribute("age")
+    val tableRelation = UnresolvedRelation(Seq("table"))
+
+    val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()
+    val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")()
+    val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
+    val expectedPlan = Project(star, aggregatePlan)
+
+    assertEquals(logPlan, "source=[table] | stats avg(age) | fields + *")
+    assert(compareByString(expectedPlan) === compareByString(context.getPlan))
+  }
+
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanComplexQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
similarity index 98%
rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanComplexQueriesTranslatorTestSuite.scala
rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
index 293bd3729..517db2ec7 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanComplexQueriesTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala
@@ -14,7 +14,7 @@ import org.opensearch.flint.spark.ppl.PlaneUtils.plan
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
 import org.scalatest.matchers.should.Matchers
 
-class PPLLogicalPlanComplexQueriesTranslatorTestSuite
+class PPLLogicalPlanBasicQueriesTranslatorTestSuite
    extends SparkFunSuite
    with Matchers {
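
Note (not part of the diff): a minimal Scala sketch of the Catalyst expression the new `visitSpan` pushes for `span(age, 10)`. The bucket is computed as `floor(age / 10) * 10`, which is exactly what the tests above alias as `span (age,10,NONE)`. The object name `SpanBucketSketch` is hypothetical and only used for illustration.

// Illustrative only: mirrors the expression built by CatalystQueryPlanVisitor.visitSpan.
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Divide, Expression, Floor, Literal, Multiply}

object SpanBucketSketch {
  // Each row is grouped by the lower bound of its 10-year bucket, e.g. age 25 -> 20.
  val age: Expression = UnresolvedAttribute("age")
  val ageSpan: Expression = Multiply(Floor(Divide(age, Literal(10))), Literal(10))
}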