diff --git a/build.sbt b/build.sbt index 593542e5c..73fb481a6 100644 --- a/build.sbt +++ b/build.sbt @@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind"), "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), - "software.amazon.awssdk" % "auth-crt" % "2.25.23", + "software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index 49dc8e355..efb001785 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -13,8 +13,39 @@ import org.opensearch.flint.common.model.FlintStatement trait QueryResultWriter { /** - * Writes the given DataFrame, which represents the result of a query execution, to an external - * data storage based on the provided FlintStatement metadata. + * Writes the given DataFrame to external data storage based on the FlintStatement metadata. + * This method is responsible for persisting the query results. + * + * Note: This method typically involves I/O operations and may trigger Spark actions to + * materialize the DataFrame if it hasn't been processed yet. + * + * @param dataFrame + * The DataFrame containing the query results to be written. + * @param flintStatement + * The FlintStatement containing metadata that guides the writing process. */ def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit + + /** + * Defines transformations on the given DataFrame and triggers an action to process it. This + * method applies necessary transformations based on the FlintStatement metadata and executes an + * action to compute the result. + * + * Note: Calling this method will trigger the actual data processing in Spark. If the Spark SQL + * thread is waiting for the result of a query, any termination attempt on the same thread will be blocked + * until the action completes. + * + * @param dataFrame + * The DataFrame to be processed. + * @param flintStatement + * The FlintStatement containing statement metadata. + * @param queryStartTime + * The start time of the query execution. + * @return + * The processed DataFrame after applying transformations and executing an action. 
+ */ + def processDataFrame( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryStartTime: Long): DataFrame } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala index 596626698..e10b2e2a6 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala @@ -10,7 +10,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLEvalITSuite @@ -22,6 +22,7 @@ class FlintSparkPPLEvalITSuite /** Test table and index name */ private val testTable = "spark_catalog.default.flint_ppl_test" private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log" + private val duplicatesNullableTestTable = "spark_catalog.default.duplicates_nullable_test" override def beforeAll(): Unit = { super.beforeAll() @@ -29,6 +30,7 @@ class FlintSparkPPLEvalITSuite // Create test table createPartitionedStateCountryTable(testTable) createTableHttpLog(testTableHttpLog) + createDuplicationNullableTable(duplicatesNullableTestTable) } protected override def afterEach(): Unit = { @@ -632,8 +634,45 @@ class FlintSparkPPLEvalITSuite EqualTo(Literal(true), and) } - // Todo excluded fields not support yet + test("Test eval and signum function") { + val frame = sql(s""" + | source = $duplicatesNullableTestTable | fields id | sort id | eval i = pow(-2, id), s = signum(i) | head 5 + | """.stripMargin) + val rows = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1, -2d, -1d), + Row(2, 4d, 1d), + Row(3, -8d, -1d), + Row(4, 16d, 1d), + Row(5, -32d, -1d)) + assert(rows.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val tablePlan = + UnresolvedRelation(Seq("spark_catalog", "default", "duplicates_nullable_test")) + val projectIdPlan = Project(Seq(UnresolvedAttribute("id")), tablePlan) + val sortPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, projectIdPlan) + val evalPlan = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "pow", + Seq(Literal(-2), UnresolvedAttribute("id")), + isDistinct = false), + "i")(), + Alias( + UnresolvedFunction("signum", Seq(UnresolvedAttribute("i")), isDistinct = false), + "s")()), + sortPlan) + val localLimitPlan = LocalLimit(Literal(5), evalPlan) + val globalLimitPlan = GlobalLimit(Literal(5), localLimitPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), globalLimitPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + // Todo excluded fields not support yet ignore("test single eval expression with excluded fields") { val frame = sql(s""" | source = $testTable | eval new_field = "New Field" | fields - age diff 
--git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala new file mode 100644 index 000000000..ee08e692a --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -0,0 +1,407 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{And, Descending, EqualTo, InSubquery, ListQuery, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLInSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test where id in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // InSubquery: (0, 2, 3, 5, 6) + val frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id) in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // InSubquery: (0, 2, 3, 5, 6) + val frame = sql(s""" + source = $outerTable + | | where (id) in [ + | source = $innerTable | 
fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id, name) in (select uid, name from inner)") { + // InSubquery: (0, 2, 3, 5) + val frame = sql(s""" + source = $outerTable + | | where (id, name) in [ + | source = $innerTable | fields uid, name + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + ListQuery( + Project(Seq(UnresolvedAttribute("uid"), UnresolvedAttribute("name")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where id not in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // Not InSubquery: (1, 4) + val frame = sql(s""" + source = $outerTable + | | where id not in [ + | source = $innerTable | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + 
Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id, name) not in (select uid, name from inner)") { + // Not InSubquery: (1, 4, 6) + val frame = sql(s""" + source = $outerTable + | | where (id, name) not in [ + | source = $innerTable | fields uid, name + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1001, "Hello", 70000), Row(1004, "David", 0), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + ListQuery( + Project(Seq(UnresolvedAttribute("uid"), UnresolvedAttribute("name")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test empty subquery") { + // id (0, 1, 2, 3, 4, 5, 6), uid () + // InSubquery: () + // Not InSubquery: (0, 1, 2, 3, 4, 5, 6) + var frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | where uid = 0000 | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + var results: Array[Row] = frame.collect() + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + var expectedResults: Array[Row] = Array() + assert(results.sorted.sameElements(expectedResults.sorted)) + + frame = sql(s""" + source = $outerTable + | | where id not in [ + | source = $innerTable | where uid = 0000 | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + results = frame.collect() + expectedResults = Array( + Row(1000, "Jake", 100000), + Row(1001, "Hello", 70000), + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1004, "David", 0), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test nested subquery") { + val frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable + | | where occupation in [ + | source = $nestedInnerTable | where occupation != 'Engineer' | fields occupation + | ] + | | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + frame.show() + frame.explain(true) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = 
Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val filter = + Filter(Not(EqualTo(UnresolvedAttribute("occupation"), Literal("Engineer"))), inner2) + val inSubqueryForOccupation = + Filter( + InSubquery( + Seq(UnresolvedAttribute("occupation")), + ListQuery(Project(Seq(UnresolvedAttribute("occupation")), filter))), + inner1) + val inSubqueryForId = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inSubqueryForOccupation))), + outer) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + inSubqueryForId) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test in-subquery as a join filter") { + val frame = sql(s""" + | source = $outerTable + | | inner join left=a, right=b + | ON a.id = b.uid AND b.occupation in [ + | source = $nestedInnerTable| where occupation != 'Engineer' | fields occupation + | ] + | $innerTable + | | fields a.id, a.name, a.salary + | """.stripMargin) + frame.explain(true) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val plan1 = SubqueryAlias("a", table1) + val plan2 = SubqueryAlias("b", table2) + val filter = + Filter(Not(EqualTo(UnresolvedAttribute("occupation"), Literal("Engineer"))), inner) + val inSubqueryForOccupation = + InSubquery( + Seq(UnresolvedAttribute("b.occupation")), + ListQuery(Project(Seq(UnresolvedAttribute("occupation")), filter))) + val joinCondition = + And( + EqualTo(UnresolvedAttribute("a.id"), UnresolvedAttribute("b.uid")), + inSubqueryForOccupation) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("a.id"), + UnresolvedAttribute("a.name"), + UnresolvedAttribute("a.salary")), + joinPlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("throw exception because the number of columns not match output of subquery") { + val ex = intercept[AnalysisException](sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | fields uid, department + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin)) + assert(ex.getMessage.contains( + "The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery")) + } +} diff --git a/ppl-spark-integration/README.md 
b/ppl-spark-integration/README.md index 02baaab45..6b3996f52 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -333,6 +333,7 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | eval a = 10 | fields a,b,c` - `source = table | eval a = a * 2 | stats avg(a)` - `source = table | eval a = abs(a) | where a > 0` + - `source = table | eval a = signum(a) | where a < 0` **Aggregations** - `source = table | stats avg(a) ` @@ -434,6 +435,97 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ Details of Lookup command syntax, see [PPL-Lookup-Command](../docs/PPL-Lookup-command.md) +**InSubquery** +- `source = outer | where a in [ source = inner | fields b ]` +- `source = outer | where (a) in [ source = inner | fields b ]` +- `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` +- `source = outer | where a not in [ source = inner | fields b ]` +- `source = outer | where (a) not in [ source = inner | fields b ]` +- `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) +- `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) + +SQL Migration examples with IN-Subquery PPL: +1. tpch q4 (in-subquery with aggregation) +```sql +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and o_orderkey in ( + select + l_orderkey + from + lineitem + where l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +``` +Rewritten by PPL InSubquery query: +```sql +source = orders +| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" and o_orderkey IN + [ source = lineitem + | where l_commitdate < l_receiptdate + | fields l_orderkey + ] +| stats count(1) as order_count by o_orderpriority +| sort o_orderpriority +| fields o_orderpriority, order_count +``` +2.tpch q20 (nested in-subquery) +```sql +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name +``` +Rewritten by PPL InSubquery query: +```sql +source = supplier +| where s_suppkey IN [ + source = partsupp + | where ps_partkey IN [ + source = part + | where like(p_name, "forest%") + | fields p_partkey + ] + | fields ps_suppkey + ] +| inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' + nation +| sort s_name +``` + --- #### Experimental Commands: - `correlation` - [See details](../docs/PPL-Correlation-command.md) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index aed885afb..c27721dfd 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -276,6 +276,7 @@ POWER: 'POWER'; RAND: 'RAND'; ROUND: 'ROUND'; SIGN: 'SIGN'; +SIGNUM: 'SIGNUM'; SQRT: 'SQRT'; TRUNCATE: 'TRUNCATE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index f35f8d743..33ba5c5ed 
100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -36,6 +36,10 @@ queryStatement : pplCommands (PIPE commands)* ; +subSearch + : searchCommand (PIPE commands)* + ; + // commands pplCommands : searchCommand @@ -377,6 +381,12 @@ logicalExpression comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr | valueExpression IN valueList # inExpr + | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + ; + +valueExpressionList + : valueExpression + | LT_PRTHS valueExpression (COMMA valueExpression)* RT_PRTHS ; valueExpression @@ -599,6 +609,7 @@ mathematicalFunctionName | RAND | ROUND | SIGN + | SIGNUM | SQRT | TRUNCATE | trigonometricFunctionName @@ -1049,4 +1060,13 @@ commandNames | SPARKLINE | C | DC + // JOIN TYPE + | OUTER + | INNER + | CROSS + | LEFT + | RIGHT + | FULL + | SEMI + | ANTI ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index aea7bbb1d..76f9479f4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -19,6 +19,7 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -289,4 +290,7 @@ public T visitExplain(Explain node, C context) { return visitStatement(node, context); } + public T visitInSubquery(InSubquery node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java new file mode 100644 index 000000000..ed40e4b45 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +import java.util.Arrays; +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class InSubquery extends UnresolvedExpression { + private final List value; + private final UnresolvedPlan query; + + @Override + public List getChild() { + return value; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitInSubquery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 211b73084..cbb21a51e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -13,6 
+13,8 @@ import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.SortDirection; @@ -42,6 +44,7 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -77,6 +80,8 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; @@ -125,6 +130,10 @@ public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } + public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { + return plan.accept(this, context); + } + /** * Handle Query Statement. */ @@ -342,7 +351,16 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) { if(node instanceof DescribeCommand) { return context.with(new PrintLiteralCommandDescriptionLogicalPlan(((DescribeCommand) node).getDescription())); } - if (!node.isExcluded()) { + if (node.isExcluded()) { + List<UnresolvedExpression> intersect = context.getProjectedFields().stream() + .filter(node.getProjectList()::contains) + .collect(Collectors.toList()); + if (!intersect.isEmpty()) { + // Fields are in the parent projection but have been excluded in the child. For example, + // source=t | fields - A, B | fields A, B, C will throw "[Field A, Field B] can't be resolved" + throw new SyntaxCheckException(intersect + " can't be resolved"); + } + } else { context.withProjectedFields(node.getProjectList()); } LogicalPlan child = node.getChild().get(0).accept(this, context); @@ -482,7 +500,7 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { /** * Expression Analyzer. 
*/ - public static class ExpressionAnalyzer extends AbstractNodeVisitor { + public class ExpressionAnalyzer extends AbstractNodeVisitor { public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { return unresolved.accept(this, context); @@ -729,5 +747,24 @@ public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { throw new IllegalStateException("Not Supported operation : WindowFunction"); } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 7019e11fb..dc882e3a2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -88,6 +88,12 @@ public UnresolvedPlan visitQueryStatement(OpenSearchPPLParser.QueryStatementCont return ctx.commands().stream().map(this::visit).reduce(pplCommand, (r, e) -> e.attach(r)); } + @Override + public UnresolvedPlan visitSubSearch(OpenSearchPPLParser.SubSearchContext ctx) { + UnresolvedPlan searchCommand = visit(ctx.searchCommand()); + return ctx.commands().stream().map(this::visit).reduce(searchCommand, (r, e) -> e.attach(r)); + } + /** Search command. 
*/ @Override public UnresolvedPlan visitSearchFrom(OpenSearchPPLParser.SearchFromContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 2706d85e5..f5e9269be 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; @@ -62,6 +63,13 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index a0138be08..9b004ba60 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -5,7 +5,6 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ - package org.opensearch.flint.spark.ppl import org.antlr.v4.runtime.{CommonTokenStream, Lexer} @@ -42,8 +41,11 @@ object PlaneUtils { // Create an instance of each visitor val expressionBuilder = new AstExpressionBuilder() val astBuilder = new AstBuilder(expressionBuilder, query, parser.getParserVersion()) + expressionBuilder.setAstBuilder(astBuilder) + // description visitor val astDescriptionBuilder = new AstCommandDescriptionVisitor(expressionBuilder, query, parser.getParserVersion()) + // statement visitor val statementContext = AstStatementBuilder.StatementBuilderContext.builder() // Chain visitors diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 30deecc31..cc87e8853 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers @@ -311,4 +312,33 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) comparePlans(expectedPlan, logPlan, false) } + + test("test fields + then - field list") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source=t | fields + A, B, C | fields - A, B"), + context) + + val table = UnresolvedRelation(Seq("t")) + val projectABC = Project( + Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"), 
UnresolvedAttribute("C")), + table) + val dropList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B")) + val dropAB = DataFrameDropColumns(dropList, projectABC) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropAB) + comparePlans(expectedPlan, logPlan, false) + } + + test("test fields - then + field list") { + val context = new CatalystPlanContext + val thrown = intercept[SyntaxCheckException] { + planTransformer.visit( + plan(pplParser, "source=t | fields - A, B | fields + A, B, C"), + context) + } + assert( + thrown.getMessage + === "[Field(field=A, fieldArgs=[]), Field(field=B, fieldArgs=[])] can't be resolved") + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala new file mode 100644 index 000000000..03bcdd623 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala @@ -0,0 +1,365 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThanOrEqual, InSubquery, LessThan, ListQuery, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} + +class PPLLogicalPlanInSubqueryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test where a in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a) in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = 
UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a, b, c) in (select d, e, f from inner)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where (a, b, c) in [ + | source = spark_catalog.default.inner | fields d, e, f + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + ListQuery( + Project( + Seq(UnresolvedAttribute("d"), UnresolvedAttribute("e"), UnresolvedAttribute("f")), + inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where a not in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a not in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a, b, c) not in (select d, e, f from inner)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where (a, b, c) not in [ + | source = spark_catalog.default.inner | fields d, e, f + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + ListQuery( + Project( + Seq(UnresolvedAttribute("d"), UnresolvedAttribute("e"), UnresolvedAttribute("f")), + inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test nested subquery") { + val context = new 
CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner1 + | | where b in [ + | source = spark_catalog.default.inner2 | fields c + | ] + | | fields b + | ] + | | sort - a + | | fields a, d + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "inner1")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "inner2")) + val inSubqueryForB = + Filter( + InSubquery( + Seq(UnresolvedAttribute("b")), + ListQuery(Project(Seq(UnresolvedAttribute("c")), inner2))), + inner1) + val inSubqueryForA = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inSubqueryForB))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubqueryForA) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("d")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + // TODO throw exception with syntax check, now it throw AnalysisException in Spark + ignore("The number of columns not match output of subquery") { + val context = new CatalystPlanContext + val ex = intercept[SyntaxCheckException] { + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b, d + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + } + assert(ex.getMessage === "The number of columns not match output of subquery") + } + + test("test tpch q4: in-subquery with aggregation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = orders + | | where o_orderdate >= "1993-07-01" AND o_orderdate < "1993-10-01" AND o_orderkey IN + | [ source = lineitem + | | where l_commitdate < l_receiptdate + | | fields l_orderkey + | ] + | | stats count(1) as order_count by o_orderpriority + | | sort o_orderpriority + | | fields o_orderpriority, order_count + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("orders")) + val inner = UnresolvedRelation(Seq("lineitem")) + val inSubquery = + Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("o_orderdate"), Literal("1993-07-01")), + LessThan(UnresolvedAttribute("o_orderdate"), Literal("1993-10-01"))), + InSubquery( + Seq(UnresolvedAttribute("o_orderkey")), + ListQuery( + Project( + Seq(UnresolvedAttribute("l_orderkey")), + Filter( + LessThan( + UnresolvedAttribute("l_commitdate"), + UnresolvedAttribute("l_receiptdate")), + inner))))), + outer) + val o_orderpriorityAlias = Alias(UnresolvedAttribute("o_orderpriority"), "o_orderpriority")() + val groupByAttributes = Seq(o_orderpriorityAlias) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), + "order_count")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, o_orderpriorityAlias), inSubquery) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("o_orderpriority"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project( + Seq(UnresolvedAttribute("o_orderpriority"), UnresolvedAttribute("order_count")), + sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test tpch q20 (partial): 
nested in-subquery") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = supplier + | | where s_suppkey IN [ + | source = partsupp + | | where ps_partkey IN [ + | source = part + | | where like(p_name, "forest%") + | | fields p_partkey + | ] + | | fields ps_suppkey + | ] + | | inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' + | nation + | | sort s_name + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("supplier")) + val inner = UnresolvedRelation(Seq("partsupp")) + val nestedInner = UnresolvedRelation(Seq("part")) + val right = UnresolvedRelation(Seq("nation")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("s_suppkey")), + ListQuery( + Project( + Seq(UnresolvedAttribute("ps_suppkey")), + Filter( + InSubquery( + Seq(UnresolvedAttribute("ps_partkey")), + ListQuery(Project( + Seq(UnresolvedAttribute("p_partkey")), + Filter( + UnresolvedFunction( + "like", + Seq(UnresolvedAttribute("p_name"), Literal("forest%")), + isDistinct = false), + nestedInner)))), + inner)))), + outer) + val leftPlan = SubqueryAlias("l", inSubquery) + val rightPlan = SubqueryAlias("r", right) + val joinCondition = + And( + EqualTo(UnresolvedAttribute("s_nationkey"), UnresolvedAttribute("n_nationkey")), + EqualTo(UnresolvedAttribute("n_name"), Literal("CANADA"))) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("s_name"), Ascending)), global = true, joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala index ed72a3d40..feaa7d8ca 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala @@ -191,4 +191,18 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite val expectedPlan = Project(projectList, evalProject) comparePlans(expectedPlan, logPlan, false) } + + test("test signum") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit(plan(pplParser, "source=t a = signum(b)"), context) + + val table = UnresolvedRelation(Seq("t")) + val filterExpr = EqualTo( + UnresolvedAttribute("a"), + UnresolvedFunction("signum", seq(UnresolvedAttribute("b")), isDistinct = false)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, filterPlan) + comparePlans(expectedPlan, logPlan, false) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 42b1ae2f6..56bd9cb00 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -18,7 +18,6 @@ case class CommandContext( jobType: String, sessionId: String, sessionManager: SessionManager, - queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, 
inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 24d68fd47..c076f9974 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.commons.text.StringEscapeUtils.unescapeJava +import org.opensearch.common.Strings import org.opensearch.flint.core.IRestHighLevelClient import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage} import org.opensearch.flint.core.metrics.MetricConstants @@ -533,7 +534,7 @@ trait FlintJobExecutor { } def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { - if (className.isEmpty) { + if (Strings.isNullOrEmpty(className)) { defaultConstructor } else { try { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index a0516a37a..cdeebe663 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val queryResultWriter = - instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( applicationId, jobId, @@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor { jobType, sessionId, sessionManager, - queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - + val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { logInfo(s"""Executing session with sessionId: ${sessionId}""") @@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(statementsExecutionManager, commandContext, commandState) + processCommands( + statementsExecutionManager, + queryResultWriter, + commandContext, + commandState) val ( updatedLastActivityTime, @@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -525,6 +527,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( statementExecutionManager, + queryResultWriter, flintStatement, state, context) @@ -532,7 +535,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = returnedVerificationResult finalizeCommand( 
statementExecutionManager, - context, + queryResultWriter, dataToWrite, flintStatement, statementTimerContext) @@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ private def finalizeCommand( statementExecutionManager: StatementExecutionManager, - commandContext: CommandContext, + queryResultWriter: QueryResultWriter, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, statementTimerContext: Timer.Context): Unit = { - import commandContext._ try { dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { @@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processStatementOnVerification( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, flintStatement: FlintStatement, commandState: CommandState, commandContext: CommandContext) = { @@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementsExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -801,7 +809,14 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } else { val futureQueryExecution = Future { - statementsExecutionManager.executeStatement(flintStatement) + val startTime = System.currentTimeMillis() + // Execute the statement and get the resulting DataFrame + // This step may involve Spark transformations, but not necessarily actions + val df = statementsExecutionManager.executeStatement(flintStatement) + // Process the DataFrame, applying any necessary transformations + // and triggering Spark actions to materialize the results + // This is where the actual data processing occurs + queryResultWriter.processDataFrame(df, flintStatement, startTime) }(executionContext) // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) @@ -998,11 +1013,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def instantiateQueryResultWriter( - sparkConf: SparkConf, - context: Map[String, Any]): QueryResultWriter = { + spark: SparkSession, + commandContext: CommandContext): QueryResultWriter = { instantiate( - new QueryResultWriterImpl(context), - sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + new QueryResultWriterImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) } private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { diff --git 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index deee6eb1d..58d868a2e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -61,7 +61,6 @@ case class JobOperator( jobType, "", // FlintJob doesn't have sessionId null, // FlintJob doesn't have SessionManager - null, // FlintJob doesn't have QueryResultWriter Duration.Inf, // FlintJob doesn't have queryExecutionTimeout -1, // FlintJob doesn't have inactivityLimitMillis -1, // FlintJob doesn't have queryWaitTimeMillis diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala index 23d7f42a1..61c6e0747 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -10,9 +10,14 @@ import org.opensearch.flint.common.model.FlintStatement import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.util.CleanerFactory -class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { +class QueryResultWriterImpl(commandContext: CommandContext) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + private val context = commandContext.sessionManager.getSessionContext private val resultIndex = context("resultIndex").asInstanceOf[String] // Initialize OSClient with Flint options because custom session manager implementation should not have it in the context private val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -20,4 +25,27 @@ class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) } + + override def processDataFrame( + dataFrame: DataFrame, + statement: FlintStatement, + queryStartTime: Long): DataFrame = { + import commandContext._ + + /** + * Reformat the given DataFrame to the desired format for OpenSearch storage. 
+ */ + getFormattedData( + applicationId, + jobId, + dataFrame, + spark, + dataSource, + statement.queryId, + statement.query, + sessionId, + queryStartTime, + currentTimeProvider, + CleanerFactory.cleaner(false)) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index 4e9435f7b..432d6df11 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) } override def executeStatement(statement: FlintStatement): DataFrame = { - import commandContext._ - executeQuery( - applicationId, - jobId, - spark, - statement.query, - dataSource, + import commandContext.spark + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup( statement.queryId, - sessionId, - false) + "Job group for " + statement.queryId, + interruptOnCancel = true) + spark.sql(statement.query) } private def createOpenSearchQueryReader() = { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 355bd9ede..5eeccce73 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -675,7 +675,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -748,7 +747,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -761,6 +759,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -809,7 +808,6 @@ class FlintREPLTest when(mockSparkSession.sparkContext).thenReturn(sparkContext) // Assume handleQueryException logs the error and returns an error message string - val mockErrorString = "Error due to syntax" when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) @@ -824,7 +822,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -837,6 +834,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -1076,7 +1074,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, @@ -1146,7 +1143,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, @@ -1212,7 +1208,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1283,7 +1278,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1367,7 +1361,6 @@ class FlintREPLTest override val osClient: OSClient = mockOSClient override lazy 
val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1377,7 +1370,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1453,7 +1445,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60,
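
For reviewers experimenting with the new two-step `QueryResultWriter` contract introduced above, here is a minimal, hypothetical sketch of a custom writer; the class name, result location, and added columns are illustrative and not part of this change. In FlintREPL, `processDataFrame` is invoked on the query execution thread right after `StatementExecutionManager.executeStatement`, while `writeDataFrame` runs later in `finalizeCommand`:

```scala
import org.apache.spark.sql.{DataFrame, QueryResultWriter}
import org.apache.spark.sql.functions.lit
import org.opensearch.flint.common.model.FlintStatement

// Hypothetical custom writer: shape the result in processDataFrame (where Spark
// actions may be triggered), then persist it in writeDataFrame.
class ParquetQueryResultWriter(resultLocation: String) extends QueryResultWriter {

  override def processDataFrame(
      dataFrame: DataFrame,
      flintStatement: FlintStatement,
      queryStartTime: Long): DataFrame = {
    // Attach bookkeeping columns; the REPL awaits this step, so long-running
    // actions here count against queryExecutionTimeout.
    dataFrame
      .withColumn("queryId", lit(flintStatement.queryId))
      .withColumn("queryRunTime", lit(System.currentTimeMillis() - queryStartTime))
  }

  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // Persist the processed result; this is the I/O step performed in finalizeCommand.
    dataFrame.write.mode("append").parquet(s"$resultLocation/${flintStatement.queryId}")
  }
}
```

Such a class could presumably be supplied through `FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER` (see `instantiateQueryResultWriter`), assuming it exposes a constructor the `instantiate` helper can call.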
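
The `setJobGroup` call added to `StatementExecutionManagerImpl.executeStatement` is what enables per-query cancellation: the group must be set on the thread that submits the query, and any other thread can then cancel the whole group. A rough sketch of the underlying Spark pattern, with an illustrative group id and query:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val queryId = "query-42" // illustrative; the real code uses FlintStatement.queryId

// Must be called on the thread that starts the query.
spark.sparkContext.setJobGroup(queryId, s"Job group for $queryId", interruptOnCancel = true)
val result = spark.sql("SELECT 1").collect()

// Any other thread (e.g. a timeout or cancel handler) can abort all jobs in the group.
spark.sparkContext.cancelJobGroup(queryId)
```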