diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 0ebeea9aed8d2..229da4fa17de7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3118,12 +3118,12 @@ "subClass" : { "NOT_ALLOWED_IN_SCOPE" : { "message" : [ - "Variable was declared on line , which is not allowed in this scope." + "Declaration of the variable is not allowed in this scope." ] }, "ONLY_AT_BEGINNING" : { "message" : [ - "Variable can only be declared at the beginning of the compound, but it was declared on line ." + "Variable can only be declared at the beginning of the compound." ] } }, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9620ce13d92eb..7ad7d60e70c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -173,14 +173,10 @@ class AstBuilder extends DataTypeAstBuilder case Some(c: CreateVariable) => if (allowVarDeclare) { throw SqlScriptingErrors.variableDeclarationOnlyAtBeginning( - c.origin, - toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts), - c.origin.line.get.toString) + c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) } else { throw SqlScriptingErrors.variableDeclarationNotAllowedInScope( - c.origin, - toSQLId(c.name.asInstanceOf[UnresolvedIdentifier].nameParts), - c.origin.line.get.toString) + c.origin, c.name.asInstanceOf[UnresolvedIdentifier].nameParts) } case _ => } @@ -200,7 +196,9 @@ class AstBuilder extends DataTypeAstBuilder el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) => withOrigin(bl) { throw SqlScriptingErrors.labelsMismatch( - CurrentOrigin.get, bl.multipartIdentifier().getText, el.multipartIdentifier().getText) + CurrentOrigin.get, + bl.multipartIdentifier().getText, + el.multipartIdentifier().getText) } case (None, Some(el: EndLabelContext)) => withOrigin(el) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 591d2e3e53d47..7f13dc334e06e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.errors import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLStmt import org.apache.spark.sql.exceptions.SqlScriptingException @@ -32,7 +33,7 @@ private[sql] object SqlScriptingErrors { origin = origin, errorClass = "LABELS_MISMATCH", cause = null, - messageParameters = Map("beginLabel" -> beginLabel, "endLabel" -> endLabel)) + messageParameters = Map("beginLabel" -> toSQLId(beginLabel), "endLabel" -> toSQLId(endLabel))) } def endLabelWithoutBeginLabel(origin: Origin, endLabel: String): Throwable = { @@ -40,29 +41,27 @@ private[sql] object SqlScriptingErrors { origin = origin, errorClass = "END_LABEL_WITHOUT_BEGIN_LABEL", cause = null, - messageParameters = Map("endLabel" -> endLabel)) + messageParameters = Map("endLabel" -> toSQLId(endLabel))) } def variableDeclarationNotAllowedInScope( origin: Origin, - varName: String, - lineNumber: String): Throwable = { + varName: Seq[String]): Throwable = { new SqlScriptingException( origin = origin, errorClass = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", cause = null, - messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) + messageParameters = Map("varName" -> toSQLId(varName))) } def variableDeclarationOnlyAtBeginning( origin: Origin, - varName: String, - lineNumber: String): Throwable = { + varName: Seq[String]): Throwable = { new SqlScriptingException( origin = origin, errorClass = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", cause = null, - messageParameters = Map("varName" -> varName, "lineNumber" -> lineNumber)) + messageParameters = Map("varName" -> toSQLId(varName))) } def invalidBooleanStatement( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala index 4354e7e3635e4..f0c28c95046eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.exceptions.SqlScriptingException.errorMessageWithLin class SqlScriptingException ( errorClass: String, cause: Throwable, - origin: Origin, + val origin: Origin, messageParameters: Map[String, String] = Map.empty) extends Exception( errorMessageWithLineNumber(Option(origin), errorClass, messageParameters), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 24ad32c5300bc..ba634333e06fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, Project} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { @@ -206,13 +207,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl_end""".stripMargin - + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, + exception = exception, condition = "LABELS_MISMATCH", - parameters = Map("beginLabel" -> "lbl_begin", "endLabel" -> "lbl_end")) + parameters = Map("beginLabel" -> toSQLId("lbl_begin"), "endLabel" -> toSQLId("lbl_end"))) + assert(exception.origin.line.contains(2)) } test("compound: endLabel") { @@ -225,13 +227,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT a, b, c FROM T; | SELECT * FROM T; |END lbl""".stripMargin - + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, + exception = exception, condition = "END_LABEL_WITHOUT_BEGIN_LABEL", - parameters = Map("endLabel" -> "lbl")) + parameters = Map("endLabel" -> toSQLId("lbl"))) + assert(exception.origin.line.contains(8)) } test("compound: beginLabel + endLabel with different casing") { @@ -287,12 +290,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | DECLARE testVariable INTEGER; |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, + exception = exception, condition = "INVALID_VARIABLE_DECLARATION.ONLY_AT_BEGINNING", - parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4")) + parameters = Map("varName" -> "`testVariable`")) + assert(exception.origin.line.contains(4)) } test("declare in wrong scope") { @@ -303,12 +308,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | DECLARE testVariable INTEGER; | END IF; |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parseScript(sqlScriptText) + } checkError( - exception = intercept[SqlScriptingException] { - parseScript(sqlScriptText) - }, + exception = exception, condition = "INVALID_VARIABLE_DECLARATION.NOT_ALLOWED_IN_SCOPE", - parameters = Map("varName" -> "`testVariable`", "lineNumber" -> "4")) + parameters = Map("varName" -> "`testVariable`")) + assert(exception.origin.line.contains(4)) } test("SET VAR statement test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 4851faf897a02..3fad99eba509a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -755,13 +755,16 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin + val exception = intercept[SqlScriptingException] { + runSqlScript(commands) + } checkError( - exception = intercept[SqlScriptingException] ( - runSqlScript(commands) - ), + exception = exception, condition = "INVALID_BOOLEAN_STATEMENT", parameters = Map("invalidStatement" -> "1") ) + assert(exception.origin.line.isDefined) + assert(exception.origin.line.get == 3) } } @@ -777,13 +780,16 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END IF; |END |""".stripMargin + val exception = intercept[SqlScriptingException] { + runSqlScript(commands1) + } checkError( - exception = intercept[SqlScriptingException] ( - runSqlScript(commands1) - ), + exception = exception, condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", parameters = Map("invalidStatement" -> "(SELECT * FROM T1)") ) + assert(exception.origin.line.isDefined) + assert(exception.origin.line.get == 4) // too many rows ( > 1 ) val commands2 =