From b25ee02d7a3e607c1ab95b589213b5fdc55a1b53 Mon Sep 17 00:00:00 2001 From: Milan Dankovic Date: Mon, 25 Nov 2024 12:29:55 +0100 Subject: [PATCH] Add session to be constructor argument for SqlScriptingInterpreter --- .../sql/scripting/SqlScriptingExecution.scala | 2 +- .../scripting/SqlScriptingInterpreter.scala | 29 +++-- .../SqlScriptingExecutionSuite.scala | 106 +++++------------- .../SqlScriptingInterpreterSuite.scala | 8 +- 4 files changed, 47 insertions(+), 98 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index dfee573178eb0..c55bc91cf5cb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -31,7 +31,7 @@ class SqlScriptingExecution( // Build the execution plan for the script private val executionPlan: Iterator[CompoundStatementExec] = - SqlScriptingInterpreter().buildExecutionPlan(sqlScript, session, args) + SqlScriptingInterpreter(session).buildExecutionPlan(sqlScript, args) private var current = getNextResult diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 3acf75b9fa110..2a395435f8af2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -25,8 +25,11 @@ import org.apache.spark.sql.catalyst.trees.Origin /** * SQL scripting interpreter - builds SQL script execution plan. + * + * @param session + * Spark session that SQL script is executed within. */ -case class SqlScriptingInterpreter() { +case class SqlScriptingInterpreter(session: SparkSession) { /** * Build execution plan and return statements that need to be executed, @@ -34,16 +37,13 @@ case class SqlScriptingInterpreter() { * * @param compound * CompoundBody for which to build the plan. - * @param session - * Spark session that SQL script is executed within. * @return * Iterator through collection of statements to be executed. */ def buildExecutionPlan( compound: CompoundBody, - session: SparkSession, args: Map[String, Expression]): Iterator[CompoundStatementExec] = { - transformTreeIntoExecutable(compound, session, args) + transformTreeIntoExecutable(compound, args) .asInstanceOf[CompoundBodyExec].getTreeIterator } @@ -65,14 +65,11 @@ case class SqlScriptingInterpreter() { * * @param node * Root node of the parsed tree. - * @param session - * Spark session that SQL script is executed within. * @return * Executable statement. */ private def transformTreeIntoExecutable( node: CompoundPlanStatement, - session: SparkSession, args: Map[String, Expression]): CompoundStatementExec = node match { case CompoundBody(collection, label) => @@ -86,7 +83,7 @@ case class SqlScriptingInterpreter() { .map(new SingleStatementExec(_, Origin(), args, isInternal = true)) .reverse new CompoundBodyExec( - collection.map(st => transformTreeIntoExecutable(st, session, args)) ++ dropVariables, + collection.map(st => transformTreeIntoExecutable(st, args)) ++ dropVariables, label) case IfElseStatement(conditions, conditionalBodies, elseBody) => @@ -97,9 +94,9 @@ case class SqlScriptingInterpreter() { args, isInternal = false)) val conditionalBodiesExec = conditionalBodies.map(body => - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]) val unconditionalBodiesExec = elseBody.map(body => - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]) new IfElseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) @@ -112,9 +109,9 @@ case class SqlScriptingInterpreter() { args, isInternal = false)) val conditionalBodiesExec = conditionalBodies.map(body => - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]) val unconditionalBodiesExec = elseBody.map(body => - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec]) + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]) new CaseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) @@ -126,7 +123,7 @@ case class SqlScriptingInterpreter() { args, isInternal = false) val bodyExec = - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec] + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec] new WhileStatementExec(conditionExec, bodyExec, label, session) case RepeatStatement(condition, body, label) => @@ -137,11 +134,11 @@ case class SqlScriptingInterpreter() { args, isInternal = false) val bodyExec = - transformTreeIntoExecutable(body, session, args).asInstanceOf[CompoundBodyExec] + transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec] new RepeatStatementExec(conditionExec, bodyExec, label, session) case LoopStatement(body, label) => - val bodyExec = transformTreeIntoExecutable(body, session, args) + val bodyExec = transformTreeIntoExecutable(body, args) .asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 0d0b3acc0faf6..140fa495447e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -90,9 +90,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |FROM t; |END |""".stripMargin - val expected = Seq( - Seq(Row(false)) // select - ) + val expected = Seq(Seq(Row(false))) verifySqlScriptResult(sqlScript, expected) } } @@ -106,9 +104,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |SELECT var; |END |""".stripMargin - val expected = Seq( - Seq(Row(2)), // select - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScript, expected) } @@ -121,9 +117,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |SELECT var; |END |""".stripMargin - val expected = Seq( - Seq(Row(2)), // select - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScript, expected) } @@ -149,7 +143,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq(Row(1)), // select Seq(Row(2)), // select - Seq(Row(4)), // select + Seq(Row(4)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -164,9 +158,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |DROP TEMPORARY VARIABLE var; |END |""".stripMargin - val expected = Seq( - Seq(Row(2)), // select - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScript, expected) } @@ -179,9 +171,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq( - Seq(Row(42)) - ) + val expected = Seq(Seq(Row(42))) verifySqlScriptResult(commands, expected) } @@ -214,7 +204,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(42))) verifySqlScriptResult(commands, expected) } @@ -234,7 +223,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -251,7 +239,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -271,7 +258,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(44))) verifySqlScriptResult(commands, expected) } @@ -291,7 +277,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -314,7 +299,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -405,7 +389,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -429,7 +412,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -447,7 +429,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq() + val expected = Seq.empty verifySqlScriptResult(commands, expected) } @@ -538,7 +520,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq(Seq(Row(42))) verifySqlScriptResult(commands, expected) } @@ -562,7 +543,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq(Seq(Row(44))) verifySqlScriptResult(commands, expected) } @@ -580,7 +560,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq() + val expected = Seq.empty verifySqlScriptResult(commands, expected) } @@ -598,7 +578,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - val expected = Seq(Seq(Row(43))) verifySqlScriptResult(commands, expected) } @@ -615,11 +594,10 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END WHILE; |END |""".stripMargin - val expected = Seq( Seq(Row(0)), // select i Seq(Row(1)), // select i - Seq(Row(2)), // select i + Seq(Row(2)) // select i ) verifySqlScriptResult(commands, expected) } @@ -635,7 +613,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END WHILE; |END |""".stripMargin - val expected = Seq.empty verifySqlScriptResult(commands, expected) } @@ -656,12 +633,11 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END WHILE; |END |""".stripMargin - val expected = Seq( Seq(Row(0, 0)), // select i, j Seq(Row(0, 1)), // select i, j Seq(Row(1, 0)), // select i, j - Seq(Row(1, 1)), // select i, j + Seq(Row(1, 1)) // select i, j ) verifySqlScriptResult(commands, expected) } @@ -678,10 +654,9 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |END WHILE; |END |""".stripMargin - val expected = Seq( Seq(Row(42)), // select - Seq(Row(42)), // select + Seq(Row(42)) // select ) verifySqlScriptResult(commands, expected) } @@ -700,11 +675,10 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END REPEAT; |END |""".stripMargin - val expected = Seq( Seq(Row(0)), // select i Seq(Row(1)), // select i - Seq(Row(2)), // select i + Seq(Row(2)) // select i ) verifySqlScriptResult(commands, expected) } @@ -723,9 +697,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |END |""".stripMargin - val expected = Seq( - Seq(Row(3)), // select i - ) + val expected = Seq(Seq(Row(3))) verifySqlScriptResult(commands, expected) } @@ -752,7 +724,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { Seq(Row(0, 0)), // select i, j Seq(Row(0, 1)), // select i, j Seq(Row(1, 0)), // select i, j - Seq(Row(1, 1)), // select i, j + Seq(Row(1, 1)) // select i, j ) verifySqlScriptResult(commands, expected) } @@ -773,7 +745,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq(Row(42)), // select - Seq(Row(42)), // select + Seq(Row(42)) // select ) verifySqlScriptResult(commands, expected) } @@ -789,9 +761,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | SELECT 2; | END; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -804,9 +774,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | LEAVE lbl; | END WHILE; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -820,9 +788,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | UNTIL 1 = 2 | END REPEAT; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select 1 - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -839,9 +805,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END WHILE; | SELECT x; |END""".stripMargin - val expected = Seq( - Seq(Row(2)), // select - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScriptText, expected) } @@ -859,9 +823,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END REPEAT; | SELECT x; |END""".stripMargin - val expected = Seq( - Seq(Row(2)), // select x - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScriptText, expected) } @@ -878,9 +840,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | UNTIL 1 = 2 | END REPEAT; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select 1 - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -895,9 +855,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END WHILE; | END WHILE; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -919,7 +877,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq(Row(1)), // select 1 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -948,7 +906,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { Seq(Row(2)), // select 2 Seq(Row(1)), // select 1 Seq(Row(2)), // select 2 - Seq(Row(2)), // select x + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -973,7 +931,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq(Row(1)), // select 1 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -998,7 +956,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select x Seq(Row(2)), // select x Seq(Row(3)), // select x - Seq(Row(3)), // select x + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1030,7 +988,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { Seq(Row(0, 0)), // select x, y Seq(Row(0, 1)), // select x, y Seq(Row(1, 0)), // select x, y - Seq(Row(1, 1)), // select x, y + Seq(Row(1, 1)) // select x, y ) verifySqlScriptResult(commands, expected) } @@ -1051,9 +1009,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END LOOP; | SELECT x; |END""".stripMargin - val expected = Seq( - Seq(Row(2)), // select x - ) + val expected = Seq(Seq(Row(2))) verifySqlScriptResult(sqlScriptText, expected) } @@ -1068,9 +1024,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | END LOOP; | END LOOP; |END""".stripMargin - val expected = Seq( - Seq(Row(1)) // select 1 - ) + val expected = Seq(Seq(Row(1))) verifySqlScriptResult(sqlScriptText, expected) } @@ -1096,7 +1050,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq(Row(1)), // select 1 Seq(Row(1)), // select 1 - Seq(Row(3)), // select x + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8d7366db0d6ab..177ffc24d180a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -43,9 +43,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { private def runSqlScript( sqlText: String, args: Map[String, Expression] = Map.empty): Array[DataFrame] = { - val interpreter = SqlScriptingInterpreter() + val interpreter = SqlScriptingInterpreter(spark) val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] - val executionPlan = interpreter.buildExecutionPlan(compoundBody, spark, args) + val executionPlan = interpreter.buildExecutionPlan(compoundBody, args) executionPlan.flatMap { case statement: SingleStatementExec => if (statement.isExecuted) { @@ -236,9 +236,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin - val expected = Seq( - Seq(Row(42)) - ) + val expected = Seq(Seq(Row(42))) verifySqlScriptResult(commands, expected) }