Skip to content

Commit

Permalink
Add catch all handler
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Aug 9, 2024
1 parent 736e481 commit 3b6dca0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class SingleStatementExec(
}
case throwable: Throwable =>
raisedError = true
errorState = Some("UNKNOWN")
errorState = Some("SQLEXCEPTION")
rethrow = Some(throwable)
}
}
Expand All @@ -200,7 +200,7 @@ class CompoundBodyExec(
case Some(handler) if condition.startsWith("02") => Some(handler)
case _ => None
})
.orElse(conditionHandlerMap.get("UNKNOWN"))
.orElse(conditionHandlerMap.get("SQLEXCEPTION"))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.errors.SqlScriptingErrors


/**
* SQL scripting interpreter - builds SQL script execution plan.
*/
Expand Down Expand Up @@ -97,12 +98,12 @@ case class SqlScriptingInterpreter(session: SparkSession) {

if (isExitHandler) {
val leave = new LeaveStatementExec(label)
val stmts = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++
val statements = compoundBody.collection.map(st => transformTreeIntoExecutable(st)) ++
dropVariables :+ leave

return new CompoundBodyExec(
compoundBody.label,
stmts,
statements,
conditionHandlerMap,
session)
}
Expand Down Expand Up @@ -150,7 +151,16 @@ case class SqlScriptingInterpreter(session: SparkSession) {
val executionPlan = buildExecutionPlan(compoundBody)
executionPlan.flatMap {
case statement: SingleStatementExec if statement.raisedError =>
throw statement.rethrow.get
val sqlState = statement.errorState.getOrElse(throw statement.rethrow.get)

// SQLWARNING and NOT FOUND are not considered as errors.
if (!sqlState.startsWith("01") || !sqlState.startsWith("02")) {
// Throw the error for SQLEXCEPTION.
throw statement.rethrow.get
}

// Return empty result set for SQLWARNING and NOT FOUND.
None
case statement: SingleStatementExec if statement.shouldCollectResult => statement.result
case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.spark.sql.scripting

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

Expand Down Expand Up @@ -434,6 +434,33 @@ class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession
verifySqlScriptResult(sqlScript, expected)
}

test("handler - continue resolve by the CATCH ALL handler") {
val sqlScript =
"""
|BEGIN
| DECLARE flag INT = -1;
| DECLARE CONTINUE HANDLER FOR SQLEXCEPTION
| BEGIN
| SELECT flag;
| SET VAR flag = 1;
| END;
| SELECT 2;
| SELECT 1/0;
| SELECT 3;
| SELECT flag;
|END
|""".stripMargin
val expected = Seq(
Array.empty[Row], // declare var
Array(Row(2)), // select
Array(Row(-1)), // select flag
Array.empty[Row], // set flag
Array(Row(3)), // select
Array(Row(1)), // select
)
verifySqlScriptResult(sqlScript, expected)
}

test("chained begin end blocks") {
val sqlScript =
"""
Expand Down

0 comments on commit 3b6dca0

Please sign in to comment.