From 09c132ca61189e02814dc874348e7a6e278783a6 Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Tue, 3 Dec 2024 12:58:50 +0100
Subject: [PATCH] Support error conditions and handlers in SQL scripting

---
 .../resources/error/error-conditions.json        |  18 ++
 .../spark/sql/catalyst/parser/SqlBaseLexer.g4    |   6 +
 .../sql/catalyst/parser/SqlBaseParser.g4         |  32 ++-
 .../sql/catalyst/parser/AstBuilder.scala         |  75 +++++-
 .../logical/SqlScriptingLogicalPlans.scala       |  57 +++-
 .../spark/sql/errors/SqlScriptingErrors.scala    |  26 ++
 .../parser/SqlScriptingParserSuite.scala         |  61 ++++-
 .../sql/scripting/SqlScriptingExecution.scala    |  44 ++--
 .../SqlScriptingExecutionContext.scala           |  51 +++-
 .../scripting/SqlScriptingExecutionNode.scala    |  26 +-
 .../scripting/SqlScriptingInterpreter.scala      | 118 +++++++--
 .../SqlScriptingExecutionSuite.scala             | 248 +++++++++++++++++-
 .../SqlScriptingInterpreterSuite.scala           |   1 -
 13 files changed, 707 insertions(+), 56 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 77437f6c56179..3148bbcbcd0b2 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1194,6 +1194,18 @@
     ],
     "sqlState" : "42614"
   },
+  "DUPLICATE_CONDITION_NAME_FOR_DIFFERENT_SQL_STATE" : {
+    "message" : [
+      "Found duplicate condition name <conditionName> for different SQL states. Please remove one of them."
+    ],
+    "sqlState" : "42710"
+  },
+  "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE" : {
+    "message" : [
+      "Found duplicate handlers for the same SQL state <sqlState>. Please remove one of them."
+    ],
+    "sqlState" : "42710"
+  },
   "DUPLICATE_KEY" : {
     "message" : [
       "Found duplicate keys <keyColumn>."
     ],
@@ -1218,6 +1230,12 @@
     },
     "sqlState" : "4274K"
   },
+  "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER" : {
+    "message" : [
+      "Found duplicate SQL state <sqlState> for the same handler. Please remove one of them."
+ ], + "sqlState" : "42710" + }, "EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED" : { "message" : [ "Previous node emitted a row with eventTime= which is older than current_watermark_value=", diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index eeebe89de8ff1..27cbf29f9f586 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -174,8 +174,10 @@ COMPACTIONS: 'COMPACTIONS'; COMPENSATION: 'COMPENSATION'; COMPUTE: 'COMPUTE'; CONCATENATE: 'CONCATENATE'; +CONDITION: 'CONDITION'; CONSTRAINT: 'CONSTRAINT'; CONTAINS: 'CONTAINS'; +CONTINUE: 'CONTINUE'; COST: 'COST'; CREATE: 'CREATE'; CROSS: 'CROSS'; @@ -226,6 +228,7 @@ EXCEPT: 'EXCEPT'; EXCHANGE: 'EXCHANGE'; EXCLUDE: 'EXCLUDE'; EXISTS: 'EXISTS'; +EXIT: 'EXIT'; EXPLAIN: 'EXPLAIN'; EXPORT: 'EXPORT'; EXTEND: 'EXTEND'; @@ -244,6 +247,7 @@ FOR: 'FOR'; FOREIGN: 'FOREIGN'; FORMAT: 'FORMAT'; FORMATTED: 'FORMATTED'; +FOUND: 'FOUND'; FROM: 'FROM'; FULL: 'FULL'; FUNCTION: 'FUNCTION'; @@ -253,6 +257,7 @@ GLOBAL: 'GLOBAL'; GRANT: 'GRANT'; GROUP: 'GROUP'; GROUPING: 'GROUPING'; +HANDLER: 'HANDLER'; HAVING: 'HAVING'; BINARY_HEX: 'X'; HOUR: 'HOUR'; @@ -412,6 +417,7 @@ SORTED: 'SORTED'; SOURCE: 'SOURCE'; SPECIFIC: 'SPECIFIC'; SQL: 'SQL'; +SQLEXCEPTION: 'SQLEXCEPTION'; START: 'START'; STATISTICS: 'STATISTICS'; STORED: 'STORED'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 4b7b4634b74b2..3ef0f00111f67 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -60,7 +60,9 @@ compoundBody ; compoundStatement - : statement + : declareCondition + | declareHandler + | statement | setStatementWithOptionalVarKeyword | beginEndCompoundBlock | ifElseStatement @@ -73,6 +75,23 @@ compoundStatement | forStatement ; +conditionValue + : stringLit + | multipartIdentifier + ; + +conditionValueList + : ((conditionValues+=conditionValue (COMMA conditionValues+=conditionValue)*) | SQLEXCEPTION | NOT FOUND) + ; + +declareCondition + : DECLARE multipartIdentifier CONDITION (FOR stringLit)? + ; + +declareHandler + : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN compoundBody END | statement | setStatementWithOptionalVarKeyword) + ; + setStatementWithOptionalVarKeyword : SET variable? assignmentList #setVariableWithOptionalKeyword | SET variable? 
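For reference, the kind of script the declareCondition and declareHandler rules above are meant to accept looks like the following. This is an illustrative sketch based on the grammar and on the tests later in this patch; the condition name, SQLSTATEs, and variable are made up:

    BEGIN
      DECLARE flag INT = -1;
      -- Named condition bound to a SQLSTATE ('45000' is assumed when FOR is omitted).
      DECLARE zero_division CONDITION FOR '22012';
      -- Compound-body handler declared for a named condition and a raw SQLSTATE.
      DECLARE CONTINUE HANDLER FOR zero_division, '22003'
      BEGIN
        SET VAR flag = 1;
      END;
      -- Single-statement handler using the generic SQLEXCEPTION condition value.
      DECLARE EXIT HANDLER FOR SQLEXCEPTION SET VAR flag = 2;
      SELECT 1 / 0;
      SELECT flag;
    END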
LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ @@ -1592,6 +1611,7 @@ ansiNonReserved | COMPUTE | CONCATENATE | CONTAINS + | CONTINUE | COST | CUBE | CURRENT @@ -1631,6 +1651,7 @@ ansiNonReserved | EXCHANGE | EXCLUDE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTEND @@ -1644,11 +1665,13 @@ ansiNonReserved | FOLLOWING | FORMAT | FORMATTED + | FOUND | FUNCTION | FUNCTIONS | GENERATED | GLOBAL | GROUPING + | HANDLER | HOUR | HOURS | IDENTIFIER_KW @@ -1780,6 +1803,7 @@ ansiNonReserved | SORTED | SOURCE | SPECIFIC + | SQLEXCEPTION | START | STATISTICS | STORED @@ -1927,8 +1951,10 @@ nonReserved | COMPENSATION | COMPUTE | CONCATENATE + | CONDITION | CONSTRAINT | CONTAINS + | CONTINUE | COST | CREATE | CUBE @@ -1978,6 +2004,7 @@ nonReserved | EXCLUDE | EXECUTE | EXISTS + | EXIT | EXPLAIN | EXPORT | EXTEND @@ -1996,6 +2023,7 @@ nonReserved | FOREIGN | FORMAT | FORMATTED + | FOUND | FROM | FUNCTION | FUNCTIONS @@ -2004,6 +2032,7 @@ nonReserved | GRANT | GROUP | GROUPING + | HANDLER | HAVING | HOUR | HOURS @@ -2153,6 +2182,7 @@ nonReserved | SOURCE | SPECIFIC | SQL + | SQLEXCEPTION | START | STATISTICS | STORED diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 04dff0b6e6680..6a2c31a4dac3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set} +import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Set} import scala.jdk.CollectionConverters._ import scala.util.{Left, Right} @@ -142,6 +142,52 @@ class AstBuilder extends DataTypeAstBuilder } } + override def visitConditionValue(ctx: ConditionValueContext): String = { + Option(ctx.multipartIdentifier()).map(_.getText) + .getOrElse(ctx.stringLit().getText).replace("'", "") + } + + override def visitConditionValueList(ctx: ConditionValueListContext): Seq[String] = { + Option(ctx.SQLEXCEPTION()).map(_ => Seq("SQLEXCEPTION")).getOrElse { + Option(ctx.NOT()).map(_ => Seq("NOT FOUND")).getOrElse { + val buff = scala.collection.mutable.Set[String]() + ctx.conditionValues.forEach { conditionValue => + val elem = visit(conditionValue).asInstanceOf[String] + if (buff(elem)) { + throw SqlScriptingErrors.duplicateSqlStateForSameHandler(CurrentOrigin.get, elem) + } + buff += elem + } + buff.toSeq + } + } + } + + override def visitDeclareCondition(ctx: DeclareConditionContext): ErrorCondition = { + val conditionName = ctx.multipartIdentifier().getText + val conditionValue = Option(ctx.stringLit()).map(_.getText.replace("'", "")).getOrElse("45000") + + val sqlStateRegex = "^[A-Za-z0-9]{5}$".r + assert(sqlStateRegex.findFirstIn(conditionValue).isDefined) + + ErrorCondition(conditionName, conditionValue) + } + + def visitDeclareHandlerImpl( + ctx: DeclareHandlerContext, labelCtx: SqlScriptingLabelContext): ErrorHandler = { + val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]] + val handlerType = Option(ctx.EXIT()).map(_ => HandlerType.EXIT).getOrElse(HandlerType.CONTINUE) + + val body = if (!ctx.compoundBody().isEmpty) { + visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = true, labelCtx) + } else { + val logicalPlan = visitChildren(ctx).asInstanceOf[LogicalPlan] + 
CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)), None) + }.asInstanceOf[CompoundBody] + + ErrorHandler(conditions, body, handlerType) + } + override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = { val labelCtx = new SqlScriptingLabelContext() visitCompoundBodyImpl(ctx.compoundBody(), Some("root"), allowVarDeclare = true, labelCtx) @@ -153,8 +199,27 @@ class AstBuilder extends DataTypeAstBuilder allowVarDeclare: Boolean, labelCtx: SqlScriptingLabelContext): CompoundBody = { val buff = ListBuffer[CompoundPlanStatement]() - ctx.compoundStatements.forEach( - compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx)) + val handlers = ListBuffer[ErrorHandler]() + val conditions = HashMap[String, String]() + val sqlStates = Set[String]() + + ctx.compoundStatements.forEach(compoundStatement => { + val stmt = visitCompoundStatementImpl(compoundStatement, labelCtx) + stmt match { + case handler: ErrorHandler => handlers += handler + case condition: ErrorCondition => + if (conditions.contains(condition.conditionName)) { + throw SqlScriptingErrors.duplicateConditionNameForDifferentSqlState( + CurrentOrigin.get, condition.conditionName) + } + conditions += condition.conditionName -> condition.value + sqlStates += condition.value + case s => buff += s + } + }) + +// ctx.compoundStatements.forEach( +// compoundStatement => buff += visitCompoundStatementImpl(compoundStatement, labelCtx)) val compoundStatements = buff.toList @@ -183,7 +248,7 @@ class AstBuilder extends DataTypeAstBuilder case _ => } - CompoundBody(buff.toSeq, label) + CompoundBody(buff.toSeq, label, handlers.toSeq, conditions) } private def visitBeginEndCompoundBlockImpl( @@ -228,6 +293,8 @@ class AstBuilder extends DataTypeAstBuilder visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx) case forStatementContext: ForStatementContext => visitForStatementImpl(forStatementContext, labelCtx) + case declareHandlerContext: DeclareHandlerContext => + visitDeclareHandlerImpl(declareHandlerContext, labelCtx) case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement] } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index 4faf1f5d26672..49fb997f21810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable.HashMap + import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.HandlerType.HandlerType import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -65,7 +68,9 @@ case class SingleStatement(parsedPlan: LogicalPlan) */ case class CompoundBody( collection: Seq[CompoundPlanStatement], - label: Option[String]) extends Command with CompoundPlanStatement { + label: Option[String], + handlers: Seq[ErrorHandler] = Seq.empty, + conditions: HashMap[String, String] = HashMap()) extends Command with CompoundPlanStatement { override def children: Seq[LogicalPlan] = collection @@ -295,3 +300,53 @@ case class ForStatement( ForStatement(query, variableName, body, label) } } + +/** + * Logical operator for an error condition. 
+ * @param conditionName Name of the error condition. + * @param value SQLSTATE or Error Code. + */ +case class ErrorCondition( + conditionName: String, + value: String) extends CompoundPlanStatement { + override def output: Seq[Attribute] = Seq.empty + + /** + * Returns a Seq of the children of this node. + * Children should not change. Immutability required for containsChild optimization + */ + override def children: Seq[LogicalPlan] = Seq.empty + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this.copy() +} + +object HandlerType extends Enumeration { + type HandlerType = Value + val EXIT, CONTINUE = Value +} + +/** + * Logical operator for an error condition. + * @param conditions Name of the error condition variable for which the handler is built. + * @param body CompoundBody of the handler. + * @param handlerType Type of the handler (CONTINUE or EXIT). + */ +case class ErrorHandler( + conditions: Seq[String], + body: CompoundBody, + handlerType: HandlerType) extends CompoundPlanStatement { + override def output: Seq[Attribute] = Seq.empty + + /** + * Returns a Seq of the children of this node. + * Children should not change. Immutability required for containsChild optimization + */ + override def children: Seq[LogicalPlan] = Seq(body) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = { + assert(newChildren.length == 1) + ErrorHandler(conditions, newChildren(0).asInstanceOf[CompoundBody], handlerType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 2a4b8fde6989c..b05ade60a2226 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -133,4 +133,30 @@ private[sql] object SqlScriptingErrors { cause = null, messageParameters = Map("labelName" -> toSQLStmt(labelName))) } + + def duplicateHandlerForSameSqlState(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + + def duplicateSqlStateForSameHandler(origin: Origin, sqlState: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER", + cause = null, + messageParameters = Map("sqlState" -> sqlState)) + } + + def duplicateConditionNameForDifferentSqlState( + origin: Origin, + conditionName: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "DUPLICATE_CONDITION_NAME_FOR_DIFFERENT_SQL_STATE", + cause = null, + messageParameters = Map("conditionName" -> conditionName)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index ab647f83b42a4..e5930c8225103 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, 
Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ErrorHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf @@ -2167,6 +2167,65 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 3") } +// test("declare condition: default sqlstate") { +// val sqlScriptText = +// """ +// |BEGIN +// | DECLARE test CONDITION; +// |END""".stripMargin +// val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] +// assert(tree.conditions.size == 1) +// assert(tree.conditions("test").equals("45000")) // Default SQLSTATE +// } + + test("declare condition: custom sqlstate") { + val sqlScriptText = + """ + |BEGIN + | SELECT 1; + | DECLARE test CONDITION FOR '12000'; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.conditions.size == 1) + assert(tree.conditions("test").equals("12000")) + } + + test("declare handler") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test BEGIN SELECT 1; END; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ErrorHandler]) + } + + test("declare handler single statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test SELECT 1; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.handlers.length == 1) + assert(tree.handlers.head.isInstanceOf[ErrorHandler]) + } + + test("declare handler duplicate sqlState") { + val sqlScriptText = + """ + |BEGIN + | DECLARE CONTINUE HANDLER FOR test, test BEGIN SELECT 1; END; + |END""".stripMargin + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "DUPLICATE_SQL_STATE_FOR_SAME_HANDLER", + parameters = Map("sqlState" -> "test")) + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 0dd04a12f0ab6..d133ce4bba922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.scripting -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody} @@ -40,24 +40,25 @@ class SqlScriptingExecution( // Frames to keep what is being executed. 
  private val context: SqlScriptingExecutionContext = {
    val ctx = new SqlScriptingExecutionContext()
-    val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
-    val frame = new SqlScriptingExecutionFrame(executionPlan)
-    frame.enterScope(sqlScript.label.get)
-    ctx.frames
-      .addOne(new SqlScriptingExecutionFrame(
-        interpreter.buildExecutionPlan(sqlScript, args, ctx)))
+    // buildExecutionPlan registers the root frame and scope on the context.
+    interpreter.buildExecutionPlan(sqlScript, args, ctx)
     ctx
   }
 
-  private var current = getNextResult
+  private var current: Option[DataFrame] = None
 
-  override def hasNext: Boolean = current.isDefined
+  override def hasNext: Boolean = {
+    // Fetch the next result lazily and cache it, so that repeated calls to
+    // hasNext do not advance the iterator and skip statements.
+    if (current.isEmpty) {
+      current = getNextResult
+    }
+    current.isDefined
+  }
 
   override def next(): DataFrame = {
     if (!hasNext) throw SparkException.internalError("No more elements to iterate through.")
-    val nextDataFrame = current.get
-    current = getNextResult
-    nextDataFrame
+    val nextDataFrame = current.get
+    current = None
+    nextDataFrame
   }
 
   /** Helper method to iterate get next statements from the first available frame. */
@@ -92,18 +94,26 @@ class SqlScriptingExecution(
       None
   }
 
-  private def handleException(e: Throwable): Unit = {
-    // Rethrow the exception.
-    // TODO: SPARK-48353 Add error handling for SQL scripts
-    throw e
+  private def handleException(e: SparkThrowable): Unit = {
+    // If a handler is declared for the error's SQLSTATE, push its body as a new
+    // execution frame; otherwise rethrow the error.
+    context.findHandler(e.getSqlState) match {
+      case Some(handler) =>
+        context.frames.addOne(new SqlScriptingExecutionFrame(handler.getTreeIterator))
+      case None =>
+        throw e.asInstanceOf[Throwable]
+    }
   }
 
   def withErrorHandling(f: => Unit): Unit = {
     try {
       f
     } catch {
-      case e: Throwable =>
+      case e: SparkThrowable =>
+        // Try to find a handler for the exception.
         handleException(e)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index b3011f5b3a65f..6a2adc75cd9fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.scripting
 
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{HashMap, ListBuffer}
 
 import org.apache.spark.SparkException
 
@@ -28,11 +28,11 @@ class SqlScriptingExecutionContext {
   // List of frames that are currently active.
val frames: ListBuffer[SqlScriptingExecutionFrame] = ListBuffer() - def enterScope(label: String): Unit = { + def enterScope(label: String, conditionHandlerMap: HashMap[String, ErrorHandlerExec]): Unit = { if (frames.isEmpty) { throw SparkException.internalError(s"Cannot enter scope: no frames.") } - frames.last.enterScope(label) + frames.last.enterScope(label, conditionHandlerMap) } def exitScope(label: String): Unit = { @@ -41,6 +41,21 @@ class SqlScriptingExecutionContext { } frames.last.exitScope(label) } + + def findHandler(condition: String): Option[ErrorHandlerExec] = { + if (frames.isEmpty) { + throw SparkException.internalError(s"Cannot find handler: no frames.") + } + + frames.reverseIterator.foreach { frame => + val handler = frame.findHandler(condition) + if (handler.isDefined) { + return handler + } + } + None + } + } /** @@ -53,8 +68,7 @@ class SqlScriptingExecutionFrame( executionPlan: Iterator[CompoundStatementExec]) extends Iterator[CompoundStatementExec] { // List of scopes that are currently active. - private val scopes: ListBuffer[SqlScriptingExecutionScope] = - ListBuffer(new SqlScriptingExecutionScope("root")) + private val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty override def hasNext: Boolean = executionPlan.hasNext @@ -63,8 +77,22 @@ class SqlScriptingExecutionFrame( executionPlan.next() } - def enterScope(label: String): Unit = { - scopes.addOne(new SqlScriptingExecutionScope(label)) + def findHandler(condition: String): Option[ErrorHandlerExec] = { + if (scopes.isEmpty) { + throw SparkException.internalError(s"Cannot find handler: no scopes.") + } + + scopes.reverseIterator.foreach { scope => + val handler = scope.findHandler(condition) + if (handler.isDefined) { + return handler + } + } + None + } + + def enterScope(label: String, conditionHandlerMap: HashMap[String, ErrorHandlerExec]): Unit = { + scopes.addOne(new SqlScriptingExecutionScope(label, conditionHandlerMap)) } def exitScope(label: String): Unit = { @@ -89,4 +117,11 @@ class SqlScriptingExecutionFrame( * @param label * Label of the scope. 
*/ -class SqlScriptingExecutionScope(val label: String) +class SqlScriptingExecutionScope( + val label: String, + val conditionHandlerMap: HashMap[String, ErrorHandlerExec]) { + + def findHandler(condition: String): Option[ErrorHandlerExec] = { + conditionHandlerMap.get(condition) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 9a3714ef6fff9..b1fae2ee64a6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.scripting import java.util - +import scala.collection.mutable.HashMap import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.HandlerType.HandlerType import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors @@ -174,7 +175,8 @@ class SingleStatementExec( class CompoundBodyExec( statements: Seq[CompoundStatementExec], label: Option[String] = None, - context: SqlScriptingExecutionContext) + context: SqlScriptingExecutionContext, + conditionHandlerMap: HashMap[String, ErrorHandlerExec] = HashMap()) extends NonLeafStatementExec { private var localIterator = statements.iterator @@ -182,14 +184,14 @@ class CompoundBodyExec( private var scopeEntered = false private var scopeExited = false - private def enterScope(): Unit = { + def enterScope(): Unit = { if (label.isDefined && !scopeEntered) { scopeEntered = true - context.enterScope(label.get) + context.enterScope(label.get, conditionHandlerMap) } } - private def exitScope(): Unit = { + def exitScope(): Unit = { if (label.isDefined && !scopeExited) { scopeExited = true context.exitScope(label.get) @@ -920,3 +922,17 @@ class ForStatementExec( body.reset() } } + +/** + * Executable node for ErrorHandlerStatement. + * @param body Executable CompoundBody of the error handler. 
+ */ +class ErrorHandlerExec( + val body: CompoundBodyExec, + val handlerType: HandlerType, + val scopeToExit: Option[String]) extends NonLeafStatementExec { + + override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator + + override def reset(): Unit = body.reset() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 4aac8f5fe1364..7e277fa421e2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.scripting +import scala.collection.mutable.HashMap import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} -import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, HandlerType, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.errors.SqlScriptingErrors /** * SQL scripting interpreter - builds SQL script execution plan. @@ -46,8 +48,11 @@ case class SqlScriptingInterpreter(session: SparkSession) { compound: CompoundBody, args: Map[String, Expression], context: SqlScriptingExecutionContext): Iterator[CompoundStatementExec] = { - transformTreeIntoExecutable(compound, args, context) - .asInstanceOf[CompoundBodyExec].getTreeIterator + val compoundBodyExec = transformTreeIntoExecutable(compound, args, context) + .asInstanceOf[CompoundBodyExec] + context.frames.addOne(new SqlScriptingExecutionFrame(compoundBodyExec.getTreeIterator)) + compoundBodyExec.enterScope() + compoundBodyExec.getTreeIterator } /** @@ -63,6 +68,92 @@ case class SqlScriptingInterpreter(session: SparkSession) { case _ => None } + /** + * Transform [[CompoundBody]] into [[CompoundBodyExec]]. + * @param compoundBody + * CompoundBody to be transformed into CompoundBodyExec. + * @param isExitHandler + * Flag to indicate if the body is an exit handler body to add leave statement at the end. + * @param exitHandlerLabel + * If body is an exit handler body, this is the label of surrounding CompoundBody + * that should be exited. + * @return + * Executable version of the CompoundBody . 
+ */ + private def transformBodyIntoExec( + compoundBody: CompoundBody, + args: Map[String, Expression], + context: SqlScriptingExecutionContext, + isExitHandler: Boolean = false, + exitHandlerLabel: String = ""): CompoundBodyExec = { + val variables = compoundBody.collection.flatMap { + case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) + case _ => None + } + val dropVariables = variables + .map(varName => DropVariable(varName, ifExists = true)) + .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context)) + .reverse + + // Create a map of conditions (SqlStates) to their respective handlers. + val conditionHandlerMap = HashMap[String, ErrorHandlerExec]() + compoundBody.handlers.foreach(handler => { + val handlerBodyExec = + transformBodyIntoExec( + handler.body, + args, + context, + handler.handlerType == HandlerType.EXIT, + compoundBody.label.get) + + // Execution node of handler. + val scopeToExit = if (handler.handlerType == HandlerType.EXIT) { + Some(compoundBody.label.get) + } else { + None + } + + val handlerExec = new ErrorHandlerExec( + handlerBodyExec, + handler.handlerType, + scopeToExit) + + // For each condition handler is defined for, add corresponding key value pair + // to the conditionHandlerMap. + handler.conditions.foreach(condition => { + // Condition can either be the key in conditions map or SqlState. + val conditionValue = compoundBody.conditions.getOrElse(condition, condition) + if (conditionHandlerMap.contains(conditionValue)) { + throw SqlScriptingErrors.duplicateHandlerForSameSqlState( + CurrentOrigin.get, conditionValue) + } else { + conditionHandlerMap.put(conditionValue, handlerExec) + } + }) + }) + + if (isExitHandler) { + // Create leave statement to exit the surrounding CompoundBody after handler execution. + val leave = new LeaveStatementExec(exitHandlerLabel) + val statements = compoundBody.collection.map(st => + transformTreeIntoExecutable(st, args, context)) ++ + dropVariables :+ leave + + return new CompoundBodyExec( + statements, + compoundBody.label, + context, + conditionHandlerMap) + } + + new CompoundBodyExec( + compoundBody.collection + .map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables, + compoundBody.label, + context, + conditionHandlerMap) + } + /** * Transform the parsed tree to the executable node. * @@ -78,20 +169,9 @@ case class SqlScriptingInterpreter(session: SparkSession) { args: Map[String, Expression], context: SqlScriptingExecutionContext): CompoundStatementExec = node match { - case CompoundBody(collection, label) => + case body: CompoundBody => // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. 
-      val variables = collection.flatMap {
-        case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
-        case _ => None
-      }
-      val dropVariables = variables
-        .map(varName => DropVariable(varName, ifExists = true))
-        .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context))
-        .reverse
-      new CompoundBodyExec(
-        collection.map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables,
-        label,
-        context)
+      transformBodyIntoExec(body, args, context)
 
     case IfElseStatement(conditions, conditionalBodies, elseBody) =>
       val conditionsExec = conditions.map(condition =>
@@ -177,5 +257,9 @@ case class SqlScriptingInterpreter(session: SparkSession) {
         args,
         isInternal = false,
         context)
+
+    case _ =>
+      throw new UnsupportedOperationException(
+        "Unsupported operation in the execution plan.")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index bbeae942f9fe7..0fc9b3789ffa3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -18,9 +18,10 @@ package org.apache.spark.sql.scripting
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
+import org.apache.spark.sql.exceptions.SqlScriptingException
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -55,6 +56,251 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  private def runSqlScriptWithHandler(
+      sqlText: String,
+      args: Map[String, Expression] = Map.empty): Seq[Row] = {
+    val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
+    val sse = new SqlScriptingExecution(compoundBody, spark, args)
+
+    val result = scala.collection.mutable.ListBuffer.empty[Row]
+    while (sse.hasNext) {
+      // Trigger execution inside the error-handling wrapper so that a failing
+      // statement can be picked up by a declared handler; after a handled error,
+      // hasNext continues from the handler's execution frame.
+      sse.withErrorHandling {
+        val df: DataFrame = sse.next()
+        result ++= df.collect()
+      }
+    }
+    result.toSeq
+  }
+
+  // Handler tests
+  test("duplicate handler") {
+    val sqlScript =
+      """
+        |BEGIN
+        | DECLARE flag INT = -1;
+        | DECLARE zero_division CONDITION FOR '22012';
+        | DECLARE CONTINUE HANDLER FOR zero_division
+        | BEGIN
+        |   SET VAR flag = 1;
+        | END;
+        | DECLARE CONTINUE HANDLER FOR '22012'
+        | BEGIN
+        |   SET VAR flag = 2;
+        | END;
+        | SELECT 1/0;
+        | SELECT flag;
+        |END
+        |""".stripMargin
+    checkError(
+      exception = intercept[SqlScriptingException] {
+        verifySqlScriptResult(sqlScript, Seq.empty)
+      },
+      condition = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE",
+      parameters = Map("sqlState" -> "22012"))
+  }
+
+  test("handler - continue resolve in the same block") {
+    val sqlScript =
+      """
+        |BEGIN
+        | DECLARE flag INT = -1;
+        | DECLARE zero_division CONDITION FOR '22012';
+        | DECLARE CONTINUE HANDLER FOR zero_division
+        | BEGIN
+        |   SELECT flag;
+        |   SET VAR flag = 5;
+        | END;
+        | SELECT 2;
+        | SELECT 3;
+        | SELECT 1/0;
+        | SELECT 4;
+        | SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(4)), // select + Array(Row(1)) // select + ) + runSqlScriptWithHandler(sqlScript) + } + + test("handler - exit resolve in the same block") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | BEGIN + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE EXIT HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 3; + | END; + | SELECT 2; + | SELECT 3; + | SELECT 1/0; + | SELECT 4; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(2)), // select + Seq(Row(3)), // select + Seq(Row(-1)), // select flag + Seq.empty[Row], // set flag + Seq(Row(1)) // select flag from the outer body + ) + runSqlScriptWithHandler(sqlScript) +// verifySqlScriptResult(sqlScript, expected) + } + + test("handler - continue resolve in outer block") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 10; + | END; + | SELECT 2; + | BEGIN + | SELECT 3; + | BEGIN + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT 7; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(5)), // select + Array(Row(6)), // select + Array(Row(7)), // select + Array(Row(1)) // select + ) + runSqlScriptWithHandler(sqlScript) +// verifySqlScriptResult(sqlScript, expected) + } + + test("handler - continue resolve in the same block nested") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | SELECT 2; + | BEGIN + | SELECT 3; + | BEGIN + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE CONTINUE HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 15; + | END; + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT 7; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + Array(Row(5)), // select + Array(Row(6)), // select + Array(Row(7)), // select + Array(Row(1)) // select + ) + runSqlScriptWithHandler(sqlScript) + // verifySqlScriptResult(sqlScript, expected) + } + + + test("handler - exit resolve in outer block") { + val sqlScript = + """ + |BEGIN + | DECLARE flag INT = -1; + | BEGIN + | DECLARE zero_division CONDITION FOR '22012'; + | DECLARE EXIT HANDLER FOR zero_division + | BEGIN + | SELECT flag; + | SET VAR flag = 25; + | END; + | SELECT 2; + | SELECT 3; + | BEGIN + | SELECT 4; + | SELECT 1/0; + | SELECT 5; + | END; + | SELECT 6; + | END; + | SELECT flag; + |END + |""".stripMargin + val expected = Seq( + Array.empty[Row], // declare var + Array(Row(2)), // select + Array(Row(3)), // select + Array(Row(4)), // select + Array(Row(-1)), // select flag + Array.empty[Row], // set flag + // skip select 5 + // skip select 6 + Array(Row(1)) // select flag from the outer body + ) + runSqlScriptWithHandler(sqlScript) + // verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2435cd0c5bff7..849875178b3a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -47,7 +47,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] val context = new SqlScriptingExecutionContext() val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context) - context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan)) executionPlan.flatMap { case statement: SingleStatementExec => if (statement.isExecuted) {
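Taken together, the handler tests in this patch illustrate the intended runtime semantics: a CONTINUE handler runs its body and resumes with the statement after the failing one, while an EXIT handler runs its body and then leaves the compound block in which it was declared. A condensed sketch of the EXIT case (an illustrative script in the style of the tests, not one of them verbatim):

    BEGIN
      DECLARE flag INT = -1;
      BEGIN
        DECLARE zero_division CONDITION FOR '22012';
        DECLARE EXIT HANDLER FOR zero_division SET VAR flag = 1;
        SELECT 1 / 0;  -- raises SQLSTATE '22012'; the handler body runs
        SELECT 2;      -- skipped: the EXIT handler leaves this inner block
      END;
      SELECT flag;     -- expected to return 1
    END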