diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 4964fb882..f6a0f2bd5 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -22,7 +22,7 @@ sbt scalafmtAll ``` The code style is automatically checked, but users can also manually check it. ``` -sbt sbt scalastyle +sbt scalastyle ``` For IntelliJ user, read more in [scalafmt IntelliJ](https://scalameta.org/scalafmt/docs/installation.html#intellij) to integrate scalafmt with IntelliJ diff --git a/README.md b/README.md index 030f34b5e..2f957a40a 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,12 @@ OpenSearch Flint is ... It consists of two modules: - `flint-core`: a module that contains Flint specification and client. - `flint-spark-integration`: a module that provides Spark integration for Flint and derived dataset based on it. +- `ppl-spark-integration`: a module that provides PPL query execution on top of Spark. See the [PPL repository](https://github.com/opensearch-project/piped-processing-language). ## Documentation Please refer to the [Flint Index Reference Manual](./docs/index.md) for more information. +For the PPL language, please refer to the [PPL Reference Manual](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst) for more information.
## Prerequisites @@ -17,7 +19,7 @@ Version compatibility: |---------------|-------------|---------------|---------------|------------| | 0.1.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ | -## Usage +## Flint Extension Usage To use this application, you can run Spark with Flint extension: @@ -25,6 +27,14 @@ To use this application, you can run Spark with Flint extension: spark-sql --conf "spark.sql.extensions=org.opensearch.flint.FlintSparkExtensions" ``` +## PPL Extension Usage + +To use PPL to Spark translation, you can run Spark with the PPL extension: + +``` +spark-sql --conf "spark.sql.extensions=org.opensearch.flint.FlintPPLSparkExtensions" +``` + ## Build To build and run this application with Spark, you can run: @@ -37,6 +47,18 @@ then add org.opensearch:opensearch-spark_2.12 when run spark application, for ex bin/spark-shell --packages "org.opensearch:opensearch-spark_2.12:0.1.0-SNAPSHOT" ``` +### PPL Build & Run + +To build and run the PPL extension with Spark, you can run: + +``` +sbt clean sparkPPLCosmetic/publishM2 +``` +then add org.opensearch:opensearch-spark-ppl_2.12 when running the Spark application, for example, +``` +bin/spark-shell --packages "org.opensearch:opensearch-spark-ppl_2.12:0.1.0-SNAPSHOT" +``` + ## Code of Conduct This project has adopted an [Open Source Code of Conduct](./CODE_OF_CONDUCT.md).
diff --git a/build.sbt b/build.sbt index 29a3cdba7..6b7c8d53a 100644 --- a/build.sbt +++ b/build.sbt @@ -43,7 +43,7 @@ lazy val commonSettings = Seq( Test / test := ((Test / test) dependsOn testScalastyle).value) lazy val root = (project in file(".")) - .aggregate(flintCore, flintSparkIntegration, sparkSqlApplication) + .aggregate(flintCore, flintSparkIntegration, pplSparkIntegration, sparkSqlApplication) .disablePlugins(AssemblyPlugin) .settings(name := "flint", publish / skip := true) @@ -61,6 +61,42 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind")), publish / skip := true) +lazy val pplSparkIntegration = (project in file("ppl-spark-integration")) + .enablePlugins(AssemblyPlugin, Antlr4Plugin) + .settings( + commonSettings, + name := "ppl-spark-integration", + scalaVersion := scala212, + libraryDependencies ++= Seq( + "org.scalactic" %% "scalactic" % "3.2.15" % "test", + "org.scalatest" %% "scalatest" % "3.2.15" % "test", + "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", + "org.scalatestplus" %% "mockito-4-6" % "3.2.15.0" % "test", + "com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test", + "com.github.sbt" % "junit-interface" % "0.13.3" % "test"), + libraryDependencies ++= deps(sparkVersion), + // ANTLR settings + Antlr4 / antlr4Version := "4.8", + Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.ppl"), + Antlr4 / antlr4GenListener := true, + Antlr4 / antlr4GenVisitor := true, + // Assembly settings + assemblyPackageScala / assembleArtifact := false, + assembly / assemblyOption ~= { + _.withIncludeScala(false) + }, + assembly / assemblyMergeStrategy := { + case PathList(ps @ _*) if ps.last endsWith ("module-info.class") => + MergeStrategy.discard + case PathList("module-info.class") => MergeStrategy.discard + case PathList("META-INF", "versions", xs @ _, "module-info.class") => + MergeStrategy.discard + case x => + val oldStrategy = (assembly / 
assemblyMergeStrategy).value + oldStrategy(x) + }, + assembly / test := (Test / test).value) + lazy val flintSparkIntegration = (project in file("flint-spark-integration")) .dependsOn(flintCore) .enablePlugins(AssemblyPlugin, Antlr4Plugin) @@ -102,7 +138,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration")) // Test assembly package with integration test. lazy val integtest = (project in file("integ-test")) - .dependsOn(flintSparkIntegration % "test->test") + .dependsOn(flintSparkIntegration % "test->test", pplSparkIntegration % "test->test" ) .settings( commonSettings, name := "integ-test", @@ -118,7 +154,7 @@ lazy val integtest = (project in file("integ-test")) "org.opensearch.client" % "opensearch-java" % "2.6.0" % "test" exclude ("com.fasterxml.jackson.core", "jackson-databind")), libraryDependencies ++= deps(sparkVersion), - Test / fullClasspath += (flintSparkIntegration / assembly).value) + Test / fullClasspath ++= Seq((flintSparkIntegration / assembly).value, (pplSparkIntegration / assembly).value)) lazy val standaloneCosmetic = project .settings( @@ -144,6 +180,14 @@ lazy val sparkSqlApplicationCosmetic = project exportJars := true, Compile / packageBin := (sparkSqlApplication / assembly).value) +lazy val sparkPPLCosmetic = project + .settings( + name := "opensearch-spark-ppl", + commonSettings, + releaseSettings, + exportJars := true, + Compile / packageBin := (pplSparkIntegration / assembly).value) + lazy val releaseSettings = Seq( publishMavenStyle := true, publishArtifact := true, diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala new file mode 100644 index 000000000..d9c0a1b8c --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} + +/** + * Shared helper functions for tests that verify the PPL to Spark logical plan transformation. + */ +trait LogicalPlanTestUtils { + + /** + * Renders a logical plan as a string after replacing the auto-generated ExprId of every Alias + * found in Project and Aggregate nodes with ExprId(0), so that plans can be compared by string. + * @param plan the logical plan to normalize; callers compare the outputs of two invocations + * @return the toString of the transformed plan, with those Alias expression ids set to ExprId(0) + */ + def compareByString(plan: LogicalPlan): String = { + // Create a rule to replace Alias's ExprId with a dummy id + val rule: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + val newProjections = p.projectList.map { + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case other => other + } + p.copy(projectList = newProjections) + + case agg: Aggregate => + val newGrouping = agg.groupingExpressions.map { + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case other => other + } + val newAggregations = agg.aggregateExpressions.map { + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case other => other + } + agg.copy(groupingExpressions = newGrouping, aggregateExpressions = newAggregations) + + case other => other + } + + // Apply the rule using transform + val transformedPlan = plan.transform(rule) + + // Return the string representation of the transformed plan + transformedPlan.toString + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala new file mode 100644 index 000000000..70951ad27 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -0,0
+1,295 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlintPPLSuite + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLAggregationWithSpanITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + /** + * | age_span | count_age | + * |:---------|----------:| + * | 20 | 2 | + * | 30 | 1 | + * | 70 | 1 | + */ + test("create ppl simple count age by span of interval of 10 years query test ") { + val frame = sql(s""" + | source = $testTable| stats count(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the 
results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | average_age | + * |:---------|------------:| + * | 20 | 22.5 | + * | 30 | 30 | + * | 70 | 70 | + */ + test("create ppl simple avg age by span of interval of 10 years query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = 
UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val limitPlan = Limit(Literal(2), aggregatePlan) + val expectedPlan = Project(star, limitPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | + */ + test("create ppl average age by span of interval of 10 years group by country query test ") { + val frame = sql(s""" + | 
source = $testTable| stats avg(age) by span(age, 10) as age_span, country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 3 + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val limitPlan = Limit(Literal(3), aggregatePlan) + val expectedPlan = Project(star, limitPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | sort - age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + 
"age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), + global = true, + aggregatePlan) + val limitPlan = Limit(Literal(2), sortedPlan) + val expectedPlan = Project(star, limitPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala new file mode 100644 index 000000000..0a58a039d --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -0,0 +1,415 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlintPPLSuite + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, LessThan, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLAggregationsITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + 
|""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple age avg query test") { + val frame = sql(s""" + | source = $testTable| stats avg(age) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(36.25)) + + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age avg query with filter test") { + val frame = sql(s""" + | source = $testTable| where age < 50 | stats avg(age) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(25)) + + // Compare the results + // Compare the results + 
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = LessThan(ageField, Literal(50)) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age avg group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val countryAlias = Alias(countryField, "country")() + + val 
aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, countryAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age avg group by country head (limit) query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by country | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val projectPlan = Limit(Literal(1), aggregatePlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age max group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats max(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70, "USA"), Row(25, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = 
frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age min group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats min(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(30, "USA"), Row(20, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, 
aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age sum group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats sum(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(100L, "USA"), Row(45L, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age sum group by country order by age query test with sort ") { + val frame = sql(s""" + | source = $testTable| stats sum(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(45L, "Canada"), Row(100L, "USA")) + + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = 
frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age count group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats count(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), 
isDistinct = false), "count(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl simple age avg group by country with state filter query test ") { + val frame = sql(s""" + | source = $testTable| where state != 'Quebec' | stats avg(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(25.0, "Canada"), Row(50.0, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === 
compareByString(logicalPlan)) + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala new file mode 100644 index 000000000..0a4810b01 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -0,0 +1,452 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlintPPLSuite + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLFiltersITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // 
Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple age literal equal filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable age=25 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("John", 25)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = EqualTo(UnresolvedAttribute("age"), Literal(25)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable age>10 and country != 'USA' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = 
UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = And( + GreaterThan(UnresolvedAttribute("age"), Literal(10)), + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA")))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { + val frame = sql(s""" + | source = $testTable age>10 and country != 'USA' | sort - age | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = And( + GreaterThan(UnresolvedAttribute("age"), Literal(10)), + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA")))) + val filterPlan = Filter(filterExpr, table) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, filterPlan) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable age<=20 OR country = 'USA' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: 
Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Or( + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20)), + EqualTo(UnresolvedAttribute("country"), Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + val frame = sql(s""" + | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Or( + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20)), + EqualTo(UnresolvedAttribute("country"), Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val limitPlan = Limit(Literal(1), Project(projectList, filterPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + // Compare the two plans + 
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal greater than filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable age>25 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = GreaterThan(UnresolvedAttribute("age"), Literal(25)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal smaller than equals filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable age<=65 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", 
"flint_ppl_test")) + val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age literal smaller than equals filter query with two fields result with sort test") { + val frame = sql(s""" + | source = $testTable age<=65 | sort name | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("Jane", 20), Row("John", 25)) + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, filterPlan) + val expectedPlan = Project(projectList, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple name literal equal filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable name='Jake' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = 
Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("Jake")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple name literal not equal filter query with two fields result test") { + val frame = sql(s""" + | source = $testTable name!='Jake' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + + // Compare the results + // Sort both sides by name so the comparison ignores row order + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Not(EqualTo(UnresolvedAttribute("name"), Literal("Jake"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple avg age by span of interval of 10 years query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span
+ | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + 
"age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val limitPlan = Limit(Literal(2), aggregatePlan) + val expectedPlan = Project(star, limitPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | + */ + test("create ppl average age by span of interval of 10 years group by country query test ") { + val dataFrame = spark.sql( + "SELECT FLOOR(age / 10) * 10 AS age_span, country, AVG(age) AS average_age FROM default.flint_ppl_test GROUP BY FLOOR(age / 10) * 10, country ") + dataFrame.collect(); + dataFrame.show() + + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), 
Seq(aggregateExpressions, countryAlias, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala new file mode 100644 index 000000000..848c2af2c --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLITSuite.scala @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlintPPLSuite + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): 
Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple query test") { + val frame = sql(s""" + | source = $testTable + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4)) + // Compare the results + // Sort both sides by name so the comparison ignores row order + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } + + test("create ppl simple query with head (limit) 2 test") { + val frame = sql(s""" + | source = $testTable| head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val limitPlan: LogicalPlan = + Limit(Literal(2), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + val expectedPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple query with head (limit) and sorted test") { + val frame = sql(s""" + | source = $testTable| sort name | head 2 + | """.stripMargin) + + // Retrieve the results + // (row contents are not checked for this query, only the row count)
+ val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + UnresolvedRelation(Seq("default", "flint_ppl_test"))) + + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedStar(None)), Limit(Literal(2), sortedPlan)) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple query two with fields result test") { + val frame = sql(s""" + | source = $testTable| fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } + + test("create ppl simple sorted query two with fields result test sorted") { + val frame = sql(s""" + | source = $testTable| sort age | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row("Jane", 20), Row("John", 25), Row("Hello", 30), Row("Jake", 70)) + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan 
= frame.queryExecution.logical + + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("default", "flint_ppl_test"))) + + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sortedPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple query two with fields and head (limit) test") { + val frame = sql(s""" + | source = $testTable| fields name, age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) + // Define the expected logical plan + val limitPlan: LogicalPlan = Limit(Literal(1), project) + val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), limitPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple query two with fields and head (limit) with sorting test") { + val frame = sql(s""" + | source = $testTable| fields name, age | head 1 | sort age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) + // Define the expected logical plan + val limitPlan: LogicalPlan = Limit(Literal(1), project) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = 
true, limitPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan); + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala new file mode 100644 index 000000000..40bcbdcb9 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTimeWindowITSuite.scala @@ -0,0 +1,413 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import java.sql.Timestamp + +import org.opensearch.flint.spark.FlintPPLSuite + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, Floor, GenericRowWithSchema, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLTimeWindowITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_sales_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | transactionId STRING, + | transactionDate TIMESTAMP, + | productId STRING, + | productsAmount INT, + | customerId STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for April 2023 + sql(s""" + |INSERT INTO $testTable 
PARTITION (year=2023, month=4) + |VALUES + |('txn001', CAST('2023-04-01 10:30:00' AS TIMESTAMP), 'prod1', 2, 'cust1'), + |('txn001', CAST('2023-04-01 14:30:00' AS TIMESTAMP), 'prod1', 4, 'cust1'), + |('txn002', CAST('2023-04-02 11:45:00' AS TIMESTAMP), 'prod2', 1, 'cust2'), + |('txn003', CAST('2023-04-03 12:15:00' AS TIMESTAMP), 'prod3', 3, 'cust1'), + |('txn004', CAST('2023-04-04 09:50:00' AS TIMESTAMP), 'prod1', 1, 'cust3') + | """.stripMargin) + + // Update data insertion + // -- Inserting records into the testTable for May 2023 + sql(s""" + |INSERT INTO $testTable PARTITION (year=2023, month=5) + |VALUES + |('txn005', CAST('2023-05-01 08:30:00' AS TIMESTAMP), 'prod2', 1, 'cust4'), + |('txn006', CAST('2023-05-02 07:25:00' AS TIMESTAMP), 'prod4', 5, 'cust2'), + |('txn007', CAST('2023-05-03 15:40:00' AS TIMESTAMP), 'prod3', 1, 'cust3'), + |('txn007', CAST('2023-05-03 19:30:00' AS TIMESTAMP), 'prod3', 2, 'cust3'), + |('txn008', CAST('2023-05-04 14:15:00' AS TIMESTAMP), 'prod1', 4, 'cust1') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl query count sales by days window test") { + /* + val dataFrame = spark.read.table(testTable) + val query = dataFrame + .groupBy( + window( + col("transactionDate"), " 1 days") + ).agg(sum(col("productsAmount"))) + + query.show(false) + */ + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + row.getAs[GenericRowWithSchema](1).get(0), + row.getAs[GenericRowWithSchema](1).get(1))) + + // Define the expected results + val expectedResults = Array( + Row(6, Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-04 17:00:00")), + Row(3, 
Timestamp.valueOf("2023-04-02 17:00:00"), Timestamp.valueOf("2023-04-03 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-01 17:00:00"), Timestamp.valueOf("2023-04-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-03 17:00:00"), Timestamp.valueOf("2023-04-04 17:00:00")), + Row(1, Timestamp.valueOf("2023-05-02 17:00:00"), Timestamp.valueOf("2023-05-03 17:00:00")), + Row(5, Timestamp.valueOf("2023-05-01 17:00:00"), Timestamp.valueOf("2023-05-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-30 17:00:00"), Timestamp.valueOf("2023-05-01 17:00:00")), + Row(6, Timestamp.valueOf("2023-03-31 17:00:00"), Timestamp.valueOf("2023-04-01 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl query count sales by days window with sorting test") { + val frame = 
sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + row.getAs[GenericRowWithSchema](1).get(0), + row.getAs[GenericRowWithSchema](1).get(1))) + + // Define the expected results + val expectedResults = Array( + Row(6, Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-04 17:00:00")), + Row(3, Timestamp.valueOf("2023-04-02 17:00:00"), Timestamp.valueOf("2023-04-03 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-01 17:00:00"), Timestamp.valueOf("2023-04-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-03 17:00:00"), Timestamp.valueOf("2023-04-04 17:00:00")), + Row(1, Timestamp.valueOf("2023-05-02 17:00:00"), Timestamp.valueOf("2023-05-03 17:00:00")), + Row(5, Timestamp.valueOf("2023-05-01 17:00:00"), Timestamp.valueOf("2023-05-02 17:00:00")), + Row(1, Timestamp.valueOf("2023-04-30 17:00:00"), Timestamp.valueOf("2023-05-01 17:00:00")), + Row(6, Timestamp.valueOf("2023-03-31 17:00:00"), Timestamp.valueOf("2023-04-01 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + 
"age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl query count sales by days window and productId with sorting test") { + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date, productId | sort age_date + | """.stripMargin) + + frame.show(false) + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + row.get(1), + row.getAs[GenericRowWithSchema](2).get(0), + row.getAs[GenericRowWithSchema](2).get(1))) + + // Define the expected results + val expectedResults = Array( + Row( + 6, + "prod1", + Timestamp.valueOf("2023-03-31 17:00:00"), + Timestamp.valueOf("2023-04-01 17:00:00")), + Row( + 1, + "prod2", + Timestamp.valueOf("2023-04-01 17:00:00"), + Timestamp.valueOf("2023-04-02 17:00:00")), + Row( + 3, + "prod3", + Timestamp.valueOf("2023-04-02 17:00:00"), + Timestamp.valueOf("2023-04-03 17:00:00")), + Row( + 1, + "prod1", + Timestamp.valueOf("2023-04-03 17:00:00"), + Timestamp.valueOf("2023-04-04 17:00:00")), + Row( + 1, + "prod2", + Timestamp.valueOf("2023-04-30 17:00:00"), + Timestamp.valueOf("2023-05-01 17:00:00")), + Row( + 5, + "prod4", + Timestamp.valueOf("2023-05-01 17:00:00"), + Timestamp.valueOf("2023-05-02 17:00:00")), + Row( + 1, + "prod3", + Timestamp.valueOf("2023-05-02 17:00:00"), + Timestamp.valueOf("2023-05-03 17:00:00")), + Row( + 4, + "prod1", + Timestamp.valueOf("2023-05-03 17:00:00"), + Timestamp.valueOf("2023-05-04 17:00:00")), 
+ Row( + 2, + "prod3", + Timestamp.valueOf("2023-05-03 17:00:00"), + Timestamp.valueOf("2023-05-04 17:00:00"))) + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](2)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsId = Alias(UnresolvedAttribute("productId"), "productId")() + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = Aggregate( + Seq(productsId, windowExpression), + Seq(aggregateExpressions, productsId, windowExpression), + table) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + test("create ppl query count sales by weeks window and productId with sorting test") { + val frame = sql(s""" + | source = $testTable| stats sum(productsAmount) by span(transactionDate, 1w) as age_date | sort age_date + | """.stripMargin) + + frame.show(false) + // Retrieve the results + val results: Array[Row] = frame + .collect() + .map(row => + Row( + row.get(0), + 
row.getAs[GenericRowWithSchema](1).get(0), + row.getAs[GenericRowWithSchema](1).get(1))) + + // Define the expected results + val expectedResults = Array( + Row(11, Timestamp.valueOf("2023-03-29 17:00:00"), Timestamp.valueOf("2023-04-05 17:00:00")), + Row(7, Timestamp.valueOf("2023-04-26 17:00:00"), Timestamp.valueOf("2023-05-03 17:00:00")), + Row(6, Timestamp.valueOf("2023-05-03 17:00:00"), Timestamp.valueOf("2023-05-10 17:00:00"))) + + // Compare the results + implicit val timestampOrdering: Ordering[Timestamp] = new Ordering[Timestamp] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Timestamp](_.getAs[Timestamp](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("default", "flint_ppl_sales_test")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + ignore("create ppl simple count age by span of interval of 10 years query order by age test ") { + val frame = sql(s""" + | source = $testTable| stats count(age) by 
span(age, 10) as age_span | sort age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) + + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), + global = true, + expectedPlan) + // Compare the two plans + assert(sortedPlan === logicalPlan) + } +} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md new file mode 100644 index 000000000..a824cfaaf --- /dev/null +++ b/ppl-spark-integration/README.md @@ -0,0 +1,280 @@ +## PPL Language Support On Spark + +This module provides the support for running [PPL](https://github.com/opensearch-project/piped-processing-language) queries on Spark using direct logical plan +translation between PPL's logical plan to Spark's Catalyst logical plan. + +### Context +The next concepts are the main purpose of introduction this functionality: +- Transforming PPL to become OpenSearch default query language (specifically for logs/traces/metrics signals) +- Promoting PPL as a viable candidate for the proposed CNCF Observability universal query language. 
+- Seamlessly Interact with different datasources (S3 / Prometheus / data-lake) from within OpenSearch +- Improve and promote PPL to become an extensible, general-purpose query language to be adopted by the community + +Acknowledging spark is an excellent conduit for promoting these goals and showcasing the capabilities of PPL to interact & federate data across multiple sources and domains. + +Another byproduct of introducing PPL on spark would be the much anticipated JOIN capability that will emerge from the usage of Spark compute engine. + +**What solution would you like?** + +For PPL to become a library which has a simple and easy means of importing and extending, PPL client (the thin API layer) which can interact and provide a generic query composition framework to be used in any type of application independently of OpenSearch plugins. + +![PPL endpoint](https://github.com/opensearch-project/opensearch-spark/assets/48943349/e9831a8f-abde-484c-9c62-331570e88460) + +As depicted in the above image, the protocol & AST (ANTLR-based language traversals) verticals should be detached and composed into a self-sustainable component that can be imported regardless of OpenSearch plugins. + +--- + +## PPL On Spark + +Running PPL on spark is a goal for allowing simple adoption of PPL query language and also for simplifying the Flint project to allow visualization for federated queries using the Observability dashboards capabilities. + + +### Background + +In Apache Spark, the DataFrame API serves as a programmatic interface for data manipulation and queries, allowing the construction of complex operations using a chain of method calls. This API can work in tandem with other query languages like SQL or PPL. + +For instance, if you have a PPL query and a translator, you can convert it into DataFrame operations to generate an optimized execution plan.
Spark's underlying Catalyst optimizer will convert these DataFrame transformations and actions into an optimized physical plan executed over RDDs or Datasets. + +The following section describes the two main options for translating the PPL query (using the logical plan) into the corresponding Spark component (either dataframe API or spark logical plan) + + +### Translation Process + +**Using Catalyst Logical Plan Grammar** +The leading option for translation would be using the Catalyst Grammar for directly translating the Logical plan steps +Here is an example of such translation outcome: + + +Our goal would be translating the PPL into the Unresolved logical plan so that the Analysis phase would behave in a similar manner to the SQL originated query. + +![spark execution process](https://github.com/opensearch-project/opensearch-spark/assets/48943349/780c0072-0ab4-4fb4-afb1-11fb3bfbd2c3) + +**The following PPL query:** +`search source=t | where a=1` + +Translates into the PPL logical plan: +`Relation(tableName=t, alias=null), Compare(operator==, left=Field(field=a, fieldArgs=[]), right=1)` + +Would be transformed into the next catalyst Plan: +``` +// Create an UnresolvedRelation for the table 't' +val table = UnresolvedRelation(TableIdentifier("t")) +// Create an EqualTo expression for "a == 1" +val equalToCondition = EqualTo(UnresolvedAttribute("a"), Literal(1)) +// Create a Filter LogicalPlan +val filterPlan = Filter(equalToCondition, table) +``` + +The following PPL query: +`source=t | stats count(a) by b` + +Would produce the next PPL Logical Plan: +``` +Aggregation(aggExprList=[Alias(name=count(a), delegated=count(Field(field=a, fieldArgs=[])), alias=null)], +sortExprList=[], groupExprList=[Alias(name=b, delegated=Field(field=b, fieldArgs=[]), alias=null)], span=null, argExprList=[Argument(argName=partitions, value=1), Argument(argName=allnum, value=false), Argument(argName=delim, value= ), Argument(argName=dedupsplit, value=false)],
child=[Relation(tableName=t, alias=null)]) +``` + +Would be transformed into the next catalyst Plan: +``` +// Create an UnresolvedRelation for the table 't' + val table = UnresolvedRelation(TableIdentifier("t")) + // Create an Alias for the aggregation expression 'count(a)' +val aggExpr = Alias(Count(UnresolvedAttribute("a")), "count(a)")() +// Create an Alias for the grouping expression 'b' +val groupExpr = Alias(UnresolvedAttribute("b"), "b")() +val aggregatePlan = Aggregate(Seq(groupExpr), Seq(groupExpr, aggExpr), table) // Create an Aggregate LogicalPlan +``` + +--- + + +## Design Considerations + +In general when translating between two query languages we have the following options: + +**1) Source Grammar Tree To destination Dataframe API Translation** +This option uses the syntax tree to directly translate from one language syntax grammar tree to the other language (dataframe) API thus eliminating the parsing phase and creating a strongly validated process that can be verified and tested with a high degree of confidence. + +**Advantages :** +- Simpler solution to develop since the abstract structure of the query language is simpler to transform into compared with other transformation options, using the built-in traverse visitor API +- Optimization potential by leveraging the specific knowledge of the actual original language and being able to directly use specific grammar function and commands directly. + +**Disadvantages :** +- Fully dependent on the Source Code of the target language including potentially internal structure of its grammatical components - In spark case this is not a severe disadvantage since this is a very well-known and well structured API grammar.
+- Not sufficiently portable since this api is coupled with the destination (Spark) DataFrame API + +**2) Source Logical Plan To destination Logical Plan (Catalyst) [Preferred Option]** +This option uses the syntax tree to directly translate from one language syntax grammar tree to the other language syntax grammar tree thus eliminating the parsing phase and creating a strongly validated process that can be verified and tested with high degree of confidence. + +Once the target plan is created - it can be analyzed and executed separately from the translation process (or location) + +``` + SparkSession spark = SparkSession.builder() + .appName("SparkExecuteLogicalPlan") + .master("local") + .getOrCreate(); + + // catalyst logical plan - translated from PPL Logical plan + Seq scalaProjectList = //... your project list + LogicalPlan unresolvedTable = //... your unresolved table + LogicalPlan projectNode = new Project(scalaProjectList, unresolvedTable); + + // Analyze and execute + Analyzer analyzer = new Analyzer(spark.sessionState().catalog(), spark.sessionState().conf()); + LogicalPlan analyzedPlan = analyzer.execute(projectNode); + LogicalPlan optimizedPlan = spark.sessionState().optimizer().execute(analyzedPlan); + + QueryExecution qe = spark.sessionState().executePlan(optimizedPlan); + Dataset result = new Dataset<>(spark, qe, RowEncoder.apply(qe.analyzed().schema())); + +``` +**Advantages :** +- A little more complex to develop compared to the first option but still relatively simple since the abstract structure of the query language is simpler to transform into another language’s syntax grammar tree + +- Optimization potential by leveraging the specific knowledge of the actual original language and being able to directly use specific grammar function and commands directly.
+ +**Disadvantages :** +- Fully dependent on the Source Code of the target language including potentially internal structure of its grammatical components - In spark case this is not a severe disadvantage since this is a very well-known and well structured API grammar. +- Adds an additional phase for analyzing the logical plan and generating the physical plan and the execution part itself. + + +**3) Source Grammar Tree To destination Query Translation** +This option uses the syntax tree to translate from the original query language into the target query (SQL in our case). This is a more generalized solution that may be utilized for additional purposes such as direct query to an RDBMS server. + +**Advantages :** +- A general purpose solution that may be utilized for other SQL compliant servers + +**Disadvantages :** +- This is a more complicated use case since it requires an additional layer of complexity to be able to correctly translate the original syntax tree to a textual representation of the outcome language that has to be parsed and verified +- The SQL plugin already supports SQL, so it's not clear what is the advantage of translating PPL back to SQL since our plugin already supports SQL out of the box. + +--- +### Architecture + +**1. Using Spark Connect (PPL Grammar To dataframe API Translation)** + +In Apache Spark 3.4, Spark Connect introduced a decoupled client-server architecture that allows remote connectivity to Spark clusters using the DataFrame API and unresolved logical plans as the protocol. + +**How Spark Connect works**: +The Spark Connect client library is designed to simplify Spark application development. It is a thin API that can be embedded everywhere: in application servers, IDEs, notebooks, and programming languages. The Spark Connect API builds on Spark’s DataFrame API using unresolved logical plans as a language-agnostic protocol between the client and the Spark driver.
+ +The Spark Connect client translates DataFrame operations into unresolved logical query plans which are encoded using protocol buffers. These are sent to the server using the gRPC framework. +The Spark Connect endpoint embedded on the Spark Server receives and translates unresolved logical plans into Spark’s logical plan operators. This is similar to parsing a SQL query, where attributes and relations are parsed and an initial parse plan is built. From there, the standard Spark execution process kicks in, ensuring that Spark Connect leverages all of Spark’s optimizations and enhancements. Results are streamed back to the client through gRPC as Apache Arrow-encoded row batches. + +**Advantages :** +Stability: Applications that use too much memory will now only impact their own environment as they can run in their own processes. Users can define their own dependencies on the client and don’t need to worry about potential conflicts with the Spark driver. + +Upgradability: The Spark driver can now seamlessly be upgraded independently of applications, for example to benefit from performance improvements and security fixes. This means applications can be forward-compatible, as long as the server-side RPC definitions are designed to be backwards compatible. + +Debuggability and observability: Spark Connect enables interactive debugging during development directly from your favorite IDE. Similarly, applications can be monitored using the application’s framework native metrics and logging libraries. + +No need to separate PPL into a dedicated library - all can be done from the existing SQL repository. + +**Disadvantages :** +Not all _managed_ Spark solutions support this "new" feature so as part of using this capability we will need to manually deploy the corresponding spark-connect plugins as part of flint’s deployment.
+ +All the context creation would have to be done from the spark client - this creates some additional complexity since the Flint spark plugin has some contextual requirements that have to be somehow propagated from the client’s side. + +--- + +### Implemented solution +As presented here and detailed in the [issue](https://github.com/opensearch-project/opensearch-spark/issues/30), there are several options to allow spark to be able to understand and run ppl queries. + +The selected option is to use the PPL AST logical plan API and traversals to transform the PPL logical plan into Catalyst logical plan thus enabling the longer term +solution for using spark-connect as a part of the ppl-client (as described below): + +Advantages of the selected approach: + +- **reuse** of existing PPL code that is tested and in production +- **simplify** development while relying on well known and structured codebase +- **long term support** in case the `spark-connect` will become user chosen strategy - existing code can be used without any changes +- **single place of maintenance** by reusing the PPL logical model which relies on ppl antlr parser, we can use a single repository to maintain and develop the PPL language without the need to constantly merge changes upstream.
+ +The following diagram shows the high level architecture of the selected implementation solution : + +![ppl logical architecture ](https://github.com/opensearch-project/opensearch-spark/assets/48943349/6965258f-9823-4f12-a4f9-529c1365fc4a) + +The **logical Architecture** shows the next artifacts: +- **_Libraries_**: + - PPL ( the ppl core , protocol, parser & logical plan utils) + - SQL ( the SQL core , protocol, parser - depends on PPL for using the logical plan utils) + +- **_Drivers_**: + - PPL OpenSearch Driver (depends on OpenSearch core) + - PPL Spark Driver (depends on Spark core) + - PPL Prometheus Driver (directly translates PPL to PromQL ) + - SQL OpenSearch Driver (depends on OpenSearch core) + +**Physical Architecture :** +Currently the drivers reside inside the PPL client repository within the OpenSearch Plugins. +Next tasks ahead will resolve this: + +- Extract PPL logical component outside the SQL plugin into a (non-plugin) library - publish library to maven +- Separate the PPL / SQL drivers inside the OpenSearch PPL client to better distinguish +- Create a thin PPL client capable of interaction with the PPL Driver regardless of which driver (Spark , OpenSearch , Prometheus ) + +--- + +### Roadmap + +This section describes the next steps planned for enabling additional commands and grammar translation.
+ +#### Supported +The next samples of PPL queries are currently supported: + +**Fields** + - `source = table` + - `source = table | fields a,b,c` + +**Filters** + - `source = table | where a = 1 | fields a,b,c` + - `source = table | where a >= 1 | fields a,b,c` + - `source = table | where a < 1 | fields a,b,c` + - `source = table | where b != 'test' | fields a,b,c` + - `source = table | where c = 'test' | fields a,b,c | head 3` + +**Filters With Logical Conditions** + - `source = table | where c = 'test' AND a = 1 | fields a,b,c` + - `source = table | where c != 'test' OR a > 1 | fields a,b,c | head 1` + - `source = table | where c = 'test' NOT a > 1 | fields a,b,c` + +**Aggregations** + - `source = table | stats avg(a) ` + - `source = table | where a < 50 | stats avg(c) ` + - `source = table | stats max(c) by b` + - `source = table | stats count(c) by b | head 5` + +**Aggregations With Span** +- `source = table | stats count(a) by span(a, 10) as a_span` +- `source = table | stats sum(age) by span(age, 5) as age_span | head 2` +- `source = table | stats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` + +**Aggregations With TimeWindow Span (tumble windowing function)** +- `source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date` +- `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId` + +#### Supported Commands: + - `search` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/search.rst) + - `where` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/where.rst) + - `fields` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/fields.rst) + - `head` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/head.rst) + - `stats` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/stats.rst) + -
`sort` - [See details](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) + +> For additional details review the next [Integration Test ](../integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala) +> For additional details review the next [Integration Time Window Test ](../integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLTimeWindowITSuite.scala) + +--- + +#### Planned Support + + - support the `explain` command to return the explained PPL query logical plan and expected execution plan + + - attend [sort](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/sort.rst) partially supported, missing capability to sort by alias field (span like or aggregation) + - attend `alias` - partially supported, missing capability to sort by / group-by alias field name + + - add [conditions](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/condition.rst) support + - add [top](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/top.rst) support + - add [cast](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/conversion.rst) support + - add [math](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/functions/math.rst) support + - add [deduplicate](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/cmd/dedup.rst) support \ No newline at end of file diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 new file mode 100644 index 000000000..e74aed30e --- /dev/null +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -0,0 +1,400 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +lexer grammar OpenSearchPPLLexer; + +channels { WHITESPACE, ERRORCHANNEL } + + +// COMMAND KEYWORDS +SEARCH: 'SEARCH'; +DESCRIBE: 'DESCRIBE'; +SHOW: 'SHOW'; +FROM: 'FROM'; +WHERE: 'WHERE'; 
+FIELDS: 'FIELDS'; +RENAME: 'RENAME'; +STATS: 'STATS'; +DEDUP: 'DEDUP'; +SORT: 'SORT'; +EVAL: 'EVAL'; +HEAD: 'HEAD'; +TOP: 'TOP'; +RARE: 'RARE'; +PARSE: 'PARSE'; +METHOD: 'METHOD'; +REGEX: 'REGEX'; +PUNCT: 'PUNCT'; +GROK: 'GROK'; +PATTERN: 'PATTERN'; +PATTERNS: 'PATTERNS'; +NEW_FIELD: 'NEW_FIELD'; +KMEANS: 'KMEANS'; +AD: 'AD'; +ML: 'ML'; + +// COMMAND ASSIST KEYWORDS +AS: 'AS'; +BY: 'BY'; +SOURCE: 'SOURCE'; +INDEX: 'INDEX'; +D: 'D'; +DESC: 'DESC'; +DATASOURCES: 'DATASOURCES'; + +// CLAUSE KEYWORDS +SORTBY: 'SORTBY'; + +// FIELD KEYWORDS +AUTO: 'AUTO'; +STR: 'STR'; +IP: 'IP'; +NUM: 'NUM'; + +// ARGUMENT KEYWORDS +KEEPEMPTY: 'KEEPEMPTY'; +CONSECUTIVE: 'CONSECUTIVE'; +DEDUP_SPLITVALUES: 'DEDUP_SPLITVALUES'; +PARTITIONS: 'PARTITIONS'; +ALLNUM: 'ALLNUM'; +DELIM: 'DELIM'; +CENTROIDS: 'CENTROIDS'; +ITERATIONS: 'ITERATIONS'; +DISTANCE_TYPE: 'DISTANCE_TYPE'; +NUMBER_OF_TREES: 'NUMBER_OF_TREES'; +SHINGLE_SIZE: 'SHINGLE_SIZE'; +SAMPLE_SIZE: 'SAMPLE_SIZE'; +OUTPUT_AFTER: 'OUTPUT_AFTER'; +TIME_DECAY: 'TIME_DECAY'; +ANOMALY_RATE: 'ANOMALY_RATE'; +CATEGORY_FIELD: 'CATEGORY_FIELD'; +TIME_FIELD: 'TIME_FIELD'; +TIME_ZONE: 'TIME_ZONE'; +TRAINING_DATA_SIZE: 'TRAINING_DATA_SIZE'; +ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; + +// COMPARISON FUNCTION KEYWORDS +CASE: 'CASE'; +IN: 'IN'; + +// LOGICAL KEYWORDS +NOT: 'NOT'; +OR: 'OR'; +AND: 'AND'; +XOR: 'XOR'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +REGEXP: 'REGEXP'; + +// DATETIME, INTERVAL AND UNIT KEYWORDS +CONVERT_TZ: 'CONVERT_TZ'; +DATETIME: 'DATETIME'; +DAY: 'DAY'; +DAY_HOUR: 'DAY_HOUR'; +DAY_MICROSECOND: 'DAY_MICROSECOND'; +DAY_MINUTE: 'DAY_MINUTE'; +DAY_OF_YEAR: 'DAY_OF_YEAR'; +DAY_SECOND: 'DAY_SECOND'; +HOUR: 'HOUR'; +HOUR_MICROSECOND: 'HOUR_MICROSECOND'; +HOUR_MINUTE: 'HOUR_MINUTE'; +HOUR_OF_DAY: 'HOUR_OF_DAY'; +HOUR_SECOND: 'HOUR_SECOND'; +INTERVAL: 'INTERVAL'; +MICROSECOND: 'MICROSECOND'; +MILLISECOND: 'MILLISECOND'; +MINUTE: 'MINUTE'; +MINUTE_MICROSECOND: 'MINUTE_MICROSECOND'; +MINUTE_OF_DAY: 'MINUTE_OF_DAY'; +MINUTE_OF_HOUR: 
'MINUTE_OF_HOUR'; +MINUTE_SECOND: 'MINUTE_SECOND'; +MONTH: 'MONTH'; +MONTH_OF_YEAR: 'MONTH_OF_YEAR'; +QUARTER: 'QUARTER'; +SECOND: 'SECOND'; +SECOND_MICROSECOND: 'SECOND_MICROSECOND'; +SECOND_OF_MINUTE: 'SECOND_OF_MINUTE'; +WEEK: 'WEEK'; +WEEK_OF_YEAR: 'WEEK_OF_YEAR'; +YEAR: 'YEAR'; +YEAR_MONTH: 'YEAR_MONTH'; + +// DATASET TYPES +DATAMODEL: 'DATAMODEL'; +LOOKUP: 'LOOKUP'; +SAVEDSEARCH: 'SAVEDSEARCH'; + +// CONVERTED DATA TYPES +INT: 'INT'; +INTEGER: 'INTEGER'; +DOUBLE: 'DOUBLE'; +LONG: 'LONG'; +FLOAT: 'FLOAT'; +STRING: 'STRING'; +BOOLEAN: 'BOOLEAN'; + +// SPECIAL CHARACTERS AND OPERATORS +PIPE: '|'; +COMMA: ','; +DOT: '.'; +EQUAL: '='; +GREATER: '>'; +LESS: '<'; +NOT_GREATER: '<' '='; +NOT_LESS: '>' '='; +NOT_EQUAL: '!' '='; +PLUS: '+'; +MINUS: '-'; +STAR: '*'; +DIVIDE: '/'; +MODULE: '%'; +EXCLAMATION_SYMBOL: '!'; +COLON: ':'; +LT_PRTHS: '('; +RT_PRTHS: ')'; +LT_SQR_PRTHS: '['; +RT_SQR_PRTHS: ']'; +SINGLE_QUOTE: '\''; +DOUBLE_QUOTE: '"'; +BACKTICK: '`'; + +// Operators. Bit + +BIT_NOT_OP: '~'; +BIT_AND_OP: '&'; +BIT_XOR_OP: '^'; + +// AGGREGATIONS +AVG: 'AVG'; +COUNT: 'COUNT'; +DISTINCT_COUNT: 'DISTINCT_COUNT'; +ESTDC: 'ESTDC'; +ESTDC_ERROR: 'ESTDC_ERROR'; +MAX: 'MAX'; +MEAN: 'MEAN'; +MEDIAN: 'MEDIAN'; +MIN: 'MIN'; +MODE: 'MODE'; +RANGE: 'RANGE'; +STDEV: 'STDEV'; +STDEVP: 'STDEVP'; +SUM: 'SUM'; +SUMSQ: 'SUMSQ'; +VAR_SAMP: 'VAR_SAMP'; +VAR_POP: 'VAR_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; +STDDEV_POP: 'STDDEV_POP'; +PERCENTILE: 'PERCENTILE'; +TAKE: 'TAKE'; +FIRST: 'FIRST'; +LAST: 'LAST'; +LIST: 'LIST'; +VALUES: 'VALUES'; +EARLIEST: 'EARLIEST'; +EARLIEST_TIME: 'EARLIEST_TIME'; +LATEST: 'LATEST'; +LATEST_TIME: 'LATEST_TIME'; +PER_DAY: 'PER_DAY'; +PER_HOUR: 'PER_HOUR'; +PER_MINUTE: 'PER_MINUTE'; +PER_SECOND: 'PER_SECOND'; +RATE: 'RATE'; +SPARKLINE: 'SPARKLINE'; +C: 'C'; +DC: 'DC'; + +// BASIC FUNCTIONS +ABS: 'ABS'; +CBRT: 'CBRT'; +CEIL: 'CEIL'; +CEILING: 'CEILING'; +CONV: 'CONV'; +CRC32: 'CRC32'; +E: 'E'; +EXP: 'EXP'; +FLOOR: 'FLOOR'; +LN: 'LN'; +LOG: 'LOG'; +LOG10: 
'LOG10'; +LOG2: 'LOG2'; +MOD: 'MOD'; +PI: 'PI'; +POSITION: 'POSITION'; +POW: 'POW'; +POWER: 'POWER'; +RAND: 'RAND'; +ROUND: 'ROUND'; +SIGN: 'SIGN'; +SQRT: 'SQRT'; +TRUNCATE: 'TRUNCATE'; + +// TRIGONOMETRIC FUNCTIONS +ACOS: 'ACOS'; +ASIN: 'ASIN'; +ATAN: 'ATAN'; +ATAN2: 'ATAN2'; +COS: 'COS'; +COT: 'COT'; +DEGREES: 'DEGREES'; +RADIANS: 'RADIANS'; +SIN: 'SIN'; +TAN: 'TAN'; + +// DATE AND TIME FUNCTIONS +ADDDATE: 'ADDDATE'; +ADDTIME: 'ADDTIME'; +CURDATE: 'CURDATE'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURTIME: 'CURTIME'; +DATE: 'DATE'; +DATEDIFF: 'DATEDIFF'; +DATE_ADD: 'DATE_ADD'; +DATE_FORMAT: 'DATE_FORMAT'; +DATE_SUB: 'DATE_SUB'; +DAYNAME: 'DAYNAME'; +DAYOFMONTH: 'DAYOFMONTH'; +DAYOFWEEK: 'DAYOFWEEK'; +DAYOFYEAR: 'DAYOFYEAR'; +DAY_OF_MONTH: 'DAY_OF_MONTH'; +DAY_OF_WEEK: 'DAY_OF_WEEK'; +EXTRACT: 'EXTRACT'; +FROM_DAYS: 'FROM_DAYS'; +FROM_UNIXTIME: 'FROM_UNIXTIME'; +GET_FORMAT: 'GET_FORMAT'; +LAST_DAY: 'LAST_DAY'; +LOCALTIME: 'LOCALTIME'; +LOCALTIMESTAMP: 'LOCALTIMESTAMP'; +MAKEDATE: 'MAKEDATE'; +MAKETIME: 'MAKETIME'; +MONTHNAME: 'MONTHNAME'; +NOW: 'NOW'; +PERIOD_ADD: 'PERIOD_ADD'; +PERIOD_DIFF: 'PERIOD_DIFF'; +SEC_TO_TIME: 'SEC_TO_TIME'; +STR_TO_DATE: 'STR_TO_DATE'; +SUBDATE: 'SUBDATE'; +SUBTIME: 'SUBTIME'; +SYSDATE: 'SYSDATE'; +TIME: 'TIME'; +TIMEDIFF: 'TIMEDIFF'; +TIMESTAMP: 'TIMESTAMP'; +TIMESTAMPADD: 'TIMESTAMPADD'; +TIMESTAMPDIFF: 'TIMESTAMPDIFF'; +TIME_FORMAT: 'TIME_FORMAT'; +TIME_TO_SEC: 'TIME_TO_SEC'; +TO_DAYS: 'TO_DAYS'; +TO_SECONDS: 'TO_SECONDS'; +UNIX_TIMESTAMP: 'UNIX_TIMESTAMP'; +UTC_DATE: 'UTC_DATE'; +UTC_TIME: 'UTC_TIME'; +UTC_TIMESTAMP: 'UTC_TIMESTAMP'; +WEEKDAY: 'WEEKDAY'; +YEARWEEK: 'YEARWEEK'; + +// TEXT FUNCTIONS +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +LTRIM: 'LTRIM'; +RTRIM: 'RTRIM'; +TRIM: 'TRIM'; +TO: 'TO'; +LOWER: 'LOWER'; +UPPER: 'UPPER'; +CONCAT: 'CONCAT'; +CONCAT_WS: 'CONCAT_WS'; +LENGTH: 'LENGTH'; +STRCMP: 'STRCMP'; +RIGHT: 'RIGHT'; +LEFT: 'LEFT'; +ASCII: 'ASCII'; +LOCATE: 
'LOCATE'; +REPLACE: 'REPLACE'; +REVERSE: 'REVERSE'; +CAST: 'CAST'; + +// BOOL FUNCTIONS +LIKE: 'LIKE'; +ISNULL: 'ISNULL'; +ISNOTNULL: 'ISNOTNULL'; + +// FLOWCONTROL FUNCTIONS +IFNULL: 'IFNULL'; +NULLIF: 'NULLIF'; +IF: 'IF'; +TYPEOF: 'TYPEOF'; + +// RELEVANCE FUNCTIONS AND PARAMETERS +MATCH: 'MATCH'; +MATCH_PHRASE: 'MATCH_PHRASE'; +MATCH_PHRASE_PREFIX: 'MATCH_PHRASE_PREFIX'; +MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; +SIMPLE_QUERY_STRING: 'SIMPLE_QUERY_STRING'; +MULTI_MATCH: 'MULTI_MATCH'; +QUERY_STRING: 'QUERY_STRING'; + +ALLOW_LEADING_WILDCARD: 'ALLOW_LEADING_WILDCARD'; +ANALYZE_WILDCARD: 'ANALYZE_WILDCARD'; +ANALYZER: 'ANALYZER'; +AUTO_GENERATE_SYNONYMS_PHRASE_QUERY:'AUTO_GENERATE_SYNONYMS_PHRASE_QUERY'; +BOOST: 'BOOST'; +CUTOFF_FREQUENCY: 'CUTOFF_FREQUENCY'; +DEFAULT_FIELD: 'DEFAULT_FIELD'; +DEFAULT_OPERATOR: 'DEFAULT_OPERATOR'; +ENABLE_POSITION_INCREMENTS: 'ENABLE_POSITION_INCREMENTS'; +ESCAPE: 'ESCAPE'; +FLAGS: 'FLAGS'; +FUZZY_MAX_EXPANSIONS: 'FUZZY_MAX_EXPANSIONS'; +FUZZY_PREFIX_LENGTH: 'FUZZY_PREFIX_LENGTH'; +FUZZY_TRANSPOSITIONS: 'FUZZY_TRANSPOSITIONS'; +FUZZY_REWRITE: 'FUZZY_REWRITE'; +FUZZINESS: 'FUZZINESS'; +LENIENT: 'LENIENT'; +LOW_FREQ_OPERATOR: 'LOW_FREQ_OPERATOR'; +MAX_DETERMINIZED_STATES: 'MAX_DETERMINIZED_STATES'; +MAX_EXPANSIONS: 'MAX_EXPANSIONS'; +MINIMUM_SHOULD_MATCH: 'MINIMUM_SHOULD_MATCH'; +OPERATOR: 'OPERATOR'; +PHRASE_SLOP: 'PHRASE_SLOP'; +PREFIX_LENGTH: 'PREFIX_LENGTH'; +QUOTE_ANALYZER: 'QUOTE_ANALYZER'; +QUOTE_FIELD_SUFFIX: 'QUOTE_FIELD_SUFFIX'; +REWRITE: 'REWRITE'; +SLOP: 'SLOP'; +TIE_BREAKER: 'TIE_BREAKER'; +TYPE: 'TYPE'; +ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; + +// SPAN KEYWORDS +SPAN: 'SPAN'; +MS: 'MS'; +S: 'S'; +M: 'M'; +H: 'H'; +W: 'W'; +Q: 'Q'; +Y: 'Y'; + + +// LITERALS AND VALUES +//STRING_LITERAL: DQUOTA_STRING | SQUOTA_STRING | BQUOTA_STRING; +ID: ID_LITERAL; +CLUSTER: CLUSTER_PREFIX_LITERAL; +INTEGER_LITERAL: DEC_DIGIT+; +DECIMAL_LITERAL: (DEC_DIGIT+)? '.' 
DEC_DIGIT+; + +fragment DATE_SUFFIX: ([\-.][*0-9]+)+; +fragment ID_LITERAL: [@*A-Z]+?[*A-Z_\-0-9]*; +fragment CLUSTER_PREFIX_LITERAL: [*A-Z]+?[*A-Z_\-0-9]* COLON; +ID_DATE_SUFFIX: CLUSTER_PREFIX_LITERAL? ID_LITERAL DATE_SUFFIX; +DQUOTA_STRING: '"' ( '\\'. | '""' | ~('"'| '\\') )* '"'; +SQUOTA_STRING: '\'' ('\\'. | '\'\'' | ~('\'' | '\\'))* '\''; +BQUOTA_STRING: '`' ( '\\'. | '``' | ~('`'|'\\'))* '`'; +fragment DEC_DIGIT: [0-9]; + + +ERROR_RECOGNITION: . -> channel(ERRORCHANNEL); diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 new file mode 100644 index 000000000..69f560f25 --- /dev/null +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -0,0 +1,913 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +parser grammar OpenSearchPPLParser; + + +options { tokenVocab = OpenSearchPPLLexer; } +root + : pplStatement? EOF + ; + +// statement +pplStatement + : dmlStatement + ; + +dmlStatement + : queryStatement + ; + +queryStatement + : pplCommands (PIPE commands)* + ; + +// commands +pplCommands + : searchCommand + | describeCommand + | showDataSourcesCommand + ; + +commands + : whereCommand + | fieldsCommand + | renameCommand + | statsCommand + | dedupCommand + | sortCommand + | evalCommand + | headCommand + | topCommand + | rareCommand + | grokCommand + | parseCommand + | patternsCommand + | kmeansCommand + | adCommand + | mlCommand + ; + +searchCommand + : (SEARCH)? fromClause # searchFrom + | (SEARCH)? fromClause logicalExpression # searchFromFilter + | (SEARCH)? logicalExpression fromClause # searchFilterFrom + ; + +describeCommand + : DESCRIBE tableSourceClause + ; + +showDataSourcesCommand + : SHOW DATASOURCES + ; + +whereCommand + : WHERE logicalExpression + ; + +fieldsCommand + : FIELDS (PLUS | MINUS)? 
fieldList + ; + +renameCommand + : RENAME renameClasue (COMMA renameClasue)* + ; + +statsCommand + : STATS (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? + ; + +dedupCommand + : DEDUP (number = integerLiteral)? fieldList (KEEPEMPTY EQUAL keepempty = booleanLiteral)? (CONSECUTIVE EQUAL consecutive = booleanLiteral)? + ; + +sortCommand + : SORT sortbyClause + ; + +evalCommand + : EVAL evalClause (COMMA evalClause)* + ; + +headCommand + : HEAD (number = integerLiteral)? (FROM from = integerLiteral)? + ; + +topCommand + : TOP (number = integerLiteral)? fieldList (byClause)? + ; + +rareCommand + : RARE fieldList (byClause)? + ; + +grokCommand + : GROK (source_field = expression) (pattern = stringLiteral) + ; + +parseCommand + : PARSE (source_field = expression) (pattern = stringLiteral) + ; + +patternsCommand + : PATTERNS (patternsParameter)* (source_field = expression) + ; + +patternsParameter + : (NEW_FIELD EQUAL new_field = stringLiteral) + | (PATTERN EQUAL pattern = stringLiteral) + ; + +patternsMethod + : PUNCT + | REGEX + ; + +kmeansCommand + : KMEANS (kmeansParameter)* + ; + +kmeansParameter + : (CENTROIDS EQUAL centroids = integerLiteral) + | (ITERATIONS EQUAL iterations = integerLiteral) + | (DISTANCE_TYPE EQUAL distance_type = stringLiteral) + ; + +adCommand + : AD (adParameter)* + ; + +adParameter + : (NUMBER_OF_TREES EQUAL number_of_trees = integerLiteral) + | (SHINGLE_SIZE EQUAL shingle_size = integerLiteral) + | (SAMPLE_SIZE EQUAL sample_size = integerLiteral) + | (OUTPUT_AFTER EQUAL output_after = integerLiteral) + | (TIME_DECAY EQUAL time_decay = decimalLiteral) + | (ANOMALY_RATE EQUAL anomaly_rate = decimalLiteral) + | (CATEGORY_FIELD EQUAL category_field = stringLiteral) + | (TIME_FIELD EQUAL time_field = stringLiteral) + | (DATE_FORMAT EQUAL date_format = stringLiteral) 
+ | (TIME_ZONE EQUAL time_zone = stringLiteral) + | (TRAINING_DATA_SIZE EQUAL training_data_size = integerLiteral) + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold = decimalLiteral) + ; + +mlCommand + : ML (mlArg)* + ; + +mlArg + : (argName = ident EQUAL argValue = literalValue) + ; + +// clauses +fromClause + : SOURCE EQUAL tableSourceClause + | INDEX EQUAL tableSourceClause + | SOURCE EQUAL tableFunction + | INDEX EQUAL tableFunction + ; + +tableSourceClause + : tableSource (COMMA tableSource)* + ; + +renameClasue + : orignalField = wcFieldExpression AS renamedField = wcFieldExpression + ; + +byClause + : BY fieldList + ; + +statsByClause + : BY fieldList + | BY bySpanClause + | BY bySpanClause COMMA fieldList + ; + +bySpanClause + : spanClause (AS alias = qualifiedName)? + ; + +spanClause + : SPAN LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS + ; + +sortbyClause + : sortField (COMMA sortField)* + ; + +evalClause + : fieldExpression EQUAL expression + ; + +// aggregation terms +statsAggTerm + : statsFunction (AS alias = wcFieldExpression)? + ; + +// aggregation functions +statsFunction + : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall + | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + | percentileAggFunction # percentileAggFunctionCall + | takeAggFunction # takeAggFunctionCall + ; + +statsFunctionName + : AVG + | COUNT + | SUM + | MIN + | MAX + | VAR_SAMP + | VAR_POP + | STDDEV_SAMP + | STDDEV_POP + ; + +takeAggFunction + : TAKE LT_PRTHS fieldExpression (COMMA size = integerLiteral)? 
RT_PRTHS + ; + +percentileAggFunction + : PERCENTILE LESS value = integerLiteral GREATER LT_PRTHS aggField = fieldExpression RT_PRTHS + ; + +// expressions +expression + : logicalExpression + | comparisonExpression + | valueExpression + ; + +logicalExpression + : comparisonExpression # comparsion + | NOT logicalExpression # logicalNot + | left = logicalExpression OR right = logicalExpression # logicalOr + | left = logicalExpression (AND)? right = logicalExpression # logicalAnd + | left = logicalExpression XOR right = logicalExpression # logicalXor + | booleanExpression # booleanExpr + | relevanceExpression # relevanceExpr + ; + +comparisonExpression + : left = valueExpression comparisonOperator right = valueExpression # compareExpr + | valueExpression IN valueList # inExpr + ; + +valueExpression + : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic + | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic + | primaryExpression # valueExpressionDefault + | positionFunction # positionFunctionCall + | extractFunction # extractFunctionCall + | getFormatFunction # getFormatFunctionCall + | timestampFunction # timestampFunctionCall + | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr + ; + +primaryExpression + : evalFunctionCall + | dataTypeFunctionCall + | fieldExpression + | literalValue + ; + +positionFunction + : positionFunctionName LT_PRTHS functionArg IN functionArg RT_PRTHS + ; + +booleanExpression + : booleanFunctionCall + ; + +relevanceExpression + : singleFieldRelevanceFunction + | multiFieldRelevanceFunction + ; + +// Field is a single column +singleFieldRelevanceFunction + : singleFieldRelevanceFunctionName LT_PRTHS field = relevanceField COMMA query = relevanceQuery (COMMA relevanceArg)* RT_PRTHS + ; + +// Field is a list of columns +multiFieldRelevanceFunction + : multiFieldRelevanceFunctionName LT_PRTHS LT_SQR_PRTHS field = relevanceFieldAndWeight 
(COMMA field = relevanceFieldAndWeight)* RT_SQR_PRTHS COMMA query = relevanceQuery (COMMA relevanceArg)* RT_PRTHS + ; + +// tables +tableSource + : tableQualifiedName + | ID_DATE_SUFFIX + ; + +tableFunction + : qualifiedName LT_PRTHS functionArgs RT_PRTHS + ; + +// fields +fieldList + : fieldExpression (COMMA fieldExpression)* + ; + +wcFieldList + : wcFieldExpression (COMMA wcFieldExpression)* + ; + +sortField + : (PLUS | MINUS)? sortFieldExpression + ; + +sortFieldExpression + : fieldExpression + | AUTO LT_PRTHS fieldExpression RT_PRTHS + | STR LT_PRTHS fieldExpression RT_PRTHS + | IP LT_PRTHS fieldExpression RT_PRTHS + | NUM LT_PRTHS fieldExpression RT_PRTHS + ; + +fieldExpression + : qualifiedName + ; + +wcFieldExpression + : wcQualifiedName + ; + +// functions +evalFunctionCall + : evalFunctionName LT_PRTHS functionArgs RT_PRTHS + ; + +// cast function +dataTypeFunctionCall + : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS + ; + +// boolean functions +booleanFunctionCall + : conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS + ; + +convertedDataType + : typeName = DATE + | typeName = TIME + | typeName = TIMESTAMP + | typeName = INT + | typeName = INTEGER + | typeName = DOUBLE + | typeName = LONG + | typeName = FLOAT + | typeName = STRING + | typeName = BOOLEAN + ; + +evalFunctionName + : mathematicalFunctionName + | dateTimeFunctionName + | textFunctionName + | conditionFunctionBase + | systemFunctionName + | positionFunctionName + ; + +functionArgs + : (functionArg (COMMA functionArg)*)? + ; + +functionArg + : (ident EQUAL)? 
valueExpression + ; + +relevanceArg + : relevanceArgName EQUAL relevanceArgValue + ; + +relevanceArgName + : ALLOW_LEADING_WILDCARD + | ANALYZER + | ANALYZE_WILDCARD + | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY + | BOOST + | CUTOFF_FREQUENCY + | DEFAULT_FIELD + | DEFAULT_OPERATOR + | ENABLE_POSITION_INCREMENTS + | ESCAPE + | FIELDS + | FLAGS + | FUZZINESS + | FUZZY_MAX_EXPANSIONS + | FUZZY_PREFIX_LENGTH + | FUZZY_REWRITE + | FUZZY_TRANSPOSITIONS + | LENIENT + | LOW_FREQ_OPERATOR + | MAX_DETERMINIZED_STATES + | MAX_EXPANSIONS + | MINIMUM_SHOULD_MATCH + | OPERATOR + | PHRASE_SLOP + | PREFIX_LENGTH + | QUOTE_ANALYZER + | QUOTE_FIELD_SUFFIX + | REWRITE + | SLOP + | TIE_BREAKER + | TIME_ZONE + | TYPE + | ZERO_TERMS_QUERY + ; + +relevanceFieldAndWeight + : field = relevanceField + | field = relevanceField weight = relevanceFieldWeight + | field = relevanceField BIT_XOR_OP weight = relevanceFieldWeight + ; + +relevanceFieldWeight + : integerLiteral + | decimalLiteral + ; + +relevanceField + : qualifiedName + | stringLiteral + ; + +relevanceQuery + : relevanceArgValue + ; + +relevanceArgValue + : qualifiedName + | literalValue + ; + +mathematicalFunctionName + : ABS + | CBRT + | CEIL + | CEILING + | CONV + | CRC32 + | E + | EXP + | FLOOR + | LN + | LOG + | LOG10 + | LOG2 + | MOD + | PI + | POW + | POWER + | RAND + | ROUND + | SIGN + | SQRT + | TRUNCATE + | trigonometricFunctionName + ; + +trigonometricFunctionName + : ACOS + | ASIN + | ATAN + | ATAN2 + | COS + | COT + | DEGREES + | RADIANS + | SIN + | TAN + ; + +dateTimeFunctionName + : ADDDATE + | ADDTIME + | CONVERT_TZ + | CURDATE + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURTIME + | DATE + | DATEDIFF + | DATETIME + | DATE_ADD + | DATE_FORMAT + | DATE_SUB + | DAY + | DAYNAME + | DAYOFMONTH + | DAYOFWEEK + | DAYOFYEAR + | DAY_OF_MONTH + | DAY_OF_WEEK + | DAY_OF_YEAR + | FROM_DAYS + | FROM_UNIXTIME + | HOUR + | HOUR_OF_DAY + | LAST_DAY + | LOCALTIME + | LOCALTIMESTAMP + | MAKEDATE + | MAKETIME + | MICROSECOND 
+ | MINUTE + | MINUTE_OF_DAY + | MINUTE_OF_HOUR + | MONTH + | MONTHNAME + | MONTH_OF_YEAR + | NOW + | PERIOD_ADD + | PERIOD_DIFF + | QUARTER + | SECOND + | SECOND_OF_MINUTE + | SEC_TO_TIME + | STR_TO_DATE + | SUBDATE + | SUBTIME + | SYSDATE + | TIME + | TIMEDIFF + | TIMESTAMP + | TIME_FORMAT + | TIME_TO_SEC + | TO_DAYS + | TO_SECONDS + | UNIX_TIMESTAMP + | UTC_DATE + | UTC_TIME + | UTC_TIMESTAMP + | WEEK + | WEEKDAY + | WEEK_OF_YEAR + | YEAR + | YEARWEEK + ; + +getFormatFunction + : GET_FORMAT LT_PRTHS getFormatType COMMA functionArg RT_PRTHS + ; + +getFormatType + : DATE + | DATETIME + | TIME + | TIMESTAMP + ; + +extractFunction + : EXTRACT LT_PRTHS datetimePart FROM functionArg RT_PRTHS + ; + +simpleDateTimePart + : MICROSECOND + | SECOND + | MINUTE + | HOUR + | DAY + | WEEK + | MONTH + | QUARTER + | YEAR + ; + +complexDateTimePart + : SECOND_MICROSECOND + | MINUTE_MICROSECOND + | MINUTE_SECOND + | HOUR_MICROSECOND + | HOUR_SECOND + | HOUR_MINUTE + | DAY_MICROSECOND + | DAY_SECOND + | DAY_MINUTE + | DAY_HOUR + | YEAR_MONTH + ; + +datetimePart + : simpleDateTimePart + | complexDateTimePart + ; + +timestampFunction + : timestampFunctionName LT_PRTHS simpleDateTimePart COMMA firstArg = functionArg COMMA secondArg = functionArg RT_PRTHS + ; + +timestampFunctionName + : TIMESTAMPADD + | TIMESTAMPDIFF + ; + +// condition function return boolean value +conditionFunctionBase + : LIKE + | IF + | ISNULL + | ISNOTNULL + | IFNULL + | NULLIF + ; + +systemFunctionName + : TYPEOF + ; + +textFunctionName + : SUBSTR + | SUBSTRING + | TRIM + | LTRIM + | RTRIM + | LOWER + | UPPER + | CONCAT + | CONCAT_WS + | LENGTH + | STRCMP + | RIGHT + | LEFT + | ASCII + | LOCATE + | REPLACE + | REVERSE + ; + +positionFunctionName + : POSITION + ; + +// operators + comparisonOperator + : EQUAL + | NOT_EQUAL + | LESS + | NOT_LESS + | GREATER + | NOT_GREATER + | REGEXP + ; + +singleFieldRelevanceFunctionName + : MATCH + | MATCH_PHRASE + | MATCH_BOOL_PREFIX + | MATCH_PHRASE_PREFIX + ; + 
+multiFieldRelevanceFunctionName + : SIMPLE_QUERY_STRING + | MULTI_MATCH + | QUERY_STRING + ; + +// literals and values +literalValue + : intervalLiteral + | stringLiteral + | integerLiteral + | decimalLiteral + | booleanLiteral + | datetimeLiteral //#datetime + ; + +intervalLiteral + : INTERVAL valueExpression intervalUnit + ; + +stringLiteral + : DQUOTA_STRING + | SQUOTA_STRING + ; + +integerLiteral + : (PLUS | MINUS)? INTEGER_LITERAL + ; + +decimalLiteral + : (PLUS | MINUS)? DECIMAL_LITERAL + ; + +booleanLiteral + : TRUE + | FALSE + ; + +// Date and Time Literal, follow ANSI 92 +datetimeLiteral + : dateLiteral + | timeLiteral + | timestampLiteral + ; + +dateLiteral + : DATE date = stringLiteral + ; + +timeLiteral + : TIME time = stringLiteral + ; + +timestampLiteral + : TIMESTAMP timestamp = stringLiteral + ; + +intervalUnit + : MICROSECOND + | SECOND + | MINUTE + | HOUR + | DAY + | WEEK + | MONTH + | QUARTER + | YEAR + | SECOND_MICROSECOND + | MINUTE_MICROSECOND + | MINUTE_SECOND + | HOUR_MICROSECOND + | HOUR_SECOND + | HOUR_MINUTE + | DAY_MICROSECOND + | DAY_SECOND + | DAY_MINUTE + | DAY_HOUR + | YEAR_MONTH + ; + +timespanUnit + : MS + | S + | M + | H + | D + | W + | Q + | Y + | MILLISECOND + | SECOND + | MINUTE + | HOUR + | DAY + | WEEK + | MONTH + | QUARTER + | YEAR + ; + +valueList + : LT_PRTHS literalValue (COMMA literalValue)* RT_PRTHS + ; + +qualifiedName + : ident (DOT ident)* # identsAsQualifiedName + ; + +tableQualifiedName + : tableIdent (DOT ident)* # identsAsTableQualifiedName + ; + +wcQualifiedName + : wildcard (DOT wildcard)* # identsAsWildcardQualifiedName + ; + +ident + : (DOT)? ID + | BACKTICK ident BACKTICK + | BQUOTA_STRING + | keywordsCanBeId + ; + +tableIdent + : (CLUSTER)? ident + ; + +wildcard + : ident (MODULE ident)* (MODULE)? 
+ | SINGLE_QUOTE wildcard SINGLE_QUOTE + | DOUBLE_QUOTE wildcard DOUBLE_QUOTE + | BACKTICK wildcard BACKTICK + ; + +keywordsCanBeId + : D // OD SQL and ODBC special + | timespanUnit + | SPAN + | evalFunctionName + | relevanceArgName + | intervalUnit + | dateTimeFunctionName + | textFunctionName + | mathematicalFunctionName + | positionFunctionName + // commands + | SEARCH + | DESCRIBE + | SHOW + | FROM + | WHERE + | FIELDS + | RENAME + | STATS + | DEDUP + | SORT + | EVAL + | HEAD + | TOP + | RARE + | PARSE + | METHOD + | REGEX + | PUNCT + | GROK + | PATTERN + | PATTERNS + | NEW_FIELD + | KMEANS + | AD + | ML + // commands assist keywords + | SOURCE + | INDEX + | DESC + | DATASOURCES + // CLAUSEKEYWORDS + | SORTBY + // FIELDKEYWORDSAUTO + | STR + | IP + | NUM + // ARGUMENT KEYWORDS + | KEEPEMPTY + | CONSECUTIVE + | DEDUP_SPLITVALUES + | PARTITIONS + | ALLNUM + | DELIM + | CENTROIDS + | ITERATIONS + | DISTANCE_TYPE + | NUMBER_OF_TREES + | SHINGLE_SIZE + | SAMPLE_SIZE + | OUTPUT_AFTER + | TIME_DECAY + | ANOMALY_RATE + | CATEGORY_FIELD + | TIME_FIELD + | TIME_ZONE + | TRAINING_DATA_SIZE + | ANOMALY_SCORE_THRESHOLD + // AGGREGATIONS + | AVG + | COUNT + | DISTINCT_COUNT + | ESTDC + | ESTDC_ERROR + | MAX + | MEAN + | MEDIAN + | MIN + | MODE + | RANGE + | STDEV + | STDEVP + | SUM + | SUMSQ + | VAR_SAMP + | VAR_POP + | STDDEV_SAMP + | STDDEV_POP + | PERCENTILE + | TAKE + | FIRST + | LAST + | LIST + | VALUES + | EARLIEST + | EARLIEST_TIME + | LATEST + | LATEST_TIME + | PER_DAY + | PER_HOUR + | PER_MINUTE + | PER_SECOND + | RATE + | SPARKLINE + | C + | DC + ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java new file mode 100644 index 000000000..9a2e88484 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -0,0 +1,260 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: 
Apache-2.0 + */ + +package org.opensearch.sql.ast; + +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.AttributeList; +import org.opensearch.sql.ast.expression.Between; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.EqualTo; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.expression.UnresolvedAttribute; +import org.opensearch.sql.ast.expression.When; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.statement.Explain; +import org.opensearch.sql.ast.statement.Query; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Limit; +import org.opensearch.sql.ast.tree.Parse; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import 
org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Rename; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Values; + +/** AST nodes visitor Defines the traverse path. */ +public abstract class AbstractNodeVisitor { + + public T visit(Node node, C context) { + return null; + } + + /** + * Visit child node. + * + * @param node {@link Node} + * @param context Context + * @return Return Type. + */ + public T visitChildren(Node node, C context) { + T result = defaultResult(); + + for (Node child : node.getChild()) { + T childResult = child.accept(this, context); + result = aggregateResult(result, childResult); + } + return result; + } + + private T defaultResult() { + return null; + } + + private T aggregateResult(T aggregate, T nextResult) { + return nextResult; + } + + public T visitRelation(Relation node, C context) { + return visitChildren(node, context); + } + + public T visitTableFunction(TableFunction node, C context) { + return visitChildren(node, context); + } + + public T visitFilter(Filter node, C context) { + return visitChildren(node, context); + } + + public T visitProject(Project node, C context) { + return visitChildren(node, context); + } + + public T visitAggregation(Aggregation node, C context) { + return visitChildren(node, context); + } + + public T visitEqualTo(EqualTo node, C context) { + return visitChildren(node, context); + } + + public T visitLiteral(Literal node, C context) { + return visitChildren(node, context); + } + + public T visitUnresolvedAttribute(UnresolvedAttribute node, C context) { + return visitChildren(node, context); + } + + public T visitAttributeList(AttributeList node, C context) { + return visitChildren(node, context); + } + + public T visitMap(Map node, C context) { + return visitChildren(node, context); + } + + public T visitNot(Not node, C context) { + return visitChildren(node, context); + } + + public T visitOr(Or 
node, C context) { + return visitChildren(node, context); + } + + public T visitAnd(And node, C context) { + return visitChildren(node, context); + } + + public T visitXor(Xor node, C context) { + return visitChildren(node, context); + } + + public T visitAggregateFunction(AggregateFunction node, C context) { + return visitChildren(node, context); + } + + public T visitFunction(Function node, C context) { + return visitChildren(node, context); + } + + public T visitWindowFunction(WindowFunction node, C context) { + return visitChildren(node, context); + } + + public T visitIn(In node, C context) { + return visitChildren(node, context); + } + + public T visitCompare(Compare node, C context) { + return visitChildren(node, context); + } + + public T visitBetween(Between node, C context) { + return visitChildren(node, context); + } + + public T visitArgument(Argument node, C context) { + return visitChildren(node, context); + } + + public T visitField(Field node, C context) { + return visitChildren(node, context); + } + + public T visitQualifiedName(QualifiedName node, C context) { + return visitChildren(node, context); + } + + public T visitRename(Rename node, C context) { + return visitChildren(node, context); + } + + public T visitEval(Eval node, C context) { + return visitChildren(node, context); + } + + public T visitParse(Parse node, C context) { + return visitChildren(node, context); + } + + public T visitLet(Let node, C context) { + return visitChildren(node, context); + } + + public T visitSort(Sort node, C context) { + return visitChildren(node, context); + } + + public T visitDedupe(Dedupe node, C context) { + return visitChildren(node, context); + } + + public T visitHead(Head node, C context) { + return visitChildren(node, context); + } + + public T visitRareTopN(RareTopN node, C context) { + return visitChildren(node, context); + } + public T visitValues(Values node, C context) { + return visitChildren(node, context); + } + + public T visitAlias(Alias 
node, C context) { + return visitChildren(node, context); + } + + public T visitAllFields(AllFields node, C context) { + return visitChildren(node, context); + } + + public T visitInterval(Interval node, C context) { + return visitChildren(node, context); + } + + public T visitCase(Case node, C context) { + return visitChildren(node, context); + } + + public T visitWhen(When node, C context) { + return visitChildren(node, context); + } + + public T visitUnresolvedArgument(UnresolvedArgument node, C context) { + return visitChildren(node, context); + } + + public T visitLimit(Limit node, C context) { + return visitChildren(node, context); + } + + public T visitSpan(Span node, C context) { + return visitChildren(node, context); + } + + public T visitKmeans(Kmeans node, C context) { + return visitChildren(node, context); + } + + public T visitStatement(Statement node, C context) { + return visit(node, context); + } + + public T visitQuery(Query node, C context) { + return visitStatement(node, context); + } + + public T visitExplain(Explain node, C context) { + return visitStatement(node, context); + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/Node.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/Node.java new file mode 100644 index 000000000..710142ea0 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/Node.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast; + +import java.util.List; + +/** AST node. 
*/ +public abstract class Node { + + public R accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitChildren(this, context); + } + + public List getChild() { + return null; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java new file mode 100644 index 000000000..b912ef686 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Collections; +import java.util.List; + +import static java.lang.String.format; + +/** + * Expression node of aggregate functions. Params include aggregate function name (AVG, SUM, MAX + * etc.) and the field to aggregate. + */ +public class AggregateFunction extends UnresolvedExpression { + private final String funcName; + private final UnresolvedExpression field; + private final List argList; + + private UnresolvedExpression condition; + + private Boolean distinct = false; + + /** + * Constructor. + * + * @param funcName function name. + * @param field {@link UnresolvedExpression}. + */ + public AggregateFunction(String funcName, UnresolvedExpression field) { + this.funcName = funcName; + this.field = field; + this.argList = Collections.emptyList(); + } + + /** + * Constructor. + * + * @param funcName function name. + * @param field {@link UnresolvedExpression}. + * @param distinct whether distinct field is specified or not. 
+ */ + public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { + this.funcName = funcName; + this.field = field; + this.argList = Collections.emptyList(); + this.distinct = distinct; + } + + public AggregateFunction(String funcName, UnresolvedExpression field, List argList) { + this.funcName = funcName; + this.field = field; + this.argList = argList; + } + + @Override + public List getChild() { + return Collections.singletonList(field); + } + + public String getFuncName() { + return funcName; + } + + public UnresolvedExpression getField() { + return field; + } + + public List getArgList() { + return argList; + } + + public UnresolvedExpression getCondition() { + return condition; + } + + public Boolean getDistinct() { + return distinct; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAggregateFunction(this, context); + } + + @Override + public String toString() { + return format("%s(%s)", funcName, field); + } + + public UnresolvedExpression condition(UnresolvedExpression condition) { + this.condition = condition; + return this; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Alias.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Alias.java new file mode 100644 index 000000000..83e08330f --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Alias.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Alias abstraction that associate an unnamed expression with a name and an optional alias. The + * name and alias information preserved is useful for semantic analysis and response formatting + * eventually. 
This can avoid restoring the info in toString() method which is inaccurate because + * original info is already lost. + */ +public class Alias extends UnresolvedExpression { + + /** Original field name. */ + private String name; + + /** Expression aliased. */ + private UnresolvedExpression delegated; + + /** Optional field alias. */ + private String alias; + + public Alias(String name, UnresolvedExpression delegated, String alias) { + this.name = name; + this.delegated = delegated; + this.alias = alias; + } + + public Alias(String name, UnresolvedExpression delegated) { + this.name = name; + this.delegated = delegated; + } + + public String getName() { + return name; + } + + public UnresolvedExpression getDelegated() { + return delegated; + } + + public String getAlias() { + return alias; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAlias(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AllFields.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AllFields.java new file mode 100644 index 000000000..eb4a16efa --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AllFields.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +import java.util.Collections; +import java.util.List; + +/** Represent the All fields which is been used in SELECT *. 
*/ +public class AllFields extends UnresolvedExpression { + public static final AllFields INSTANCE = new AllFields(); + + private AllFields() {} + + public static AllFields of() { + return INSTANCE; + } + + @Override + public List getChild() { + return Collections.emptyList(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAllFields(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java new file mode 100644 index 000000000..f19de2a05 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/And.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of logic AND. 
*/ +public class And extends UnresolvedExpression { + private UnresolvedExpression left; + private UnresolvedExpression right; + + public And(UnresolvedExpression left, UnresolvedExpression right) { + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + public UnresolvedExpression getLeft() { + return left; + } + + public UnresolvedExpression getRight() { + return right; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAnd(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java new file mode 100644 index 000000000..3f51b595e --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Argument.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Argument. 
*/ +public class Argument extends UnresolvedExpression { + private final String name; + private String argName; + private Literal value; + + public Argument(String name, Literal value) { + this.name = name; + this.value = value; + } + + // private final DataType valueType; + @Override + public List getChild() { + return Arrays.asList(value); + } + + public String getArgName() { + return argName; + } + + public Literal getValue() { + return value; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitArgument(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AttributeList.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AttributeList.java new file mode 100644 index 000000000..c08265ea8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/AttributeList.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Expression node that includes a list of Expression nodes. 
*/ +public class AttributeList extends UnresolvedExpression { + private List attrList; + + @Override + public List getChild() { + return ImmutableList.copyOf(attrList); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAttributeList(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Between.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Between.java new file mode 100644 index 000000000..c936da71c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Between.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +import java.util.Arrays; +import java.util.List; + +/** Unresolved expression for BETWEEN. */ +public class Between extends UnresolvedExpression { + + /** Value for range check. */ + private UnresolvedExpression value; + + /** Lower bound of the range (inclusive). */ + private UnresolvedExpression lowerBound; + + /** Upper bound of the range (inclusive). 
*/ + private UnresolvedExpression upperBound; + + public Between(UnresolvedExpression value, UnresolvedExpression lowerBound, UnresolvedExpression upperBound) { + this.value = value; + this.lowerBound = lowerBound; + this.upperBound = upperBound; + } + + @Override + public List getChild() { + return Arrays.asList(value, lowerBound, upperBound); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitBetween(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Case.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Case.java new file mode 100644 index 000000000..265db3ba7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Case.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +import java.util.List; + +/** AST node that represents CASE clause similar as Switch statement in programming language. */ +public class Case extends UnresolvedExpression { + + /** Value to be compared by WHEN statements. Null in the case of CASE WHEN conditions. */ + private UnresolvedExpression caseValue; + + /** + * Expression list that represents WHEN statements. Each is a mapping from condition to its + * result. + */ + private List whenClauses; + + /** Expression that represents ELSE statement result. 
*/ + private UnresolvedExpression elseClause; + + public Case(UnresolvedExpression caseValue, List whenClauses, UnresolvedExpression elseClause) { + this.caseValue =caseValue; + this.whenClauses = whenClauses; + this.elseClause = elseClause; + } + + @Override + public List getChild() { + ImmutableList.Builder children = ImmutableList.builder(); + if (caseValue != null) { + children.add(caseValue); + } + children.addAll(whenClauses); + + if (elseClause != null) { + children.add(elseClause); + } + return children.build(); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCase(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Compare.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Compare.java new file mode 100644 index 000000000..d623612e8 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Compare.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +public class Compare extends UnresolvedExpression { + private String operator; + private UnresolvedExpression left; + private UnresolvedExpression right; + + public Compare(String operator, UnresolvedExpression left, UnresolvedExpression right) { + this.operator = operator; + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + public String getOperator() { + return operator; + } + + public UnresolvedExpression getLeft() { + return left; + } + + public UnresolvedExpression getRight() { + return right; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCompare(this, context); + } +} diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java new file mode 100644 index 000000000..516106705 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/DataType.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.data.type.ExprCoreType; + +/** The DataType definition in AST. Question, could we use {@link ExprCoreType} directly in AST? */ + +public enum DataType { + TYPE_ERROR(ExprCoreType.UNKNOWN), + NULL(ExprCoreType.UNDEFINED), + + INTEGER(ExprCoreType.INTEGER), + LONG(ExprCoreType.LONG), + SHORT(ExprCoreType.SHORT), + FLOAT(ExprCoreType.FLOAT), + DOUBLE(ExprCoreType.DOUBLE), + STRING(ExprCoreType.STRING), + BOOLEAN(ExprCoreType.BOOLEAN), + + DATE(ExprCoreType.DATE), + TIME(ExprCoreType.TIME), + TIMESTAMP(ExprCoreType.TIMESTAMP), + INTERVAL(ExprCoreType.INTERVAL); + + private final ExprCoreType coreType; + + DataType(ExprCoreType type) { + this.coreType = type; + } + + public ExprCoreType getCoreType() { + return coreType; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/EqualTo.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/EqualTo.java new file mode 100644 index 000000000..d792e59e7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/EqualTo.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of binary operator or comparison relation EQUAL. 
*/ + +public class EqualTo extends UnresolvedExpression { + private UnresolvedExpression left; + private UnresolvedExpression right; + + public EqualTo(UnresolvedExpression left, UnresolvedExpression right) { + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitEqualTo(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java new file mode 100644 index 000000000..7c77fae1f --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Field.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Collections; +import java.util.List; +public class Field extends UnresolvedExpression { + private final UnresolvedExpression field; + private final List fieldArgs; + + /** Constructor of Field. */ + public Field(UnresolvedExpression field) { + this(field, Collections.emptyList()); + } + + /** Constructor of Field. 
*/ + public Field(UnresolvedExpression field, List fieldArgs) { + this.field = field; + this.fieldArgs = fieldArgs; + } + + public UnresolvedExpression getField() { + return field; + } + + public List getFieldArgs() { + return fieldArgs; + } + + public boolean hasArgument() { + return !fieldArgs.isEmpty(); + } + + @Override + public List getChild() { + return ImmutableList.of(this.field); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitField(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Function.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Function.java new file mode 100644 index 000000000..c546d001d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Function.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Expression node of scalar function. 
Params include function name (@funcName) and function + * arguments (@funcArgs) + */ + +public class Function extends UnresolvedExpression { + private String funcName; + private List funcArgs; + + public Function(String funcName, List funcArgs) { + this.funcName = funcName; + this.funcArgs = funcArgs; + } + + @Override + public List getChild() { + return Collections.unmodifiableList(funcArgs); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFunction(this, context); + } + + public String getFuncName() { + return funcName; + } + + public List getFuncArgs() { + return funcArgs; + } + + @Override + public String toString() { + return String.format( + "%s(%s)", + funcName, funcArgs.stream().map(Object::toString).collect(Collectors.joining(", "))); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/In.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/In.java new file mode 100644 index 000000000..16a75963e --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/In.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** + * Expression node of one-to-many mapping relation IN. Params include the field expression and/or + * wildcard field expression, nested field expression (@field). And the values that the field is + * mapped to (@valueList). 
+ */ + +public class In extends UnresolvedExpression { + private UnresolvedExpression field; + private List valueList; + + public In(UnresolvedExpression field, List valueList) { + this.field = field; + this.valueList = valueList; + } + + @Override + public List getChild() { + return Arrays.asList(field); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitIn(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java new file mode 100644 index 000000000..92b5ca333 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Interval.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Collections; +import java.util.List; + + +public class Interval extends UnresolvedExpression { + + private final UnresolvedExpression value; + private final IntervalUnit unit; + + public Interval(UnresolvedExpression value, IntervalUnit unit) { + this.value = value; + this.unit = unit; + } + public Interval(UnresolvedExpression value, String unit) { + this.value = value; + this.unit = IntervalUnit.of(unit); + } + + @Override + public List getChild() { + return Collections.singletonList(value); + } + + public UnresolvedExpression getValue() { + return value; + } + + public IntervalUnit getUnit() { + return unit; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitInterval(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java new file mode 100644 index 000000000..14c7e0d45 
--- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/IntervalUnit.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + + +public enum IntervalUnit { + UNKNOWN, + + MICROSECOND, + SECOND, + MINUTE, + HOUR, + DAY, + WEEK, + MONTH, + QUARTER, + YEAR, + SECOND_MICROSECOND, + MINUTE_MICROSECOND, + MINUTE_SECOND, + HOUR_MICROSECOND, + HOUR_SECOND, + HOUR_MINUTE, + DAY_MICROSECOND, + DAY_SECOND, + DAY_MINUTE, + DAY_HOUR, + YEAR_MONTH; + + private static final List INTERVAL_UNITS; + + static { + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + INTERVAL_UNITS = builder.add(IntervalUnit.values()).build(); + } + + /** Util method to get interval unit given the unit name. */ + public static IntervalUnit of(String unit) { + return INTERVAL_UNITS.stream() + .filter(v -> unit.equalsIgnoreCase(v.name())) + .findFirst() + .orElse(IntervalUnit.UNKNOWN); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Let.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Let.java new file mode 100644 index 000000000..85c5f45de --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Let.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** + * Represent the assign operation. e.g. velocity = distance/speed. 
+ */ + + +public class Let extends UnresolvedExpression { + private Field var; + private UnresolvedExpression expression; + + public Let(Field var, UnresolvedExpression expression) { + this.var = var; + this.expression = expression; + } + + public Field getVar() { + return var; + } + + public UnresolvedExpression getExpression() { + return expression; + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLet(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Literal.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Literal.java new file mode 100644 index 000000000..e7f1937ba --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Literal.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** + * Expression node of literal type Params include literal value (@value) and literal data type + * (@type) which can be selected from {@link DataType}. 
+ */ + +public class Literal extends UnresolvedExpression { + + private Object value; + private DataType type; + + public Literal(Object value, DataType dataType) { + this.value = value; + this.type = dataType; + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLiteral(this, context); + } + + public Object getValue() { + return value; + } + + public DataType getType() { + return type; + } + + @Override + public String toString() { + return String.valueOf(value); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Map.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Map.java new file mode 100644 index 000000000..825a0f184 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Map.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of one-to-one mapping relation. 
*/ + +public class Map extends UnresolvedExpression { + private UnresolvedExpression origin; + private UnresolvedExpression target; + + public Map(UnresolvedExpression origin, UnresolvedExpression target) { + this.origin = origin; + this.target = target; + } + + public UnresolvedExpression getOrigin() { + return origin; + } + + public UnresolvedExpression getTarget() { + return target; + } + + @Override + public List getChild() { + return Arrays.asList(origin, target); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitMap(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Not.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Not.java new file mode 100644 index 000000000..f55433774 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Not.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of the logic NOT. 
*/ + +public class Not extends UnresolvedExpression { + private UnresolvedExpression expression; + + public Not(UnresolvedExpression expression) { + this.expression = expression; + } + + @Override + public List getChild() { + return Arrays.asList(expression); + } + + public UnresolvedExpression getExpression() { + return expression; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitNot(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java new file mode 100644 index 000000000..65e1a2e6d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Or.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of the logic OR. 
*/ + +public class Or extends UnresolvedExpression { + private UnresolvedExpression left; + private UnresolvedExpression right; + + public Or(UnresolvedExpression left, UnresolvedExpression right) { + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + public UnresolvedExpression getLeft() { + return left; + } + + public UnresolvedExpression getRight() { + return right; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitOr(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ParseMethod.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ParseMethod.java new file mode 100644 index 000000000..2ae3235a9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ParseMethod.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +public enum ParseMethod { + REGEX("regex"), + GROK("grok"), + PATTERNS("patterns"); + + private final String name; + + ParseMethod(String name) { + this.name = name; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java new file mode 100644 index 000000000..8abd3a98c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/QualifiedName.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; 
+import java.util.stream.StreamSupport; + +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; + +public class QualifiedName extends UnresolvedExpression { + private final List parts; + + public QualifiedName(String name) { + this.parts = Collections.singletonList(name); + } + + /** QualifiedName Constructor. */ + public QualifiedName(Iterable parts) { + List partsList = StreamSupport.stream(parts.spliterator(), false).collect(toList()); + if (partsList.isEmpty()) { + throw new IllegalArgumentException("parts is empty"); + } + this.parts = partsList; + } + + public List getParts() { + return parts; + } + + /** Construct {@link QualifiedName} from list of string. */ + public static QualifiedName of(String first, String... rest) { + requireNonNull(first); + ArrayList parts = new ArrayList<>(); + parts.add(first); + parts.addAll(Arrays.asList(rest)); + return new QualifiedName(parts); + } + + public static QualifiedName of(Iterable parts) { + return new QualifiedName(parts); + } + + /** Get Prefix of {@link QualifiedName}. */ + public Optional getPrefix() { + if (parts.size() == 1) { + return Optional.empty(); + } + return Optional.of(QualifiedName.of(parts.subList(0, parts.size() - 1))); + } + + public String getSuffix() { + return parts.get(parts.size() - 1); + } + + /** + * Get first part of the qualified name. + * + * @return first part + */ + public Optional first() { + if (parts.size() == 1) { + return Optional.empty(); + } + return Optional.of(parts.get(0)); + } + + /** + *
+   * Get rest parts of the qualified name. Assume that there must be remaining parts so caller is
+   * responsible for the check (first() or size() must be called first).
+   * For example:
+   * {@code
+   * QualifiedName name = ...
+   * Optional<String> first = name.first();
+   * if (first.isPresent()) {
+   *    name.rest() ...
+   * }
+   * }
+   * @return rest part(s)
+   * 
+ */ + public QualifiedName rest() { + return QualifiedName.of(parts.subList(1, parts.size())); + } + + public String toString() { + return String.join(".", this.parts); + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitQualifiedName(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java new file mode 100644 index 000000000..b68edbc62 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Span.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** Span expression node. Params include field expression and the span value. 
*/ +public class Span extends UnresolvedExpression { + private UnresolvedExpression field; + private UnresolvedExpression value; + private SpanUnit unit; + + public Span(UnresolvedExpression field, UnresolvedExpression value, SpanUnit unit) { + this.field = field; + this.value = value; + this.unit = unit; + } + + public UnresolvedExpression getField() { + return field; + } + + public UnresolvedExpression getValue() { + return value; + } + + public SpanUnit getUnit() { + return unit; + } + + @Override + public List getChild() { + return ImmutableList.of(field, value); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitSpan(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/SpanUnit.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/SpanUnit.java new file mode 100644 index 000000000..d8bacc2f9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/SpanUnit.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + + +public enum SpanUnit { + UNKNOWN("unknown"), + NONE(""), + MILLISECOND("ms"), + MS("ms"), + SECOND("s"), + S("s"), + MINUTE("m"), + m("m"), + HOUR("h"), + H("h"), + DAY("d"), + D("d"), + WEEK("w"), + W("w"), + MONTH("M"), + M("M"), + QUARTER("q"), + Q("q"), + YEAR("y"), + Y("y"); + + private final String name; + private static final List SPAN_UNITS; + + static { + ImmutableList.Builder builder = ImmutableList.builder(); + SPAN_UNITS = builder.add(SpanUnit.values()).build(); + } + + SpanUnit(String name) { + this.name = name; + } + + /** Util method to get span unit given the unit name. 
*/ + public static SpanUnit of(String unit) { + switch (unit) { + case "": + return NONE; + case "M": + return M; + case "m": + return m; + default: + return SPAN_UNITS.stream() + .filter(v -> unit.equalsIgnoreCase(v.name())) + .findFirst() + .orElse(UNKNOWN); + } + } + + public static String getName(SpanUnit unit) { + return unit.name; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedArgument.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedArgument.java new file mode 100644 index 000000000..38daa476b --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedArgument.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Argument. */ + +public class UnresolvedArgument extends UnresolvedExpression { + private final String argName; + private final UnresolvedExpression value; + + public UnresolvedArgument(String argName, UnresolvedExpression value) { + this.argName = argName; + this.value = value; + } + + @Override + public List getChild() { + return Arrays.asList(value); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitUnresolvedArgument(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedAttribute.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedAttribute.java new file mode 100644 index 000000000..043d1dd02 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedAttribute.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + 
+import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** + * Expression node, representing the syntax that is not resolved to any other expression nodes yet + * but non-negligible. This expression is often created as the index name, field name, etc. + */ + +public class UnresolvedAttribute extends UnresolvedExpression { + private String attr; + + public UnresolvedAttribute(String attr) { + this.attr = attr; + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitUnresolvedAttribute(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedExpression.java new file mode 100644 index 000000000..25029e07d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/UnresolvedExpression.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +public abstract class UnresolvedExpression extends Node { + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitChildren(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/When.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/When.java new file mode 100644 index 000000000..9341f6c2e --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/When.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + 
+import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +import java.util.List; + +/** AST node that represents WHEN clause. */ +public class When extends UnresolvedExpression { + + /** WHEN condition, either a search condition or compare value if case value present. */ + private UnresolvedExpression condition; + + /** Result to return if condition matched. */ + private UnresolvedExpression result; + + public When(UnresolvedExpression condition, UnresolvedExpression result) { + this.condition = condition; + this.result = result; + } + + @Override + public List getChild() { + return ImmutableList.of(condition, result); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitWhen(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/WindowFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/WindowFunction.java new file mode 100644 index 000000000..eccf5c6e7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/WindowFunction.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import com.google.common.collect.ImmutableList; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.tree.Sort.SortOption; + +import java.util.List; + +public class WindowFunction extends UnresolvedExpression { + private UnresolvedExpression function; + private List partitionByList; + private List> sortList; + + public WindowFunction(UnresolvedExpression function, List partitionByList, List> sortList) { + this.function = function; + this.partitionByList = partitionByList; + this.sortList = sortList; + } + + @Override + public 
List getChild() { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(function); + children.addAll(partitionByList); + sortList.forEach(pair -> children.add(pair.getRight())); + return children.build(); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitWindowFunction(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java new file mode 100644 index 000000000..9368a6363 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Xor.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.Arrays; +import java.util.List; + +/** Expression node of the logic XOR. */ + +public class Xor extends UnresolvedExpression { + private UnresolvedExpression left; + private UnresolvedExpression right; + + public Xor(UnresolvedExpression left, UnresolvedExpression right) { + this.left = left; + this.right = right; + } + + @Override + public List getChild() { + return Arrays.asList(left, right); + } + + public UnresolvedExpression getLeft() { + return left; + } + + public UnresolvedExpression getRight() { + return right; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitXor(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java new file mode 100644 index 000000000..4968668ac --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch 
Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sql.ast.statement; + +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** Explain Statement. */ +public class Explain extends Statement { + + private Statement statement; + + public Explain(Query statement) { + this.statement = statement; + } + + public Statement getStatement() { + return statement; + } + + @Override + public R accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitExplain(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Query.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Query.java new file mode 100644 index 000000000..6a7ac1530 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Query.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sql.ast.statement; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +/** Query Statement. 
*/ +public class Query extends Statement { + + protected UnresolvedPlan plan; + protected int fetchSize; + + public Query(UnresolvedPlan plan, int fetchSize) { + this.plan = plan; + this.fetchSize = fetchSize; + } + + public UnresolvedPlan getPlan() { + return plan; + } + + @Override + public R accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitQuery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Statement.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Statement.java new file mode 100644 index 000000000..d90071a0c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Statement.java @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sql.ast.statement; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** Statement is the high interface of core engine. 
*/ +public abstract class Statement extends Node { + @Override + public R accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitStatement(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java new file mode 100644 index 000000000..825c6d340 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Aggregation.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** Logical plan node of Aggregation, the interface for building aggregation actions in queries. */ +public class Aggregation extends UnresolvedPlan { + private List aggExprList; + private List sortExprList; + private List groupExprList; + private UnresolvedExpression span; + private List argExprList; + private UnresolvedPlan child; + + /** Aggregation Constructor without span and argument. */ + public Aggregation( + List aggExprList, + List sortExprList, + List groupExprList) { + this(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + } + + /** Aggregation Constructor. 
*/ + public Aggregation( + List aggExprList, + List sortExprList, + List groupExprList, + UnresolvedExpression span, + List argExprList) { + this.aggExprList = aggExprList; + this.sortExprList = sortExprList; + this.groupExprList = groupExprList; + this.span = span; + this.argExprList = argExprList; + } + + public List getAggExprList() { + return aggExprList; + } + + public List getSortExprList() { + return sortExprList; + } + + public List getGroupExprList() { + return groupExprList; + } + + public UnresolvedExpression getSpan() { + return span; + } + + public List getArgExprList() { + return argExprList; + } + + public boolean hasArgument() { + return !aggExprList.isEmpty(); + } + + @Override + public Aggregation attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitAggregation(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Dedupe.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Dedupe.java new file mode 100644 index 000000000..a428e68ad --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Dedupe.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +/** AST node represent Dedupe operation. 
*/ +public class Dedupe extends UnresolvedPlan { + private UnresolvedPlan child; + private List options; + private List fields; + + public Dedupe(UnresolvedPlan child, List options, List fields) { + this.child = child; + this.options = options; + this.fields = fields; + } + public Dedupe(List options, List fields) { + this.options = options; + this.fields = fields; + } + + @Override + public Dedupe attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + public List getOptions() { + return options; + } + + public List getFields() { + return fields; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitDedupe(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java new file mode 100644 index 000000000..24a6bb428 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Eval.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Let; + +import java.util.List; + +/** AST node represent Eval operation. 
*/ +public class Eval extends UnresolvedPlan { + private List expressionList; + private UnresolvedPlan child; + + public Eval(List expressionList) { + this.expressionList = expressionList; + } + + public List getExpressionList() { + return expressionList; + } + + @Override + public Eval attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitEval(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Filter.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Filter.java new file mode 100644 index 000000000..244181653 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Filter.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +/** Logical plan node of Filter, the interface for building filters in queries. 
*/ + +public class Filter extends UnresolvedPlan { + private UnresolvedExpression condition; + private UnresolvedPlan child; + + public Filter(UnresolvedExpression condition) { + this.condition = condition; + } + + @Override + public Filter attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + public UnresolvedExpression getCondition() { + return condition; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitFilter(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Head.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Head.java new file mode 100644 index 000000000..560ffda6e --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Head.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +/** AST node represent Head operation. 
*/ + +public class Head extends UnresolvedPlan { + + private UnresolvedPlan child; + private Integer size; + private Integer from; + + public Head(UnresolvedPlan child, Integer size, Integer from) { + this.child = child; + this.size = size; + this.from = from; + } + + public Head(Integer size, Integer from) { + this.size = size; + this.from = from; + } + + @Override + public Head attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + public Integer getSize() { + return size; + } + + public Integer getFrom() { + return from; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitHead(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java new file mode 100644 index 000000000..6e3e67eaa --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; + +import java.util.List; +import java.util.Map; + +public class Kmeans extends UnresolvedPlan { + private UnresolvedPlan child; + + private Map arguments; + + public Kmeans(ImmutableMap arguments) { + this.arguments = arguments; + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitKmeans(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } +} diff 
--git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Limit.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Limit.java new file mode 100644 index 000000000..3fce9c0aa --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Limit.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +import java.util.List; + +public class Limit extends UnresolvedPlan { + private UnresolvedPlan child; + private Integer limit; + private Integer offset; + + public Limit(Integer limit, Integer offset) { + this.limit = limit; + this.offset = offset; + } + + @Override + public Limit attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitLimit(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Parse.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Parse.java new file mode 100644 index 000000000..4b2d6e9c1 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Parse.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.ParseMethod; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Map; + +/** AST node represent Parse with regex operation. 
*/ + +public class Parse extends UnresolvedPlan { + /** Method used to parse a field. */ + private ParseMethod parseMethod; + + /** Field. */ + private UnresolvedExpression sourceField; + + /** Pattern. */ + private Literal pattern; + + /** Optional arguments. */ + private Map arguments; + + /** Child Plan. */ + private UnresolvedPlan child; + + public Parse(ParseMethod parseMethod, UnresolvedExpression sourceField, Literal pattern, Map arguments, UnresolvedPlan child) { + this.parseMethod = parseMethod; + this.sourceField = sourceField; + this.pattern = pattern; + this.arguments = arguments; + this.child = child; + } + + public Parse(ParseMethod parseMethod, UnresolvedExpression sourceField, Literal pattern, Map arguments) { + + this.parseMethod = parseMethod; + this.sourceField = sourceField; + this.pattern = pattern; + this.arguments = arguments; + } + + public ParseMethod getParseMethod() { + return parseMethod; + } + + public UnresolvedExpression getSourceField() { + return sourceField; + } + + public Literal getPattern() { + return pattern; + } + + public Map getArguments() { + return arguments; + } + + @Override + public Parse attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitParse(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Project.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Project.java new file mode 100644 index 000000000..6237f6b4c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Project.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import 
org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** Logical plan node of Project, the interface for building the list of searching fields. */ +public class Project extends UnresolvedPlan { + private List projectList; + private List argExprList; + private UnresolvedPlan child; + + public Project(List projectList) { + this.projectList = projectList; + this.argExprList = Collections.emptyList(); + } + + public Project(List projectList, List argExprList) { + this.projectList = projectList; + this.argExprList = argExprList; + } + + public List getProjectList() { + return projectList; + } + + public List getArgExprList() { + return argExprList; + } + + public boolean hasArgument() { + return !argExprList.isEmpty(); + } + + /** A Project can also be used to exclude fields from the source: returns the Boolean value of the first argument, or false when there are no arguments. */ + public boolean isExcluded() { + if (hasArgument()) { + Argument argument = argExprList.get(0); + return (Boolean) argument.getValue().getValue(); + } + return false; + } + + @Override + public Project attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + + return nodeVisitor.visitProject(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java new file mode 100644 index 000000000..6b05288cc --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import 
org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** AST node represent RareTopN operation. */ + +public class RareTopN extends UnresolvedPlan { + + private UnresolvedPlan child; + private CommandType commandType; + private List noOfResults; + private List fields; + private List groupExprList; + + public RareTopN( CommandType commandType, List noOfResults, List fields, List groupExprList) { + this.commandType = commandType; + this.noOfResults = noOfResults; + this.fields = fields; + this.groupExprList = groupExprList; + } + + @Override + public RareTopN attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + public CommandType getCommandType() { + return commandType; + } + + public List getNoOfResults() { + return noOfResults; + } + + public List getFields() { + return fields; + } + + public List getGroupExprList() { + return groupExprList; + } + + @Override + public List getChild() { + return Collections.singletonList(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitRareTopN(this, context); + } + + public enum CommandType { + TOP, + RARE + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java new file mode 100644 index 000000000..6a482db67 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.QualifiedName; +import 
org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** Logical plan node of Relation, the interface for building the searching sources. */ + +public class Relation extends UnresolvedPlan { + private static final String COMMA = ","; + + private final List tableName; + + public Relation(UnresolvedExpression tableName) { + this(tableName, null); + } + + public Relation(List tableName) { + this.tableName = tableName; + } + + public Relation(UnresolvedExpression tableName, String alias) { + this.tableName = Arrays.asList(tableName); + this.alias = alias; + } + + /** Optional alias name for the relation. */ + private String alias; + + /** + * Return the table names as strings. + * + * @return list with the toString() form of each table name expression + */ + public List getTableName() { + return tableName.stream().map(Object::toString).collect(Collectors.toList()); + } + + + /** + * Return alias. + * + * @return alias. + */ + public String getAlias() { + return alias; + } + + /** + * Get the qualified name, which preserves the parts of the user-given identifiers. This can later be + * used to determine the DataSource, Schema and Table Name during the Analyzer stage, so the + * QualifiedName is passed directly to the Analyzer stage. + * + * @return TableQualifiedName. 
+ */ + public QualifiedName getTableQualifiedName() { + if (tableName.size() == 1) { + return (QualifiedName) tableName.get(0); + } else { + return new QualifiedName( + tableName.stream() + .map(UnresolvedExpression::toString) + .collect(Collectors.joining(COMMA))); + } + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitRelation(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + return this; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java new file mode 100644 index 000000000..c3f215177 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Rename.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Map; + +import java.util.List; + +public class Rename extends UnresolvedPlan { + private final List renameList; + private UnresolvedPlan child; + + public Rename(List renameList, UnresolvedPlan child) { + this.renameList = renameList; + this.child = child; + } + + public Rename(List renameList) { + this.renameList = renameList; + } + + public List getRenameList() { + return renameList; + } + + @Override + public Rename attach(UnresolvedPlan child) { + if (null == this.child) { + this.child = child; + } else { + this.child.attach(child); + } + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitRename(this, context); + } +} diff --git 
a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Sort.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Sort.java new file mode 100644 index 000000000..e502662f4 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Sort.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Field; + +import java.util.List; + +import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; +import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST; +import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; +import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; + +/** + * AST node for Sort {@link Sort#sortList} represent a list of sort expression and sort options. + */ + + +public class Sort extends UnresolvedPlan { + private UnresolvedPlan child; + private List sortList; + + public Sort(List sortList) { + this.sortList = sortList; + } + public Sort(UnresolvedPlan child, List sortList) { + this.child = child; + this.sortList = sortList; + } + + @Override + public Sort attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitSort(this, context); + } + + public List getSortList() { + return sortList; + } + + /** + * Sort Options. + */ + + public static class SortOption { + + /** + * Default ascending sort option, null first. + */ + public static SortOption DEFAULT_ASC = new SortOption(ASC, NULL_FIRST); + + /** + * Default descending sort option, null last. 
+ */ + public static SortOption DEFAULT_DESC = new SortOption(DESC, NULL_LAST); + + private SortOrder sortOrder; + private NullOrder nullOrder; + + public SortOption(SortOrder sortOrder, NullOrder nullOrder) { + this.sortOrder = sortOrder; + this.nullOrder = nullOrder; + } + } + + public enum SortOrder { + ASC, + DESC + } + + public enum NullOrder { + NULL_FIRST, + NULL_LAST + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TableFunction.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TableFunction.java new file mode 100644 index 000000000..823c975e9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TableFunction.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +/** + * AST Node for Table Function. 
+ */ + + +public class TableFunction extends UnresolvedPlan { + + private UnresolvedExpression functionName; + + private List arguments; + + public TableFunction(UnresolvedExpression functionName, List arguments) { + this.functionName = functionName; + this.arguments = arguments; + } + + public List getArguments() { + return arguments; + } + + public QualifiedName getFunctionName() { + return (QualifiedName) functionName; + } + + @Override + public List getChild() { + return ImmutableList.of(); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitTableFunction(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + return null; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/UnresolvedPlan.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/UnresolvedPlan.java new file mode 100644 index 000000000..2de40e53a --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/UnresolvedPlan.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** Abstract unresolved plan. 
*/ + + +public abstract class UnresolvedPlan extends Node { + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitChildren(this, context); + } + + public abstract UnresolvedPlan attach(UnresolvedPlan child); +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Values.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Values.java new file mode 100644 index 000000000..d6af4b6be --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Values.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Literal; + +import java.util.List; + +/** + * AST node class for a sequence of literal values. + */ + + +public class Values extends UnresolvedPlan { + + private List> values; + + public Values(List list) { + + } + + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + throw new UnsupportedOperationException("Values node is supposed to have no child node"); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitValues(this, context); + } + + @Override + public List getChild() { + return ImmutableList.of(); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/CaseInsensitiveCharStream.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/CaseInsensitiveCharStream.java new file mode 100644 index 000000000..89381872c --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/CaseInsensitiveCharStream.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.common.antlr; + +import org.antlr.v4.runtime.CharStream; +import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.misc.Interval; + +/** + * Custom stream to convert character to upper case for case insensitive grammar before sending to + * lexer. + */ +public class CaseInsensitiveCharStream implements CharStream { + + /** Character stream. */ + private final CharStream charStream; + + public CaseInsensitiveCharStream(String sql) { + this.charStream = CharStreams.fromString(sql); + } + + @Override + public String getText(Interval interval) { + return charStream.getText(interval); + } + + @Override + public void consume() { + charStream.consume(); + } + + @Override + public int LA(int i) { + int c = charStream.LA(i); + if (c <= 0) { + return c; + } + return Character.toUpperCase(c); + } + + @Override + public int mark() { + return charStream.mark(); + } + + @Override + public void release(int marker) { + charStream.release(marker); + } + + @Override + public int index() { + return charStream.index(); + } + + @Override + public void seek(int index) { + charStream.seek(index); + } + + @Override + public int size() { + return charStream.size(); + } + + @Override + public String getSourceName() { + return charStream.getSourceName(); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java new file mode 100644 index 000000000..7962f53ef --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/Parser.java @@ -0,0 +1,7 @@ +package org.opensearch.sql.common.antlr; + +import org.antlr.v4.runtime.tree.ParseTree; + +public interface Parser { + ParseTree parse(String query); +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxAnalysisErrorListener.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxAnalysisErrorListener.java 
new file mode 100644 index 000000000..42f35a15f --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxAnalysisErrorListener.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.antlr; + +import org.antlr.v4.runtime.BaseErrorListener; +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.RecognitionException; +import org.antlr.v4.runtime.Recognizer; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.misc.IntervalSet; + +import java.util.Locale; + +/** + * Syntax analysis error listener that handles any syntax error by throwing exception with useful + * information. + */ +public class SyntaxAnalysisErrorListener extends BaseErrorListener { + + @Override + public void syntaxError( + Recognizer recognizer, + Object offendingSymbol, + int line, + int charPositionInLine, + String msg, + RecognitionException e) { + + CommonTokenStream tokens = (CommonTokenStream) recognizer.getInputStream(); + Token offendingToken = (Token) offendingSymbol; + String query = tokens.getText(); + + throw new SyntaxCheckException( + String.format( + Locale.ROOT, + "Failed to parse query due to offending symbol [%s] " + + "at: '%s' <--- HERE... More details: %s", + getOffendingText(offendingToken), + truncateQueryAtOffendingToken(query, offendingToken), + getDetails(recognizer, msg, e))); + } + + private String getOffendingText(Token offendingToken) { + return offendingToken.getText(); + } + + private String truncateQueryAtOffendingToken(String query, Token offendingToken) { + return query.substring(0, offendingToken.getStopIndex() + 1); + } + + /** + * As official JavaDoc says, e=null means parser was able to recover from the error. In other + * words, "msg" argument includes the information we want. 
+ */ + private String getDetails(Recognizer recognizer, String msg, RecognitionException e) { + String details; + if (e == null) { + details = msg; + } else { + IntervalSet followSet = e.getExpectedTokens(); + details = "Expecting tokens in " + followSet.toString(recognizer.getVocabulary()); + } + return details; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxCheckException.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxCheckException.java new file mode 100644 index 000000000..d3c9c111e --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/antlr/SyntaxCheckException.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.antlr; + +public class SyntaxCheckException extends RuntimeException { + public SyntaxCheckException(String message) { + super(message); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java new file mode 100644 index 000000000..ef2717ac1 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprCoreType.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.data.type; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** Expression Type. */ +public enum ExprCoreType implements ExprType { + /** Unknown due to unsupported data type. */ + UNKNOWN, + + /** + * Undefined type for special literal such as NULL. As the root of data type tree, it is + * compatible with any other type. 
In other word, undefined type is the "narrowest" type. + */ + UNDEFINED, + + /** Numbers. */ + BYTE(UNDEFINED), + SHORT(BYTE), + INTEGER(SHORT), + LONG(INTEGER), + FLOAT(LONG), + DOUBLE(FLOAT), + + /** String. */ + STRING(UNDEFINED), + + /** Boolean. */ + BOOLEAN(STRING), + + /** Date. */ + DATE(STRING), + TIME(STRING), + TIMESTAMP(STRING, DATE, TIME), + INTERVAL(UNDEFINED), + + /** Struct. */ + STRUCT(UNDEFINED), + + /** Array. */ + ARRAY(UNDEFINED); + + /** Parents (wider/compatible types) of current base type. */ + private final List parents = new ArrayList<>(); + + /** The mapping between Type and legacy JDBC type name. */ + private static final Map LEGACY_TYPE_NAME_MAPPING = + new ImmutableMap.Builder() + .put(STRUCT, "OBJECT") + .put(ARRAY, "NESTED") + .put(STRING, "KEYWORD") + .build(); + + private static final Set NUMBER_TYPES = + new ImmutableSet.Builder() + .add(BYTE) + .add(SHORT) + .add(INTEGER) + .add(LONG) + .add(FLOAT) + .add(DOUBLE) + .build(); + + ExprCoreType(ExprCoreType... compatibleTypes) { + for (ExprCoreType subType : compatibleTypes) { + subType.parents.add(this); + } + } + + @Override + public List getParent() { + return parents.isEmpty() ? ExprType.super.getParent() : parents; + } + + @Override + public String typeName() { + return this.name(); + } + + @Override + public String legacyTypeName() { + return LEGACY_TYPE_NAME_MAPPING.getOrDefault(this, this.name()); + } + + /** Return all the valid ExprCoreType. 
*/ + public static List coreTypes() { + return Arrays.stream(ExprCoreType.values()) + .filter(type -> type != UNKNOWN) + .filter(type -> type != UNDEFINED) + .collect(Collectors.toList()); + } + + public static Set numberTypes() { + return NUMBER_TYPES; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprType.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprType.java new file mode 100644 index 000000000..39f55540d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/data/type/ExprType.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.data.type; + + +import java.util.Arrays; +import java.util.List; + +import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN; + +/** The Type of {@link Expression} and {@link ExprValue}. */ +public interface ExprType { + /** Is compatible with other types. */ + default boolean isCompatible(ExprType other) { + if (this.equals(other)) { + return true; + } else { + if (other.equals(UNKNOWN)) { + return false; + } + for (ExprType parentTypeOfOther : other.getParent()) { + if (isCompatible(parentTypeOfOther)) { + return true; + } + } + return false; + } + } + + /** + * Should cast this type to other type or not. By default, cast is always required if the given + * type is different from this type. + * + * @param other other data type + * @return true if cast is required, otherwise false + */ + default boolean shouldCast(ExprType other) { + return !this.equals(other); + } + + /** Get the parent type. */ + default List getParent() { + return Arrays.asList(UNKNOWN); + } + + /** Get the type name. */ + String typeName(); + + /** Get the legacy type name for old engine. 
*/ + default String legacyTypeName() { + return typeName(); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java new file mode 100644 index 000000000..f12648eb2 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -0,0 +1,298 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import com.google.common.collect.ImmutableMap; + +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +/** Builtin Function Name. */ +public enum BuiltinFunctionName { + /** Mathematical Functions. */ + ABS(FunctionName.of("abs")), + CEIL(FunctionName.of("ceil")), + CEILING(FunctionName.of("ceiling")), + CONV(FunctionName.of("conv")), + CRC32(FunctionName.of("crc32")), + E(FunctionName.of("e")), + EXP(FunctionName.of("exp")), + EXPM1(FunctionName.of("expm1")), + FLOOR(FunctionName.of("floor")), + LN(FunctionName.of("ln")), + LOG(FunctionName.of("log")), + LOG10(FunctionName.of("log10")), + LOG2(FunctionName.of("log2")), + PI(FunctionName.of("pi")), + POW(FunctionName.of("pow")), + POWER(FunctionName.of("power")), + RAND(FunctionName.of("rand")), + RINT(FunctionName.of("rint")), + ROUND(FunctionName.of("round")), + SIGN(FunctionName.of("sign")), + SIGNUM(FunctionName.of("signum")), + SINH(FunctionName.of("sinh")), + SQRT(FunctionName.of("sqrt")), + CBRT(FunctionName.of("cbrt")), + TRUNCATE(FunctionName.of("truncate")), + + ACOS(FunctionName.of("acos")), + ASIN(FunctionName.of("asin")), + ATAN(FunctionName.of("atan")), + ATAN2(FunctionName.of("atan2")), + COS(FunctionName.of("cos")), + COSH(FunctionName.of("cosh")), + COT(FunctionName.of("cot")), + DEGREES(FunctionName.of("degrees")), + RADIANS(FunctionName.of("radians")), + 
SIN(FunctionName.of("sin")), + TAN(FunctionName.of("tan")), + + /** Date and Time Functions. */ + ADDDATE(FunctionName.of("adddate")), + ADDTIME(FunctionName.of("addtime")), + CONVERT_TZ(FunctionName.of("convert_tz")), + DATE(FunctionName.of("date")), + DATEDIFF(FunctionName.of("datediff")), + DATETIME(FunctionName.of("datetime")), + DATE_ADD(FunctionName.of("date_add")), + DATE_FORMAT(FunctionName.of("date_format")), + DATE_SUB(FunctionName.of("date_sub")), + DAY(FunctionName.of("day")), + DAYNAME(FunctionName.of("dayname")), + DAYOFMONTH(FunctionName.of("dayofmonth")), + DAY_OF_MONTH(FunctionName.of("day_of_month")), + DAYOFWEEK(FunctionName.of("dayofweek")), + DAYOFYEAR(FunctionName.of("dayofyear")), + DAY_OF_WEEK(FunctionName.of("day_of_week")), + DAY_OF_YEAR(FunctionName.of("day_of_year")), + EXTRACT(FunctionName.of("extract")), + FROM_DAYS(FunctionName.of("from_days")), + FROM_UNIXTIME(FunctionName.of("from_unixtime")), + GET_FORMAT(FunctionName.of("get_format")), + HOUR(FunctionName.of("hour")), + HOUR_OF_DAY(FunctionName.of("hour_of_day")), + LAST_DAY(FunctionName.of("last_day")), + MAKEDATE(FunctionName.of("makedate")), + MAKETIME(FunctionName.of("maketime")), + MICROSECOND(FunctionName.of("microsecond")), + MINUTE(FunctionName.of("minute")), + MINUTE_OF_DAY(FunctionName.of("minute_of_day")), + MINUTE_OF_HOUR(FunctionName.of("minute_of_hour")), + MONTH(FunctionName.of("month")), + MONTH_OF_YEAR(FunctionName.of("month_of_year")), + MONTHNAME(FunctionName.of("monthname")), + PERIOD_ADD(FunctionName.of("period_add")), + PERIOD_DIFF(FunctionName.of("period_diff")), + QUARTER(FunctionName.of("quarter")), + SEC_TO_TIME(FunctionName.of("sec_to_time")), + SECOND(FunctionName.of("second")), + SECOND_OF_MINUTE(FunctionName.of("second_of_minute")), + STR_TO_DATE(FunctionName.of("str_to_date")), + SUBDATE(FunctionName.of("subdate")), + SUBTIME(FunctionName.of("subtime")), + TIME(FunctionName.of("time")), + TIMEDIFF(FunctionName.of("timediff")), + 
TIME_TO_SEC(FunctionName.of("time_to_sec")), + TIMESTAMP(FunctionName.of("timestamp")), + TIMESTAMPADD(FunctionName.of("timestampadd")), + TIMESTAMPDIFF(FunctionName.of("timestampdiff")), + TIME_FORMAT(FunctionName.of("time_format")), + TO_DAYS(FunctionName.of("to_days")), + TO_SECONDS(FunctionName.of("to_seconds")), + UTC_DATE(FunctionName.of("utc_date")), + UTC_TIME(FunctionName.of("utc_time")), + UTC_TIMESTAMP(FunctionName.of("utc_timestamp")), + UNIX_TIMESTAMP(FunctionName.of("unix_timestamp")), + WEEK(FunctionName.of("week")), + WEEKDAY(FunctionName.of("weekday")), + WEEKOFYEAR(FunctionName.of("weekofyear")), + WEEK_OF_YEAR(FunctionName.of("week_of_year")), + YEAR(FunctionName.of("year")), + YEARWEEK(FunctionName.of("yearweek")), + + // `now`-like functions + NOW(FunctionName.of("now")), + CURDATE(FunctionName.of("curdate")), + CURRENT_DATE(FunctionName.of("current_date")), + CURTIME(FunctionName.of("curtime")), + CURRENT_TIME(FunctionName.of("current_time")), + LOCALTIME(FunctionName.of("localtime")), + CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")), + LOCALTIMESTAMP(FunctionName.of("localtimestamp")), + SYSDATE(FunctionName.of("sysdate")), + + /** Text Functions. */ + TOSTRING(FunctionName.of("tostring")), + + /** Arithmetic Operators. */ + ADD(FunctionName.of("+")), + ADDFUNCTION(FunctionName.of("add")), + DIVIDE(FunctionName.of("/")), + DIVIDEFUNCTION(FunctionName.of("divide")), + MOD(FunctionName.of("mod")), + MODULUS(FunctionName.of("%")), + MODULUSFUNCTION(FunctionName.of("modulus")), + MULTIPLY(FunctionName.of("*")), + MULTIPLYFUNCTION(FunctionName.of("multiply")), + SUBTRACT(FunctionName.of("-")), + SUBTRACTFUNCTION(FunctionName.of("subtract")), + + /** Boolean Operators. 
*/ + AND(FunctionName.of("and")), + OR(FunctionName.of("or")), + XOR(FunctionName.of("xor")), + NOT(FunctionName.of("not")), + EQUAL(FunctionName.of("=")), + NOTEQUAL(FunctionName.of("!=")), + LESS(FunctionName.of("<")), + LTE(FunctionName.of("<=")), + GREATER(FunctionName.of(">")), + GTE(FunctionName.of(">=")), + LIKE(FunctionName.of("like")), + NOT_LIKE(FunctionName.of("not like")), + + /** Aggregation Function. */ + AVG(FunctionName.of("avg")), + SUM(FunctionName.of("sum")), + COUNT(FunctionName.of("count")), + MIN(FunctionName.of("min")), + MAX(FunctionName.of("max")), + // sample variance + VARSAMP(FunctionName.of("var_samp")), + // population standard variance + VARPOP(FunctionName.of("var_pop")), + // sample standard deviation. + STDDEV_SAMP(FunctionName.of("stddev_samp")), + // population standard deviation. + STDDEV_POP(FunctionName.of("stddev_pop")), + // take top documents from aggregation bucket. + TAKE(FunctionName.of("take")), + // Not always an aggregation query + NESTED(FunctionName.of("nested")), + + /** Text Functions. */ + ASCII(FunctionName.of("ascii")), + CONCAT(FunctionName.of("concat")), + CONCAT_WS(FunctionName.of("concat_ws")), + LEFT(FunctionName.of("left")), + LENGTH(FunctionName.of("length")), + LOCATE(FunctionName.of("locate")), + LOWER(FunctionName.of("lower")), + LTRIM(FunctionName.of("ltrim")), + POSITION(FunctionName.of("position")), + REGEXP(FunctionName.of("regexp")), + REPLACE(FunctionName.of("replace")), + REVERSE(FunctionName.of("reverse")), + RIGHT(FunctionName.of("right")), + RTRIM(FunctionName.of("rtrim")), + STRCMP(FunctionName.of("strcmp")), + SUBSTR(FunctionName.of("substr")), + SUBSTRING(FunctionName.of("substring")), + TRIM(FunctionName.of("trim")), + UPPER(FunctionName.of("upper")), + + /** NULL Test. 
*/ + IS_NULL(FunctionName.of("is null")), + IS_NOT_NULL(FunctionName.of("is not null")), + IFNULL(FunctionName.of("ifnull")), + IF(FunctionName.of("if")), + NULLIF(FunctionName.of("nullif")), + ISNULL(FunctionName.of("isnull")), + + ROW_NUMBER(FunctionName.of("row_number")), + RANK(FunctionName.of("rank")), + DENSE_RANK(FunctionName.of("dense_rank")), + + INTERVAL(FunctionName.of("interval")), + + /** Data Type Convert Function. */ + CAST_TO_STRING(FunctionName.of("cast_to_string")), + CAST_TO_BYTE(FunctionName.of("cast_to_byte")), + CAST_TO_SHORT(FunctionName.of("cast_to_short")), + CAST_TO_INT(FunctionName.of("cast_to_int")), + CAST_TO_LONG(FunctionName.of("cast_to_long")), + CAST_TO_FLOAT(FunctionName.of("cast_to_float")), + CAST_TO_DOUBLE(FunctionName.of("cast_to_double")), + CAST_TO_BOOLEAN(FunctionName.of("cast_to_boolean")), + CAST_TO_DATE(FunctionName.of("cast_to_date")), + CAST_TO_TIME(FunctionName.of("cast_to_time")), + CAST_TO_TIMESTAMP(FunctionName.of("cast_to_timestamp")), + CAST_TO_DATETIME(FunctionName.of("cast_to_datetime")), + TYPEOF(FunctionName.of("typeof")), + + /** Relevance Function. */ + MATCH(FunctionName.of("match")), + SIMPLE_QUERY_STRING(FunctionName.of("simple_query_string")), + MATCH_PHRASE(FunctionName.of("match_phrase")), + MATCHPHRASE(FunctionName.of("matchphrase")), + MATCHPHRASEQUERY(FunctionName.of("matchphrasequery")), + QUERY_STRING(FunctionName.of("query_string")), + MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")), + HIGHLIGHT(FunctionName.of("highlight")), + MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")), + SCORE(FunctionName.of("score")), + SCOREQUERY(FunctionName.of("scorequery")), + SCORE_QUERY(FunctionName.of("score_query")), + + /** Legacy Relevance Function. 
*/ + QUERY(FunctionName.of("query")), + MATCH_QUERY(FunctionName.of("match_query")), + MATCHQUERY(FunctionName.of("matchquery")), + MULTI_MATCH(FunctionName.of("multi_match")), + MULTIMATCH(FunctionName.of("multimatch")), + MULTIMATCHQUERY(FunctionName.of("multimatchquery")), + WILDCARDQUERY(FunctionName.of("wildcardquery")), + WILDCARD_QUERY(FunctionName.of("wildcard_query")); + + private FunctionName name; + + private static final Map ALL_NATIVE_FUNCTIONS; + + BuiltinFunctionName(FunctionName functionName) { + this.name = functionName; + } + + public FunctionName getName() { + return name; + } + + static { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (BuiltinFunctionName func : BuiltinFunctionName.values()) { + builder.put(func.getName(), func); + } + ALL_NATIVE_FUNCTIONS = builder.build(); + } + + + private static final Map AGGREGATION_FUNC_MAPPING = + new ImmutableMap.Builder() + .put("max", BuiltinFunctionName.MAX) + .put("min", BuiltinFunctionName.MIN) + .put("avg", BuiltinFunctionName.AVG) + .put("count", BuiltinFunctionName.COUNT) + .put("sum", BuiltinFunctionName.SUM) + .put("var_pop", BuiltinFunctionName.VARPOP) + .put("var_samp", BuiltinFunctionName.VARSAMP) + .put("variance", BuiltinFunctionName.VARPOP) + .put("std", BuiltinFunctionName.STDDEV_POP) + .put("stddev", BuiltinFunctionName.STDDEV_POP) + .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) + .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) + .put("take", BuiltinFunctionName.TAKE) + .build(); + + public static Optional of(String str) { + return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); + } + + public static Optional ofAggregation(String functionName) { + return Optional.ofNullable( + AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/FunctionName.java 
/**
 * The definition of Function Name.
 *
 * <p>Two instances are equal when their names are equal, which makes the class usable as a map
 * key. Names built through {@link #of(String)} are lower-cased; the public constructor stores
 * the name exactly as given.
 */
public class FunctionName implements Serializable {
    private String functionName;

    /**
     * @param functionName name to store, kept verbatim (no case normalization)
     */
    public FunctionName(String functionName) {
        this.functionName = functionName;
    }

    /**
     * Creates a function name with its text normalized to lower case.
     *
     * @param functionName name in any letter case; must not be {@code null}
     * @return a function name holding the lower-cased text
     * @throws NullPointerException if {@code functionName} is {@code null}
     */
    public static FunctionName of(String functionName) {
        // Locale.ROOT keeps the result independent of the JVM default locale. With a Turkish
        // default locale, the no-argument toLowerCase() turns "IF" into "ıf" (dotless i), which
        // would never match a name registered as "if".
        return new FunctionName(functionName.toLowerCase(java.util.Locale.ROOT));
    }

    @Override
    public String toString() {
        return functionName;
    }

    /**
     * @return the stored name, identical to {@link #toString()}
     */
    public String getFunctionName() {
        return toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        FunctionName that = (FunctionName) o;
        return Objects.equals(getFunctionName(), that.getFunctionName());
    }

    @Override
    public int hashCode() {
        return Objects.hash(getFunctionName());
    }
}
org.apache.spark.sql.catalyst.plans.logical.Union; +import scala.collection.Seq; + +import java.util.Stack; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.collection.JavaConverters.asScalaBuffer; + +/** + * The context used for Catalyst logical plan. + */ +public class CatalystPlanContext { + /** + * Catalyst evolving logical plan + **/ + private Stack planBranches = new Stack<>(); + + /** + * NamedExpression contextual parameters + **/ + private final Stack namedParseExpressions = new Stack<>(); + + /** + * Grouping NamedExpression contextual parameters + **/ + private final Stack groupingParseExpressions = new Stack<>(); + + public LogicalPlan getPlan() { + if (this.planBranches.size() == 1) { + return planBranches.peek(); + } + //default unify sub-plans + return new Union(asScalaBuffer(this.planBranches), true, true); + } + + public Stack getNamedParseExpressions() { + return namedParseExpressions; + } + + public Stack getGroupingParseExpressions() { + return groupingParseExpressions; + } + + /** + * append context with evolving plan + * + * @param plan + */ + public void with(LogicalPlan plan) { + this.planBranches.push(plan); + } + + public LogicalPlan plan(Function transformFunction) { + this.planBranches.replaceAll(transformFunction::apply); + return getPlan(); + } + + /** + * retain all expressions and clear expression stack + * @return + */ + public Seq retainAllNamedParseExpressions(Function transformFunction) { + Seq aggregateExpressions = seq(getNamedParseExpressions().stream() + .map(transformFunction::apply).collect(Collectors.toList())); + getNamedParseExpressions().retainAll(emptyList()); + return aggregateExpressions; + } + + /** + * retain all aggregate expressions and clear expression stack + * @return + */ + public Seq retainAllGroupingNamedParseExpressions(Function 
transformFunction) { + Seq aggregateExpressions = seq(getGroupingParseExpressions().stream() + .map(transformFunction::apply).collect(Collectors.toList())); + getGroupingParseExpressions().retainAll(emptyList()); + return aggregateExpressions; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java new file mode 100644 index 000000000..ff7e54e22 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -0,0 +1,379 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Case; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.Field; +import 
org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.statement.Explain; +import org.opensearch.sql.ast.statement.Query; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ppl.utils.AggregatorTranslator; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.SortUtils; +import scala.Option; +import scala.collection.Seq; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static java.util.List.of; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Utility class to traverse PPL logical plan and translate it into catalyst logical plan + */ +public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { + + private final ExpressionAnalyzer expressionAnalyzer; + + public CatalystQueryPlanVisitor() { + 
this.expressionAnalyzer = new ExpressionAnalyzer(); + } + + public LogicalPlan visit(Statement plan, CatalystPlanContext context) { + return plan.accept(this, context); + } + + /** + * Handle Query Statement. + */ + @Override + public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { + return node.getPlan().accept(this, context); + } + + @Override + public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { + return node.getStatement().accept(this, context); + } + + @Override + public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + node.getTableName().forEach(t -> { + // Resolving the qualifiedName which is composed of a datasource.schema.table + context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)); + }); + return context.getPlan(); + } + + @Override + public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + Expression conditionExpression = visitExpression(node.getCondition(), context); + Expression innerConditionExpression = context.getNamedParseExpressions().pop(); + return context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression, p)); + } + + @Override + public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List aggsExpList = visitExpressionList(node.getAggExprList(), context); + List groupExpList = visitExpressionList(node.getGroupExprList(), context); + + if (!groupExpList.isEmpty()) { + //add group by fields to context + context.getGroupingParseExpressions().addAll(groupExpList); + } + + UnresolvedExpression span = node.getSpan(); + if (!Objects.isNull(span)) { + span.accept(this, context); + //add span's group alias field (most recent added expression) + context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek()); + } + // build the 
aggregation logical step + return extractedAggregation(context); + } + + private static LogicalPlan extractedAggregation(CatalystPlanContext context) { + Seq groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p); + Seq aggregateExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + return context.plan(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); + } + + @Override + public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { + expressionAnalyzer.visitAlias(node, context); + return context.getPlan(); + } + + @Override + public LogicalPlan visitProject(Project node, CatalystPlanContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List expressionList = visitExpressionList(node.getProjectList(), context); + + // Create a projection list from the existing expressions + Seq projectList = seq(context.getNamedParseExpressions()); + if (!projectList.isEmpty()) { + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } + if (node.hasArgument()) { + Argument argument = node.getArgExprList().get(0); + //todo exclude the argument from the projected arguments list + } + return child; + } + + @Override + public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + visitFieldList(node.getSortList(), context); + Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); + return context.plan(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); + } + + @Override + public LogicalPlan visitHead(Head node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + return context.plan(p -> 
(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( + node.getSize(), DataTypes.IntegerType), p)); + } + + private void visitFieldList(List fieldList, CatalystPlanContext context) { + fieldList.forEach(field -> visitExpression(field, context)); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? emptyList() + : expressionList.stream().map(field -> visitExpression(field, context)) + .collect(Collectors.toList()); + } + + private Expression visitExpression(UnresolvedExpression expression, CatalystPlanContext context) { + return expressionAnalyzer.analyze(expression, context); + } + + @Override + public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public LogicalPlan visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public LogicalPlan visitIn(In node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : In"); + } + + @Override + public LogicalPlan visitCase(Case node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Case"); + } + + @Override + public LogicalPlan visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : dedupe "); + } + + /** + * Expression Analyzer. 
+ */ + private static class ExpressionAnalyzer extends AbstractNodeVisitor { + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Expression left = (Expression) context.getNamedParseExpressions().pop(); + node.getRight().accept(this, context); + Expression right = (Expression) context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Expression left = (Expression) context.getNamedParseExpressions().pop(); + node.getRight().accept(this, context); + Expression right = (Expression) context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Expression left = (Expression) context.getNamedParseExpressions().pop(); + node.getRight().accept(this, context); + Expression right = (Expression) context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Expression arg = (Expression) 
context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(arg)); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression field = (Expression) context.getNamedParseExpressions().pop(); + node.getValue().accept(this, context); + Expression value = (Expression) context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.getNamedParseExpressions().pop(); + Expression aggregator = AggregatorTranslator.aggregator(node, arg); + return context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Expression left = (Expression) context.getNamedParseExpressions().pop(); + analyze(node.getRight(), context); + Expression right = (Expression) context.getNamedParseExpressions().pop(); + Predicate comparator = ComparatorTransformer.comparator(node, left, right); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + + @Override + public Expression visitField(Field node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getField().toString()))); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + // Case of aggregation step - no start projection can be added + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + 
context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.getNamedParseExpressions().pop(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Function"); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Interval"); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : In"); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Case"); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression 
visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java new file mode 100644 index 000000000..1b26255f9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -0,0 +1,343 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.parser; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.ParseMethod; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Parse; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Relation; +import 
org.opensearch.sql.ast.tree.Rename; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ppl.utils.ArgumentFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + + +/** Class of building the AST. Refines the visit path and build the AST nodes */ +public class AstBuilder extends OpenSearchPPLParserBaseVisitor { + + private AstExpressionBuilder expressionBuilder; + + /** + * PPL query to get original token text. This is necessary because token.getText() returns text + * without whitespaces or other characters discarded by lexer. + */ + private String query; + + public AstBuilder(AstExpressionBuilder expressionBuilder, String query) { + this.expressionBuilder = expressionBuilder; + this.query = query; + } + + @Override + public UnresolvedPlan visitQueryStatement(OpenSearchPPLParser.QueryStatementContext ctx) { + UnresolvedPlan pplCommand = visit(ctx.pplCommands()); + return ctx.commands().stream().map(this::visit).reduce(pplCommand, (r, e) -> e.attach(r)); + } + + /** Search command. 
*/ + @Override + public UnresolvedPlan visitSearchFrom(OpenSearchPPLParser.SearchFromContext ctx) { + return visitFromClause(ctx.fromClause()); + } + + @Override + public UnresolvedPlan visitSearchFromFilter(OpenSearchPPLParser.SearchFromFilterContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())) + .attach(visit(ctx.fromClause())); + } + + @Override + public UnresolvedPlan visitSearchFilterFrom(OpenSearchPPLParser.SearchFilterFromContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())) + .attach(visit(ctx.fromClause())); + } + + @Override + public UnresolvedPlan visitDescribeCommand(OpenSearchPPLParser.DescribeCommandContext ctx) { + final Relation table = (Relation) visitTableSourceClause(ctx.tableSourceClause()); + QualifiedName tableQualifiedName = table.getTableQualifiedName(); + ArrayList parts = new ArrayList<>(tableQualifiedName.getParts()); + return new Relation(new QualifiedName(parts)); + } + + /** Where command. */ + @Override + public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext ctx) { + return new Filter(internalVisitExpression(ctx.logicalExpression())); + } + + /** Fields command. */ + @Override + public UnresolvedPlan visitFieldsCommand(OpenSearchPPLParser.FieldsCommandContext ctx) { + return new Project( + ctx.fieldList().fieldExpression().stream() + .map(this::internalVisitExpression) + .collect(Collectors.toList()), + ArgumentFactory.getArgumentList(ctx)); + } + + /** Rename command. */ + @Override + public UnresolvedPlan visitRenameCommand(OpenSearchPPLParser.RenameCommandContext ctx) { + return new Rename( + ctx.renameClasue().stream() + .map( + ct -> + new Map( + internalVisitExpression(ct.orignalField), + internalVisitExpression(ct.renamedField))) + .collect(Collectors.toList())); + } + + /** Stats command. 
*/ + @Override + public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext ctx) { + ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + for (OpenSearchPPLParser.StatsAggTermContext aggCtx : ctx.statsAggTerm()) { + UnresolvedExpression aggExpression = internalVisitExpression(aggCtx.statsFunction()); + String name = + aggCtx.alias == null + ? getTextInQuery(aggCtx) + : aggCtx.alias.getText(); + Alias alias = new Alias(name, aggExpression); + aggListBuilder.add(alias); + } + + List groupList = + Optional.ofNullable(ctx.statsByClause()) + .map(OpenSearchPPLParser.StatsByClauseContext::fieldList) + .map( + expr -> + expr.fieldExpression().stream() + .map( + groupCtx -> + (UnresolvedExpression) + new Alias( + getTextInQuery(groupCtx), + internalVisitExpression(groupCtx))) + .collect(Collectors.toList())) + .orElse(Collections.emptyList()); + + UnresolvedExpression span = + Optional.ofNullable(ctx.statsByClause()) + .map(OpenSearchPPLParser.StatsByClauseContext::bySpanClause) + .map(this::internalVisitExpression) + .orElse(null); + + Aggregation aggregation = + new Aggregation( + aggListBuilder.build(), + Collections.emptyList(), + groupList, + span, + ArgumentFactory.getArgumentList(ctx)); + return aggregation; + } + + /** Dedup command. */ + @Override + public UnresolvedPlan visitDedupCommand(OpenSearchPPLParser.DedupCommandContext ctx) { + return new Dedupe(ArgumentFactory.getArgumentList(ctx), getFieldList(ctx.fieldList())); + } + + /** Head command visitor. */ + @Override + public UnresolvedPlan visitHeadCommand(OpenSearchPPLParser.HeadCommandContext ctx) { + Integer size = ctx.number != null ? Integer.parseInt(ctx.number.getText()) : 10; + Integer from = ctx.from != null ? Integer.parseInt(ctx.from.getText()) : 0; + return new Head(size, from); + } + + /** Sort command. 
*/ + @Override + public UnresolvedPlan visitSortCommand(OpenSearchPPLParser.SortCommandContext ctx) { + return new Sort( + ctx.sortbyClause().sortField().stream() + .map(sort -> (Field) internalVisitExpression(sort)) + .collect(Collectors.toList())); + } + + /** Eval command. */ + @Override + public UnresolvedPlan visitEvalCommand(OpenSearchPPLParser.EvalCommandContext ctx) { + return new Eval( + ctx.evalClause().stream() + .map(ct -> (Let) internalVisitExpression(ct)) + .collect(Collectors.toList())); + } + + private List getGroupByList(OpenSearchPPLParser.ByClauseContext ctx) { + return ctx.fieldList().fieldExpression().stream() + .map(this::internalVisitExpression) + .collect(Collectors.toList()); + } + + private List getFieldList(OpenSearchPPLParser.FieldListContext ctx) { + return ctx.fieldExpression().stream() + .map(field -> (Field) internalVisitExpression(field)) + .collect(Collectors.toList()); + } + + /** Rare command. */ + @Override + public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { + throw new RuntimeException("Rare Command is not supported "); + } + + @Override + public UnresolvedPlan visitGrokCommand(OpenSearchPPLParser.GrokCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + Literal pattern = (Literal) internalVisitExpression(ctx.pattern); + + return new Parse(ParseMethod.GROK, sourceField, pattern, ImmutableMap.of()); + } + + @Override + public UnresolvedPlan visitParseCommand(OpenSearchPPLParser.ParseCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + Literal pattern = (Literal) internalVisitExpression(ctx.pattern); + + return new Parse(ParseMethod.REGEX, sourceField, pattern, ImmutableMap.of()); + } + + @Override + public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + ImmutableMap.Builder 
builder = ImmutableMap.builder(); + ctx.patternsParameter() + .forEach( + x -> { + builder.put( + x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + java.util.Map arguments = builder.build(); + Literal pattern = arguments.getOrDefault("pattern", new Literal("", DataType.STRING)); + + return new Parse(ParseMethod.PATTERNS, sourceField, pattern, arguments); + } + + /** Top command. */ + @Override + public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { + List groupList = + ctx.byClause() == null ? Collections.emptyList() : getGroupByList(ctx.byClause()); + return new RareTopN( + RareTopN.CommandType.TOP, + ArgumentFactory.getArgumentList(ctx), + getFieldList(ctx.fieldList()), + groupList); + } + + /** From clause. */ + @Override + public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { + if (ctx.tableFunction() != null) { + return visitTableFunction(ctx.tableFunction()); + } else { + return visitTableSourceClause(ctx.tableSourceClause()); + } + } + + @Override + public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { + return new Relation( + ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + } + + @Override + public UnresolvedPlan visitTableFunction(OpenSearchPPLParser.TableFunctionContext ctx) { + ImmutableList.Builder builder = ImmutableList.builder(); + ctx.functionArgs() + .functionArg() + .forEach( + arg -> { + String argName = (arg.ident() != null) ? arg.ident().getText() : null; + builder.add( + new UnresolvedArgument( + argName, this.internalVisitExpression(arg.valueExpression()))); + }); + return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); + } + + /** Navigate to & build AST expression. 
*/ + private UnresolvedExpression internalVisitExpression(ParseTree tree) { + return expressionBuilder.visit(tree); + } + + /** Simply return non-default value for now. */ + @Override + protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPlan nextResult) { + if (nextResult != defaultResult()) { + return nextResult; + } + return aggregate; + } + + /** Kmeans command. */ + @Override + public UnresolvedPlan visitKmeansCommand(OpenSearchPPLParser.KmeansCommandContext ctx) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.kmeansParameter() + .forEach( + x -> { + builder.put( + x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + return new Kmeans(builder.build()); + } + + /** AD command. */ + @Override + public UnresolvedPlan visitAdCommand(OpenSearchPPLParser.AdCommandContext ctx) { + throw new RuntimeException("AD Command is not supported "); + + } + + /** ml command. */ + @Override + public UnresolvedPlan visitMlCommand(OpenSearchPPLParser.MlCommandContext ctx) { + throw new RuntimeException("ML Command is not supported "); + } + + /** Get original text in query. 
*/ + private String getTextInQuery(ParserRuleContext ctx) { + Token start = ctx.getStart(); + Token stop = ctx.getStop(); + return query.substring(start.getStartIndex(), stop.getStopIndex() + 1); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java new file mode 100644 index 000000000..e7d723afd --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -0,0 +1,387 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.parser; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.RuleContext; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.IntervalUnit; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.SpanUnit; +import 
org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ppl.utils.ArgumentFactory; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION; + + +/** + * Class of building AST Expression nodes. + */ +public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor { + + private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10; + + /** + * The function name mapping between fronted and core engine. + */ + private static Map FUNCTION_NAME_MAPPING = + new ImmutableMap.Builder() + .put("isnull", IS_NULL.getName().getFunctionName()) + .put("isnotnull", IS_NOT_NULL.getName().getFunctionName()) + .build(); + + /** + * Eval clause. + */ + @Override + public UnresolvedExpression visitEvalClause(OpenSearchPPLParser.EvalClauseContext ctx) { + return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); + } + + /** + * Logical expression excluding boolean, comparison. 
+ */ + @Override + public UnresolvedExpression visitLogicalNot(OpenSearchPPLParser.LogicalNotContext ctx) { + return new Not(visit(ctx.logicalExpression())); + } + + @Override + public UnresolvedExpression visitLogicalOr(OpenSearchPPLParser.LogicalOrContext ctx) { + return new Or(visit(ctx.left), visit(ctx.right)); + } + + @Override + public UnresolvedExpression visitLogicalAnd(OpenSearchPPLParser.LogicalAndContext ctx) { + return new And(visit(ctx.left), visit(ctx.right)); + } + + @Override + public UnresolvedExpression visitLogicalXor(OpenSearchPPLParser.LogicalXorContext ctx) { + return new Xor(visit(ctx.left), visit(ctx.right)); + } + + /** + * Comparison expression. + */ + @Override + public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprContext ctx) { + return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right)); + } + + /** + * Value Expression. + */ + @Override + public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) { + return new Function( + ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + } + + @Override + public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) { + return visit(ctx.valueExpression()); // Discard parenthesis around + } + + /** + * Field expression. 
+ */ + @Override + public UnresolvedExpression visitFieldExpression(OpenSearchPPLParser.FieldExpressionContext ctx) { + return new Field((QualifiedName) visit(ctx.qualifiedName())); + } + + @Override + public UnresolvedExpression visitWcFieldExpression(OpenSearchPPLParser.WcFieldExpressionContext ctx) { + return new Field((QualifiedName) visit(ctx.wcQualifiedName())); + } + + @Override + public UnresolvedExpression visitSortField(OpenSearchPPLParser.SortFieldContext ctx) { + return new Field( + visit(ctx.sortFieldExpression().fieldExpression().qualifiedName()), + ArgumentFactory.getArgumentList(ctx)); + } + + /** + * Aggregation function. + */ + @Override + public UnresolvedExpression visitStatsFunctionCall(OpenSearchPPLParser.StatsFunctionCallContext ctx) { + return new AggregateFunction(ctx.statsFunctionName().getText(), visit(ctx.valueExpression())); + } + + @Override + public UnresolvedExpression visitCountAllFunctionCall(OpenSearchPPLParser.CountAllFunctionCallContext ctx) { + return new AggregateFunction("count", AllFields.of()); + } + + @Override + public UnresolvedExpression visitDistinctCountFunctionCall(OpenSearchPPLParser.DistinctCountFunctionCallContext ctx) { + return new AggregateFunction("count", visit(ctx.valueExpression()), true); + } + + @Override + public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.PercentileAggFunctionContext ctx) { + return new AggregateFunction( + ctx.PERCENTILE().getText(), + visit(ctx.aggField), + Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value)))); + } + + @Override + public UnresolvedExpression visitTakeAggFunctionCall( + OpenSearchPPLParser.TakeAggFunctionCallContext ctx) { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add( + new UnresolvedArgument( + "size", + ctx.takeAggFunction().size != null + ? 
visit(ctx.takeAggFunction().size) + : new Literal(DEFAULT_TAKE_FUNCTION_SIZE_VALUE, DataType.INTEGER))); + return new AggregateFunction( + "take", visit(ctx.takeAggFunction().fieldExpression()), builder.build()); + } + + /** + * Eval function. + */ + @Override + public UnresolvedExpression visitBooleanFunctionCall(OpenSearchPPLParser.BooleanFunctionCallContext ctx) { + final String functionName = ctx.conditionFunctionBase().getText(); + return buildFunction( + FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), + ctx.functionArgs().functionArg()); + } + + /** + * Eval function. + */ + @Override + public UnresolvedExpression visitEvalFunctionCall(OpenSearchPPLParser.EvalFunctionCallContext ctx) { + return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg()); + } + + @Override + public UnresolvedExpression visitConvertedDataType(OpenSearchPPLParser.ConvertedDataTypeContext ctx) { + return new Literal(ctx.getText(), DataType.STRING); + } + + private Function buildFunction( + String functionName, List args) { + return new Function( + functionName, args.stream().map(this::visitFunctionArg).collect(Collectors.toList())); + } + + public AstExpressionBuilder() { + } + + @Override + public UnresolvedExpression visitMultiFieldRelevanceFunction( + OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { + return new Function( + ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(), + multiFieldRelevanceArguments(ctx)); + } + + @Override + public UnresolvedExpression visitTableSource(OpenSearchPPLParser.TableSourceContext ctx) { + if (ctx.getChild(0) instanceof OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) { + return visitIdentsAsTableQualifiedName((OpenSearchPPLParser.IdentsAsTableQualifiedNameContext) ctx.getChild(0)); + } else { + return visitIdentifiers(Arrays.asList(ctx)); + } + } + + @Override + public UnresolvedExpression visitPositionFunction( + OpenSearchPPLParser.PositionFunctionContext ctx) { + return 
new Function( + POSITION.getName().getFunctionName(), + Arrays.asList(visitFunctionArg(ctx.functionArg(0)), visitFunctionArg(ctx.functionArg(1)))); + } + + @Override + public UnresolvedExpression visitExtractFunctionCall( + OpenSearchPPLParser.ExtractFunctionCallContext ctx) { + return new Function( + ctx.extractFunction().EXTRACT().toString(), getExtractFunctionArguments(ctx)); + } + + private List getExtractFunctionArguments( + OpenSearchPPLParser.ExtractFunctionCallContext ctx) { + List args = + Arrays.asList( + new Literal(ctx.extractFunction().datetimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.extractFunction().functionArg())); + return args; + } + + @Override + public UnresolvedExpression visitGetFormatFunctionCall( + OpenSearchPPLParser.GetFormatFunctionCallContext ctx) { + return new Function( + ctx.getFormatFunction().GET_FORMAT().toString(), getFormatFunctionArguments(ctx)); + } + + private List getFormatFunctionArguments( + OpenSearchPPLParser.GetFormatFunctionCallContext ctx) { + List args = + Arrays.asList( + new Literal(ctx.getFormatFunction().getFormatType().getText(), DataType.STRING), + visitFunctionArg(ctx.getFormatFunction().functionArg())); + return args; + } + + @Override + public UnresolvedExpression visitTimestampFunctionCall( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + return new Function( + ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); + } + + private List timestampFunctionArguments( + OpenSearchPPLParser.TimestampFunctionCallContext ctx) { + List args = + Arrays.asList( + new Literal(ctx.timestampFunction().simpleDateTimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.timestampFunction().firstArg), + visitFunctionArg(ctx.timestampFunction().secondArg)); + return args; + } + + /** + * Literal and value. 
+ */ + @Override + public UnresolvedExpression visitIdentsAsQualifiedName(OpenSearchPPLParser.IdentsAsQualifiedNameContext ctx) { + return visitIdentifiers(ctx.ident()); + } + + @Override + public UnresolvedExpression visitIdentsAsTableQualifiedName( + OpenSearchPPLParser.IdentsAsTableQualifiedNameContext ctx) { + return visitIdentifiers( + Stream.concat(Stream.of(ctx.tableIdent()), ctx.ident().stream()) + .collect(Collectors.toList())); + } + + @Override + public UnresolvedExpression visitIdentsAsWildcardQualifiedName( + OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext ctx) { + return visitIdentifiers(ctx.wildcard()); + } + + @Override + public UnresolvedExpression visitIntervalLiteral(OpenSearchPPLParser.IntervalLiteralContext ctx) { + return new Interval( + visit(ctx.valueExpression()), IntervalUnit.of(ctx.intervalUnit().getText())); + } + + @Override + public UnresolvedExpression visitStringLiteral(OpenSearchPPLParser.StringLiteralContext ctx) { + return new Literal(ctx.getText(), DataType.STRING); + } + + @Override + public UnresolvedExpression visitIntegerLiteral(OpenSearchPPLParser.IntegerLiteralContext ctx) { + long number = Long.parseLong(ctx.getText()); + if (Integer.MIN_VALUE <= number && number <= Integer.MAX_VALUE) { + return new Literal((int) number, DataType.INTEGER); + } + return new Literal(number, DataType.LONG); + } + + @Override + public UnresolvedExpression visitDecimalLiteral(OpenSearchPPLParser.DecimalLiteralContext ctx) { + return new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE); + } + + @Override + public UnresolvedExpression visitBooleanLiteral(OpenSearchPPLParser.BooleanLiteralContext ctx) { + return new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN); + } + + @Override + public UnresolvedExpression visitBySpanClause(OpenSearchPPLParser.BySpanClauseContext ctx) { + String name = ctx.spanClause().getText(); + return ctx.alias != null + ? 
new Alias( + name, visit(ctx.spanClause()), ctx.alias.getText()) + : new Alias(name, visit(ctx.spanClause())); + } + + @Override + public UnresolvedExpression visitSpanClause(OpenSearchPPLParser.SpanClauseContext ctx) { + String unit = ctx.unit != null ? ctx.unit.getText() : ""; + return new Span(visit(ctx.fieldExpression()), visit(ctx.value), SpanUnit.of(unit)); + } + + private QualifiedName visitIdentifiers(List ctx) { + return new QualifiedName( + ctx.stream() + .map(RuleContext::getText) + .collect(Collectors.toList())); + } + + private List singleFieldRelevanceArguments( + OpenSearchPPLParser.SingleFieldRelevanceFunctionContext ctx) { + // all the arguments are defaulted to string values + // to skip environment resolving and function signature resolving + ImmutableList.Builder builder = ImmutableList.builder(); + builder.add( + new UnresolvedArgument( + "field", new QualifiedName(ctx.field.getText()))); + builder.add( + new UnresolvedArgument( + "query", new Literal(ctx.query.getText(), DataType.STRING))); + ctx.relevanceArg() + .forEach( + v -> + builder.add( + new UnresolvedArgument( + v.relevanceArgName().getText().toLowerCase(), + new Literal( + v.relevanceArgValue().getText(), + DataType.STRING)))); + return builder.build(); + } + + private List multiFieldRelevanceArguments( + OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { + throw new RuntimeException("ML Command is not supported "); + + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java new file mode 100644 index 000000000..23ca992d9 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible 
open source license. + */ + +package org.opensearch.sql.ppl.parser; + +import com.google.common.collect.ImmutableList; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.statement.Explain; +import org.opensearch.sql.ast.statement.Query; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +/** Build {@link Statement} from PPL Query. */ + +public class AstStatementBuilder extends OpenSearchPPLParserBaseVisitor { + + private AstBuilder astBuilder; + + private StatementBuilderContext context; + + public AstStatementBuilder(AstBuilder astBuilder, StatementBuilderContext context) { + this.astBuilder = astBuilder; + this.context = context; + } + + @Override + public Statement visitDmlStatement(OpenSearchPPLParser.DmlStatementContext ctx) { + Query query = new Query(addSelectAll(astBuilder.visit(ctx)), context.getFetchSize()); + return context.isExplain ? new Explain(query) : query; + } + + @Override + protected Statement aggregateResult(Statement aggregate, Statement nextResult) { + return nextResult != null ? 
nextResult : aggregate; + } + + public AstBuilder builder() { + return astBuilder; + } + + public StatementBuilderContext getContext() { + return context; + } + + public static class StatementBuilderContext { + private boolean isExplain; + private int fetchSize; + + public StatementBuilderContext(boolean isExplain, int fetchSize) { + this.isExplain = isExplain; + this.fetchSize = fetchSize; + } + + public static StatementBuilderContext builder() { + //todo set the default statement builder init params configurable + return new StatementBuilderContext(false,1000); + } + + public StatementBuilderContext explain(boolean isExplain) { + this.isExplain = isExplain; + return this; + } + + public int getFetchSize() { + return fetchSize; + } + + public Object build() { + return null; + } + } + + private UnresolvedPlan addSelectAll(UnresolvedPlan plan) { + if ((plan instanceof Project) && !((Project) plan).isExcluded()) { + return plan; + } else { + return new Project(ImmutableList.of(AllFields.of())).attach(plan); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java new file mode 100644 index 000000000..e15324cc0 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; + +/** + * aggregator expression builder building a catalyst aggregation function from PPL's aggregation logical step + * + * @return + */ +public interface 
AggregatorTranslator { + + static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction aggregateFunction, Expression arg) { + if (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).isEmpty()) + throw new IllegalStateException("Unexpected value: " + aggregateFunction.getFuncName()); + + // Additional aggregation function operators will be added here + switch (BuiltinFunctionName.ofAggregation(aggregateFunction.getFuncName()).get()) { + case MAX: + return new UnresolvedFunction(seq("MAX"), seq(arg),false, empty(),false); + case MIN: + return new UnresolvedFunction(seq("MIN"), seq(arg),false, empty(),false); + case AVG: + return new UnresolvedFunction(seq("AVG"), seq(arg),false, empty(),false); + case COUNT: + return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false); + case SUM: + return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false); + } + throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java new file mode 100644 index 000000000..43f696bcd --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.opensearch.flint.spark.ppl.OpenSearchPPLParser; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** Util class to get all arguments as a list from the PPL command. 
*/ +public class ArgumentFactory { + + /** + * Get list of {@link Argument}. + * + * @param ctx FieldsCommandContext instance + * @return the list of arguments fetched from the fields command + */ + public static List getArgumentList(OpenSearchPPLParser.FieldsCommandContext ctx) { + return Collections.singletonList( + ctx.MINUS() != null + ? new Argument("exclude", new Literal(true, DataType.BOOLEAN)) + : new Argument("exclude", new Literal(false, DataType.BOOLEAN))); + } + + /** + * Get list of {@link Argument}. + * + * @param ctx StatsCommandContext instance + * @return the list of arguments fetched from the stats command + */ + public static List getArgumentList(OpenSearchPPLParser.StatsCommandContext ctx) { + return Arrays.asList( + ctx.partitions != null + ? new Argument("partitions", getArgumentValue(ctx.partitions)) + : new Argument("partitions", new Literal(1, DataType.INTEGER)), + ctx.allnum != null + ? new Argument("allnum", getArgumentValue(ctx.allnum)) + : new Argument("allnum", new Literal(false, DataType.BOOLEAN)), + ctx.delim != null + ? new Argument("delim", getArgumentValue(ctx.delim)) + : new Argument("delim", new Literal(" ", DataType.STRING)), + ctx.dedupsplit != null + ? new Argument("dedupsplit", getArgumentValue(ctx.dedupsplit)) + : new Argument("dedupsplit", new Literal(false, DataType.BOOLEAN))); + } + + /** + * Get list of {@link Argument}. + * + * @param ctx DedupCommandContext instance + * @return the list of arguments fetched from the dedup command + */ + public static List getArgumentList(OpenSearchPPLParser.DedupCommandContext ctx) { + return Arrays.asList( + ctx.number != null + ? new Argument("number", getArgumentValue(ctx.number)) + : new Argument("number", new Literal(1, DataType.INTEGER)), + ctx.keepempty != null + ? new Argument("keepempty", getArgumentValue(ctx.keepempty)) + : new Argument("keepempty", new Literal(false, DataType.BOOLEAN)), + ctx.consecutive != null + ? 
new Argument("consecutive", getArgumentValue(ctx.consecutive)) + : new Argument("consecutive", new Literal(false, DataType.BOOLEAN))); + } + + /** + * Get list of {@link Argument}. + * + * @param ctx SortFieldContext instance + * @return the list of arguments fetched from the sort field in sort command + */ + public static List getArgumentList(OpenSearchPPLParser.SortFieldContext ctx) { + return Arrays.asList( + ctx.MINUS() != null + ? new Argument("asc", new Literal(false, DataType.BOOLEAN)) + : new Argument("asc", new Literal(true, DataType.BOOLEAN)), + ctx.sortFieldExpression().AUTO() != null + ? new Argument("type", new Literal("auto", DataType.STRING)) + : ctx.sortFieldExpression().IP() != null + ? new Argument("type", new Literal("ip", DataType.STRING)) + : ctx.sortFieldExpression().NUM() != null + ? new Argument("type", new Literal("num", DataType.STRING)) + : ctx.sortFieldExpression().STR() != null + ? new Argument("type", new Literal("str", DataType.STRING)) + : new Argument("type", new Literal(null, DataType.NULL))); + } + + /** + * Get list of {@link Argument}. + * + * @param ctx TopCommandContext instance + * @return the list of arguments fetched from the top command + */ + public static List getArgumentList(OpenSearchPPLParser.TopCommandContext ctx) { + return Collections.singletonList( + ctx.number != null + ? new Argument("noOfResults", getArgumentValue(ctx.number)) + : new Argument("noOfResults", new Literal(10, DataType.INTEGER))); + } + + /** + * Get list of {@link Argument}. + * + * @param ctx RareCommandContext instance + * @return the list of argument with default number of results for the rare command + */ + public static List getArgumentList(OpenSearchPPLParser.RareCommandContext ctx) { + return Collections.singletonList( + new Argument("noOfResults", new Literal(10, DataType.INTEGER))); + } + + /** + * parse argument value into Literal. 
+ * + * @param ctx ParserRuleContext instance + * @return Literal + */ + private static Literal getArgumentValue(ParserRuleContext ctx) { + return ctx instanceof OpenSearchPPLParser.IntegerLiteralContext + ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) + : ctx instanceof OpenSearchPPLParser.BooleanLiteralContext + ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : new Literal(ctx.getText(), DataType.STRING); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java new file mode 100644 index 000000000..2a176ec3d --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ComparatorTransformer.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GreaterThan; +import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.LessThan; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.Not; +import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Transform the PPL Logical comparator into catalyst comparator + */ +public interface ComparatorTransformer { + /** + * comparator expression builder building a catalyst binary comparator from PPL's compare logical step + * + * @return + */ + static Predicate comparator(Compare expression, Expression left, Expression right) { + if (BuiltinFunctionName.of(expression.getOperator()).isEmpty()) + throw new 
IllegalStateException("Unexpected value: " + expression.getOperator()); + + if (left == null) { + throw new IllegalStateException("Unexpected value: No Left operands found in expression"); + } + + if (right == null) { + throw new IllegalStateException("Unexpected value: No Right operands found in expression"); + } + + // Additional function operators will be added here + switch (BuiltinFunctionName.of(expression.getOperator()).get()) { + case EQUAL: + return new EqualTo(left, right); + case NOTEQUAL: + return new Not(new EqualTo(left, right)); + case LESS: + return new LessThan(left, right); + case LTE: + return new LessThanOrEqual(left, right); + case GREATER: + return new GreaterThan(left, right); + case GTE: + return new GreaterThanOrEqual(left, right); + } + throw new IllegalStateException("Not Supported value: " + expression.getOperator()); + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java new file mode 100644 index 000000000..0c7269a07 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DataTypeTransformer.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + + +import org.apache.spark.sql.types.ByteType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.StringType$; +import org.apache.spark.unsafe.types.UTF8String; +import org.opensearch.sql.ast.expression.SpanUnit; +import scala.collection.mutable.Seq; + +import java.util.List; + +import static org.opensearch.sql.ast.expression.SpanUnit.DAY; +import static org.opensearch.sql.ast.expression.SpanUnit.HOUR; +import static org.opensearch.sql.ast.expression.SpanUnit.MILLISECOND; +import static 
org.opensearch.sql.ast.expression.SpanUnit.MINUTE; +import static org.opensearch.sql.ast.expression.SpanUnit.MONTH; +import static org.opensearch.sql.ast.expression.SpanUnit.NONE; +import static org.opensearch.sql.ast.expression.SpanUnit.QUARTER; +import static org.opensearch.sql.ast.expression.SpanUnit.SECOND; +import static org.opensearch.sql.ast.expression.SpanUnit.WEEK; +import static org.opensearch.sql.ast.expression.SpanUnit.YEAR; +import static scala.collection.JavaConverters.asScalaBufferConverter; + +/** + * translate the PPL ast expressions data-types into catalyst data-types + */ +public interface DataTypeTransformer { + static Seq seq(T element) { + return seq(List.of(element)); + } + static Seq seq(List list) { + return asScalaBufferConverter(list).asScala().seq(); + } + + static DataType translate(org.opensearch.sql.ast.expression.DataType source) { + switch (source.getCoreType()) { + case TIME: + return DateType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case BYTE: + return ByteType$.MODULE$; + default: + return StringType$.MODULE$; + } + } + + static Object translate(Object value, org.opensearch.sql.ast.expression.DataType source) { + switch (source.getCoreType()) { + case STRING: + /* The regex ^'(.*)'$ matches strings that start and end with a single quote. The content inside the quotes is captured using the (.*). + * The $1 in the replaceAll method refers to the first captured group, which is the content inside the quotes. + * If the string matches the pattern, the content inside the quotes is returned; otherwise, the original string is returned. 
+ */ + return UTF8String.fromString(value.toString().replaceAll("^'(.*)'$", "$1")); + default: + return value; + } + } + + static String translate(SpanUnit unit) { + switch (unit) { + case UNKNOWN: + case NONE: + return NONE.name(); + case MILLISECOND: + case MS: + return MILLISECOND.name(); + case SECOND: + case S: + return SECOND.name(); + case MINUTE: + case m: + return MINUTE.name(); + case HOUR: + case H: + return HOUR.name(); + case DAY: + case D: + return DAY.name(); + case WEEK: + case W: + return WEEK.name(); + case MONTH: + case M: + return MONTH.name(); + case QUARTER: + case Q: + return QUARTER.name(); + case YEAR: + case Y: + return YEAR.name(); + } + return ""; + } +} \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java new file mode 100644 index 000000000..83603b031 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/SortUtils.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.jetbrains.annotations.NotNull; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.Sort; + +import java.util.ArrayList; +import java.util.Optional; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +/** + * Utility interface for sorting operations. + * Provides methods to generate sort orders based on given criteria. 
+ */ +public interface SortUtils { + + /** + * Retrieves the sort direction for a given field name from a sort node. + * + * @param node The sort node containing the list of fields and their sort directions. + * @param expression The field name for which the sort direction is to be retrieved. + * @return SortOrder representing the sort direction of the given field name or null if the field is not found. + */ + static SortOrder getSortDirection(Sort node, NamedExpression expression) { + Optional field = node.getSortList().stream() + .filter(f -> f.getField().toString().equals(expression.name())) + .findAny(); + + return field.map(value -> sortOrder((Expression) expression, + (Boolean) value.getFieldArgs().get(0).getValue().getValue())) + .orElse(null); + } + + @NotNull + static SortOrder sortOrder(Expression expression, boolean ascending) { + return new SortOrder( + expression, + ascending ? Ascending$.MODULE$ : Descending$.MODULE$, + ascending ? Ascending$.MODULE$.defaultNullOrdering() : Descending$.MODULE$.defaultNullOrdering(), + seq(new ArrayList()) + ); + } +} \ No newline at end of file diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java new file mode 100644 index 000000000..c215caec5 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.Divide; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Floor; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Multiply; +import org.apache.spark.sql.catalyst.expressions.TimeWindow; +import 
org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.StringType$; +import org.opensearch.sql.ast.expression.SpanUnit; + +import static java.lang.String.format; +import static org.opensearch.sql.ast.expression.DataType.STRING; +import static org.opensearch.sql.ast.expression.SpanUnit.NONE; +import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; + +public interface WindowSpecTransformer { + + /** + * create a static window buckets based on the given value + * + * @param fieldExpression + * @param valueExpression + * @param unit + * @return + */ + static Expression window(Expression fieldExpression, Expression valueExpression, SpanUnit unit) { + // In case the unit is time unit - use TimeWindowSpec if possible + if (isTimeBased(unit)) { + return new TimeWindow(fieldExpression,timeLiteral(valueExpression, unit)); + } + // if the unit is not time base - create a math expression to bucket the span partitions + return new Multiply(new Floor(new Divide(fieldExpression, valueExpression)), valueExpression); + } + + static boolean isTimeBased(SpanUnit unit) { + return !(unit == NONE || unit == UNKNOWN); + } + + + static org.apache.spark.sql.catalyst.expressions.Literal timeLiteral( Expression valueExpression, SpanUnit unit) { + String format = format("%s %s", valueExpression.toString(), translate(unit)); + return new org.apache.spark.sql.catalyst.expressions.Literal( + translate(format, STRING), translate(STRING)); + } +} diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintPPLSparkExtensions.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintPPLSparkExtensions.scala new file mode 100644 index 000000000..26ad4b69b --- /dev/null +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintPPLSparkExtensions.scala @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: 
Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.opensearch.flint.spark.ppl.FlintSparkPPLParser
+
+import org.apache.spark.sql.SparkSessionExtensions
+
+/**
+ * Spark session extension that injects [[FlintSparkPPLParser]] as a parser for the session.
+ */
+class FlintPPLSparkExtensions extends (SparkSessionExtensions => Unit) {
+
+  /**
+   * Wrap the parser Spark hands to the builder in a [[FlintSparkPPLParser]]; the session is unused.
+   */
+  override def apply(extensions: SparkSessionExtensions): Unit =
+    extensions.injectParser((_, delegate) => new FlintSparkPPLParser(delegate))
+}
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
new file mode 100644
index 000000000..332dabc95
--- /dev/null
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/*
+ * This file contains code from the Apache Spark project (original license below).
+ * It contains modifications, which are licensed as above:
+ */
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.common.antlr.SyntaxCheckException
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Flint PPL parser that parses PPL Query Language into a Spark logical plan - if parsing fails it
+ * falls back to Spark's parser.
+ *
+ * @param sparkParser
+ *   Spark SQL parser
+ */
+class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface {
+
+  /** Visitor that translates the PPL statement AST into a Catalyst logical plan. */
+  private val planTransformer = new CatalystQueryPlanVisitor()
+
+  private val pplParser = new PPLSyntaxParser()
+
+  override def parsePlan(sqlText: String): LogicalPlan = {
+    try {
+      // if successful build ppl logical plan and translate to catalyst logical plan
+      val context = new CatalystPlanContext
+      planTransformer.visit(plan(pplParser, sqlText, isExplain = false), context)
+      context.getPlan
+    } catch {
+      // Fall back to Spark parse plan logic if flint cannot parse
+      case _: ParseException | _: SyntaxCheckException => sparkParser.parsePlan(sqlText)
+    }
+  }
+
+  override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
+
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    sparkParser.parseTableIdentifier(sqlText)
+
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    sparkParser.parseFunctionIdentifier(sqlText)
+
+  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+    sparkParser.parseMultipartIdentifier(sqlText)
+
+  override def parseTableSchema(sqlText: String): StructType =
+    sparkParser.parseTableSchema(sqlText)
+
+  override def parseDataType(sqlText: String): DataType = sparkParser.parseDataType(sqlText)
+
+  override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText)
+
+}
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala
new file mode 100644
index 000000000..e579d82f4
--- /dev/null
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala
@@ -0,0 +1,58 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.flint.spark.ppl
+
+import org.antlr.v4.runtime.{CommonTokenStream, Lexer}
+import org.antlr.v4.runtime.tree.ParseTree
+import org.opensearch.sql.ast.statement.Statement
+import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, Parser, SyntaxAnalysisErrorListener}
+import org.opensearch.sql.ppl.parser.{AstBuilder, AstExpressionBuilder, AstStatementBuilder}
+
+/**
+ * Syntax parser that turns PPL query text into an ANTLR parse tree.
+ */
+class PPLSyntaxParser extends Parser {
+
+  /**
+   * Analyze the query syntax: lex and parse `query` starting at the `root` rule, with a
+   * [[SyntaxAnalysisErrorListener]] registered on the parser.
+   */
+  override def parse(query: String): ParseTree = {
+    val parser = createParser(createLexer(query))
+    parser.addErrorListener(new SyntaxAnalysisErrorListener())
+    parser.root()
+  }
+
+  private def createParser(lexer: Lexer): OpenSearchPPLParser = {
+    new OpenSearchPPLParser(new CommonTokenStream(lexer))
+  }
+
+  private def createLexer(query: String): OpenSearchPPLLexer = {
+    new OpenSearchPPLLexer(new CaseInsensitiveCharStream(query))
+  }
+}
+
+/** Helpers for building the PPL statement AST from query text. */
+object PlaneUtils {
+
+  /**
+   * Parse `query` and build its [[Statement]].
+   *
+   * @param parser
+   *   syntax parser that produces the parse tree
+   * @param query
+   *   PPL query text
+   * @param isExplain
+   *   passed to the statement builder context through `explain(isExplain)`
+   * @return
+   *   the statement produced by [[AstStatementBuilder]]
+   */
+  def plan(parser: PPLSyntaxParser, query: String, isExplain: Boolean): Statement = {
+    val builder = new AstStatementBuilder(
+      new AstBuilder(new AstExpressionBuilder(), query),
+      AstStatementBuilder.StatementBuilderContext.builder().explain(isExplain))
+    builder.visit(parser.parse(query))
+  }
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala
new file mode 100644
index 000000000..450f21c63
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintPPLSuite.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Shared Spark session test base with the Flint PPL extension registered.
+ */
+trait FlintPPLSuite extends SharedSparkSession {
+  override protected def sparkConf: SparkConf = {
+    new SparkConf()
+      .set("spark.ui.enabled", "false")
+      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+      .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
+      // Disable ConvertToLocalRelation for better test coverage. Test cases built on
+      // LocalRelation will exercise the optimization rules better by disabling it as
+      // this rule may potentially block testing of other optimization rules such as
+      // ConstantPropagation etc.
+      .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+      .set("spark.sql.extensions", classOf[FlintPPLSparkExtensions].getName)
+  }
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala
new file mode 100644
index 000000000..a36b34ef4
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/LogicalPlanTestUtils.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+
+/**
+ * general utility functions for ppl to spark transformation test
+ */
+trait LogicalPlanTestUtils {
+
+  /**
+   * utility function to compare two logical plans while ignoring the auto-generated expressionId
+   * associated with the alias which is used for projection or aggregation
+   * @param plan
+   * @return
+   */
+  def compareByString(plan: LogicalPlan): String = {
+    // Create a rule to replace Alias's ExprId with a dummy id
+    val rule: PartialFunction[LogicalPlan, LogicalPlan] = {
+      case p: Project =>
+        val newProjections = p.projectList.map {
+          case alias: Alias =>
+            Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
+          case other => other
+        }
+        p.copy(projectList = newProjections)
+
+      case agg: Aggregate =>
+        val newGrouping = agg.groupingExpressions.map {
+          case alias: Alias =>
+            Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
+          case other => other
+        }
+        val newAggregations = agg.aggregateExpressions.map {
+          case alias: Alias =>
+            Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier)
+          case other => other
+        }
+        agg.copy(groupingExpressions = newGrouping, aggregateExpressions = newAggregations)
+
+      case other => other
+    }
+
+    //
Apply the rule using transform + val transformedPlan = plan.transform(rule) + + // Return the string representation of the transformed plan + transformedPlan.toString + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala new file mode 100644 index 000000000..8434c5bf1 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalAdvancedTranslatorTestSuite.scala @@ -0,0 +1,483 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, Like, Literal, SortOrder, UnixTimestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalAdvancedTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTrnasformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + ignore("Find What are the average prices for different types of properties") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = housing_properties | stats avg(price) by property_type", false), + context) + // SQL: SELECT property_type, AVG(price) FROM 
housing_properties GROUP BY property_type + val table = UnresolvedRelation(Seq("housing_properties")) + + val avgPrice = Alias(Average(UnresolvedAttribute("price")), "avg(price)")() + val propertyType = UnresolvedAttribute("property_type") + val grouped = Aggregate(Seq(propertyType), Seq(propertyType, avgPrice), table) + + val projectList = Seq( + UnresolvedAttribute("property_type"), + Alias(Average(UnresolvedAttribute("price")), "avg(price)")()) + val expectedPlan = Project(projectList, grouped) + + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } + + ignore( + "Find the top 10 most expensive properties in California, including their addresses, prices, and cities") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = housing_properties | where state = `CA` | fields address, price, city | sort - price | head 10", + false), + context) + // SQL: SELECT address, price, city FROM housing_properties WHERE state = 'CA' ORDER BY price DESC LIMIT 10 + + // Constructing the expected Catalyst Logical Plan + val table = UnresolvedRelation(Seq("housing_properties")) + val filter = Filter(EqualTo(UnresolvedAttribute("state"), Literal("CA")), table) + val projectList = Seq( + UnresolvedAttribute("address"), + UnresolvedAttribute("price"), + UnresolvedAttribute("city")) + val projected = Project(projectList, filter) + val sortOrder = SortOrder(UnresolvedAttribute("price"), Descending) :: Nil + val sorted = Sort(sortOrder, true, projected) + val limited = Limit(Literal(10), sorted) + val finalProjectList = Seq( + UnresolvedAttribute("address"), + UnresolvedAttribute("price"), + UnresolvedAttribute("city")) + + val expectedPlan = Project(finalProjectList, limited) + + // Assert that the generated plan is as expected + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find the average price per unit of land space for properties in different cities") 
{ + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = housing_properties | where land_space > 0 | eval price_per_land_unit = price / land_space | stats avg(price_per_land_unit) by city", + false), + context) + // SQL: SELECT city, AVG(price / land_space) AS avg_price_per_land_unit FROM housing_properties WHERE land_space > 0 GROUP BY city + val table = UnresolvedRelation(Seq("housing_properties")) + val filter = Filter(GreaterThan(UnresolvedAttribute("land_space"), Literal(0)), table) + val expression = AggregateExpression( + Average(Divide(UnresolvedAttribute("price"), UnresolvedAttribute("land_space"))), + mode = Complete, + isDistinct = false) + val aggregateExpr = Alias(expression, "avg_price_per_land_unit")() + val groupBy = Aggregate( + groupingExpressions = Seq(UnresolvedAttribute("city")), + aggregateExpressions = Seq(aggregateExpr), + filter) + + val expectedPlan = Project( + projectList = + Seq(UnresolvedAttribute("city"), UnresolvedAttribute("avg_price_per_land_unit")), + groupBy) + // Continue with your test... 
+ assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find the houses posted in the last month, how many are still for sale") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "search source=housing_properties | where listing_age >= 0 | where listing_age < 30 | stats count() by property_status", + false), + context) + // SQL: SELECT property_status, COUNT(*) FROM housing_properties WHERE listing_age >= 0 AND listing_age < 30 GROUP BY property_status; + + val filter = Filter( + LessThan(UnresolvedAttribute("listing_age"), Literal(30)), + Filter( + GreaterThanOrEqual(UnresolvedAttribute("listing_age"), Literal(0)), + UnresolvedRelation(Seq("housing_properties")))) + + val expression = AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) + + val aggregateExpressions = Seq(Alias(expression, "count")()) + + val groupByAttributes = Seq(UnresolvedAttribute("property_status")) + val expectedPlan = Aggregate(groupByAttributes, aggregateExpressions, filter) + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore( + "Find all the houses listed by agency Compass in decreasing price order. 
Also provide only price, address and agency name information.") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = housing_properties | where match( agency_name , `Compass` ) | fields address , agency_name , price | sort - price ", + false), + context) + // SQL: SELECT address, agency_name, price FROM housing_properties WHERE agency_name LIKE '%Compass%' ORDER BY price DESC + + val projectList = Seq( + UnresolvedAttribute("address"), + UnresolvedAttribute("agency_name"), + UnresolvedAttribute("price")) + val table = UnresolvedRelation(Seq("housing_properties")) + + val filterCondition = Like(UnresolvedAttribute("agency_name"), Literal("%Compass%"), '\\') + val filter = Filter(filterCondition, table) + + val sortOrder = Seq(SortOrder(UnresolvedAttribute("price"), Descending)) + val sort = Sort(sortOrder, true, filter) + + val expectedPlan = Project(projectList, sort) + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find details of properties owned by Zillow with at least 3 bedrooms and 2 bathrooms") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = housing_properties | where is_owned_by_zillow = 1 and bedroom_number >= 3 and bathroom_number >= 2 | fields address, price, city, listing_age", + false), + context) + // SQL:SELECT address, price, city, listing_age FROM housing_properties WHERE is_owned_by_zillow = 1 AND bedroom_number >= 3 AND bathroom_number >= 2; + val projectList = Seq( + UnresolvedAttribute("address"), + UnresolvedAttribute("price"), + UnresolvedAttribute("city"), + UnresolvedAttribute("listing_age")) + + val filterCondition = And( + And( + EqualTo(UnresolvedAttribute("is_owned_by_zillow"), Literal(1)), + GreaterThanOrEqual(UnresolvedAttribute("bedroom_number"), Literal(3))), + GreaterThanOrEqual(UnresolvedAttribute("bathroom_number"), Literal(2))) + + val expectedPlan = Project( + 
projectList, + Filter(filterCondition, UnresolvedRelation(TableIdentifier("housing_properties")))) + // Add to your unit test + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find which cities in WA state have the largest number of houses for sale") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = housing_properties | where property_status = 'FOR_SALE' and state = 'WA' | stats count() as count by city | sort -count | head", + false), + context) + // SQL : SELECT city, COUNT(*) as count FROM housing_properties WHERE property_status = 'FOR_SALE' AND state = 'WA' GROUP BY city ORDER BY count DESC LIMIT 10; + val aggregateExpressions = Seq( + Alias( + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), + "count")()) + val groupByAttributes = Seq(UnresolvedAttribute("city")) + + val filterCondition = And( + EqualTo(UnresolvedAttribute("property_status"), Literal("FOR_SALE")), + EqualTo(UnresolvedAttribute("state"), Literal("WA"))) + + val expectedPlan = Limit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("count"), Descending)), + true, + Aggregate( + groupByAttributes, + aggregateExpressions, + Filter(filterCondition, UnresolvedRelation(TableIdentifier("housing_properties")))))) + + // Add to your unit test + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find the top 5 referrers for the '/' path in apache access logs") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = access_logs | where path = `/` | top 5 referer", false), + context) + /* + SQL: SELECT referer, COUNT(*) as count + FROM access_logs + WHERE path = '/' GROUP BY referer ORDER BY count DESC LIMIT 5; + */ + val aggregateExpressions = Seq( + Alias( + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), + "count")()) + val 
groupByAttributes = Seq(UnresolvedAttribute("referer")) + val filterCondition = EqualTo(UnresolvedAttribute("path"), Literal("/")) + val expectedPlan = Limit( + Literal(5), + Sort( + Seq(SortOrder(UnresolvedAttribute("count"), Descending)), + true, + Aggregate( + groupByAttributes, + aggregateExpressions, + Filter(filterCondition, UnresolvedRelation(TableIdentifier("access_logs")))))) + + // Add to your unit test + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } + + ignore( + "Find access paths by status code. How many error responses (status code 400 or higher) are there for each access path in the Apache access logs") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = access_logs | where status >= 400 | stats count() by path, status", + false), + context) + /* + SQL: SELECT path, status, COUNT(*) as count + FROM access_logs + WHERE status >=400 GROUP BY path, status; + */ + val aggregateExpressions = Seq( + Alias( + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false), + "count")()) + val groupByAttributes = Seq(UnresolvedAttribute("path"), UnresolvedAttribute("status")) + + val filterCondition = GreaterThanOrEqual(UnresolvedAttribute("status"), Literal(400)) + + val expectedPlan = Aggregate( + groupByAttributes, + aggregateExpressions, + Filter(filterCondition, UnresolvedRelation(TableIdentifier("access_logs")))) + + // Add to your unit test + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore("Find max size of nginx access requests for every 15min") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = access_logs | stats max(size) by span( request_time , 15m) ", + false), + context) + + // SQL: SELECT MAX(size) AS max_size, floor(request_time / 900) AS time_span FROM access_logs GROUP BY time_span; + val aggregateExpressions = Seq(Alias( + 
AggregateExpression(Max(UnresolvedAttribute("size")), mode = Complete, isDistinct = false), + "max_size")()) + val groupByAttributes = + Seq(Alias(Floor(Divide(UnresolvedAttribute("request_time"), Literal(900))), "time_span")()) + + val expectedPlan = Aggregate( + groupByAttributes, + aggregateExpressions ++ groupByAttributes, + UnresolvedRelation(TableIdentifier("access_logs"))) + + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } +
+  ignore("Find nginx logs with non 2xx status code and url containing 'products'") {
+    val context = new CatalystPlanContext
+    val logPlan = planTrnasformer.visit(
+      plan(
+        pplParser,
+        "source = sso_logs-nginx-* | where match(http.url, 'products') and http.response.status_code >= `300`",
+        false),
+      context)
+    // SQL : SELECT * FROM `sso_logs-nginx-*` WHERE http.url LIKE '%products%' AND http.response.status_code >= 300;
+    // Expected plan follows the SQL above: both predicates are AND-ed into a single filter over
+    // the nginx index, with match(field, text) written as LIKE '%text%'.
+    val urlContainsProducts =
+      Like(UnresolvedAttribute("http.url"), Literal("%products%"), '\\')
+    val statusAtLeast300 =
+      GreaterThanOrEqual(UnresolvedAttribute("http.response.status_code"), Literal(300))
+
+    val expectedPlan = Filter(
+      And(urlContainsProducts, statusAtLeast300),
+      UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")))
+
+    // Add to your unit test
+    assertEquals(expectedPlan, context.getPlan)
+    assertEquals(logPlan, "???")
+
+  }
+
+ ignore( + "Find What are the details (URL, response status code, timestamp, source address) of events in the nginx logs where the response status code is 400 or higher") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = sso_logs-nginx-* | where http.response.status_code >= `400` | fields http.url, http.response.status_code, @timestamp, communication.source.address", + false), + context) + // SQL : SELECT http.url, http.response.status_code, @timestamp, communication.source.address FROM sso_logs-nginx-* WHERE
http.response.status_code >= 400; + val projectList = Seq( + UnresolvedAttribute("http.url"), + UnresolvedAttribute("http.response.status_code"), + UnresolvedAttribute("@timestamp"), + UnresolvedAttribute("communication.source.address")) + + val filterCondition = + GreaterThanOrEqual(UnresolvedAttribute("http.response.status_code"), Literal(400)) + + val expectedPlan = Project( + projectList, + Filter(filterCondition, UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")))) + + // Add to your unit test + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } + + ignore( + "Find What are the average and max http response sizes, grouped by request method, for access events in the nginx logs") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = sso_logs-nginx-* | where event.name = `access` | stats avg(http.response.bytes), max(http.response.bytes) by http.request.method", + false), + context) + // SQL : SELECT AVG(http.response.bytes) AS avg_size, MAX(http.response.bytes) AS max_size, http.request.method FROM sso_logs-nginx-* WHERE event.name = 'access' GROUP BY http.request.method; + val aggregateExpressions = Seq( + Alias( + AggregateExpression( + Average(UnresolvedAttribute("http.response.bytes")), + mode = Complete, + isDistinct = false), + "avg_size")(), + Alias( + AggregateExpression( + Max(UnresolvedAttribute("http.response.bytes")), + mode = Complete, + isDistinct = false), + "max_size")()) + val groupByAttributes = Seq(UnresolvedAttribute("http.request.method")) + + val expectedPlan = Aggregate( + groupByAttributes, + aggregateExpressions ++ groupByAttributes, + Filter( + EqualTo(UnresolvedAttribute("event.name"), Literal("access")), + UnresolvedRelation(TableIdentifier("sso_logs-nginx-*")))) + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + } + + ignore( + "Find flights from which carrier has the longest average delay for flights over 6k 
miles") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = opensearch_dashboards_sample_data_flights | where DistanceMiles > 6000 | stats avg(FlightDelayMin) by Carrier | sort -`avg(FlightDelayMin)` | head 1", + false), + context) + // SQL: SELECT AVG(FlightDelayMin) AS avg_delay, Carrier FROM opensearch_dashboards_sample_data_flights WHERE DistanceMiles > 6000 GROUP BY Carrier ORDER BY avg_delay DESC LIMIT 1; + val aggregateExpressions = Seq( + Alias( + AggregateExpression( + Average(UnresolvedAttribute("FlightDelayMin")), + mode = Complete, + isDistinct = false), + "avg_delay")()) + val groupByAttributes = Seq(UnresolvedAttribute("Carrier")) + + val expectedPlan = Limit( + Literal(1), + Sort( + Seq(SortOrder(UnresolvedAttribute("avg_delay"), Descending)), + true, + Aggregate( + groupByAttributes, + aggregateExpressions ++ groupByAttributes, + Filter( + GreaterThan(UnresolvedAttribute("DistanceMiles"), Literal(6000)), + UnresolvedRelation(TableIdentifier("opensearch_dashboards_sample_data_flights")))))) + + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } + + ignore("Find What's the average ram usage of windows machines over time aggregated by 1 week") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = opensearch_dashboards_sample_data_logs | where match(machine.os, 'win') | stats avg(machine.ram) by span(timestamp,1w)", + false), + context) + // SQL : SELECT AVG(machine.ram) AS avg_ram, floor(extract(epoch from timestamp) / 604800) + // AS week_span FROM opensearch_dashboards_sample_data_logs WHERE machine.os LIKE '%win%' GROUP BY week_span + val aggregateExpressions = Seq( + Alias( + AggregateExpression( + Average(UnresolvedAttribute("machine.ram")), + mode = Complete, + isDistinct = false), + "avg_ram")()) + val groupByAttributes = Seq( + Alias( + Floor( + Divide( + 
UnixTimestamp(UnresolvedAttribute("timestamp"), Literal("yyyy-MM-dd HH:mm:ss")), + Literal(604800))), + "week_span")()) + + val expectedPlan = Aggregate( + groupByAttributes, + aggregateExpressions ++ groupByAttributes, + Filter( + Like(UnresolvedAttribute("machine.os"), Literal("%win%"), '\\'), + UnresolvedRelation(TableIdentifier("opensearch_dashboards_sample_data_logs")))) + + assertEquals(expectedPlan, context.getPlan) + assertEquals(logPlan, "???") + + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala new file mode 100644 index 000000000..955aac3f5 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -0,0 +1,302 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, Literal, Multiply, SortOrder, TimeWindow} +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalPlanAggregationQueriesTranslatorTestSuite + extends SparkFunSuite + with LogicalPlanTestUtils + with Matchers { + + private val planTrnasformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test average price ") { + // if successful build ppl logical plan and translate to catalyst 
logical plan + val context = new CatalystPlanContext + val logPlan = + planTrnasformer.visit(plan(pplParser, "source = table | stats avg(price) ", false), context) + // SQL: SELECT avg(price) as avg_price FROM table + val star = Seq(UnresolvedStar(None)) + + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } + + ignore("test average price with Alias") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = table | stats avg(price) as avg_price", false), + context) + // SQL: SELECT avg(price) as avg_price FROM table + val star = Seq(UnresolvedStar(None)) + + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg_price")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } + + test("test average price group by product ") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = table | stats avg(price) by product", false), + context) + // SQL: SELECT product, AVG(price) AS avg_price FROM table GROUP BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = 
UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } + + test("test average price group by product and filter") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = table country ='USA' | stats avg(price) by product", false), + context) + // SQL: SELECT product, AVG(price) AS avg_price FROM table WHERE country = 'USA' GROUP BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() + + val filterExpr = EqualTo(countryField, Literal("USA")) + val filterPlan = Filter(filterExpr, table) + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } + + test("test average price group by product and filter sorted") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan =
planTrnasformer.visit( + plan( + pplParser, + "source = table country ='USA' | stats avg(price) by product | sort product", + false), + context) + // SQL: SELECT product, AVG(price) AS avg_price FROM table WHERE country = 'USA' GROUP BY product ORDER BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() + + val filterExpr = EqualTo(countryField, Literal("USA")) + val filterPlan = Filter(filterExpr, table) + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("product"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + assertEquals(compareByString(expectedPlan), compareByString(logPlan)) + } + test("create ppl simple avg age by span of interval of 10 years query test ") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan(pplParser, "source = table | stats avg(age) by span(age, 10) as age_span", false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val expectedPlan = Project(star, aggregatePlan) + +
assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("create ppl simple avg age by span of interval of 10 years query with sort test ") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = table | stats avg(age) by span(age, 10) as age_span | sort age", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("create ppl simple avg age by span of interval of 10 years by country query test ") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = table | stats avg(age) by span(age, 10) as age_span, country", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate( + Seq(countryAlias, span), + 
Seq(aggregateExpressions, countryAlias, span), + tableRelation) + val expectedPlan = Project(star, aggregatePlan) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + test("create ppl query count sales by weeks window and productId with sorting test") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date | sort age_date", + false), + context) + + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsAmount = UnresolvedAttribute("productsAmount") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = + Aggregate(Seq(windowExpression), Seq(aggregateExpressions, windowExpression), table) + + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + aggregatePlan) + + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("create ppl query count sales by days window and productId with sorting test") { + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit( + plan( + pplParser, + "source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date, productId | sort age_date", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val productsId = Alias(UnresolvedAttribute("productId"), "productId")() + val productsAmount = UnresolvedAttribute("productsAmount") + val table = 
UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("transactionDate"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "age_date")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(productsAmount), isDistinct = false), + "sum(productsAmount)")() + val aggregatePlan = Aggregate( + Seq(productsId, windowExpression), + Seq(aggregateExpressions, productsId, windowExpression), + table) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_date"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala new file mode 100644 index 000000000..1b04189db --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.junit.Assert.assertEquals +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + 
+// Checks how CatalystQueryPlanVisitor translates basic PPL searches (source, fields, sort, head,
+// multi-table union) by comparing each result with a hand-built unresolved Catalyst plan.
+class PPLLogicalPlanBasicQueriesTranslatorTestSuite
+    extends SparkFunSuite
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test simple search with only one table and no explicit fields (defaults to all fields)") {
+    // build the PPL logical plan, then translate it into a Catalyst logical plan
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(plan(pplParser, "source=table", false), context)
+
+    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with schema.table and no explicit fields (defaults to all fields)") {
+    // build the PPL logical plan, then translate it into a Catalyst logical plan
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(plan(pplParser, "source=schema.table", false), context)
+
+    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with schema.table and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=schema.table | fields A", false), context)
+
+    val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
+    val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with only one table with one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=table | fields A", false), context)
+
+    val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
+    val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with only one table with two fields projected") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(plan(pplParser, "source=t | fields A, B", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val expectedPlan = Project(projectList, table)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with one table with two fields projected sorted by one field") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t | sort A | fields A, B", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    // Sort by A ascending
+    val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Ascending))
+    val sorted = Sort(sortOrder, global = true, table)
+    val expectedPlan = Project(projectList, sorted)
+
+    assert(compareByString(expectedPlan) === compareByString(logPlan))
+  }
+
+  test(
+    "test simple search with only one table with two fields with head (limit ) command projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t | fields A, B | head 5", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val planWithLimit =
+      GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table)))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with two fields with head (limit ) command projected sorted by one descending field") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, "source=t | sort - A | fields A, B | head 5", false),
+      context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val sortOrder = Seq(SortOrder(UnresolvedAttribute("A"), Descending))
+    val sorted = Sort(sortOrder, global = true, table)
+    val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val projectAB = Project(projectList, sorted)
+
+    val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
+  }
+
+  test(
+    "Search multiple tables - translated into union call - fields expected to exist in both tables ") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, "search source = table1, table2 | fields A, B", false),
+      context)
+
+    val table1 = UnresolvedRelation(Seq("table1"))
+    val table2 = UnresolvedRelation(Seq("table2"))
+
+    val allFields1 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+    val allFields2 = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
+
+    val projectedTable1 = Project(allFields1, table1)
+    val projectedTable2 = Project(allFields2, table2)
+
+    val expectedPlan =
+      Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)
+
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("Search multiple tables - translated into union call - no explicit fields (all fields)") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source = table1, table2 ", false), context)
+
+    val table1 = UnresolvedRelation(Seq("table1"))
+    val table2 = UnresolvedRelation(Seq("table2"))
+
+    val allFields1 = UnresolvedStar(None)
+    val allFields2 = UnresolvedStar(None)
+
+    val projectedTable1 = Project(Seq(allFields1), table1)
+    val projectedTable2 = Project(Seq(allFields2), table2)
+
+    val expectedPlan =
+      Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)
+
+    assertEquals(expectedPlan, logPlan)
+  }
+}
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala
new file mode 100644
index 000000000..fe9485f4b
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFiltersTranslatorTestSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.apache.hadoop.conf.Configuration
+import org.junit.Assert.assertEquals
+import org.mockito.Mockito.when
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar.mock
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Like, Literal, NamedExpression, Not, Or, SortOrder, UnixTimestamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+// Checks how CatalystQueryPlanVisitor translates PPL filter expressions (=, !=, >, >=, <, <=, and,
+// or, not) into Catalyst Filter nodes, compared against hand-built unresolved plans.
+class PPLLogicalPlanFiltersTranslatorTestSuite
+    extends SparkFunSuite
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test simple search with only one table with one field literal filtered ") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(plan(pplParser, "source=t a = 1 ", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with only one table with two field with 'and' filtered ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a = 1 AND b != 2", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterAExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
+    val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
+    val filterPlan = Filter(And(filterAExpr, filterBExpr), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with only one table with two field with 'or' filtered ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a = 1 OR b != 2", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterAExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
+    val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
+    val filterPlan = Filter(Or(filterAExpr, filterBExpr), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test("test simple search with only one table with two field with 'not' filtered ") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t not a = 1 or b != 2 ", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterAExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal(1)))
+    val filterBExpr = Not(EqualTo(UnresolvedAttribute("b"), Literal(2)))
+    val filterPlan = Filter(Or(filterAExpr, filterBExpr), table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field literal int equality filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a = 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field literal string equality filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, """source=t a = 'hi' | fields a""", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = EqualTo(UnresolvedAttribute("a"), Literal("hi"))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field literal string none equality filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, """source=t a != 'bye' | fields a""", false),
+      context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal("bye")))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field greater than filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a > 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = GreaterThan(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field greater than equal filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a >= 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = GreaterThanOrEqual(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field lower than filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a < 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = LessThan(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field lower than equal filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a <= 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = LessThanOrEqual(UnresolvedAttribute("a"), Literal(1))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field not equal filtered and one field projected") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(plan(pplParser, "source=t a != 1 | fields a", false), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal(1)))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val expectedPlan = Project(projectList, filterPlan)
+    assertEquals(expectedPlan, logPlan)
+  }
+
+  test(
+    "test simple search with only one table with one field not equal filtered and one field projected and sorted") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(
+      plan(pplParser, "source=t a != 1 | fields a | sort a", false),
+      context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = Not(EqualTo(UnresolvedAttribute("a"), Literal(1)))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedAttribute("a"))
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(SortOrder(UnresolvedAttribute("a"), Ascending)),
+        global = true,
+        Project(projectList, filterPlan))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan)
+
+    assertEquals(compareByString(expectedPlan), compareByString(logPlan))
+  }
+}
diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e0258d98b..c52b1d229 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -63,7 +63,7 @@ This file is divided into 3 sections: - + true @@ -106,7 +106,7 @@ This file 
is divided into 3 sections: - +