From 5e8a914075341cf8f00b8e58ce48c1e635d85fb7 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 25 Sep 2023 14:28:02 -0700 Subject: [PATCH] adding window function support for time based spans Signed-off-by: YANGDB --- .../flint/spark/FlintSparkPPLITSuite.scala | 35 -- .../FlintSparkPPLTimeWindowITSuite.scala | 359 ++++++++++++++++++ .../sql/ppl/utils/DataTypeTransformer.java | 47 +++ .../sql/ppl/utils/WindowSpecTransformer.java | 30 +- 4 files changed, 433 insertions(+), 38 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLTimeWindowITSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index 54e6f7339..d2448a3f4 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -857,41 +857,6 @@ class FlintSparkPPLITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - ignore("create ppl simple count age by span of interval of 10 years query order by age test ") { - val frame = sql(s""" - | source = $testTable| stats count(age) by span(age, 10) as age_span | sort age_span - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) - - // Compare the results - assert(results === expectedResults) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val 
span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "span (age,10,NONE)")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), - global = true, - expectedPlan) - // Compare the two plans - assert(sortedPlan === logicalPlan) - } - /** * | age_span | average_age | * |:---------|------------:| diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLTimeWindowITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLTimeWindowITSuite.scala new file mode 100644 index 000000000..e7452436a --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLTimeWindowITSuite.scala @@ -0,0 +1,359 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.sql.Timestamp + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, Floor, GenericRowWithSchema, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTimeWindowITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_sales_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | transactionId STRING, + | transactionDate TIMESTAMP, + | productId STRING, + | productsAmount INT, + | customerId STRING + | ) + | USING 
CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for April 2023 + sql(s""" + |INSERT INTO $testTable PARTITION (year=2023, month=4) + |VALUES + |('txn001', CAST('2023-04-01 10:30:00' AS TIMESTAMP), 'prod1', 2, 'cust1'), + |('txn001', CAST('2023-04-01 14:30:00' AS TIMESTAMP), 'prod1', 4, 'cust1'), + |('txn002', CAST('2023-04-02 11:45:00' AS TIMESTAMP), 'prod2', 1, 'cust2'), + |('txn003', CAST('2023-04-03 12:15:00' AS TIMESTAMP), 'prod3', 3, 'cust1'), + |('txn004', CAST('2023-04-04 09:50:00' AS TIMESTAMP), 'prod1', 1, 'cust3') + | """.stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for May 2023 + sql(s""" + |INSERT INTO $testTable PARTITION (year=2023, month=5) + |VALUES + |('txn005', CAST('2023-05-01 08:30:00' AS TIMESTAMP), 'prod2', 1, 'cust4'), + |('txn006', CAST('2023-05-02 07:25:00' AS TIMESTAMP), 'prod4', 5, 'cust2'), + |('txn007', CAST('2023-05-03 15:40:00' AS TIMESTAMP), 'prod3', 1, 'cust3'), + |('txn007', CAST('2023-05-03 19:30:00' AS TIMESTAMP), 'prod3', 2, 'cust3'), + |('txn008', CAST('2023-05-04 14:15:00' AS TIMESTAMP), 'prod1', 4, 'cust1') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl query count sales by days window test") { + /* + val dataFrame = spark.read.table(testTable) + val query = dataFrame + .groupBy( + window( + col("transactionDate"), " 1 days") + ).agg(sum(col("productsAmount"))) + + query.show(false) + */ + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + 
row.getAs[GenericRowWithSchema](1).get(0), + row.getAs[GenericRowWithSchema](1).get(1))) + + // Define the expected results + val expectedResults = Array( + Row(6, Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-04 17:00:00")), + Row(3, Timestamp.valueOf("2023-04-02 17:00:00"), Timestamp.valueOf("2023-04-03 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-01 17:00:00"), Timestamp.valueOf("2023-04-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-03 17:00:00"), Timestamp.valueOf("2023-04-04 17:00:00")), + Row(1, Timestamp.valueOf("2023-05-02 17:00:00"), Timestamp.valueOf("2023-05-03 17:00:00")), + Row(5, Timestamp.valueOf("2023-05-01 17:00:00"), Timestamp.valueOf("2023-05-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-30 17:00:00"), Timestamp.valueOf("2023-05-01 17:00:00")), + Row(6, Timestamp.valueOf("2023-03-31 17:00:00"), Timestamp.valueOf("2023-04-01 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), 
Seq(aggregateExpressions, windowExpression), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl query count sales by days window with sorting test") { + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + row.getAs[GenericRowWithSchema](1).get(0), + row.getAs[GenericRowWithSchema](1).get(1))) + + // Define the expected results + val expectedResults = Array( + Row(6, Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-04 17:00:00")), + Row(3, Timestamp.valueOf("2023-04-02 17:00:00"), Timestamp.valueOf("2023-04-03 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-01 17:00:00"), Timestamp.valueOf("2023-04-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-03 17:00:00"), Timestamp.valueOf("2023-04-04 17:00:00")), + Row(1, Timestamp.valueOf("2023-05-02 17:00:00"), Timestamp.valueOf("2023-05-03 17:00:00")), + Row(5, Timestamp.valueOf("2023-05-01 17:00:00"), Timestamp.valueOf("2023-05-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-30 17:00:00"), Timestamp.valueOf("2023-05-01 17:00:00")), + Row(6, Timestamp.valueOf("2023-03-31 17:00:00"), Timestamp.valueOf("2023-04-01 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val 
productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table) + val expectedPlan = Project(star, aggregatePlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + expectedPlan) + // Compare the two plans + assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } + + test("create ppl query count sales by days window and productId with sorting test") { + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date, productId | sort age_date + | """.stripMargin) + + frame.show(false) + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + row.get(1), + row.getAs[GenericRowWithSchema](2).get(0), + row.getAs[GenericRowWithSchema](2).get(1))) + + // Define the expected results + val expectedResults = Array( + Row( + 6, + "prod1", + Timestamp.valueOf("2023-03-31 17:00:00"), + Timestamp.valueOf("2023-04-01 17:00:00")), + Row( + 1, + "prod2", + Timestamp.valueOf("2023-04-01 17:00:00"), + Timestamp.valueOf("2023-04-02 17:00:00")), + Row( + 3, + "prod3", + Timestamp.valueOf("2023-04-02 17:00:00"), + Timestamp.valueOf("2023-04-03 17:00:00")), + Row( + 1, + "prod1", + Timestamp.valueOf("2023-04-03 17:00:00"), + Timestamp.valueOf("2023-04-04 17:00:00")), + Row( + 1, + "prod2", + Timestamp.valueOf("2023-04-30 17:00:00"), + Timestamp.valueOf("2023-05-01 17:00:00")), + Row( + 5, + 
"prod4", + Timestamp.valueOf("2023-05-01 17:00:00"), + Timestamp.valueOf("2023-05-02 17:00:00")), + Row( + 1, + "prod3", + Timestamp.valueOf("2023-05-02 17:00:00"), + Timestamp.valueOf("2023-05-03 17:00:00")), + Row( + 4, + "prod1", + Timestamp.valueOf("2023-05-03 17:00:00"), + Timestamp.valueOf("2023-05-04 17:00:00")), + Row( + 2, + "prod3", + Timestamp.valueOf("2023-05-03 17:00:00"), + Timestamp.valueOf("2023-05-04 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](2)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsId = Alias(UnresolvedAttribute("productId"), "productId")() + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = Aggregate( + Seq(productsId, windowExpression), + Seq(aggregateExpressions, productsId, windowExpression), + table) + val expectedPlan = Project(star, aggregatePlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + expectedPlan) + // Compare the two plans + assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } + + ignore("create ppl simple count age by 
span of interval of 10 years query order by age test ") { + val frame = sql(s""" + | source = $testTable| stats count(age) by span(age, 10) as age_span | sort age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) + + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), + global = true, + expectedPlan) + // Compare the two plans + assert(sortedPlan === logicalPlan) + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java index f0369ae69..0c7269a07 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -12,10 +12,21 @@ import org.apache.spark.sql.types.IntegerType$; import org.apache.spark.sql.types.StringType$; import org.apache.spark.unsafe.types.UTF8String; +import org.opensearch.sql.ast.expression.SpanUnit; import scala.collection.mutable.Seq; 
import java.util.List; +import static org.opensearch.sql.ast.expression.SpanUnit.DAY; +import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; +import static org.opensearch.sql.ast.expression.SpanUnit.MILLISECOND; +import static org.opensearch.sql.ast.expression.SpanUnit.MINUTE; +import static org.opensearch.sql.ast.expression.SpanUnit.MONTH; +import static org.opensearch.sql.ast.expression.SpanUnit.NONE; +import static org.opensearch.sql.ast.expression.SpanUnit.QUARTER; +import static org.opensearch.sql.ast.expression.SpanUnit.SECOND; +import static org.opensearch.sql.ast.expression.SpanUnit.WEEK; +import static org.opensearch.sql.ast.expression.SpanUnit.YEAR; import static scala.collection.JavaConverters.asScalaBufferConverter; /** @@ -54,4 +65,40 @@ static Object translate(Object value, org.opensearch.sql.ast.expression.DataType return value; } } + + static String translate(SpanUnit unit) { + switch (unit) { + case UNKNOWN: + case NONE: + return NONE.name(); + case MILLISECOND: + case MS: + return MILLISECOND.name(); + case SECOND: + case S: + return SECOND.name(); + case MINUTE: + case m: + return MINUTE.name(); + case HOUR: + case H: + return HOUR.name(); + case DAY: + case D: + return DAY.name(); + case WEEK: + case W: + return WEEK.name(); + case MONTH: + case M: + return MONTH.name(); + case QUARTER: + case Q: + return QUARTER.name(); + case YEAR: + case Y: + return YEAR.name(); + } + return ""; + } } \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java index 5fb1c1942..c215caec5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -8,9 +8,19 @@ import org.apache.spark.sql.catalyst.expressions.Divide; import 
org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.Floor; +import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.Multiply; +import org.apache.spark.sql.catalyst.expressions.TimeWindow; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.StringType$; import org.opensearch.sql.ast.expression.SpanUnit; +import static java.lang.String.format; +import static org.opensearch.sql.ast.expression.DataType.STRING; +import static org.opensearch.sql.ast.expression.SpanUnit.NONE; +import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; + public interface WindowSpecTransformer { /** @@ -22,8 +32,22 @@ public interface WindowSpecTransformer { * @return */ static Expression window(Expression fieldExpression, Expression valueExpression, SpanUnit unit) { - // todo check can WindowSpec provide the same functionality as below - // todo for time unit - use TimeWindowSpec if possible + // In case the unit is a time unit - use TimeWindowSpec if possible + if (isTimeBased(unit)) { + return new TimeWindow(fieldExpression,timeLiteral(valueExpression, unit)); + } + // if the unit is not time based - create a math expression to bucket the span partitions return new Multiply(new Floor(new Divide(fieldExpression, valueExpression)), valueExpression); - } + } + + static boolean isTimeBased(SpanUnit unit) { + return !(unit == NONE || unit == UNKNOWN); + } + + + static org.apache.spark.sql.catalyst.expressions.Literal timeLiteral( Expression valueExpression, SpanUnit unit) { + String format = format("%s %s", valueExpression.toString(), translate(unit)); + return new org.apache.spark.sql.catalyst.expressions.Literal( + translate(format, STRING), translate(STRING)); + } }