Skip to content

Commit

Permalink
Merge with the latest changes and update tests:
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Aug 6, 2024
1 parent b5516e4 commit 56ca9fd
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ class AstBuilder extends DataTypeAstBuilder
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
val stmt = Option(ctx.beginEndCompoundBlock()).
getOrElse(Option(ctx.declareHandler()).getOrElse(ctx.declareCondition()))
getOrElse(Option(ctx.declareHandler()).
getOrElse(Option(ctx.declareCondition()).
getOrElse(ctx.ifElseStatement())))
visit(stmt).asInstanceOf[CompoundPlanStatement]
}
}
Expand All @@ -242,6 +244,7 @@ class AstBuilder extends DataTypeAstBuilder
buff.toSeq
}
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
IfElseStatement(
Expand All @@ -255,7 +258,6 @@ class AstBuilder extends DataTypeAstBuilder
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
)
}
}

override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = {
val conditionName = ctx.multipartIdentifier().getText
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,20 @@

package org.apache.spark.sql.catalyst.plans

import java.util.IdentityHashMap

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.rules.RuleId
import org.apache.spark.sql.catalyst.rules.UnknownRuleId
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag}
import org.apache.spark.sql.catalyst.rules.{RuleId, UnknownRuleId}
import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION}
import org.apache.spark.sql.catalyst.trees.TreePatternBits
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag, TreePatternBits}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.collection.BitSet

import java.util.IdentityHashMap
import scala.collection.mutable

/**
* An abstraction of the Spark SQL query plan tree, which can be logical or physical. This class
* defines some basic properties of a query plan node, as well as some new transform APIs to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ trait NonLeafStatementExec extends CompoundStatementExec {
session: SparkSession,
statement: LeafStatementExec): Boolean = statement match {
case statement: SingleStatementExec =>
assert(!statement.isExecuted)
statement.isExecuted = true

// DataFrame evaluates to True if it is single row, single column
// of boolean type with value True.
val df = Dataset.ofRows(session, statement.parsedPlan)
Expand Down Expand Up @@ -190,7 +187,7 @@ class SingleStatementExec(
* Spark session.
*/
class CompoundBodyExec(
label: Option[String],
label: Option[String] = None,
statements: Seq[CompoundStatementExec],
conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = mutable.HashMap(),
session: SparkSession)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ case class SqlScriptingInterpreter(session: SparkSession) {
* @return
* Iterator through collection of statements to be executed.
*/
def buildExecutionPlan(
compound: CompoundBody,
session: SparkSession): Iterator[CompoundStatementExec] = {
transformTreeIntoExecutable(compound, session).asInstanceOf[CompoundBodyExec].getTreeIterator
def buildExecutionPlan(compound: CompoundBody): Iterator[CompoundStatementExec] = {
transformTreeIntoExecutable(compound).asInstanceOf[CompoundBodyExec].getTreeIterator
}

/**
Expand Down Expand Up @@ -115,13 +113,10 @@ case class SqlScriptingInterpreter(session: SparkSession) {
*
* @param node
* Root node of the parsed tree.
* @param session
* Spark session that SQL script is executed within.
* @return
* Executable statement.
*/
private def transformTreeIntoExecutable(
node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec =
private def transformTreeIntoExecutable(node: CompoundPlanStatement): CompoundStatementExec =
node match {
case body: CompoundBody =>
// TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
Expand All @@ -130,9 +125,9 @@ case class SqlScriptingInterpreter(session: SparkSession) {
val conditionsExec = conditions.map(condition =>
new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false))
val conditionalBodiesExec = conditionalBodies.map(body =>
transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
transformTreeIntoExecutable(body).asInstanceOf[CompoundBodyExec])
val unconditionalBodiesExec = elseBody.map(body =>
transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
transformTreeIntoExecutable(body).asInstanceOf[CompoundBodyExec])
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
case sparkStatement: SingleStatement =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi

// Tests
test("test body - single statement") {
val iter = new CompoundBodyExec(Seq(TestLeafStatement("one"))).getTreeIterator
val iter = TestBody(Seq(TestLeafStatement("one"))).getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("one"))
}

test("test body - no nesting") {
val iter = new CompoundBodyExec(
val iter = TestBody(
Seq(
TestLeafStatement("one"),
TestLeafStatement("two"),
Expand All @@ -75,26 +75,26 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("test body - nesting") {
val iter = new CompoundBodyExec(
val iter = TestBody(
Seq(
new CompoundBodyExec(Seq(TestLeafStatement("one"), TestLeafStatement("two"))),
TestBody(Seq(TestLeafStatement("one"), TestLeafStatement("two"))),
TestLeafStatement("three"),
new CompoundBodyExec(Seq(TestLeafStatement("four"), TestLeafStatement("five")))))
TestBody(Seq(TestLeafStatement("four"), TestLeafStatement("five")))))
.getTreeIterator
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("one", "two", "three", "four", "five"))
}

test("if else - enter body of the IF clause") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = true, description = "con1")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1")))
TestBody(Seq(TestLeafStatement("body1")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body2")))),
session = spark
)
)).getTreeIterator
Expand All @@ -103,15 +103,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else - enter body of the ELSE clause") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1")))
TestBody(Seq(TestLeafStatement("body1")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body2")))),
session = spark
)
)).getTreeIterator
Expand All @@ -120,17 +120,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - enter body of the IF clause") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = true, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))),
session = spark
)
)).getTreeIterator
Expand All @@ -139,17 +139,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - enter body of the ELSE IF clause") {
val iter = new CompoundBodyExec(Seq(
val iter = new TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = true, description = "con2")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))),
session = spark
)
)).getTreeIterator
Expand All @@ -158,19 +158,19 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - enter body of the second ELSE IF clause") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2"),
TestIfElseCondition(condVal = true, description = "con3")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2"))),
new CompoundBodyExec(Seq(TestLeafStatement("body3")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2"))),
TestBody(Seq(TestLeafStatement("body3")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body4")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body4")))),
session = spark
)
)).getTreeIterator
Expand All @@ -179,17 +179,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - enter body of the ELSE clause") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2")))
),
elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
elseBody = Some(TestBody(Seq(TestLeafStatement("body3")))),
session = spark
)
)).getTreeIterator
Expand All @@ -198,15 +198,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - without else (successful check)") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = true, description = "con2")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2")))
),
elseBody = None,
session = spark
Expand All @@ -217,15 +217,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
}

test("if else if - without else (unsuccessful checks)") {
val iter = new CompoundBodyExec(Seq(
val iter = TestBody(Seq(
new IfElseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2")
),
conditionalBodies = Seq(
new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
new CompoundBodyExec(Seq(TestLeafStatement("body2")))
TestBody(Seq(TestLeafStatement("body1"))),
TestBody(Seq(TestLeafStatement("body2")))
),
elseBody = None,
session = spark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
| END IF;
|END
|""".stripMargin
val expected = Seq(Seq(Row(42)))
val expected = Seq(Array(Row(42)))
verifySqlScriptResult(commands, expected)
}

Expand All @@ -462,7 +462,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
| END IF;
|END
|""".stripMargin
val expected = Seq(Seq(Row(42)))
val expected = Seq(Array(Row(42)))
verifySqlScriptResult(commands, expected)
}

Expand Down Expand Up @@ -536,7 +536,7 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
|END
|""".stripMargin

val expected = Seq(Seq(Row(44)))
val expected = Seq(Array(Row(44)))
verifySqlScriptResult(commands, expected)
}

Expand Down

0 comments on commit 56ca9fd

Please sign in to comment.