Skip to content

Commit

Permalink
Fix continue handler and add check for duplicate handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Aug 7, 2024
1 parent 6962986 commit 29b4f3d
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 10 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,12 @@
],
"sqlState" : "42614"
},
"DUPLICATE_HANDLER_FOR_SAME_SQL_STATE" : {
"message" : [
"Found duplicate handlers for the same SQL state <sqlState>. Please, remove one of them."
],
"sqlState" : "42710"
},
"DUPLICATE_KEY" : {
"message" : [
"Found duplicate keys <keyColumn>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class AstBuilder extends DataTypeAstBuilder
ctx: CompoundOrSingleStatementContext): CompoundBody = withOrigin(ctx) {
Option(ctx.singleCompoundStatement()).map { s =>
if (!SQLConf.get.sqlScriptingEnabled) {
throw SqlScriptingErrors.sqlScriptingNotEnabled()
throw SqlScriptingErrors.sqlScriptingNotEnabled(CurrentOrigin.get)
}
visit(s).asInstanceOf[CompoundBody]
}.getOrElse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ private[sql] object SqlScriptingErrors {
messageParameters = Map("sqlScriptingEnabled" -> SQLConf.SQL_SCRIPTING_ENABLED.key))
}

/**
 * Error raised when a SQL script declares two condition handlers covering the
 * same SQL state.
 *
 * @param origin parse-tree origin used to point the error at the offending handler
 * @param sqlState the SQL state for which more than one handler was declared
 * @return the exception to be thrown by the caller
 */
def duplicateHandlerForSameSqlState(origin: Origin, sqlState: String): Throwable = {
  val params = Map("sqlState" -> sqlState)
  new SqlScriptingException(
    errorClass = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE",
    cause = null,
    origin = origin,
    messageParameters = params)
}

def variableDeclarationNotAllowedInScope(
origin: Origin,
varName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ class SingleStatementExec(
* SQL query text.
*/
/**
 * Returns the SQL text of this statement, sliced out of the full script text
 * using this statement's origin (`sqlText`, `startIndex`, `stopIndex`).
 *
 * NOTE(review): when any of the origin fields is absent we fall back to the
 * literal "DROP VARIABLE". This preserves the previous workaround for
 * synthetically generated drop-variable statements that carry no origin —
 * confirm that assumption and replace the magic string with an explicit flag
 * on the statement. The previous version caught `Exception` as control flow,
 * which also swallowed unrelated failures; matching on the Options makes the
 * fallback condition explicit instead.
 */
def getText: String = {
  (origin.sqlText, origin.startIndex, origin.stopIndex) match {
    case (Some(text), Some(start), Some(stop)) =>
      // stopIndex is inclusive, substring's end bound is exclusive.
      text.substring(start, stop + 1)
    case _ =>
      "DROP VARIABLE"
  }
}

override def reset(): Unit = {
Expand Down Expand Up @@ -211,6 +216,7 @@ class CompoundBodyExec(
getHandler(statement.errorState.get).foreach { handler =>
statement.reset() // Clear all flags and result
handler.reset()
returnHere = curr
curr = Some(handler.getHandlerBody)
}
}
Expand Down Expand Up @@ -242,6 +248,7 @@ class CompoundBodyExec(
private var curr: Option[CompoundStatementExec] =
if (localIterator.hasNext) Some(localIterator.next()) else None
private var stopIteration: Boolean = false // hard stop iteration flag
private var returnHere: Option[CompoundStatementExec] = None

def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator

Expand All @@ -261,7 +268,7 @@ class CompoundBodyExec(
case _ => throw SparkException.internalError(
"Unknown statement type encountered during SQL script interpretation.")
}
(localIterator.hasNext || childHasNext) && !stopIteration
(localIterator.hasNext || childHasNext || returnHere.isDefined) && !stopIteration
}

@scala.annotation.tailrec
Expand All @@ -273,9 +280,7 @@ class CompoundBodyExec(
handleLeave(leave)
case Some(statement: LeafStatementExec) =>
statement.execute(session) // Execute the leaf statement
if (!statement.raisedError) {
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
}
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
handleError(statement) // Handle error if raised
case Some(body: NonLeafStatementExec) =>
if (body.getTreeIterator.hasNext) {
Expand All @@ -290,7 +295,12 @@ class CompoundBodyExec(
case nonLeafStatement: NonLeafStatementExec => nonLeafStatement
}
} else {
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
if (returnHere.isDefined) {
curr = returnHere
returnHere = None
} else {
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
}
next()
}
case _ => throw SparkException.internalError(
Expand Down Expand Up @@ -325,6 +335,18 @@ class LeaveStatementExec(val label: String) extends LeafStatementExec {
override def reset(): Unit = used = false
}

/**
 * Executable node for Continue statement.
 */
class ContinueStatementExec() extends LeafStatementExec {

  // Marks whether the interpreter has already consumed this CONTINUE;
  // cleared by reset() so the node can be re-executed on iterator restart.
  var used: Boolean = false

  // A CONTINUE performs no work of its own; the enclosing compound body's
  // iterator reacts to this node, so execution is a no-op.
  override def execute(session: SparkSession): Unit = ()

  override def reset(): Unit = {
    used = false
  }
}

/**
* Executable node for IfElseStatement.
* @param conditions Collection of executable conditions. First condition corresponds to IF clause,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, HandlerType, IfElseStatement, SingleStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.errors.SqlScriptingErrors

/**
* SQL scripting interpreter - builds SQL script execution plan.
Expand Down Expand Up @@ -83,7 +84,12 @@ case class SqlScriptingInterpreter(session: SparkSession) {

handler.conditions.foreach(condition => {
val conditionValue = compoundBody.conditions.getOrElse(condition, condition)
conditionHandlerMap.put(conditionValue, handlerExec)
conditionHandlerMap.get(conditionValue) match {
case Some(_) =>
throw SqlScriptingErrors.duplicateHandlerForSameSqlState(
CurrentOrigin.get, conditionValue)
case None => conditionHandlerMap.put(conditionValue, handlerExec)
}
})

handlers += handlerExec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.scripting

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -228,6 +229,32 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
verifySqlScriptResult(sqlScript, expected)
}

test("duplicate handler") {
  // Declaring two CONTINUE handlers for the same condition ('22012', division
  // by zero) must be rejected at interpretation time with
  // DUPLICATE_HANDLER_FOR_SAME_SQL_STATE rather than silently overwriting the
  // first handler.
  val script =
    """
      |BEGIN
      |  DECLARE flag INT = -1;
      |  DECLARE zero_division CONDITION FOR '22012';
      |  DECLARE CONTINUE HANDLER FOR zero_division
      |  BEGIN
      |    SET VAR flag = 1;
      |  END;
      |  DECLARE CONTINUE HANDLER FOR zero_division
      |  BEGIN
      |    SET VAR flag = 2;
      |  END;
      |  SELECT 1/0;
      |  SELECT flag;
      |END
      |""".stripMargin
  val duplicateHandlerError = intercept[SqlScriptingException] {
    verifySqlScriptResult(script, Seq.empty)
  }
  checkError(
    exception = duplicateHandlerError,
    errorClass = "DUPLICATE_HANDLER_FOR_SAME_SQL_STATE",
    parameters = Map("sqlState" -> "22012"))
}

test("handler - continue resolve in the same block") {
val sqlScript =
"""
Expand Down

0 comments on commit 29b4f3d

Please sign in to comment.